import {
  LanguageModelV1,
  LanguageModelV1CallWarning,
  LanguageModelV1FinishReason,
  LanguageModelV1FunctionToolCall,
  LanguageModelV1StreamPart,
  UnsupportedFunctionalityError,
} from '@ai-sdk/provider';
import { z } from 'zod';
import {
  ParseResult,
  createEventSourceResponseHandler,
  createJsonResponseHandler,
  postJsonToApi,
} from '../spec';
import { anthropicFailedResponseHandler } from './anthropic-error';
import {
  AnthropicMessagesModelId,
  AnthropicMessagesSettings,
} from './anthropic-messages-settings';
import { convertToAnthropicMessagesPrompt } from './convert-to-anthropic-messages-prompt';
import { mapAnthropicStopReason } from './map-anthropic-stop-reason';

type AnthropicMessagesConfig = {
  provider: string;
  baseUrl: string;
  headers: () => Record<string, string>;
};

export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
  readonly specificationVersion = 'v1';
  readonly defaultObjectGenerationMode = 'tool';

  readonly modelId: AnthropicMessagesModelId;
  readonly settings: AnthropicMessagesSettings;

  private readonly config: AnthropicMessagesConfig;

  constructor(
    modelId: AnthropicMessagesModelId,
    settings: AnthropicMessagesSettings,
    config: AnthropicMessagesConfig,
  ) {
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }

  get provider(): string {
    return this.config.provider;
  }

  private getArgs({
    mode,
    prompt,
    maxTokens,
    temperature,
    topP,
    frequencyPenalty,
    presencePenalty,
    seed,
  }: Parameters<LanguageModelV1['doGenerate']>[0]) {
    const type = mode.type;

    const warnings: LanguageModelV1CallWarning[] = [];

    if (frequencyPenalty != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'frequencyPenalty',
      });
    }

    if (presencePenalty != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'presencePenalty',
      });
    }

    if (seed != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'seed',
      });
    }

    const messagesPrompt = convertToAnthropicMessagesPrompt(prompt);

    const baseArgs = {
      // model id:
      model: this.modelId,

      // model specific settings:
      top_k: this.settings.topK,

      // standardized settings:
      max_tokens: maxTokens ?? 4096, // 4096: max model output tokens
      temperature, // uses 0..1 scale
      top_p: topP,

      // prompt:
      system: messagesPrompt.system,
      messages: messagesPrompt.messages,
    };

    switch (type) {
      case 'regular': {
        // when the tools array is empty, change it to undefined to prevent Anthropic errors:
        const tools = mode.tools?.length ? mode.tools : undefined;

        return {
          args: {
            ...baseArgs,
            tools: tools?.map(tool => ({
              name: tool.name,
              description: tool.description,
              input_schema: tool.parameters,
            })),
          },
          warnings,
        };
      }

      case 'object-json': {
        throw new UnsupportedFunctionalityError({
          functionality: 'json-mode object generation',
        });
      }

      case 'object-tool': {
        const { name, description, parameters } = mode.tool;

        // add instruction to use the tool:
        baseArgs.messages[baseArgs.messages.length - 1].content.push({
          type: 'text',
          text: `\n\nUse the '${name}' tool.`,
        });

        return {
          args: {
            ...baseArgs,
            tools: [{ name, description, input_schema: parameters }],
          },
          warnings,
        };
      }

      case 'object-grammar': {
        throw new UnsupportedFunctionalityError({
          functionality: 'grammar-mode object generation',
        });
      }

      default: {
        const _exhaustiveCheck: never = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }
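  // For orientation, a request body produced by getArgs for a regular call with
  // one tool might look like this (illustrative sketch only; the model id, text,
  // and tool below are hypothetical, and unset settings are simply absent):
  //
  // {
  //   "model": "claude-3-haiku-20240307",
  //   "max_tokens": 4096,
  //   "system": "You are a helpful assistant.",
  //   "messages": [
  //     { "role": "user", "content": [{ "type": "text", "text": "Hi" }] }
  //   ],
  //   "tools": [
  //     { "name": "weather", "description": "...", "input_schema": { "type": "object" } }
  //   ]
  // }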
  async doGenerate(
    options: Parameters<LanguageModelV1['doGenerate']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
    const { args, warnings } = this.getArgs(options);

    const response = await postJsonToApi({
      url: `${this.config.baseUrl}/messages`,
      headers: this.config.headers(),
      body: args,
      failedResponseHandler: anthropicFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        anthropicMessagesResponseSchema,
      ),
      abortSignal: options.abortSignal,
    });

    const { messages: rawPrompt, ...rawSettings } = args;

    // extract text:
    let text = '';
    for (const content of response.content) {
      if (content.type === 'text') {
        text += content.text;
      }
    }

    // extract tool calls:
    let toolCalls: LanguageModelV1FunctionToolCall[] | undefined = undefined;
    if (response.content.some(content => content.type === 'tool_use')) {
      toolCalls = [];
      for (const content of response.content) {
        if (content.type === 'tool_use') {
          toolCalls.push({
            toolCallType: 'function',
            toolCallId: content.id,
            toolName: content.name,
            args: JSON.stringify(content.input),
          });
        }
      }
    }

    return {
      text,
      toolCalls,
      finishReason: mapAnthropicStopReason(response.stop_reason),
      usage: {
        promptTokens: response.usage.input_tokens,
        completionTokens: response.usage.output_tokens,
      },
      rawCall: { rawPrompt, rawSettings },
      warnings,
    };
  }

  async doStream(
    options: Parameters<LanguageModelV1['doStream']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
    const { args, warnings } = this.getArgs(options);

    const response = await postJsonToApi({
      url: `${this.config.baseUrl}/messages`,
      headers: this.config.headers(),
      body: {
        ...args,
        stream: true,
      },
      failedResponseHandler: anthropicFailedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(
        anthropicMessagesChunkSchema,
      ),
      abortSignal: options.abortSignal,
    });

    const { messages: rawPrompt, ...rawSettings } = args;

    let finishReason: LanguageModelV1FinishReason = 'other';
    const usage: { promptTokens: number; completionTokens: number } = {
      promptTokens: Number.NaN,
      completionTokens: Number.NaN,
    };

    return {
      stream: response.pipeThrough(
        new TransformStream<
          ParseResult<z.infer<typeof anthropicMessagesChunkSchema>>,
          LanguageModelV1StreamPart
        >({
          transform(chunk, controller) {
            if (!chunk.success) {
              controller.enqueue({ type: 'error', error: chunk.error });
              return;
            }

            const value = chunk.value;

            switch (value.type) {
              case 'ping':
              case 'content_block_start':
              case 'content_block_stop': {
                return; // ignored
              }

              case 'content_block_delta': {
                controller.enqueue({
                  type: 'text-delta',
                  textDelta: value.delta.text,
                });
                return;
              }

              case 'message_start': {
                usage.promptTokens = value.message.usage.input_tokens;
                usage.completionTokens = value.message.usage.output_tokens;
                return;
              }

              case 'message_delta': {
                usage.completionTokens = value.usage.output_tokens;
                finishReason = mapAnthropicStopReason(value.delta.stop_reason);
                return;
              }

              case 'message_stop': {
                controller.enqueue({ type: 'finish', finishReason, usage });
                return;
              }

              default: {
                const _exhaustiveCheck: never = value;
                throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`);
              }
            }
          },
        }),
      ),
      rawCall: { rawPrompt, rawSettings },
      warnings,
    };
  }
}
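// Usage sketch (illustrative only, not part of this module). The model id,
// settings, and header values below are assumptions about an Anthropic-style
// endpoint, not values defined in this file:
//
// const model = new AnthropicMessagesLanguageModel(
//   'claude-3-haiku-20240307',
//   {},
//   {
//     provider: 'anthropic.messages',
//     baseUrl: 'https://api.anthropic.com/v1',
//     headers: () => ({
//       'x-api-key': process.env.ANTHROPIC_API_KEY ?? '',
//       'anthropic-version': '2023-06-01',
//     }),
//   },
// );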
// limited version of the schema, focussed on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const anthropicMessagesResponseSchema = z.object({
  type: z.literal('message'),
  content: z.array(
    z.discriminatedUnion('type', [
      z.object({
        type: z.literal('text'),
        text: z.string(),
      }),
      z.object({
        type: z.literal('tool_use'),
        id: z.string(),
        name: z.string(),
        input: z.unknown(),
      }),
    ]),
  ),
  stop_reason: z.string().optional().nullable(),
  usage: z.object({
    input_tokens: z.number(),
    output_tokens: z.number(),
  }),
});

// limited version of the schema, focussed on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const anthropicMessagesChunkSchema = z.discriminatedUnion('type', [
  z.object({
    type: z.literal('message_start'),
    message: z.object({
      usage: z.object({
        input_tokens: z.number(),
        output_tokens: z.number(),
      }),
    }),
  }),
  z.object({
    type: z.literal('content_block_start'),
    index: z.number(),
    content_block: z.object({
      type: z.literal('text'),
      text: z.string(),
    }),
  }),
  z.object({
    type: z.literal('content_block_delta'),
    index: z.number(),
    delta: z.object({
      type: z.literal('text_delta'),
      text: z.string(),
    }),
  }),
  z.object({
    type: z.literal('content_block_stop'),
    index: z.number(),
  }),
  z.object({
    type: z.literal('message_delta'),
    delta: z.object({ stop_reason: z.string().optional().nullable() }),
    usage: z.object({ output_tokens: z.number() }),
  }),
  z.object({
    type: z.literal('message_stop'),
  }),
  z.object({
    type: z.literal('ping'),
  }),
]);
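// For reference, a streamed response that matches anthropicMessagesChunkSchema
// arrives as a sequence of events along these lines (illustrative sketch; the
// token counts and text are hypothetical, and fields not listed in the schema
// are ignored by parsing):
//
// { "type": "message_start", "message": { "usage": { "input_tokens": 10, "output_tokens": 1 } } }
// { "type": "content_block_start", "index": 0, "content_block": { "type": "text", "text": "" } }
// { "type": "content_block_delta", "index": 0, "delta": { "type": "text_delta", "text": "Hello" } }
// { "type": "content_block_stop", "index": 0 }
// { "type": "message_delta", "delta": { "stop_reason": "end_turn" }, "usage": { "output_tokens": 6 } }
// { "type": "message_stop" }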