import {
  LanguageModelV1,
  LanguageModelV1CallWarning,
  LanguageModelV1FinishReason,
  LanguageModelV1FunctionToolCall,
  LanguageModelV1StreamPart,
  UnsupportedFunctionalityError,
} from '@ai-sdk/provider';
import { z } from 'zod';
import {
  ParseResult,
  createEventSourceResponseHandler,
  createJsonResponseHandler,
  postJsonToApi,
} from '../spec';
import { anthropicFailedResponseHandler } from './anthropic-error';
import {
  AnthropicMessagesModelId,
  AnthropicMessagesSettings,
} from './anthropic-messages-settings';
import { convertToAnthropicMessagesPrompt } from './convert-to-anthropic-messages-prompt';
import { mapAnthropicStopReason } from './map-anthropic-stop-reason';

type AnthropicMessagesConfig = {
  provider: string;
  baseUrl: string;
  headers: () => Record<string, string | undefined>;
};

export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
  readonly specificationVersion = 'v1';
  readonly defaultObjectGenerationMode = 'tool';

  readonly modelId: AnthropicMessagesModelId;
  readonly settings: AnthropicMessagesSettings;

  private readonly config: AnthropicMessagesConfig;

  constructor(
    modelId: AnthropicMessagesModelId,
    settings: AnthropicMessagesSettings,
    config: AnthropicMessagesConfig,
  ) {
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }
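
  /*
   * Construction sketch. The model id, settings, and header values below are
   * illustrative assumptions, not defaults mandated by this file:
   *
   * const model = new AnthropicMessagesLanguageModel(
   *   'claude-3-haiku-20240307',
   *   { topK: 40 },
   *   {
   *     provider: 'anthropic.messages',
   *     baseUrl: 'https://api.anthropic.com/v1',
   *     headers: () => ({
   *       'x-api-key': process.env.ANTHROPIC_API_KEY,
   *       'anthropic-version': '2023-06-01',
   *     }),
   *   },
   * );
   */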

  get provider(): string {
    return this.config.provider;
  }

  private getArgs({
    mode,
    prompt,
    maxTokens,
    temperature,
    topP,
    frequencyPenalty,
    presencePenalty,
    seed,
  }: Parameters<LanguageModelV1['doGenerate']>[0]) {
    const type = mode.type;

    const warnings: LanguageModelV1CallWarning[] = [];

    if (frequencyPenalty != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'frequencyPenalty',
      });
    }

    if (presencePenalty != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'presencePenalty',
      });
    }

    if (seed != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'seed',
      });
    }

    const messagesPrompt = convertToAnthropicMessagesPrompt(prompt);

    const baseArgs = {
      // model id:
      model: this.modelId,

      // model-specific settings:
      top_k: this.settings.topK,

      // standardized settings:
      max_tokens: maxTokens ?? 4096, // 4096: max model output tokens
      temperature, // uses 0..1 scale
      top_p: topP,

      // prompt:
      system: messagesPrompt.system,
      messages: messagesPrompt.messages,
    };

    switch (type) {
      case 'regular': {
        // when the tools array is empty, change it to undefined to prevent Anthropic API errors:
        const tools = mode.tools?.length ? mode.tools : undefined;
        return {
          args: {
            ...baseArgs,
            tools: tools?.map(tool => ({
              name: tool.name,
              description: tool.description,
              input_schema: tool.parameters,
            })),
          },
          warnings,
        };
      }

      case 'object-json': {
        throw new UnsupportedFunctionalityError({
          functionality: 'json-mode object generation',
        });
      }

      case 'object-tool': {
        const { name, description, parameters } = mode.tool;

        // add instruction to use tool:
        baseArgs.messages[baseArgs.messages.length - 1].content.push({
          type: 'text',
          text: `\n\nUse the '${name}' tool.`,
        });

        return {
          args: {
            ...baseArgs,
            tools: [{ name, description, input_schema: parameters }],
          },
          warnings,
        };
      }

      case 'object-grammar': {
        throw new UnsupportedFunctionalityError({
          functionality: 'grammar-mode object generation',
        });
      }

      default: {
        const _exhaustiveCheck: never = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }
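
  /*
   * For illustration, a 'regular' mode call with one tool yields a request
   * body shaped roughly like this (all values are hypothetical):
   *
   * {
   *   model: 'claude-3-haiku-20240307',
   *   max_tokens: 4096,
   *   temperature: 0.5,
   *   system: 'You are a helpful assistant.',
   *   messages: [{ role: 'user', content: [{ type: 'text', text: 'Hi!' }] }],
   *   tools: [
   *     {
   *       name: 'weather',
   *       description: 'Get the weather for a city.',
   *       input_schema: { type: 'object', properties: { city: { type: 'string' } } },
   *     },
   *   ],
   * }
   */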

  async doGenerate(
    options: Parameters<LanguageModelV1['doGenerate']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
    const { args, warnings } = this.getArgs(options);

    const response = await postJsonToApi({
      url: `${this.config.baseUrl}/messages`,
      headers: this.config.headers(),
      body: args,
      failedResponseHandler: anthropicFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        anthropicMessagesResponseSchema,
      ),
      abortSignal: options.abortSignal,
    });

    const { messages: rawPrompt, ...rawSettings } = args;

    // extract text
    let text = '';
    for (const content of response.content) {
      if (content.type === 'text') {
        text += content.text;
      }
    }

    // extract tool calls
    let toolCalls: LanguageModelV1FunctionToolCall[] | undefined = undefined;
    if (response.content.some(content => content.type === 'tool_use')) {
      toolCalls = [];
      for (const content of response.content) {
        if (content.type === 'tool_use') {
          toolCalls.push({
            toolCallType: 'function',
            toolCallId: content.id,
            toolName: content.name,
            args: JSON.stringify(content.input),
          });
        }
      }
    }

    return {
      text,
      toolCalls,
      finishReason: mapAnthropicStopReason(response.stop_reason),
      usage: {
        promptTokens: response.usage.input_tokens,
        completionTokens: response.usage.output_tokens,
      },
      rawCall: { rawPrompt, rawSettings },
      warnings,
    };
  }
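
  /*
   * Usage sketch for doGenerate. The prompt values and the `model` instance
   * are assumptions, and real callers normally go through the AI SDK's
   * higher-level functions rather than calling the model directly:
   *
   * const { text, toolCalls, finishReason, usage } = await model.doGenerate({
   *   inputFormat: 'prompt',
   *   mode: { type: 'regular' },
   *   prompt: [
   *     { role: 'user', content: [{ type: 'text', text: 'Hello!' }] },
   *   ],
   * });
   */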

  async doStream(
    options: Parameters<LanguageModelV1['doStream']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
    const { args, warnings } = this.getArgs(options);

    const response = await postJsonToApi({
      url: `${this.config.baseUrl}/messages`,
      headers: this.config.headers(),
      body: {
        ...args,
        stream: true,
      },
      failedResponseHandler: anthropicFailedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(
        anthropicMessagesChunkSchema,
      ),
      abortSignal: options.abortSignal,
    });

    const { messages: rawPrompt, ...rawSettings } = args;

    let finishReason: LanguageModelV1FinishReason = 'other';
    const usage: { promptTokens: number; completionTokens: number } = {
      promptTokens: Number.NaN,
      completionTokens: Number.NaN,
    };

    return {
      stream: response.pipeThrough(
        new TransformStream<
          ParseResult<z.infer<typeof anthropicMessagesChunkSchema>>,
          LanguageModelV1StreamPart
        >({
          transform(chunk, controller) {
            if (!chunk.success) {
              controller.enqueue({ type: 'error', error: chunk.error });
              return;
            }

            const value = chunk.value;

            switch (value.type) {
              case 'ping':
              case 'content_block_start':
              case 'content_block_stop': {
                return; // ignored
              }

              case 'content_block_delta': {
                controller.enqueue({
                  type: 'text-delta',
                  textDelta: value.delta.text,
                });
                return;
              }

              case 'message_start': {
                usage.promptTokens = value.message.usage.input_tokens;
                usage.completionTokens = value.message.usage.output_tokens;
                return;
              }

              case 'message_delta': {
                usage.completionTokens = value.usage.output_tokens;
                finishReason = mapAnthropicStopReason(value.delta.stop_reason);
                return;
              }

              case 'message_stop': {
                controller.enqueue({ type: 'finish', finishReason, usage });
                return;
              }

              default: {
                const _exhaustiveCheck: never = value;
                throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`);
              }
            }
          },
        }),
      ),
      rawCall: { rawPrompt, rawSettings },
      warnings,
    };
  }
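
  /*
   * Streaming sketch: the returned stream emits LanguageModelV1StreamPart
   * values ('text-delta', 'finish', or 'error' in this implementation).
   * A minimal consumer, assuming a Node environment and the same options
   * as the doGenerate example above:
   *
   * const { stream } = await model.doStream({
   *   inputFormat: 'prompt',
   *   mode: { type: 'regular' },
   *   prompt: [{ role: 'user', content: [{ type: 'text', text: 'Hello!' }] }],
   * });
   *
   * const reader = stream.getReader();
   * while (true) {
   *   const { done, value } = await reader.read();
   *   if (done) break;
   *   if (value.type === 'text-delta') process.stdout.write(value.textDelta);
   * }
   */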
}

// limited version of the schema, focused on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const anthropicMessagesResponseSchema = z.object({
  type: z.literal('message'),
  content: z.array(
    z.discriminatedUnion('type', [
      z.object({
        type: z.literal('text'),
        text: z.string(),
      }),
      z.object({
        type: z.literal('tool_use'),
        id: z.string(),
        name: z.string(),
        input: z.unknown(),
      }),
    ]),
  ),
  stop_reason: z.string().optional().nullable(),
  usage: z.object({
    input_tokens: z.number(),
    output_tokens: z.number(),
  }),
});
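
/*
 * The response schema can validate a raw (non-streaming) API response before
 * it is mapped to SDK types; `rawJson` below is a hypothetical parsed body:
 *
 * const result = anthropicMessagesResponseSchema.safeParse(rawJson);
 * if (result.success) {
 *   const textBlocks = result.data.content.filter(c => c.type === 'text');
 * }
 */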

// limited version of the schema, focused on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const anthropicMessagesChunkSchema = z.discriminatedUnion('type', [
  z.object({
    type: z.literal('message_start'),
    message: z.object({
      usage: z.object({
        input_tokens: z.number(),
        output_tokens: z.number(),
      }),
    }),
  }),
  z.object({
    type: z.literal('content_block_start'),
    index: z.number(),
    content_block: z.object({
      type: z.literal('text'),
      text: z.string(),
    }),
  }),
  z.object({
    type: z.literal('content_block_delta'),
    index: z.number(),
    delta: z.object({
      type: z.literal('text_delta'),
      text: z.string(),
    }),
  }),
  z.object({
    type: z.literal('content_block_stop'),
    index: z.number(),
  }),
  z.object({
    type: z.literal('message_delta'),
    delta: z.object({ stop_reason: z.string().optional().nullable() }),
    usage: z.object({ output_tokens: z.number() }),
  }),
  z.object({
    type: z.literal('message_stop'),
  }),
  z.object({
    type: z.literal('ping'),
  }),
]);
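
/*
 * Likewise, each server-sent event's JSON payload can be checked against the
 * chunk schema; the payload string here is a hypothetical example event:
 *
 * const chunk = anthropicMessagesChunkSchema.parse(
 *   JSON.parse('{"type":"message_stop"}'),
 * );
 * // chunk.type === 'message_stop'
 */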