
Commit 49c2c01

move streamText into own file
1 parent 5378dfd commit 49c2c01

4 files changed: 214 additions & 130 deletions

src/client/index.ts

Lines changed: 21 additions & 128 deletions
@@ -15,15 +15,8 @@ import type {
   StreamTextResult,
   ToolChoice,
   ToolSet,
-  UIMessage as AIUIMessage,
-} from "ai";
-import {
-  generateObject,
-  generateText,
-  stepCountIs,
-  streamObject,
-  streamText,
 } from "ai";
+import { generateObject, generateText, stepCountIs, streamObject } from "ai";
 import { assert, omit, pick } from "convex-helpers";
 import {
   internalActionGeneric,
@@ -70,13 +63,7 @@ import {
   generateAndSaveEmbeddings,
 } from "./search.js";
 import { startGeneration } from "./start.js";
-import {
-  compressUIMessageChunks,
-  DeltaStreamer,
-  mergeTransforms,
-  syncStreams,
-  type StreamingOptions,
-} from "./streaming.js";
+import { syncStreams, type StreamingOptions } from "./streaming.js";
 import { createThread, getThreadMetadata } from "./threads.js";
 import type {
   ActionCtx,
@@ -100,6 +87,8 @@ import type {
   QueryCtx,
   AgentPrompt,
 } from "./types.js";
+import { streamText } from "./streamText.js";
+import { errorToString, willContinue } from "./utils.js";
 
 export { stepCountIs } from "ai";
 export {
@@ -550,97 +539,27 @@ export class Agent<
     > &
       GenerationOutputMetadata
   > {
-    const { threadId } = threadOpts;
-    const { args, userId, order, stepOrder, promptMessageId, ...call } =
-      await this.start(ctx, streamTextArgs, { ...threadOpts, ...options });
-
     type Tools = TOOLS extends undefined ? AgentTools : TOOLS;
-    const steps: StepResult<Tools>[] = [];
-
-    const opts = { ...this.options, ...options };
-    const streamer =
-      threadId && opts.saveStreamDeltas
-        ? new DeltaStreamer(
-            this.component,
-            ctx,
-            {
-              throttleMs:
-                typeof opts.saveStreamDeltas === "object"
-                  ? opts.saveStreamDeltas.throttleMs
-                  : undefined,
-              onAsyncAbort: call.fail,
-              compress: compressUIMessageChunks,
-              abortSignal: args.abortSignal,
-            },
-            {
-              threadId,
-              userId,
-              agentName: this.options.name,
-              model: getModelName(args.model),
-              provider: getProviderName(args.model),
-              providerOptions: args.providerOptions,
-              format: "UIMessageChunk",
-              order,
-              stepOrder,
-            },
-          )
-        : undefined;
-
-    const result = streamText({
-      ...args,
-      abortSignal: streamer?.abortController.signal ?? args.abortSignal,
-      experimental_transform: mergeTransforms(
-        options?.saveStreamDeltas,
-        streamTextArgs.experimental_transform,
-      ),
-      onError: async (error) => {
-        console.error("onError", error);
-        await call.fail(errorToString(error.error));
-        await streamer?.fail(errorToString(error.error));
-        return streamTextArgs.onError?.(error);
-      },
-      prepareStep: async (options) => {
-        const result = await streamTextArgs.prepareStep?.(options);
-        if (result) {
-          const model = result.model ?? options.model;
-          call.updateModel(model);
-          // streamer?.updateMetadata({
-          //   model: getModelName(model),
-          //   provider: getProviderName(model),
-          //   providerOptions: options.messages.at(-1)?.providerOptions,
-          // });
-          return result;
-        }
-        return undefined;
+    return streamText<Tools, OUTPUT, PARTIAL_OUTPUT>(
+      ctx,
+      this.component,
+      {
+        ...streamTextArgs,
+        model: streamTextArgs.model ?? this.options.languageModel,
+        tools: (streamTextArgs.tools ?? this.options.tools) as Tools,
+        system: streamTextArgs.system ?? this.options.instructions,
+        stopWhen: (streamTextArgs.stopWhen ?? this.options.stopWhen) as
+          | StopCondition<Tools>
+          | Array<StopCondition<Tools>>,
       },
-      onStepFinish: async (step) => {
-        steps.push(step);
-        const createPendingMessage = await willContinue(steps, args.stopWhen);
-        await call.save({ step }, createPendingMessage);
-        return args.onStepFinish?.(step);
+      {
+        ...threadOpts,
+        ...this.options,
+        agentName: this.options.name,
+        agentForToolCtx: this,
+        ...options,
       },
-    }) as StreamTextResult<
-      TOOLS extends undefined ? AgentTools : TOOLS,
-      PARTIAL_OUTPUT
-    >;
-    const stream = streamer?.consumeStream(
-      result.toUIMessageStream<AIUIMessage<Tools>>(),
     );
-    if (
-      (typeof options?.saveStreamDeltas === "object" &&
-        !options.saveStreamDeltas.returnImmediately) ||
-      options?.saveStreamDeltas === true
-    ) {
-      await stream;
-      await result.consumeStream();
-    }
-    const metadata: GenerationOutputMetadata = {
-      promptMessageId,
-      order,
-      savedMessages: call.getSavedMessages(),
-      messageId: promptMessageId,
-    };
-    return Object.assign(result, metadata);
   }
 
   /**
@@ -1569,29 +1488,3 @@ export class Agent<
     });
   }
 }
-
-async function willContinue(
-  steps: StepResult<any>[],
-
-  stopWhen: StopCondition<any> | Array<StopCondition<any>> | undefined,
-): Promise<boolean> {
-  const step = steps.at(-1)!;
-  // we aren't doing another round after a tool result
-  // TODO: whether to handle continuing after too much context used..
-  if (step.finishReason !== "tool-calls") return false;
-  // we don't have a tool result, so we'll wait for more
-  if (step.toolCalls.length > step.toolResults.length) return false;
-  if (Array.isArray(stopWhen)) {
-    return (await Promise.all(stopWhen.map(async (s) => s({ steps })))).every(
-      (stop) => !stop,
-    );
-  }
-  return !!stopWhen && !(await stopWhen({ steps }));
-}
-
-function errorToString(error: unknown): string {
-  if (error instanceof Error) {
-    return error.message;
-  }
-  return String(error);
-}

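Note: the commit reports 4 files changed, but only three diffs were captured above. Given the new import { errorToString, willContinue } from "./utils.js" in both index.ts and streamText.ts, the fourth file is presumably src/client/utils.ts, which re-homes the two helpers deleted from index.ts. A sketch of what it likely contains (an inference, not the actual diff; the export keywords and file path are assumptions):

import type { StepResult, StopCondition } from "ai";

// Presumed contents of src/client/utils.ts: the helpers removed from index.ts
// above, now exported so both index.ts and streamText.ts can share them.
export async function willContinue(
  steps: StepResult<any>[],
  stopWhen: StopCondition<any> | Array<StopCondition<any>> | undefined,
): Promise<boolean> {
  const step = steps.at(-1)!;
  // we aren't doing another round after a tool result
  if (step.finishReason !== "tool-calls") return false;
  // we don't have a tool result, so we'll wait for more
  if (step.toolCalls.length > step.toolResults.length) return false;
  if (Array.isArray(stopWhen)) {
    return (await Promise.all(stopWhen.map(async (s) => s({ steps })))).every(
      (stop) => !stop,
    );
  }
  return !!stopWhen && !(await stopWhen({ steps }));
}

export function errorToString(error: unknown): string {
  if (error instanceof Error) {
    return error.message;
  }
  return String(error);
}
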
src/client/start.ts

Lines changed: 3 additions & 2 deletions
@@ -23,7 +23,7 @@ import {
 } from "../shared.js";
 import { wrapTools, type ToolCtx } from "./createTool.js";
 import type { Agent } from "./index.js";
-import { omit } from "convex-helpers";
+import { assert, omit } from "convex-helpers";
 import { saveInputMessages } from "./saveInputMessages.js";
 import type { GenericActionCtx, GenericDataModel } from "convex/server";
 
@@ -90,7 +90,7 @@ export async function startGeneration<
   Config & {
     userId?: string | null;
     threadId?: string;
-    languageModel: LanguageModel;
+    languageModel?: LanguageModel;
     agentName: string;
     agentForToolCtx?: Agent;
   },
@@ -155,6 +155,7 @@
   let pendingMessageId = pendingMessage?._id;
 
   const model = args.model ?? opts.languageModel;
+  assert(model, "model is required");
   let activeModel: ModelOrMetadata = model;
 
   const fail = async (reason: string) => {

src/client/streamText.ts

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
+import type {
+  StepResult,
+  StreamTextResult,
+  ToolSet,
+  UIMessage as AIUIMessage,
+} from "ai";
+import { streamText as streamTextAi } from "ai";
+import {
+  compressUIMessageChunks,
+  DeltaStreamer,
+  mergeTransforms,
+  type StreamingOptions,
+} from "./streaming.js";
+import type {
+  ActionCtx,
+  AgentComponent,
+  AgentPrompt,
+  GenerationOutputMetadata,
+  Options,
+} from "./types.js";
+import { startGeneration } from "./start.js";
+import type { Agent } from "./index.js";
+import { getModelName, getProviderName } from "../shared.js";
+import { errorToString, willContinue } from "./utils.js";
+
+/**
+ * This behaves like {@link streamText} from the "ai" package except that
+ * it add context based on the userId and threadId and saves the input and
+ * resulting messages to the thread, if specified.
+ * Use {@link continueThread} to get a version of this function already scoped
+ * to a thread (and optionally userId).
+ */
+export async function streamText<
+  TOOLS extends ToolSet,
+  OUTPUT = never,
+  PARTIAL_OUTPUT = never,
+>(
+  ctx: ActionCtx,
+  component: AgentComponent,
+  /**
+   * The arguments to the streamText function, similar to the ai sdk's
+   * {@link streamText} function, along with Agent prompt options.
+   */
+  streamTextArgs: AgentPrompt &
+    Omit<
+      Parameters<typeof streamTextAi<TOOLS, OUTPUT, PARTIAL_OUTPUT>>[0],
+      "model" | "prompt" | "messages"
+    > & {
+      /**
+       * The tools to use for the tool calls. This will override tools specified
+       * in the Agent constructor or createThread / continueThread.
+       */
+      tools?: TOOLS;
+    },
+  /**
+   * The {@link ContextOptions} and {@link StorageOptions}
+   * options to use for fetching contextual messages and saving input/output messages.
+   */
+  options: Options & {
+    agentName: string;
+    userId?: string | null;
+    threadId?: string;
+    /**
+     * Whether to save incremental data (deltas) from streaming responses.
+     * Defaults to false.
+     * If false, it will not save any deltas to the database.
+     * If true, it will save deltas with {@link DEFAULT_STREAMING_OPTIONS}.
+     *
+     * Regardless of this option, when streaming you are able to use this
+     * `streamText` function as you would with the "ai" package's version:
+     * iterating over the text, streaming it over HTTP, etc.
+     */
+    saveStreamDeltas?: boolean | StreamingOptions;
+    agentForToolCtx?: Agent;
+  },
+): Promise<StreamTextResult<TOOLS, PARTIAL_OUTPUT> & GenerationOutputMetadata> {
+  const { threadId } = options ?? {};
+  const { args, userId, order, stepOrder, promptMessageId, ...call } =
+    await startGeneration(ctx, component, streamTextArgs, options);
+
+  const steps: StepResult<TOOLS>[] = [];
+
+  const streamer =
+    threadId && options.saveStreamDeltas
+      ? new DeltaStreamer(
+          component,
+          ctx,
+          {
+            throttleMs:
+              typeof options.saveStreamDeltas === "object"
+                ? options.saveStreamDeltas.throttleMs
+                : undefined,
+            onAsyncAbort: call.fail,
+            compress: compressUIMessageChunks,
+            abortSignal: args.abortSignal,
+          },
+          {
+            threadId,
+            userId,
+            agentName: options?.agentName,
+            model: getModelName(args.model),
+            provider: getProviderName(args.model),
+            providerOptions: args.providerOptions,
+            format: "UIMessageChunk",
+            order,
+            stepOrder,
+          },
+        )
+      : undefined;
+
+  const result = streamTextAi({
+    ...args,
+    abortSignal: streamer?.abortController.signal ?? args.abortSignal,
+    experimental_transform: mergeTransforms(
+      options?.saveStreamDeltas,
+      streamTextArgs.experimental_transform,
+    ),
+    onError: async (error) => {
+      console.error("onError", error);
+      await call.fail(errorToString(error.error));
+      await streamer?.fail(errorToString(error.error));
+      return streamTextArgs.onError?.(error);
+    },
+    prepareStep: async (options) => {
+      const result = await streamTextArgs.prepareStep?.(options);
+      if (result) {
+        const model = result.model ?? options.model;
+        call.updateModel(model);
+        // streamer?.updateMetadata({
+        //   model: getModelName(model),
+        //   provider: getProviderName(model),
+        //   providerOptions: options.messages.at(-1)?.providerOptions,
+        // });
+        return result;
+      }
+      return undefined;
+    },
+    onStepFinish: async (step) => {
+      steps.push(step);
+      const createPendingMessage = await willContinue(steps, args.stopWhen);
+      await call.save({ step }, createPendingMessage);
+      return args.onStepFinish?.(step);
+    },
+  }) as StreamTextResult<TOOLS, PARTIAL_OUTPUT>;
+  const stream = streamer?.consumeStream(
+    result.toUIMessageStream<AIUIMessage<TOOLS>>(),
+  );
+  if (
+    (typeof options?.saveStreamDeltas === "object" &&
+      !options.saveStreamDeltas.returnImmediately) ||
+    options?.saveStreamDeltas === true
+  ) {
+    await stream;
+    await result.consumeStream();
+  }
+  const metadata: GenerationOutputMetadata = {
+    promptMessageId,
+    order,
+    savedMessages: call.getSavedMessages(),
+    messageId: promptMessageId,
+  };
+  return Object.assign(result, metadata);
+}

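For orientation, here is a minimal sketch (not part of the commit) of calling the extracted streamText helper directly from a Convex action. The generated components import path, the model handle, the agent name, and the thread/user ids are assumptions, and it presumes AgentPrompt accepts a plain prompt and model the way the Agent methods do:

import { openai } from "@ai-sdk/openai";
import { components } from "../_generated/api.js"; // hypothetical generated path
import { streamText } from "./streamText.js";
import type { ActionCtx } from "./types.js";

export async function replyInThread(
  ctx: ActionCtx,
  threadId: string,
  userId: string,
) {
  const result = await streamText(
    ctx,
    components.agent, // hypothetical component reference
    { model: openai("gpt-4o-mini"), prompt: "Summarize the thread so far" },
    {
      threadId,
      userId,
      agentName: "example-agent", // hypothetical; agentName is required by the options type
      saveStreamDeltas: true, // persist deltas as the response streams
    },
  );
  // The result is still a regular StreamTextResult, so it can be consumed as
  // usual, and it also carries the saved-message metadata added by the wrapper.
  for await (const text of result.textStream) {
    console.log(text);
  }
  return { order: result.order, promptMessageId: result.promptMessageId };
}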