// OpenAI chat completions language model — LanguageModelV1 implementation.
import {
|
|
InvalidResponseDataError,
|
|
LanguageModelV1,
|
|
LanguageModelV1FinishReason,
|
|
LanguageModelV1StreamPart,
|
|
UnsupportedFunctionalityError,
|
|
} from '@ai-sdk/provider';
|
|
import { z } from 'zod';
|
|
import {
|
|
ParseResult,
|
|
createEventSourceResponseHandler,
|
|
createJsonResponseHandler,
|
|
generateId,
|
|
isParseableJson,
|
|
postJsonToApi,
|
|
scale,
|
|
} from '../spec';
|
|
import { convertToOpenAIChatMessages } from './convert-to-openai-chat-messages';
|
|
import { mapOpenAIFinishReason } from './map-openai-finish-reason';
|
|
import { OpenAIChatModelId, OpenAIChatSettings } from './openai-chat-settings';
|
|
import { openaiFailedResponseHandler } from './openai-error';
|
|
|
|
// Provider-level configuration shared by all chat model instances.
type OpenAIChatConfig = {
  // provider identifier exposed through the model's `provider` getter:
  provider: string;
  // API base URL; '/chat/completions' is appended for each request:
  baseUrl: string;
  // produces the request headers (e.g. authorization) per call;
  // presumably entries with undefined values are dropped by postJsonToApi — TODO confirm:
  headers: () => Record<string, string | undefined>;
};
|
|
|
|
/**
 * OpenAI chat completions API implementation of the `LanguageModelV1`
 * specification, supporting plain generation, streaming, and tool calls.
 */
export class OpenAIChatLanguageModel implements LanguageModelV1 {
  readonly specificationVersion = 'v1';
  // object generation defaults to forced tool calls (see 'object-tool' in getArgs):
  readonly defaultObjectGenerationMode = 'tool';

  readonly modelId: OpenAIChatModelId;
  readonly settings: OpenAIChatSettings;

  private readonly config: OpenAIChatConfig;

  constructor(
    modelId: OpenAIChatModelId,
    settings: OpenAIChatSettings,
    config: OpenAIChatConfig,
  ) {
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }

  get provider(): string {
    return this.config.provider;
  }

  /**
   * Builds the JSON request body for the chat completions endpoint from the
   * standardized call options.
   *
   * Standardized ranges are rescaled onto OpenAI's: temperature to
   * [0, 2], frequency/presence penalty from [-1, 1] to [-2, 2].
   *
   * @throws UnsupportedFunctionalityError for 'object-grammar' mode.
   */
  private getArgs({
    mode,
    prompt,
    maxTokens,
    temperature,
    topP,
    frequencyPenalty,
    presencePenalty,
    seed,
  }: Parameters<LanguageModelV1['doGenerate']>[0]) {
    const type = mode.type;

    const baseArgs = {
      // model id:
      model: this.modelId,

      // model specific settings:
      logit_bias: this.settings.logitBias,
      user: this.settings.user,

      // standardized settings:
      max_tokens: maxTokens,
      temperature: scale({
        value: temperature,
        outputMin: 0,
        outputMax: 2,
      }),
      top_p: topP,
      frequency_penalty: scale({
        value: frequencyPenalty,
        inputMin: -1,
        inputMax: 1,
        outputMin: -2,
        outputMax: 2,
      }),
      presence_penalty: scale({
        value: presencePenalty,
        inputMin: -1,
        inputMax: 1,
        outputMin: -2,
        outputMax: 2,
      }),
      seed,

      // messages:
      messages: convertToOpenAIChatMessages(prompt),
    };

    switch (type) {
      // plain text generation, optionally with caller-supplied tools:
      case 'regular': {
        // when the tools array is empty, change it to undefined to prevent OpenAI errors:
        const tools = mode.tools?.length ? mode.tools : undefined;

        return {
          ...baseArgs,
          tools: tools?.map(tool => ({
            type: 'function',
            function: {
              name: tool.name,
              description: tool.description,
              parameters: tool.parameters,
            },
          })),
        };
      }

      // JSON object generation via OpenAI's response_format option:
      case 'object-json': {
        return {
          ...baseArgs,
          response_format: { type: 'json_object' },
        };
      }

      // JSON object generation via a single forced tool call:
      case 'object-tool': {
        return {
          ...baseArgs,
          tool_choice: { type: 'function', function: { name: mode.tool.name } },
          tools: [
            {
              type: 'function',
              function: {
                name: mode.tool.name,
                description: mode.tool.description,
                parameters: mode.tool.parameters,
              },
            },
          ],
        };
      }

      // grammar-constrained generation is not offered by the OpenAI API:
      case 'object-grammar': {
        throw new UnsupportedFunctionalityError({
          functionality: 'object-grammar mode',
        });
      }

      default: {
        // compile-time exhaustiveness check — fails to type-check if a
        // new mode type is added without a corresponding case above:
        const _exhaustiveCheck: never = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }

  /**
   * Non-streaming generation: posts the request body from getArgs and maps
   * the first choice of the parsed response onto the standardized result.
   */
  async doGenerate(
    options: Parameters<LanguageModelV1['doGenerate']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
    const args = this.getArgs(options);

    const response = await postJsonToApi({
      url: `${this.config.baseUrl}/chat/completions`,
      headers: this.config.headers(),
      body: args,
      failedResponseHandler: openaiFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        openAIChatResponseSchema,
      ),
      abortSignal: options.abortSignal,
    });

    // split the request body into the prompt and the remaining settings so
    // both can be reported in rawCall:
    const { messages: rawPrompt, ...rawSettings } = args;
    // only the first choice is used; n > 1 is not part of the spec here:
    const choice = response.choices[0];

    return {
      text: choice.message.content ?? undefined,
      toolCalls: choice.message.tool_calls?.map(toolCall => ({
        toolCallType: 'function',
        toolCallId: toolCall.id,
        toolName: toolCall.function.name,
        args: toolCall.function.arguments!,
      })),
      finishReason: mapOpenAIFinishReason(choice.finish_reason),
      usage: {
        promptTokens: response.usage.prompt_tokens,
        completionTokens: response.usage.completion_tokens,
      },
      rawCall: { rawPrompt, rawSettings },
      warnings: [],
    };
  }

  /**
   * Streaming generation: posts the request with `stream: true` and pipes
   * the server-sent-event chunk stream through a TransformStream that emits
   * standardized stream parts (text deltas, tool-call deltas, completed
   * tool calls, and a final 'finish' event with usage and finish reason).
   */
  async doStream(
    options: Parameters<LanguageModelV1['doStream']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
    const args = this.getArgs(options);

    const response = await postJsonToApi({
      url: `${this.config.baseUrl}/chat/completions`,
      headers: this.config.headers(),
      body: {
        ...args,
        stream: true,
      },
      failedResponseHandler: openaiFailedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(
        openaiChatChunkSchema,
      ),
      abortSignal: options.abortSignal,
    });

    const { messages: rawPrompt, ...rawSettings } = args;

    // accumulates tool-call fragments across chunks; the array is indexed
    // by the `index` field OpenAI assigns to each tool call:
    const toolCalls: Array<{
      id: string;
      type: 'function';
      function: {
        name: string;
        arguments: string;
      };
    }> = [];

    // state reported in the final 'finish' event; NaN marks usage values
    // the stream never provided:
    let finishReason: LanguageModelV1FinishReason = 'other';
    let usage: { promptTokens: number; completionTokens: number } = {
      promptTokens: Number.NaN,
      completionTokens: Number.NaN,
    };

    return {
      stream: response.pipeThrough(
        new TransformStream<
          ParseResult<z.infer<typeof openaiChatChunkSchema>>,
          LanguageModelV1StreamPart
        >({
          transform(chunk, controller) {
            // chunks that failed schema parsing are forwarded as error parts:
            if (!chunk.success) {
              controller.enqueue({ type: 'error', error: chunk.error });
              return;
            }

            const value = chunk.value;

            if (value.usage != null) {
              usage = {
                promptTokens: value.usage.prompt_tokens,
                completionTokens: value.usage.completion_tokens,
              };
            }

            const choice = value.choices[0];

            if (choice?.finish_reason != null) {
              finishReason = mapOpenAIFinishReason(choice.finish_reason);
            }

            if (choice?.delta == null) {
              return;
            }

            const delta = choice.delta;

            if (delta.content != null) {
              controller.enqueue({
                type: 'text-delta',
                textDelta: delta.content,
              });
            }

            if (delta.tool_calls != null) {
              for (const toolCallDelta of delta.tool_calls) {
                const index = toolCallDelta.index;

                // Tool call start. OpenAI returns all information except the arguments in the first chunk.
                if (toolCalls[index] == null) {
                  if (toolCallDelta.type !== 'function') {
                    throw new InvalidResponseDataError({
                      data: toolCallDelta,
                      message: `Expected 'function' type.`,
                    });
                  }

                  if (toolCallDelta.id == null) {
                    throw new InvalidResponseDataError({
                      data: toolCallDelta,
                      message: `Expected 'id' to be a string.`,
                    });
                  }

                  if (toolCallDelta.function?.name == null) {
                    throw new InvalidResponseDataError({
                      data: toolCallDelta,
                      message: `Expected 'function.name' to be a string.`,
                    });
                  }

                  toolCalls[index] = {
                    id: toolCallDelta.id,
                    type: 'function',
                    function: {
                      name: toolCallDelta.function.name,
                      arguments: toolCallDelta.function.arguments ?? '',
                    },
                  };

                  // NOTE(review): any arguments text in this first chunk is
                  // stored but never emitted as a tool-call-delta — confirm
                  // this is intentional:
                  continue;
                }

                // existing tool call, merge
                const toolCall = toolCalls[index];

                if (toolCallDelta.function?.arguments != null) {
                  toolCall.function!.arguments +=
                    toolCallDelta.function?.arguments ?? '';
                }

                // send delta
                controller.enqueue({
                  type: 'tool-call-delta',
                  toolCallType: 'function',
                  toolCallId: toolCall.id,
                  toolName: toolCall.function.name,
                  argsTextDelta: toolCallDelta.function.arguments ?? '',
                });

                // check if tool call is complete
                // (completeness is detected by the accumulated arguments
                // becoming parseable JSON; until then keep accumulating):
                if (
                  toolCall.function?.name == null ||
                  toolCall.function?.arguments == null ||
                  !isParseableJson(toolCall.function.arguments)
                ) {
                  continue;
                }

                controller.enqueue({
                  type: 'tool-call',
                  toolCallType: 'function',
                  toolCallId: toolCall.id ?? generateId(),
                  toolName: toolCall.function.name,
                  args: toolCall.function.arguments,
                });
              }
            }
          },

          // emit the final event once the upstream SSE stream closes:
          flush(controller) {
            controller.enqueue({ type: 'finish', finishReason, usage });
          },
        }),
      ),
      rawCall: { rawPrompt, rawSettings },
      warnings: [],
    };
  }
}
|
|
|
|
// limited version of the schema, focussed on what is needed for the implementation
|
|
// this approach limits breakages when the API changes and increases efficiency
|
|
const openAIChatResponseSchema = z.object({
|
|
choices: z.array(
|
|
z.object({
|
|
message: z.object({
|
|
role: z.literal('assistant'),
|
|
content: z.string().nullable(),
|
|
tool_calls: z
|
|
.array(
|
|
z.object({
|
|
id: z.string(),
|
|
type: z.literal('function'),
|
|
function: z.object({
|
|
name: z.string(),
|
|
arguments: z.string(),
|
|
}),
|
|
}),
|
|
)
|
|
.optional(),
|
|
}),
|
|
index: z.number(),
|
|
finish_reason: z.string().optional().nullable(),
|
|
}),
|
|
),
|
|
object: z.literal('chat.completion'),
|
|
usage: z.object({
|
|
prompt_tokens: z.number(),
|
|
completion_tokens: z.number(),
|
|
}),
|
|
});
|
|
|
|
// limited version of the schema, focussed on what is needed for the implementation
|
|
// this approach limits breakages when the API changes and increases efficiency
|
|
const openaiChatChunkSchema = z.object({
|
|
object: z.literal('chat.completion.chunk'),
|
|
choices: z.array(
|
|
z.object({
|
|
delta: z.object({
|
|
role: z.enum(['assistant']).optional(),
|
|
content: z.string().nullable().optional(),
|
|
tool_calls: z
|
|
.array(
|
|
z.object({
|
|
index: z.number(),
|
|
id: z.string().optional(),
|
|
type: z.literal('function').optional(),
|
|
function: z.object({
|
|
name: z.string().optional(),
|
|
arguments: z.string().optional(),
|
|
}),
|
|
}),
|
|
)
|
|
.optional(),
|
|
}),
|
|
finish_reason: z.string().nullable().optional(),
|
|
index: z.number(),
|
|
}),
|
|
),
|
|
usage: z
|
|
.object({
|
|
prompt_tokens: z.number(),
|
|
completion_tokens: z.number(),
|
|
})
|
|
.optional()
|
|
.nullable(),
|
|
});
|