import { useSWR } from 'sswr';
import { Readable, Writable, derived, get, writable } from 'svelte/store';
import { callChatApi } from '../shared/call-chat-api';
import { processChatStream } from '../shared/process-chat-stream';
import type {
  ChatRequest,
  ChatRequestOptions,
  CreateMessage,
  IdGenerator,
  JSONValue,
  Message,
  UseChatOptions,
} from '../shared/types';
import { generateId as generateIdFunc } from '../shared/generate-id';

export type { CreateMessage, Message, UseChatOptions };

export type UseChatHelpers = {
  /** Current messages in the chat */
  messages: Readable<Message[]>;
  /** The error object of the API request */
  error: Readable<undefined | Error>;
  /**
   * Append a user message to the chat list. This triggers the API call to fetch
   * the assistant's response.
   * @param message The message to append
   * @param chatRequestOptions Additional options to pass to the API call
   */
  append: (
    message: Message | CreateMessage,
    chatRequestOptions?: ChatRequestOptions,
  ) => Promise<string | null | undefined>;
  /**
   * Reload the last AI chat response for the given chat history. If the last
   * message isn't from the assistant, it will request the API to generate a
   * new response.
   */
  reload: (
    chatRequestOptions?: ChatRequestOptions,
  ) => Promise<string | null | undefined>;
  /**
   * Abort the current request immediately, keep the generated tokens if any.
   */
  stop: () => void;
  /**
   * Update the `messages` state locally. This is useful when you want to
   * edit the messages on the client, and then trigger the `reload` method
   * manually to regenerate the AI response.
   */
  setMessages: (messages: Message[]) => void;
  /** The current value of the input */
  input: Writable<string>;
  /** Form submission handler to automatically reset input and append a user message */
  handleSubmit: (e: any, chatRequestOptions?: ChatRequestOptions) => void;
  /** Optional extra metadata; `useChat` itself does not populate this field. */
  metadata?: Object;
  /** Whether the API request is in progress */
  isLoading: Readable<boolean | undefined>;
  /** Additional data added on the server via StreamData */
  data: Readable<JSONValue[] | undefined>;
};

/**
 * Sends a single chat request to `api` and streams the response into the
 * chat state.
 *
 * The messages of `chatRequest` are written to the store optimistically before
 * the request is sent. `previousMessages` is handed to `callChatApi` through
 * `restoreMessagesOnFailure`, which presumably invokes it when the request
 * fails (behaviour lives in `callChatApi`).
 *
 * @param api Endpoint URL to POST to.
 * @param chatRequest Messages plus optional function/tool settings to send.
 * @param mutate Writes a new message list into the chat state.
 * @param mutateStreamData Writes the accumulated server-side `StreamData`.
 * @param existingData Stream data already received before this request.
 * @param extraMetadata Hook-level credentials, headers and body fields; the
 *   per-request `chatRequest.options` values override them.
 * @param previousMessages Snapshot to restore if the request fails.
 * @param abortControllerRef Controller whose signal is handed to the request.
 * @param generateId Id generator for new messages.
 * @param onFinish Called with the final assistant message.
 * @param onResponse Called with the raw `Response`.
 * @param sendExtraMessageFields When falsy, only the role/content/name/
 *   function/tool fields of each message are sent.
 * @returns Whatever `callChatApi` resolves to.
 */
const getStreamedResponse = async (
  api: string,
  chatRequest: ChatRequest,
  mutate: (messages: Message[]) => void,
  mutateStreamData: (data: JSONValue[] | undefined) => void,
  existingData: JSONValue[] | undefined,
  extraMetadata: {
    credentials?: RequestCredentials;
    headers?: Record<string, string> | Headers;
    body?: any;
  },
  previousMessages: Message[],
  abortControllerRef: AbortController | null,
  generateId: IdGenerator,
  onFinish?: (message: Message) => void,
  onResponse?: (response: Response) => void | Promise<void>,
  sendExtraMessageFields?: boolean,
) => {
  // Do an optimistic update to the chat state to show the updated messages
  // immediately.
  mutate(chatRequest.messages);

  // Unless asked otherwise, strip client-only fields (ids, timestamps, ...)
  // and omit optional fields that are undefined.
  const constructedMessagesPayload = sendExtraMessageFields
    ? chatRequest.messages
    : chatRequest.messages.map(
        ({ role, content, name, function_call, tool_calls, tool_call_id }) => ({
          role,
          content,
          tool_call_id,
          ...(name !== undefined && { name }),
          ...(function_call !== undefined && {
            function_call: function_call,
          }),
          ...(tool_calls !== undefined && {
            tool_calls: tool_calls,
          }),
        }),
      );

  return await callChatApi({
    api,
    messages: constructedMessagesPayload,
    body: {
      ...extraMetadata.body,
      ...chatRequest.options?.body,
      ...(chatRequest.functions !== undefined && {
        functions: chatRequest.functions,
      }),
      ...(chatRequest.function_call !== undefined && {
        function_call: chatRequest.function_call,
      }),
      ...(chatRequest.tools !== undefined && {
        tools: chatRequest.tools,
      }),
      ...(chatRequest.tool_choice !== undefined && {
        tool_choice: chatRequest.tool_choice,
      }),
    },
    credentials: extraMetadata.credentials,
    headers: {
      ...extraMetadata.headers,
      ...chatRequest.options?.headers,
    },
    abortController: () => abortControllerRef,
    restoreMessagesOnFailure() {
      mutate(previousMessages);
    },
    onResponse,
    onUpdate(merged, data) {
      mutate([...chatRequest.messages, ...merged]);
      mutateStreamData([...(existingData || []), ...(data || [])]);
    },
    onFinish,
    generateId,
  });
};

/** Counter used to derive ids for chats created without an explicit `id`. */
let uniqueId = 0;

/** Module-level message cache keyed by `${api}|${chatId}`; read by the SWR fetcher. */
const store: Record<string, Message[] | undefined> = {};

/**
 * Builds a `ChatRequest` for `messages`. Optional function/tool settings are
 * copied only when they were provided, so unset fields are absent from the
 * resulting object rather than present as `undefined`.
 */
function createChatRequest(
  messages: Message[],
  {
    options,
    functions,
    function_call,
    tools,
    tool_choice,
  }: ChatRequestOptions,
): ChatRequest {
  return {
    messages,
    options,
    ...(functions !== undefined && { functions }),
    ...(function_call !== undefined && { function_call }),
    ...(tools !== undefined && { tools }),
    ...(tool_choice !== undefined && { tool_choice }),
  };
}

/**
 * Whether a thrown value carries the name `AbortError` (the name used for
 * aborted fetches). Safe to call with any thrown value, including `null`.
 */
function isAbortError(err: unknown): boolean {
  return (
    typeof err === 'object' &&
    err !== null &&
    (err as { name?: unknown }).name === 'AbortError'
  );
}

/**
 * Svelte-store based chat state bound to a streaming chat API endpoint.
 *
 * Calls that share the same `api` and `id` share a cache key and therefore the
 * same entry in the module-level message cache. When `id` is omitted, a fresh
 * `chat-N` id is generated for each call.
 *
 * @returns Stores (`messages`, `error`, `input`, `isLoading`, `data`) and
 *   actions (`append`, `reload`, `stop`, `setMessages`, `handleSubmit`).
 */
export function useChat({
  api = '/api/chat',
  id,
  initialMessages = [],
  initialInput = '',
  sendExtraMessageFields,
  experimental_onFunctionCall,
  experimental_onToolCall,
  onResponse,
  onFinish,
  onError,
  credentials,
  headers,
  body,
  generateId = generateIdFunc,
}: UseChatOptions = {}): UseChatHelpers {
  // Generate a unique id for the chat if not provided.
  const chatId = id || `chat-${uniqueId++}`;
  const key = `${api}|${chatId}`;

  const {
    data,
    mutate: originalMutate,
    isLoading: isSWRLoading,
  } = useSWR<Message[]>(key, {
    fetcher: () => store[key] || initialMessages,
    fallbackData: initialMessages,
  });

  const streamData = writable<JSONValue[] | undefined>(undefined);
  const loading = writable<boolean>(false);

  // Seed the SWR store with `initialMessages`. NOTE(review): this runs
  // unconditionally and overwrites whatever `data` already holds for this key —
  // confirm that is intended when several components share the same id.
  data.set(initialMessages);

  // Every mutation also updates the module-level cache so the SWR fetcher
  // returns the latest messages for this key.
  const mutate = (newMessages: Message[]) => {
    store[key] = newMessages;
    return originalMutate(newMessages);
  };

  // Because of the `fallbackData` option, the `data` will never be `undefined`.
  const messages = data as Writable<Message[]>;

  // Controller of the in-flight request, if any; `stop()` aborts it.
  let abortController: AbortController | null = null;

  const extraMetadata = {
    credentials,
    headers,
    body,
  };

  const error = writable<undefined | Error>(undefined);

  /**
   * Sends `chatRequest` and streams the reply into the chat state.
   * Resolves to `null` on success or abort, and to `undefined` after any other
   * error (which is stored in `error` and, if it is an `Error`, passed to
   * `onError`).
   */
  async function triggerRequest(chatRequest: ChatRequest) {
    // Each request owns its controller. The local reference keeps follow-up
    // rounds of this request on its own (possibly already aborted) signal, and
    // prevents a request that settles late from clearing the shared variable
    // while a newer request is using it.
    const controller = new AbortController();
    abortController = controller;

    try {
      error.set(undefined);
      loading.set(true);

      await processChatStream({
        getStreamedResponse: () =>
          getStreamedResponse(
            api,
            chatRequest,
            mutate,
            data => {
              streamData.set(data);
            },
            get(streamData),
            extraMetadata,
            // Read before getStreamedResponse's optimistic update, so this is
            // the snapshot to fall back to on failure.
            get(messages),
            controller,
            generateId,
            onFinish,
            onResponse,
            sendExtraMessageFields,
          ),
        experimental_onFunctionCall,
        experimental_onToolCall,
        // Presumably used by processChatStream for follow-up rounds (e.g.
        // after a function/tool call): the next getStreamedResponse call
        // sends the replacement request.
        updateChatRequest: chatRequestParam => {
          chatRequest = chatRequestParam;
        },
        getCurrentMessages: () => get(messages),
      });

      return null;
    } catch (err) {
      // Ignore abort errors as they are expected (e.g. after `stop()`).
      if (isAbortError(err)) {
        return null;
      }

      if (onError && err instanceof Error) {
        onError(err);
      }

      error.set(err as Error);
      return undefined;
    } finally {
      // Only release the shared slot if a newer request has not claimed it.
      if (abortController === controller) {
        abortController = null;
      }
      loading.set(false);
    }
  }

  const append: UseChatHelpers['append'] = async (
    message: Message | CreateMessage,
    chatRequestOptions: ChatRequestOptions = {},
  ) => {
    // Copy instead of mutating the caller's object; a missing or empty id is
    // replaced with a generated one.
    const messageWithId: Message = {
      ...message,
      id: message.id || generateId(),
    };

    return triggerRequest(
      createChatRequest(
        get(messages).concat(messageWithId),
        chatRequestOptions,
      ),
    );
  };

  const reload: UseChatHelpers['reload'] = async (
    chatRequestOptions: ChatRequestOptions = {},
  ) => {
    const messagesSnapshot = get(messages);
    if (messagesSnapshot.length === 0) return null;

    // Remove last assistant message and retry last user message; if the last
    // message is not from the assistant, resend the history unchanged.
    const history =
      messagesSnapshot.at(-1)?.role === 'assistant'
        ? messagesSnapshot.slice(0, -1)
        : messagesSnapshot;

    return triggerRequest(createChatRequest(history, chatRequestOptions));
  };

  const stop = () => {
    if (abortController) {
      abortController.abort();
      abortController = null;
    }
  };

  const setMessages = (messages: Message[]) => {
    mutate(messages);
  };

  const input = writable<string>(initialInput);

  const handleSubmit = (e: any, options: ChatRequestOptions = {}) => {
    e.preventDefault();
    const inputValue = get(input);
    if (!inputValue) return;

    // Fire-and-forget: request errors are caught inside `triggerRequest`.
    void append(
      {
        content: inputValue,
        role: 'user',
        createdAt: new Date(),
      },
      options,
    );
    input.set('');
  };

  const isLoading = derived(
    [isSWRLoading, loading],
    ([$isSWRLoading, $loading]) => {
      return $isSWRLoading || $loading;
    },
  );

  return {
    messages,
    error,
    append,
    reload,
    stop,
    setMessages,
    input,
    handleSubmit,
    isLoading,
    data: streamData,
  };
}