hts/packages/isdk/core/generate-text/run-tools-transformation.ts

import { LanguageModelV1StreamPart, NoSuchToolError } from '@ai-sdk/provider';
import { generateId } from '../../shared/generate-id';
import { ExperimentalTool } from '../tool';
import { TextStreamPart } from './stream-text';
import { parseToolCall } from './tool-call';
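
/**
 * Transforms a raw `LanguageModelV1StreamPart` stream into a `TextStreamPart`
 * stream: text deltas and errors are forwarded as-is, tool calls are parsed and
 * validated against the provided tools, and tools with an `execute` function are
 * run asynchronously so their results can be merged back into the output stream
 * without blocking subsequent chunks.
 */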
export function runToolsTransformation<
  TOOLS extends Record<string, ExperimentalTool>,
>({
  tools,
  generatorStream,
}: {
  tools?: TOOLS;
  generatorStream: ReadableStream<LanguageModelV1StreamPart>;
}): ReadableStream<TextStreamPart<TOOLS>> {
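  // `canClose` flips to true once the generator stream has finished;
  // `outstandingToolCalls` tracks tool executions that are still in flight so the
  // tool results stream is only closed after the last pending execution has settled.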
  let canClose = false;
  const outstandingToolCalls = new Set<string>();

  // tool results stream
  let toolResultsStreamController: ReadableStreamDefaultController<
    TextStreamPart<TOOLS>
  > | null = null;
  const toolResultsStream = new ReadableStream<TextStreamPart<TOOLS>>({
    start(controller) {
      toolResultsStreamController = controller;
    },
  });
  // forward stream
  const forwardStream = new TransformStream<
    LanguageModelV1StreamPart,
    TextStreamPart<TOOLS>
  >({
    transform(
      chunk: LanguageModelV1StreamPart,
      controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>,
    ) {
      const chunkType = chunk.type;

      switch (chunkType) {
        // forward:
        case 'text-delta':
        case 'error': {
          controller.enqueue(chunk);
          break;
        }

        // process tool call:
        case 'tool-call': {
          const toolName = chunk.toolName as keyof TOOLS & string;

          if (tools == null) {
            toolResultsStreamController!.enqueue({
              type: 'error',
              error: new NoSuchToolError({ toolName: chunk.toolName }),
            });
            break;
          }

          const tool = tools[toolName];

          if (tool == null) {
            toolResultsStreamController!.enqueue({
              type: 'error',
              error: new NoSuchToolError({
                toolName: chunk.toolName,
                availableTools: Object.keys(tools),
              }),
            });
            break;
          }

          try {
            const toolCall = parseToolCall({
              toolCall: chunk,
              tools,
            });
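            // Forward the parsed tool call immediately; if the tool defines `execute`,
            // its result is delivered later through the tool results stream.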
            controller.enqueue(toolCall);

            if (tool.execute != null) {
              const toolExecutionId = generateId(); // use our own id to guarantee uniqueness
              outstandingToolCalls.add(toolExecutionId);

              // Note: we don't await the tool execution here, because we want to process
              // the next chunk as soon as possible. This is important for the case where
              // the tool execution takes a long time.
              tool.execute(toolCall.args).then(
                (result: any) => {
                  toolResultsStreamController!.enqueue({
                    ...toolCall,
                    type: 'tool-result',
                    result,
                  } as any);

                  outstandingToolCalls.delete(toolExecutionId);

                  // close the tool results controller if no more outstanding tool calls
                  if (canClose && outstandingToolCalls.size === 0) {
                    toolResultsStreamController!.close();
                  }
                },
                (error: any) => {
                  toolResultsStreamController!.enqueue({
                    type: 'error',
                    error,
                  });

                  outstandingToolCalls.delete(toolExecutionId);

                  // close the tool results controller if no more outstanding tool calls
                  if (canClose && outstandingToolCalls.size === 0) {
                    toolResultsStreamController!.close();
                  }
                },
              );
            }
          } catch (error) {
            toolResultsStreamController!.enqueue({
              type: 'error',
              error,
            });
          }

          break;
        }

        // process finish:
        case 'finish': {
          controller.enqueue({
            type: 'finish',
            finishReason: chunk.finishReason,
            usage: {
              promptTokens: chunk.usage.promptTokens,
              completionTokens: chunk.usage.completionTokens,
              totalTokens:
                chunk.usage.promptTokens + chunk.usage.completionTokens,
            },
          });
          break;
        }

        // ignore
        case 'tool-call-delta': {
          break;
        }

        default: {
          const _exhaustiveCheck: never = chunkType;
          throw new Error(`Unhandled chunk type: ${_exhaustiveCheck}`);
        }
      }
    },
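    // `flush` runs when the generator stream is exhausted; at that point the tool
    // results stream may close, but only if no tool executions are still pending.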
    flush() {
      canClose = true;

      if (outstandingToolCalls.size === 0) {
        toolResultsStreamController!.close();
      }
    },
  });

  // combine the generator stream and the tool results stream
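  // The merged stream is only closed from the tool results stream's `close` handler,
  // which fires after the generator has finished and every tool execution has settled.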
  return new ReadableStream<TextStreamPart<TOOLS>>({
    async start(controller) {
      generatorStream.pipeThrough(forwardStream).pipeTo(
        new WritableStream({
          write(chunk) {
            controller.enqueue(chunk);
          },
          close() {
            // the generator stream controller is automatically closed when it's consumed
          },
        }),
      );

      toolResultsStream.pipeTo(
        new WritableStream({
          write(chunk) {
            controller.enqueue(chunk);
          },
          close() {
            controller.close();
          },
        }),
      );
    },
  });
}
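
// Usage sketch (illustrative): how a caller such as `streamText` might wire a provider
// stream through this transformation. `providerStream` and `weatherTool` are assumed
// names and are not defined in this file.
//
//   const stream = runToolsTransformation({
//     tools: { weather: weatherTool },
//     generatorStream: providerStream,
//   });
//
//   const reader = stream.getReader();
//   while (true) {
//     const { done, value } = await reader.read();
//     if (done) break;
//     // `value` is a TextStreamPart: text-delta, tool-call, tool-result, error, or finish.
//   }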