import {
  LanguageModelV1,
  LanguageModelV1CallWarning,
  LanguageModelV1FinishReason,
  LanguageModelV1FunctionToolCall,
  LanguageModelV1StreamPart,
  UnsupportedFunctionalityError,
} from '@ai-sdk/provider';
import { z } from 'zod';
import {
  ParseResult,
  createEventSourceResponseHandler,
  createJsonResponseHandler,
  postJsonToApi,
} from '../spec';
import { anthropicFailedResponseHandler } from './anthropic-error';
import {
  AnthropicMessagesModelId,
  AnthropicMessagesSettings,
} from './anthropic-messages-settings';
import { convertToAnthropicMessagesPrompt } from './convert-to-anthropic-messages-prompt';
import { mapAnthropicStopReason } from './map-anthropic-stop-reason';

type AnthropicMessagesConfig = {
  provider: string;
  baseUrl: string;
  headers: () => Record<string, string | undefined>;
};

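/**
 * A LanguageModelV1 implementation that calls the Anthropic Messages API.
 */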
export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
  readonly specificationVersion = 'v1';
  readonly defaultObjectGenerationMode = 'tool';

  readonly modelId: AnthropicMessagesModelId;
  readonly settings: AnthropicMessagesSettings;

  private readonly config: AnthropicMessagesConfig;

  constructor(
    modelId: AnthropicMessagesModelId,
    settings: AnthropicMessagesSettings,
    config: AnthropicMessagesConfig,
  ) {
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }

  get provider(): string {
    return this.config.provider;
  }

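  // Builds the request body for the Anthropic `/messages` endpoint from the
  // standardized call options, collecting warnings for settings that the API
  // does not support.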
  private getArgs({
    mode,
    prompt,
    maxTokens,
    temperature,
    topP,
    frequencyPenalty,
    presencePenalty,
    seed,
  }: Parameters<LanguageModelV1['doGenerate']>[0]) {
    const type = mode.type;

    const warnings: LanguageModelV1CallWarning[] = [];

    if (frequencyPenalty != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'frequencyPenalty',
      });
    }

    if (presencePenalty != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'presencePenalty',
      });
    }

    if (seed != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'seed',
      });
    }

    const messagesPrompt = convertToAnthropicMessagesPrompt(prompt);

    const baseArgs = {
      // model id:
      model: this.modelId,

      // model specific settings:
      top_k: this.settings.topK,

      // standardized settings:
      max_tokens: maxTokens ?? 4096, // 4096: max model output tokens
      temperature, // uses 0..1 scale
      top_p: topP,

      // prompt:
      system: messagesPrompt.system,
      messages: messagesPrompt.messages,
    };

    switch (type) {
      case 'regular': {
        // when the tools array is empty, change it to undefined to prevent API errors:
        const tools = mode.tools?.length ? mode.tools : undefined;

        return {
          args: {
            ...baseArgs,
            tools: tools?.map(tool => ({
              name: tool.name,
              description: tool.description,
              input_schema: tool.parameters,
            })),
          },
          warnings,
        };
      }

      case 'object-json': {
        throw new UnsupportedFunctionalityError({
          functionality: 'json-mode object generation',
        });
      }

      case 'object-tool': {
        const { name, description, parameters } = mode.tool;

        // add instruction to use tool:
        baseArgs.messages[baseArgs.messages.length - 1].content.push({
          type: 'text',
          text: `\n\nUse the '${name}' tool.`,
        });

        return {
          args: {
            ...baseArgs,
            tools: [{ name, description, input_schema: parameters }],
          },
          warnings,
        };
      }

      case 'object-grammar': {
        throw new UnsupportedFunctionalityError({
          functionality: 'grammar-mode object generation',
        });
      }

      default: {
        const _exhaustiveCheck: never = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }

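  // Non-streaming generation: posts the request and maps the JSON response
  // into the standardized generation result.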
  async doGenerate(
    options: Parameters<LanguageModelV1['doGenerate']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
    const { args, warnings } = this.getArgs(options);

    const response = await postJsonToApi({
      url: `${this.config.baseUrl}/messages`,
      headers: this.config.headers(),
      body: args,
      failedResponseHandler: anthropicFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        anthropicMessagesResponseSchema,
      ),
      abortSignal: options.abortSignal,
    });

    const { messages: rawPrompt, ...rawSettings } = args;

    // extract text
    let text = '';
    for (const content of response.content) {
      if (content.type === 'text') {
        text += content.text;
      }
    }

    // extract tool calls
    let toolCalls: LanguageModelV1FunctionToolCall[] | undefined = undefined;
    if (response.content.some(content => content.type === 'tool_use')) {
      toolCalls = [];
      for (const content of response.content) {
        if (content.type === 'tool_use') {
          toolCalls.push({
            toolCallType: 'function',
            toolCallId: content.id,
            toolName: content.name,
            args: JSON.stringify(content.input),
          });
        }
      }
    }

    return {
      text,
      toolCalls,
      finishReason: mapAnthropicStopReason(response.stop_reason),
      usage: {
        promptTokens: response.usage.input_tokens,
        completionTokens: response.usage.output_tokens,
      },
      rawCall: { rawPrompt, rawSettings },
      warnings,
    };
  }

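  // Streaming generation: posts the request with `stream: true` and adapts
  // the resulting server-sent event stream into standardized stream parts.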
  async doStream(
    options: Parameters<LanguageModelV1['doStream']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
    const { args, warnings } = this.getArgs(options);

    const response = await postJsonToApi({
      url: `${this.config.baseUrl}/messages`,
      headers: this.config.headers(),
      body: {
        ...args,
        stream: true,
      },
      failedResponseHandler: anthropicFailedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(
        anthropicMessagesChunkSchema,
      ),
      abortSignal: options.abortSignal,
    });

    const { messages: rawPrompt, ...rawSettings } = args;

    let finishReason: LanguageModelV1FinishReason = 'other';
    const usage: { promptTokens: number; completionTokens: number } = {
      promptTokens: Number.NaN,
      completionTokens: Number.NaN,
    };

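    // Map Anthropic SSE chunks onto standardized stream parts, accumulating
    // token usage and the finish reason as chunks arrive.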
    return {
      stream: response.pipeThrough(
        new TransformStream<
          ParseResult<z.infer<typeof anthropicMessagesChunkSchema>>,
          LanguageModelV1StreamPart
        >({
          transform(chunk, controller) {
            if (!chunk.success) {
              controller.enqueue({ type: 'error', error: chunk.error });
              return;
            }

            const value = chunk.value;

            switch (value.type) {
              case 'ping':
              case 'content_block_start':
              case 'content_block_stop': {
                return; // ignored
              }

              case 'content_block_delta': {
                controller.enqueue({
                  type: 'text-delta',
                  textDelta: value.delta.text,
                });
                return;
              }

              case 'message_start': {
                usage.promptTokens = value.message.usage.input_tokens;
                usage.completionTokens = value.message.usage.output_tokens;
                return;
              }

              case 'message_delta': {
                usage.completionTokens = value.usage.output_tokens;
                finishReason = mapAnthropicStopReason(value.delta.stop_reason);
                return;
              }

              case 'message_stop': {
                controller.enqueue({ type: 'finish', finishReason, usage });
                return;
              }

              default: {
                const _exhaustiveCheck: never = value;
                throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`);
              }
            }
          },
        }),
      ),
      rawCall: { rawPrompt, rawSettings },
      warnings,
    };
  }
}

// limited version of the schema, focused on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const anthropicMessagesResponseSchema = z.object({
  type: z.literal('message'),
  content: z.array(
    z.discriminatedUnion('type', [
      z.object({
        type: z.literal('text'),
        text: z.string(),
      }),
      z.object({
        type: z.literal('tool_use'),
        id: z.string(),
        name: z.string(),
        input: z.unknown(),
      }),
    ]),
  ),
  stop_reason: z.string().optional().nullable(),
  usage: z.object({
    input_tokens: z.number(),
    output_tokens: z.number(),
  }),
});

// limited version of the schema, focused on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const anthropicMessagesChunkSchema = z.discriminatedUnion('type', [
  z.object({
    type: z.literal('message_start'),
    message: z.object({
      usage: z.object({
        input_tokens: z.number(),
        output_tokens: z.number(),
      }),
    }),
  }),
  z.object({
    type: z.literal('content_block_start'),
    index: z.number(),
    content_block: z.object({
      type: z.literal('text'),
      text: z.string(),
    }),
  }),
  z.object({
    type: z.literal('content_block_delta'),
    index: z.number(),
    delta: z.object({
      type: z.literal('text_delta'),
      text: z.string(),
    }),
  }),
  z.object({
    type: z.literal('content_block_stop'),
    index: z.number(),
  }),
  z.object({
    type: z.literal('message_delta'),
    delta: z.object({ stop_reason: z.string().optional().nullable() }),
    usage: z.object({ output_tokens: z.number() }),
  }),
  z.object({
    type: z.literal('message_stop'),
  }),
  z.object({
    type: z.literal('ping'),
  }),
]);
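
/*
 * Example usage (a minimal sketch: the model id, provider name, base URL, and
 * headers below are illustrative assumptions, not values defined in this file):
 *
 * const model = new AnthropicMessagesLanguageModel(
 *   'claude-3-haiku-20240307',
 *   { topK: 40 },
 *   {
 *     provider: 'anthropic.messages',
 *     baseUrl: 'https://api.anthropic.com/v1',
 *     headers: () => ({
 *       'x-api-key': process.env.ANTHROPIC_API_KEY,
 *       'anthropic-version': '2023-06-01',
 *     }),
 *   },
 * );
 */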