import {
  LanguageModelV1,
  LanguageModelV1CallWarning,
  LanguageModelV1FinishReason,
  LanguageModelV1StreamPart,
  UnsupportedFunctionalityError,
} from '@ai-sdk/provider';
import { z } from 'zod';
import {
  ParseResult,
  createEventSourceResponseHandler,
  createJsonResponseHandler,
  postJsonToApi,
} from '../spec';
import { convertToGoogleGenerativeAIMessages } from './convert-to-google-generative-ai-messages';
import { googleFailedResponseHandler } from './google-error';
import { GoogleGenerativeAIContentPart } from './google-generative-ai-prompt';
import {
  GoogleGenerativeAIModelId,
  GoogleGenerativeAISettings,
} from './google-generative-ai-settings';
import { mapGoogleGenerativeAIFinishReason } from './map-google-generative-ai-finish-reason';

/**
 * Connection settings shared by every request issued by the model.
 */
type GoogleGenerativeAIConfig = {
  provider: string;
  baseUrl: string;
  // NOTE(review): value type reconstructed as `string | undefined`;
  // confirm against what `postJsonToApi` accepts for `headers`.
  headers: () => Record<string, string | undefined>;
  generateId: () => string;
};

/**
 * `LanguageModelV1` implementation that talks to the Google Generative AI
 * `generateContent` / `streamGenerateContent` endpoints.
 *
 * Only the `regular` mode is supported; every object generation mode throws
 * an `UnsupportedFunctionalityError`.
 */
export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
  readonly specificationVersion = 'v1';
  readonly defaultObjectGenerationMode = undefined;

  readonly modelId: GoogleGenerativeAIModelId;
  readonly settings: GoogleGenerativeAISettings;

  private readonly config: GoogleGenerativeAIConfig;

  constructor(
    modelId: GoogleGenerativeAIModelId,
    settings: GoogleGenerativeAISettings,
    config: GoogleGenerativeAIConfig,
  ) {
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }

  /** Provider name taken from the configuration. */
  get provider(): string {
    return this.config.provider;
  }

  /**
   * Builds the request body and collects warnings for settings that are
   * accepted by the interface but not forwarded to the API
   * (`frequencyPenalty`, `presencePenalty`, `seed`).
   *
   * @returns the request body (`args`) and the collected `warnings`.
   * @throws UnsupportedFunctionalityError for the `object-json`,
   *   `object-tool` and `object-grammar` modes.
   */
  private getArgs({
    mode,
    prompt,
    maxTokens,
    temperature,
    topP,
    frequencyPenalty,
    presencePenalty,
    seed,
  }: Parameters<LanguageModelV1['doGenerate']>[0]) {
    const type = mode.type;

    const warnings: LanguageModelV1CallWarning[] = [];

    if (frequencyPenalty != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'frequencyPenalty',
      });
    }

    if (presencePenalty != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'presencePenalty',
      });
    }

    if (seed != null) {
      warnings.push({
        type: 'unsupported-setting',
        setting: 'seed',
      });
    }

    const baseArgs = {
      generationConfig: {
        // model specific settings:
        topK: this.settings.topK,

        // standardized settings:
        maxOutputTokens: maxTokens,
        temperature,
        topP,
      },

      // prompt:
      contents: convertToGoogleGenerativeAIMessages(prompt),
    };

    switch (type) {
      case 'regular': {
        const functionDeclarations = mode.tools?.map(tool => ({
          name: tool.name,
          description: tool.description ?? '',
          parameters: prepareJsonSchema(tool.parameters),
        }));

        return {
          args: {
            ...baseArgs,
            // omit the `tools` field entirely when no tools were supplied
            tools:
              functionDeclarations == null
                ? undefined
                : { functionDeclarations },
          },
          warnings,
        };
      }

      case 'object-json': {
        throw new UnsupportedFunctionalityError({
          functionality: 'object-json mode',
        });
      }

      case 'object-tool': {
        throw new UnsupportedFunctionalityError({
          functionality: 'object-tool mode',
        });
      }

      case 'object-grammar': {
        throw new UnsupportedFunctionalityError({
          functionality: 'object-grammar mode',
        });
      }

      default: {
        // compile-time exhaustiveness check over `mode.type`
        const _exhaustiveCheck: never = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }

  /**
   * Performs a single non-streaming `generateContent` call and maps the first
   * candidate of the response to the `LanguageModelV1` result shape.
   *
   * `promptTokens` is always `NaN` because the parsed response schema carries
   * no prompt token count; `completionTokens` falls back to `NaN` when the
   * candidate has no `tokenCount`.
   *
   * @throws Error when the response contains no candidates.
   */
  async doGenerate(
    options: Parameters<LanguageModelV1['doGenerate']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
    const { args, warnings } = this.getArgs(options);

    const response = await postJsonToApi({
      url: `${this.config.baseUrl}/${this.modelId}:generateContent`,
      headers: this.config.headers(),
      body: args,
      failedResponseHandler: googleFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(responseSchema),
      abortSignal: options.abortSignal,
    });

    const { contents: rawPrompt, ...rawSettings } = args;
    const candidate = response.candidates[0];

    // An empty `candidates` array passes schema validation; fail with a clear
    // message instead of a TypeError on the property access below.
    if (candidate == null) {
      throw new Error('Google Generative AI response contained no candidates.');
    }

    const toolCalls = getToolCallsFromParts({
      parts: candidate.content.parts,
      generateId: this.config.generateId,
    });

    return {
      text: getTextFromParts(candidate.content.parts),
      toolCalls,
      finishReason: mapGoogleGenerativeAIFinishReason({
        finishReason: candidate.finishReason,
        hasToolCalls: toolCalls != null && toolCalls.length > 0,
      }),
      usage: {
        promptTokens: NaN,
        completionTokens: candidate.tokenCount ?? NaN,
      },
      rawCall: { rawPrompt, rawSettings },
      warnings,
    };
  }

  /**
   * Performs a streaming `streamGenerateContent` call (server-sent events) and
   * converts each parsed chunk into `LanguageModelV1StreamPart`s.
   *
   * Every function call found in a chunk is emitted as a `tool-call-delta`
   * immediately followed by the matching `tool-call`. A single `finish` part
   * with the last observed finish reason and usage is emitted when the
   * underlying stream ends.
   */
  async doStream(
    options: Parameters<LanguageModelV1['doStream']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
    const { args, warnings } = this.getArgs(options);

    const response = await postJsonToApi({
      url: `${this.config.baseUrl}/${this.modelId}:streamGenerateContent?alt=sse`,
      headers: this.config.headers(),
      body: args,
      failedResponseHandler: googleFailedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(chunkSchema),
      abortSignal: options.abortSignal,
    });

    const { contents: rawPrompt, ...rawSettings } = args;

    let finishReason: LanguageModelV1FinishReason = 'other';
    let usage: { promptTokens: number; completionTokens: number } = {
      promptTokens: Number.NaN,
      completionTokens: Number.NaN,
    };

    const generateId = this.config.generateId;
    let hasToolCalls = false;

    return {
      stream: response.pipeThrough(
        new TransformStream<
          ParseResult<z.infer<typeof chunkSchema>>,
          LanguageModelV1StreamPart
        >({
          transform(chunk, controller) {
            if (!chunk.success) {
              controller.enqueue({ type: 'error', error: chunk.error });
              return;
            }

            const candidate = chunk.value.candidates[0];

            // A chunk with an empty `candidates` array carries nothing to
            // forward; skip it instead of dereferencing `undefined`.
            if (candidate == null) {
              return;
            }

            if (candidate.tokenCount != null) {
              usage = {
                promptTokens: NaN,
                completionTokens: candidate.tokenCount,
              };
            }

            const content = candidate.content;

            if (content != null) {
              const deltaText = getTextFromParts(content.parts);
              if (deltaText != null) {
                controller.enqueue({
                  type: 'text-delta',
                  textDelta: deltaText,
                });
              }

              const toolCallDeltas = getToolCallsFromParts({
                parts: content.parts,
                generateId,
              });

              if (toolCallDeltas != null) {
                for (const toolCall of toolCallDeltas) {
                  controller.enqueue({
                    type: 'tool-call-delta',
                    toolCallType: 'function',
                    toolCallId: toolCall.toolCallId,
                    toolName: toolCall.toolName,
                    argsTextDelta: toolCall.args,
                  });

                  controller.enqueue({
                    type: 'tool-call',
                    toolCallType: 'function',
                    toolCallId: toolCall.toolCallId,
                    toolName: toolCall.toolName,
                    args: toolCall.args,
                  });

                  hasToolCalls = true;
                }
              }
            }

            // Map the finish reason only after this chunk's parts have been
            // processed, so that `hasToolCalls` already reflects function
            // calls delivered in the same chunk as the finish reason.
            if (candidate.finishReason != null) {
              finishReason = mapGoogleGenerativeAIFinishReason({
                finishReason: candidate.finishReason,
                hasToolCalls,
              });
            }
          },

          flush(controller) {
            controller.enqueue({ type: 'finish', finishReason, usage });
          },
        }),
      ),
      rawCall: { rawPrompt, rawSettings },
      warnings,
    };
  }
}

/**
 * Returns a copy of `jsonSchema` with every "additionalProperties" and
 * "$schema" key removed, recursively through nested objects and arrays
 * (these keys are not supported by Google Generative AI).
 *
 * Primitives and `null` are returned unchanged.
 */
function prepareJsonSchema(jsonSchema: any): unknown {
  // `typeof null === 'object'`, so null needs an explicit check; otherwise
  // `Object.entries(null)` below would throw (e.g. for `default: null`).
  if (jsonSchema == null || typeof jsonSchema !== 'object') {
    return jsonSchema;
  }

  if (Array.isArray(jsonSchema)) {
    return jsonSchema.map(prepareJsonSchema);
  }

  const result: Record<string, unknown> = {};

  for (const [key, value] of Object.entries(jsonSchema)) {
    if (key === 'additionalProperties' || key === '$schema') {
      continue;
    }

    result[key] = prepareJsonSchema(value);
  }

  return result;
}

/**
 * Extracts the function-call parts from `parts` and converts each one into a
 * tool call with a freshly generated id and JSON-stringified arguments.
 *
 * @returns the tool calls, or `undefined` when there is no function-call part.
 */
function getToolCallsFromParts({
  parts,
  generateId,
}: {
  parts: z.infer<typeof contentSchema>['parts'];
  generateId: () => string;
}) {
  const functionCallParts = parts.filter(
    part => 'functionCall' in part,
  ) as Array<
    GoogleGenerativeAIContentPart & {
      functionCall: { name: string; args: unknown };
    }
  >;

  return functionCallParts.length === 0
    ? undefined
    : functionCallParts.map(part => ({
        toolCallType: 'function' as const,
        toolCallId: generateId(),
        toolName: part.functionCall.name,
        args: JSON.stringify(part.functionCall.args),
      }));
}

/**
 * Concatenates the text of all text parts in `parts`.
 *
 * @returns the joined text, or `undefined` when there is no text part.
 */
function getTextFromParts(parts: z.infer<typeof contentSchema>['parts']) {
  const textParts = parts.filter(part => 'text' in part) as Array<
    GoogleGenerativeAIContentPart & { text: string }
  >;

  return textParts.length === 0
    ? undefined
    : textParts.map(part => part.text).join('');
}

// Content of a candidate: a role plus a list of text or function-call parts.
const contentSchema = z.object({
  role: z.string(),
  parts: z.array(
    z.union([
      z.object({
        text: z.string(),
      }),
      z.object({
        functionCall: z.object({
          name: z.string(),
          args: z.unknown(),
        }),
      }),
    ]),
  ),
});

// Limited version of the response schema, focused on what the implementation
// needs. This limits breakage when the API changes and increases efficiency.
const responseSchema = z.object({
  candidates: z.array(
    z.object({
      content: contentSchema,
      finishReason: z.string().optional(),
      tokenCount: z.number().optional(),
    }),
  ),
});

// Limited version of the stream chunk schema, focused on what the
// implementation needs. This limits breakage when the API changes and
// increases efficiency. Unlike the full response, `content` is optional here.
const chunkSchema = z.object({
  candidates: z.array(
    z.object({
      content: contentSchema.optional(),
      finishReason: z.string().optional(),
      tokenCount: z.number().optional(),
    }),
  ),
});