242 lines
6.5 KiB
TypeScript
242 lines
6.5 KiB
TypeScript
import { LanguageModelV1Prompt } from '@ai-sdk/provider';
|
|
import { convertStreamToArray } from '../spec/test/convert-stream-to-array';
|
|
import { JsonTestServer } from '../spec/test/json-test-server';
|
|
import { StreamingTestServer } from '../spec/test/streaming-test-server';
|
|
import { Google } from './google-facade';
|
|
|
|
/**
 * Minimal single-turn prompt shared by every test below: one user message
 * whose only content part is the text "Hello".
 */
const TEST_PROMPT: LanguageModelV1Prompt = [
  {
    role: 'user',
    content: [{ type: 'text', text: 'Hello' }],
  },
];
|
|
|
|
/**
 * Safety ratings attached to every mocked candidate (and to `promptFeedback`):
 * each of the four harm categories, in this order, rated 'NEGLIGIBLE'.
 */
const SAFETY_RATINGS = [
  'HARM_CATEGORY_SEXUALLY_EXPLICIT',
  'HARM_CATEGORY_HATE_SPEECH',
  'HARM_CATEGORY_HARASSMENT',
  'HARM_CATEGORY_DANGEROUS_CONTENT',
].map(category => ({ category, probability: 'NEGLIGIBLE' }));
|
|
|
|
/**
 * Shared provider facade for the tests. `generateId` is stubbed to a constant
 * so generated ids are deterministic — the tool-call test expects 'test-id'.
 */
const google = new Google({
  apiKey: 'test-api-key',
  generateId: () => 'test-id',
});

/** Gemini Pro model handle used by both `doGenerate` and `doStream` suites. */
const model = google.generativeAI('models/gemini-pro');
|
|
|
|
describe('doGenerate', () => {
  const server = new JsonTestServer(
    'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent',
  );

  server.setupTestEnvironment();

  /**
   * Sets the mocked `generateContent` JSON response: a single candidate with
   * finishReason 'STOP', index 0 and NEGLIGIBLE safety ratings.
   *
   * @param content - text of the single text part (used only when `parts` is omitted)
   * @param parts - explicit content parts, e.g. `functionCall` parts for tool-call tests
   */
  function prepareJsonResponse({
    content = '',
    parts = [{ text: content }],
  }: {
    content?: string;
    parts?: Array<Record<string, unknown>>;
  }) {
    server.responseBodyJson = {
      candidates: [
        {
          content: { parts, role: 'model' },
          finishReason: 'STOP',
          index: 0,
          safetyRatings: SAFETY_RATINGS,
        },
      ],
      promptFeedback: { safetyRatings: SAFETY_RATINGS },
    };
  }

  it('should extract text response', async () => {
    prepareJsonResponse({ content: 'Hello, World!' });

    const { text } = await model.doGenerate({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(text).toStrictEqual('Hello, World!');
  });

  it('should extract tool calls', async () => {
    prepareJsonResponse({
      parts: [
        {
          functionCall: {
            name: 'test-tool',
            args: { value: 'example value' },
          },
        },
      ],
    });

    const { toolCalls, finishReason, text } = await model.doGenerate({
      inputFormat: 'prompt',
      mode: {
        type: 'regular',
        tools: [
          {
            type: 'function',
            name: 'test-tool',
            parameters: {
              type: 'object',
              properties: { value: { type: 'string' } },
              required: ['value'],
              additionalProperties: false,
              $schema: 'http://json-schema.org/draft-07/schema#',
            },
          },
        ],
      },
      prompt: TEST_PROMPT,
    });

    expect(toolCalls).toStrictEqual([
      {
        toolCallId: 'test-id',
        toolCallType: 'function',
        toolName: 'test-tool',
        args: '{"value":"example value"}',
      },
    ]);
    expect(text).toBeUndefined();
    // The mocked response says 'STOP'; because it contains a function call the
    // test expects the provider to report 'tool-calls'.
    expect(finishReason).toStrictEqual('tool-calls');
  });

  it('should pass the model and the messages', async () => {
    prepareJsonResponse({ content: '' });

    await model.doGenerate({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(await server.getRequestBodyJson()).toStrictEqual({
      contents: [
        {
          role: 'user',
          parts: [{ text: 'Hello' }],
        },
      ],
      generationConfig: {},
    });
  });

  it('should pass the api key as x-goog-api-key header', async () => {
    prepareJsonResponse({ content: '' });

    // A separate provider configured with only `apiKey`; distinct names avoid
    // shadowing the module-level `google` / `model`.
    const customGoogle = new Google({ apiKey: 'test-api-key' });
    const customModel = customGoogle.generativeAI('models/gemini-pro');

    await customModel.doGenerate({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(
      (await server.getRequestHeaders()).get('x-goog-api-key'),
    ).toStrictEqual('test-api-key');
  });
});
|
|
|
|
describe('doStream', () => {
  const server = new StreamingTestServer(
    'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?alt=sse',
  );

  server.setupTestEnvironment();

  /**
   * Sets the mocked SSE response: one `data:` event per entry in `content`,
   * each carrying a single candidate (finishReason 'STOP', index 0, NEGLIGIBLE
   * safety ratings) whose only part is that text.
   *
   * The payload is serialized with JSON.stringify rather than interpolating
   * `text` into a JSON template string, so text containing quotes, backslashes
   * or newlines still yields valid JSON on a single `data:` line.
   */
  function prepareStreamResponse({ content }: { content: string[] }) {
    server.responseChunks = content.map(text => {
      const payload = {
        candidates: [
          {
            content: { parts: [{ text }], role: 'model' },
            finishReason: 'STOP',
            index: 0,
            safetyRatings: SAFETY_RATINGS,
          },
        ],
      };
      return `data: ${JSON.stringify(payload)}\n\n`;
    });
  }

  it('should stream text deltas', async () => {
    prepareStreamResponse({ content: ['Hello', ', ', 'world!'] });

    const { stream } = await model.doStream({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(await convertStreamToArray(stream)).toStrictEqual([
      { type: 'text-delta', textDelta: 'Hello' },
      { type: 'text-delta', textDelta: ', ' },
      { type: 'text-delta', textDelta: 'world!' },
      {
        type: 'finish',
        finishReason: 'stop',
        usage: { promptTokens: NaN, completionTokens: NaN },
      },
    ]);
  });

  it('should stream text deltas containing JSON-special characters', async () => {
    const tricky = 'say "hi" \\ then\nnewline';
    prepareStreamResponse({ content: [tricky] });

    const { stream } = await model.doStream({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(await convertStreamToArray(stream)).toStrictEqual([
      { type: 'text-delta', textDelta: tricky },
      {
        type: 'finish',
        finishReason: 'stop',
        usage: { promptTokens: NaN, completionTokens: NaN },
      },
    ]);
  });

  it('should pass the messages', async () => {
    prepareStreamResponse({ content: [''] });

    await model.doStream({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(await server.getRequestBodyJson()).toStrictEqual({
      contents: [
        {
          role: 'user',
          parts: [{ text: 'Hello' }],
        },
      ],
      generationConfig: {},
    });
  });

  it('should pass the api key as x-goog-api-key header', async () => {
    prepareStreamResponse({ content: [''] });

    // A separate provider configured with only `apiKey`; a distinct name avoids
    // shadowing the module-level `google`.
    const customGoogle = new Google({ apiKey: 'test-api-key' });

    await customGoogle.generativeAI('models/gemini-pro').doStream({
      inputFormat: 'prompt',
      mode: { type: 'regular' },
      prompt: TEST_PROMPT,
    });

    expect(
      (await server.getRequestHeaders()).get('x-goog-api-key'),
    ).toStrictEqual('test-api-key');
  });
});
|