Testing
To test applications that interact with language models, use `MockLanguageModel`
to simulate model responses.
It can also be used to assert that the application makes the expected calls to the model.
/* eslint-disable @typescript-eslint/no-floating-promises */import test, { suite, type TestContext } from "node:test";import type { ModelResponse } from "../types.ts";import { MockLanguageModel } from "./test.ts";// import { MockLanguageModel, type ModelResponse } from "@hoangvvo/llm-sdk/test";
suite("MockLanguageModel", () => {
  // Verifies the generate() side of the mock: queued results (response or
  // error) are consumed in FIFO order, every call's input is recorded in
  // trackedGenerateInputs, reset() clears the tracked inputs, and restore()
  // additionally discards any still-queued results.
  test("tracks generate inputs and returns mocked responses", async (t: TestContext) => {
    const model = new MockLanguageModel();

    model.enqueueGenerateResult(
      // First mocked response
      {
        response: {
          content: [{ type: "text", text: "Hello, world!" }],
        },
      },
      // Second mocked response is an error
      {
        error: new Error("Generate error"),
      },
      // Third mocked response
      {
        response: {
          content: [{ type: "text", text: "Goodbye, world!" }],
        },
      },
    );

    // First call should return the first mocked response
    const res1 = await model.generate({
      messages: [{ role: "user", content: [{ type: "text", text: "Hi" }] }],
    });
    const expected1: ModelResponse = {
      content: [{ type: "text", text: "Hello, world!" }],
    };
    t.assert.deepStrictEqual(res1, expected1);
    // The call's input is recorded verbatim for later inspection.
    t.assert.deepStrictEqual(model.trackedGenerateInputs[0], {
      messages: [{ role: "user", content: [{ type: "text", text: "Hi" }] }],
    });

    // Second call should throw an error
    await t.assert.rejects(
      model.generate({
        messages: [
          { role: "user", content: [{ type: "text", text: "Error" }] },
        ],
      }),
      { message: "Generate error" },
    );
    // Inputs are tracked even when the mocked result is an error.
    t.assert.deepStrictEqual(model.trackedGenerateInputs[1], {
      messages: [{ role: "user", content: [{ type: "text", text: "Error" }] }],
    });

    // Third call should return the last mocked response
    const res3 = await model.generate({
      messages: [
        { role: "user", content: [{ type: "text", text: "Goodbye" }] },
      ],
    });
    const expected3: ModelResponse = {
      content: [{ type: "text", text: "Goodbye, world!" }],
    };
    t.assert.deepStrictEqual(res3, expected3);
    t.assert.deepStrictEqual(model.trackedGenerateInputs[2], {
      messages: [
        { role: "user", content: [{ type: "text", text: "Goodbye" }] },
      ],
    });

    // Reset tracked inputs
    model.reset();
    t.assert.deepStrictEqual(model.trackedGenerateInputs, []);

    // Queue one more result so we can confirm restore() discards it below.
    model.enqueueGenerateResult({
      response: {
        content: [{ type: "text", text: "After reset" }],
      },
    });

    // Restore the mock to its initial state
    model.restore();
    t.assert.deepStrictEqual(model.trackedGenerateInputs, []);
    await t.assert.rejects(() => {
      // No mocked results should be available after restore
      return model.generate({
        messages: [{ role: "user", content: [{ type: "text", text: "Hi" }] }],
      });
    }, /No mocked stream results available/.source === "No mocked stream results available" ? /No mocked generate results available/ : /No mocked generate results available/);
  });

  // Same contract for the streaming API: each queued entry is either a list
  // of partials to yield in order or an error to throw, inputs are tracked
  // per call, and reset()/restore() behave exactly as for generate().
  test("tracks stream inputs and yields mocked partials", async (t: TestContext) => {
    const model = new MockLanguageModel();

    model.enqueueStreamResult(
      // First mocked stream response
      {
        partials: [
          { delta: { index: 0, part: { type: "text", text: "Hello" } } },
          { delta: { index: 0, part: { type: "text", text: ", " } } },
          { delta: { index: 0, part: { type: "text", text: "world!" } } },
        ],
      },
      // Second mocked stream response is an error
      {
        error: new Error("Stream error"),
      },
      // Third mocked stream response
      {
        partials: [
          { delta: { index: 0, part: { type: "text", text: "Goodbye" } } },
          { delta: { index: 0, part: { type: "text", text: ", " } } },
          { delta: { index: 0, part: { type: "text", text: "world!" } } },
        ],
      },
    );

    // First stream call should yield the first set of partials
    const partials1 = [];
    for await (const partial of model.stream({
      messages: [{ role: "user", content: [{ type: "text", text: "Hi" }] }],
    })) {
      partials1.push(partial);
    }
    const expectedPartials1 = [
      { delta: { index: 0, part: { type: "text", text: "Hello" } } },
      { delta: { index: 0, part: { type: "text", text: ", " } } },
      { delta: { index: 0, part: { type: "text", text: "world!" } } },
    ];
    t.assert.deepStrictEqual(partials1, expectedPartials1);
    t.assert.deepStrictEqual(model.trackedStreamInputs[0], {
      messages: [{ role: "user", content: [{ type: "text", text: "Hi" }] }],
    });

    // Second stream call should throw an error
    await t.assert.rejects(
      async () => {
        const partials2 = [];
        for await (const partial of model.stream({
          messages: [
            { role: "user", content: [{ type: "text", text: "Error" }] },
          ],
        })) {
          partials2.push(partial);
        }
        return partials2;
      },
      { message: "Stream error" },
    );
    // The input is still tracked even though the stream errored.
    t.assert.deepStrictEqual(model.trackedStreamInputs[1], {
      messages: [{ role: "user", content: [{ type: "text", text: "Error" }] }],
    });

    // Third stream call should yield the last set of partials
    const partials3 = [];
    for await (const partial of model.stream({
      messages: [
        { role: "user", content: [{ type: "text", text: "Goodbye" }] },
      ],
    })) {
      partials3.push(partial);
    }
    const expectedPartials3 = [
      { delta: { index: 0, part: { type: "text", text: "Goodbye" } } },
      { delta: { index: 0, part: { type: "text", text: ", " } } },
      { delta: { index: 0, part: { type: "text", text: "world!" } } },
    ];
    t.assert.deepStrictEqual(partials3, expectedPartials3);
    t.assert.deepStrictEqual(model.trackedStreamInputs[2], {
      messages: [
        { role: "user", content: [{ type: "text", text: "Goodbye" }] },
      ],
    });

    // Reset tracked inputs
    model.reset();
    t.assert.deepStrictEqual(model.trackedStreamInputs, []);

    // Queue one more result so we can confirm restore() discards it below.
    model.enqueueStreamResult({
      partials: [
        { delta: { index: 0, part: { type: "text", text: "After reset" } } },
      ],
    });

    // Restore the mock to its initial state
    model.restore();
    t.assert.deepStrictEqual(model.trackedStreamInputs, []);
    await t.assert.rejects(async () => {
      // No mocked results should be available after restore
      const stream = model.stream({
        messages: [{ role: "user", content: [{ type: "text", text: "Hi" }] }],
      });
      // eslint-disable-next-line @typescript-eslint/no-unused-vars
      for await (const _partial of stream) {
        // noop
      }
    }, /No mocked stream results available/);
  });
});
use futures::StreamExt;use llm_sdk::{ llm_sdk_test::{MockGenerateResult, MockLanguageModel, MockStreamResult}, ContentDelta, LanguageModel, LanguageModelError, LanguageModelInput, LanguageModelResult, LanguageModelStream, Message, ModelResponse, Part, PartDelta, PartialModelResponse, TextPartDelta, UserMessage,};
/// Build a single-turn `LanguageModelInput` containing one user text message.
fn user_input(text: &str) -> LanguageModelInput {
    let message = Message::User(UserMessage {
        content: vec![Part::text(text)],
    });
    LanguageModelInput {
        messages: vec![message],
        ..Default::default()
    }
}
/// Build a `PartialModelResponse` carrying a single text delta at index 0.
fn text_partial(text: &str) -> PartialModelResponse {
    let delta = ContentDelta {
        index: 0,
        part: PartDelta::Text(TextPartDelta {
            text: text.to_owned(),
        }),
    };
    PartialModelResponse {
        delta: Some(delta),
        ..Default::default()
    }
}
// Verifies the generate() side of the mock: queued results (responses or
// errors) are consumed in FIFO order, every call's input is recorded, reset()
// clears the tracked inputs, and restore() additionally discards any
// still-queued results.
#[tokio::test]
async fn mock_language_model_tracks_generate_inputs_and_returns_results() {
    let model = MockLanguageModel::new();

    let response1 = ModelResponse {
        content: vec![Part::text("Hello, world!")],
        ..ModelResponse::default()
    };
    let response3 = ModelResponse {
        content: vec![Part::text("Goodbye, world!")],
        ..ModelResponse::default()
    };

    // Queue: success, error, success — consumed in this order below.
    model
        .enqueue_generate(response1.clone())
        .enqueue_generate(MockGenerateResult::error(LanguageModelError::InvalidInput(
            "generate error".to_string(),
        )))
        .enqueue_generate(response3.clone());

    // First call returns the first queued response and is tracked.
    let input1 = user_input("Hi");
    let res1 = model
        .generate(input1.clone())
        .await
        .expect("first generate should succeed");
    assert_eq!(res1, response1);
    let tracked = model.tracked_generate_inputs();
    assert_eq!(tracked.len(), 1);
    assert_eq!(tracked[0].messages, input1.messages.clone());

    // Second call surfaces the queued error; the input is still tracked.
    let input2 = user_input("Error");
    let err = model
        .generate(input2.clone())
        .await
        .expect_err("second generate should error");
    match err {
        LanguageModelError::InvalidInput(msg) => {
            assert_eq!(msg, "generate error");
        }
        other => panic!("unexpected error variant: {:?}", other),
    }
    let tracked = model.tracked_generate_inputs();
    assert_eq!(tracked.len(), 2);
    assert_eq!(tracked[1].messages, input2.messages.clone());

    // Third call returns the last queued response.
    let input3 = user_input("Goodbye");
    let res3 = model
        .generate(input3.clone())
        .await
        .expect("third generate should succeed");
    assert_eq!(res3, response3);
    let tracked = model.tracked_generate_inputs();
    assert_eq!(tracked.len(), 3);
    assert_eq!(tracked[2].messages, input3.messages.clone());

    // reset() clears the tracked inputs.
    model.reset();
    assert!(model.tracked_generate_inputs().is_empty());

    // Queue one more result so we can confirm restore() discards it below.
    model.enqueue_generate(ModelResponse {
        content: vec![Part::text("After reset")],
        ..ModelResponse::default()
    });

    // restore() clears tracking AND any queued results.
    model.restore();
    assert!(model.tracked_generate_inputs().is_empty());

    // With nothing queued, generate() fails with an invariant error.
    let err = model
        .generate(input1.clone())
        .await
        .expect_err("generate after restore should fail");
    match err {
        LanguageModelError::Invariant(provider, message) => {
            assert_eq!(provider, "mock");
            assert_eq!(message, "no mocked generate results available");
        }
        other => panic!("unexpected error variant: {:?}", other),
    }
}
// Same contract for the streaming API: each queued entry is either a list of
// partials to yield in order or an error, inputs are tracked per call, and
// reset()/restore() behave exactly as for generate().
#[tokio::test]
async fn mock_language_model_tracks_stream_inputs_and_yields_partials() {
    let model = MockLanguageModel::new();

    let partials1 = vec![
        text_partial("Hello"),
        text_partial(", "),
        text_partial("world!"),
    ];
    let partials3 = vec![
        text_partial("Goodbye"),
        text_partial(", "),
        text_partial("world!"),
    ];

    // Queue: partials, error, partials — consumed in this order below.
    model
        .enqueue_stream(partials1.clone())
        .enqueue_stream(MockStreamResult::error(LanguageModelError::InvalidInput(
            "stream error".to_string(),
        )))
        .enqueue_stream(partials3.clone());

    // First stream yields the first queued set of partials and is tracked.
    let stream_input1 = user_input("Hi");
    let stream1 = model
        .stream(stream_input1.clone())
        .await
        .expect("first stream should succeed");
    let collected1 = collect_stream_partials(stream1)
        .await
        .expect("collecting partials should succeed");
    assert_eq!(collected1, partials1);
    let tracked = model.tracked_stream_inputs();
    assert_eq!(tracked.len(), 1);
    assert_eq!(tracked[0].messages, stream_input1.messages.clone());

    // Second stream surfaces the queued error; the input is still tracked.
    let stream_input2 = user_input("Error");
    let err = match model.stream(stream_input2.clone()).await {
        Ok(_) => panic!("expected stream error"),
        Err(err) => err,
    };
    match err {
        LanguageModelError::InvalidInput(msg) => assert_eq!(msg, "stream error"),
        other => panic!("unexpected error variant: {:?}", other),
    }
    let tracked = model.tracked_stream_inputs();
    assert_eq!(tracked.len(), 2);
    assert_eq!(tracked[1].messages, stream_input2.messages.clone());

    // Third stream yields the last queued set of partials.
    let stream_input3 = user_input("Goodbye");
    let stream3 = model
        .stream(stream_input3.clone())
        .await
        .expect("third stream should succeed");
    let collected3 = collect_stream_partials(stream3)
        .await
        .expect("collecting partials should succeed");
    assert_eq!(collected3, partials3);
    let tracked = model.tracked_stream_inputs();
    assert_eq!(tracked.len(), 3);
    assert_eq!(tracked[2].messages, stream_input3.messages.clone());

    // reset() clears the tracked inputs.
    model.reset();
    assert!(model.tracked_stream_inputs().is_empty());

    // Queue one more result so we can confirm restore() discards it below.
    model.enqueue_stream(vec![text_partial("After reset")]);

    // restore() clears tracking AND any queued results.
    model.restore();
    assert!(model.tracked_stream_inputs().is_empty());

    // With nothing queued, stream() fails with an invariant error.
    let err = match model.stream(stream_input1.clone()).await {
        Ok(_) => panic!("expected stream failure"),
        Err(err) => err,
    };
    match err {
        LanguageModelError::Invariant(provider, message) => {
            assert_eq!(provider, "mock");
            assert_eq!(message, "no mocked stream results available");
        }
        other => panic!("unexpected error variant: {:?}", other),
    }
}
/// Drain a stream, collecting every partial in order; propagate the first
/// error encountered.
async fn collect_stream_partials(
    mut stream: LanguageModelStream,
) -> LanguageModelResult<Vec<PartialModelResponse>> {
    let mut collected = Vec::new();
    loop {
        match stream.next().await {
            Some(item) => collected.push(item?),
            None => break,
        }
    }
    Ok(collected)
}
package llmsdktest_test
import ( "context" "errors" "testing"
"github.com/google/go-cmp/cmp" llmsdk "github.com/hoangvvo/llm-sdk/sdk-go" "github.com/hoangvvo/llm-sdk/sdk-go/llmsdktest")
func TestMockLanguageModelGenerate(t *testing.T) { model := llmsdktest.NewMockLanguageModel()
response1 := llmsdk.ModelResponse{ Content: []llmsdk.Part{{TextPart: &llmsdk.TextPart{Text: "Hello, world!"}}}, } response3 := llmsdk.ModelResponse{ Content: []llmsdk.Part{{TextPart: &llmsdk.TextPart{Text: "Goodbye, world!"}}}, }
model.EnqueueGenerateResult( llmsdktest.NewMockGenerateResultResponse(response1), llmsdktest.NewMockGenerateResultError(errors.New("generate error")), llmsdktest.NewMockGenerateResultResponse(response3), )
ctx := context.Background()
input1 := &llmsdk.LanguageModelInput{ Messages: []llmsdk.Message{ llmsdk.NewUserMessage(llmsdk.Part{TextPart: &llmsdk.TextPart{Text: "Hi"}}), }, } res1, err := model.Generate(ctx, input1) if err != nil { t.Fatalf("Generate returned error: %v", err) } if diff := cmp.Diff(res1, &response1); diff != "" { t.Errorf("unexpected first response (-want +got):\n%s", diff) } trackedGenerateInputs := model.TrackedGenerateInputs() if len(trackedGenerateInputs) != 1 { t.Fatalf("expected 1 tracked generate input, got %d", len(trackedGenerateInputs)) } if diff := cmp.Diff(trackedGenerateInputs[0], *input1); diff != "" { t.Errorf("tracked generate input mismatch (-want +got):\n%s", diff) }
input2 := &llmsdk.LanguageModelInput{ Messages: []llmsdk.Message{ llmsdk.NewUserMessage(llmsdk.Part{TextPart: &llmsdk.TextPart{Text: "Error"}}), }, } if _, err := model.Generate(ctx, input2); err == nil || err.Error() != "generate error" { t.Errorf("expected generate error, got %v", err) } trackedGenerateInputs = model.TrackedGenerateInputs() if len(trackedGenerateInputs) != 2 { t.Fatalf("expected 2 tracked generate inputs, got %d", len(trackedGenerateInputs)) } if diff := cmp.Diff(trackedGenerateInputs[1], *input2); diff != "" { t.Errorf("tracked generate input mismatch (-want +got):\n%s", diff) }
input3 := &llmsdk.LanguageModelInput{ Messages: []llmsdk.Message{ llmsdk.NewUserMessage(llmsdk.Part{TextPart: &llmsdk.TextPart{Text: "Goodbye"}}), }, } res3, err := model.Generate(ctx, input3) if err != nil { t.Fatalf("Generate returned error: %v", err) } if diff := cmp.Diff(res3, &response3); diff != "" { t.Errorf("unexpected third response (-want +got):\n%s", diff) } trackedGenerateInputs = model.TrackedGenerateInputs() if len(trackedGenerateInputs) != 3 { t.Fatalf("expected 3 tracked generate inputs, got %d", len(trackedGenerateInputs)) } if diff := cmp.Diff(trackedGenerateInputs[2], *input3); diff != "" { t.Errorf("tracked generate input mismatch (-want +got):\n%s", diff) }
model.Reset() trackedGenerateInputs = model.TrackedGenerateInputs() if len(trackedGenerateInputs) != 0 { t.Errorf("expected tracked inputs to be reset, got %d", len(trackedGenerateInputs)) }
model.EnqueueGenerateResult(llmsdktest.NewMockGenerateResultResponse(llmsdk.ModelResponse{ Content: []llmsdk.Part{{TextPart: &llmsdk.TextPart{Text: "After reset"}}}, }))
model.Restore() if len(model.TrackedGenerateInputs()) != 0 { t.Errorf("expected tracked inputs to be empty after restore, got %d", len(model.TrackedGenerateInputs())) }
if _, err := model.Generate(ctx, input1); err == nil || err.Error() != "no mocked generate results available" { t.Errorf("expected no mocked generate results error after restore, got %v", err) }}
func TestMockLanguageModelStream(t *testing.T) { model := llmsdktest.NewMockLanguageModel()
partials1 := []llmsdk.PartialModelResponse{ {Delta: &llmsdk.ContentDelta{Index: 0, Part: llmsdk.PartDelta{TextPartDelta: &llmsdk.TextPartDelta{Text: "Hello"}}}}, {Delta: &llmsdk.ContentDelta{Index: 0, Part: llmsdk.PartDelta{TextPartDelta: &llmsdk.TextPartDelta{Text: ", "}}}}, {Delta: &llmsdk.ContentDelta{Index: 0, Part: llmsdk.PartDelta{TextPartDelta: &llmsdk.TextPartDelta{Text: "world!"}}}}, } partials3 := []llmsdk.PartialModelResponse{ {Delta: &llmsdk.ContentDelta{Index: 0, Part: llmsdk.PartDelta{TextPartDelta: &llmsdk.TextPartDelta{Text: "Goodbye"}}}}, {Delta: &llmsdk.ContentDelta{Index: 0, Part: llmsdk.PartDelta{TextPartDelta: &llmsdk.TextPartDelta{Text: ", "}}}}, {Delta: &llmsdk.ContentDelta{Index: 0, Part: llmsdk.PartDelta{TextPartDelta: &llmsdk.TextPartDelta{Text: "world!"}}}}, }
model.EnqueueStreamResult( llmsdktest.NewMockStreamResultPartials(partials1), llmsdktest.NewMockStreamResultError(errors.New("stream error")), llmsdktest.NewMockStreamResultPartials(partials3), )
ctx := context.Background()
streamInput1 := &llmsdk.LanguageModelInput{ Messages: []llmsdk.Message{ llmsdk.NewUserMessage(llmsdk.Part{TextPart: &llmsdk.TextPart{Text: "Hi"}}), }, } stream1, err := model.Stream(ctx, streamInput1) if err != nil { t.Fatalf("Stream returned error: %v", err) } gotPartials1 := collectStreamPartials(t, stream1) if diff := cmp.Diff(gotPartials1, partials1); diff != "" { t.Errorf("unexpected partials from first stream (-want +got):\n%s", diff) } trackedStreamInputs := model.TrackedStreamInputs() if len(trackedStreamInputs) != 1 { t.Fatalf("expected 1 tracked stream input, got %d", len(trackedStreamInputs)) } if diff := cmp.Diff(trackedStreamInputs[0], *streamInput1); diff != "" { t.Errorf("tracked stream input mismatch (-want +got):\n%s", diff) }
streamInput2 := &llmsdk.LanguageModelInput{ Messages: []llmsdk.Message{ llmsdk.NewUserMessage(llmsdk.Part{TextPart: &llmsdk.TextPart{Text: "Error"}}), }, } if _, err := model.Stream(ctx, streamInput2); err == nil || err.Error() != "stream error" { t.Errorf("expected stream error, got %v", err) } trackedStreamInputs = model.TrackedStreamInputs() if len(trackedStreamInputs) != 2 { t.Fatalf("expected 2 tracked stream inputs, got %d", len(trackedStreamInputs)) } if diff := cmp.Diff(trackedStreamInputs[1], *streamInput2); diff != "" { t.Errorf("tracked stream input mismatch (-want +got):\n%s", diff) }
streamInput3 := &llmsdk.LanguageModelInput{ Messages: []llmsdk.Message{ llmsdk.NewUserMessage(llmsdk.Part{TextPart: &llmsdk.TextPart{Text: "Goodbye"}}), }, } stream3, err := model.Stream(ctx, streamInput3) if err != nil { t.Fatalf("Stream returned error: %v", err) } gotPartials3 := collectStreamPartials(t, stream3) if diff := cmp.Diff(gotPartials3, partials3); diff != "" { t.Errorf("unexpected partials from third stream (-want +got):\n%s", diff) } trackedStreamInputs = model.TrackedStreamInputs() if len(trackedStreamInputs) != 3 { t.Fatalf("expected 3 tracked stream inputs, got %d", len(trackedStreamInputs)) } if diff := cmp.Diff(trackedStreamInputs[2], *streamInput3); diff != "" { t.Errorf("tracked stream input mismatch (-want +got):\n%s", diff) }
model.Reset() if len(model.TrackedStreamInputs()) != 0 { t.Errorf("expected tracked stream inputs to be reset, got %d", len(model.TrackedStreamInputs())) }
model.EnqueueStreamResult(llmsdktest.NewMockStreamResultPartials([]llmsdk.PartialModelResponse{ {Delta: &llmsdk.ContentDelta{Index: 0, Part: llmsdk.PartDelta{TextPartDelta: &llmsdk.TextPartDelta{Text: "After reset"}}}}, }))
model.Restore() if len(model.TrackedStreamInputs()) != 0 { t.Errorf("expected tracked stream inputs to be empty after restore, got %d", len(model.TrackedStreamInputs())) }
if _, err := model.Stream(ctx, streamInput1); err == nil || err.Error() != "no mocked stream results available" { t.Errorf("expected no mocked stream results error after restore, got %v", err) }}
func collectStreamPartials(t *testing.T, stream *llmsdk.LanguageModelStream) []llmsdk.PartialModelResponse { t.Helper() var partials []llmsdk.PartialModelResponse for stream.Next() { current := stream.Current() if current != nil { partials = append(partials, *current) } } if err := stream.Err(); err != nil { t.Fatalf("stream error: %v", err) } return partials}