// Source listing metadata (scraped, not code): hts/packages/isdk/google/google-generative-ai-langua...
//
// 242 lines
// 6.5 KiB
// TypeScript

import { LanguageModelV1Prompt } from '@ai-sdk/provider';
import { convertStreamToArray } from '../spec/test/convert-stream-to-array';
import { JsonTestServer } from '../spec/test/json-test-server';
import { StreamingTestServer } from '../spec/test/streaming-test-server';
import { Google } from './google-facade';
/**
 * Minimal single-turn prompt shared by the tests below:
 * one user message containing a single text part, "Hello".
 */
const TEST_PROMPT: LanguageModelV1Prompt = [
  {
    role: 'user',
    content: [
      {
        type: 'text',
        text: 'Hello',
      },
    ],
  },
];
/**
 * Safety ratings attached to the mocked Gemini responses:
 * each of the four harm categories below is rated NEGLIGIBLE.
 */
const SAFETY_RATINGS = [
  'HARM_CATEGORY_SEXUALLY_EXPLICIT',
  'HARM_CATEGORY_HATE_SPEECH',
  'HARM_CATEGORY_HARASSMENT',
  'HARM_CATEGORY_DANGEROUS_CONTENT',
].map(category => ({ category, probability: 'NEGLIGIBLE' }));
/**
 * Shared provider instance for the tests.
 * `generateId` is pinned to 'test-id' so generated tool-call ids are deterministic.
 */
const google = new Google({
  apiKey: 'test-api-key',
  generateId() {
    return 'test-id';
  },
});

/** Model handle used by tests that do not construct their own provider. */
const model = google.generativeAI('models/gemini-pro');
describe('doGenerate', () => {
  const server = new JsonTestServer(
    'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent',
  );

  server.setupTestEnvironment();

  /**
   * Configures the mock server to answer with a single candidate that
   * finishes with 'STOP'.
   *
   * @param content text of the single text part; used only when `parts` is omitted
   * @param parts explicit content parts (e.g. `functionCall` parts); defaults to
   *   one text part containing `content`
   */
  function prepareJsonResponse({
    content = '',
    parts = [{ text: content }],
  }: {
    content?: string;
    parts?: Array<Record<string, unknown>>;
  }) {
    server.responseBodyJson = {
      candidates: [
        {
          content: { parts, role: 'model' },
          finishReason: 'STOP',
          index: 0,
          safetyRatings: SAFETY_RATINGS,
        },
      ],
      promptFeedback: { safetyRatings: SAFETY_RATINGS },
    };
  }

  it('should extract text response', async () => {
    prepareJsonResponse({ content: 'Hello, World!' });

    const { text } = await model.doGenerate({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(text).toStrictEqual('Hello, World!');
  });

  it('should extract tool calls', async () => {
    prepareJsonResponse({
      parts: [
        {
          functionCall: {
            name: 'test-tool',
            args: { value: 'example value' },
          },
        },
      ],
    });

    const { toolCalls, finishReason, text } = await model.doGenerate({
      inputFormat: 'prompt',
      mode: {
        type: 'regular',
        tools: [
          {
            type: 'function',
            name: 'test-tool',
            parameters: {
              type: 'object',
              properties: { value: { type: 'string' } },
              required: ['value'],
              additionalProperties: false,
              $schema: 'http://json-schema.org/draft-07/schema#',
            },
          },
        ],
      },
      prompt: TEST_PROMPT,
    });

    expect(toolCalls).toStrictEqual([
      {
        // 'test-id' comes from the module-level provider's fixed generateId.
        toolCallId: 'test-id',
        toolCallType: 'function',
        toolName: 'test-tool',
        args: '{"value":"example value"}',
      },
    ]);
    expect(text).toBeUndefined();
    // A response whose only part is a function call is reported as 'tool-calls'
    // even though the raw finishReason is 'STOP'.
    expect(finishReason).toStrictEqual('tool-calls');
  });

  it('should pass the model and the messages', async () => {
    prepareJsonResponse({ content: '' });

    await model.doGenerate({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(await server.getRequestBodyJson()).toStrictEqual({
      contents: [
        {
          role: 'user',
          parts: [{ text: 'Hello' }],
        },
      ],
      generationConfig: {},
    });
  });

  it('should pass the api key in the x-goog-api-key header', async () => {
    prepareJsonResponse({ content: '' });

    // Fresh provider without a custom generateId; named distinctly so it does
    // not shadow the module-level `google` / `model`.
    const googleWithKey = new Google({ apiKey: 'test-api-key' });
    const modelWithKey = googleWithKey.generativeAI('models/gemini-pro');

    await modelWithKey.doGenerate({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(
      (await server.getRequestHeaders()).get('x-goog-api-key'),
    ).toStrictEqual('test-api-key');
  });
});
describe('doStream', () => {
  const server = new StreamingTestServer(
    'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?alt=sse',
  );

  server.setupTestEnvironment();

  /**
   * Configures the mock server to emit one SSE `data:` event per entry in
   * `content`. Each event carries a single candidate whose only part is that text.
   *
   * The payload is built as an object and serialized with JSON.stringify.
   * Text containing quotes, backslashes or newlines is therefore escaped
   * correctly instead of producing invalid JSON.
   *
   * @param content text fragments, one per streamed chunk
   */
  function prepareStreamResponse({ content }: { content: string[] }) {
    server.responseChunks = content.map(text => {
      const chunk = {
        candidates: [
          {
            content: { parts: [{ text }], role: 'model' },
            finishReason: 'STOP',
            index: 0,
            safetyRatings: SAFETY_RATINGS,
          },
        ],
      };
      return `data: ${JSON.stringify(chunk)}\n\n`;
    });
  }

  it('should stream text deltas', async () => {
    prepareStreamResponse({ content: ['Hello', ', ', 'world!'] });

    const { stream } = await model.doStream({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    // The mocked chunks carry no usage metadata, so token counts surface as NaN.
    expect(await convertStreamToArray(stream)).toStrictEqual([
      { type: 'text-delta', textDelta: 'Hello' },
      { type: 'text-delta', textDelta: ', ' },
      { type: 'text-delta', textDelta: 'world!' },
      {
        type: 'finish',
        finishReason: 'stop',
        usage: { promptTokens: NaN, completionTokens: NaN },
      },
    ]);
  });

  it('should pass the messages', async () => {
    prepareStreamResponse({ content: [''] });

    await model.doStream({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(await server.getRequestBodyJson()).toStrictEqual({
      contents: [
        {
          role: 'user',
          parts: [{ text: 'Hello' }],
        },
      ],
      generationConfig: {},
    });
  });

  it('should pass the api key in the x-goog-api-key header', async () => {
    prepareStreamResponse({ content: [''] });

    // Fresh provider without a custom generateId; named distinctly so it does
    // not shadow the module-level `google`.
    const googleWithKey = new Google({ apiKey: 'test-api-key' });

    await googleWithKey.generativeAI('models/gemini-pro').doStream({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(
      (await server.getRequestHeaders()).get('x-goog-api-key'),
    ).toStrictEqual('test-api-key');
  });
});