import { useCallback, useEffect, useId, useRef, useState } from 'react';
import useSWR, { KeyedMutator } from 'swr';
import { callChatApi } from '../shared/call-chat-api';
import { generateId as generateIdFunc } from '../shared/generate-id';
import { processChatStream } from '../shared/process-chat-stream';
import type {
  ChatRequest,
  ChatRequestOptions,
  CreateMessage,
  IdGenerator,
  JSONValue,
  Message,
  UseChatOptions,
} from '../shared/types';
import type {
  ReactResponseRow,
  experimental_StreamingReactResponse,
} from '../streams/streaming-react-response';

export type { CreateMessage, Message, UseChatOptions };

/** Everything returned by {@link useChat}. */
export type UseChatHelpers = {
  /** Current messages in the chat */
  messages: Message[];
  /** The error object of the API request */
  error: undefined | Error;
  /**
   * Append a user message to the chat list. This triggers the API call to fetch
   * the assistant's response.
   * @param message The message to append (an id is generated if it has none)
   * @param chatRequestOptions Additional options to pass to the API call
   */
  append: (
    message: Message | CreateMessage,
    chatRequestOptions?: ChatRequestOptions,
  ) => Promise<string | null | undefined>;
  /**
   * Reload the last AI chat response for the given chat history. If the last
   * message isn't from the assistant, it will request the API to generate a
   * new response.
   */
  reload: (
    chatRequestOptions?: ChatRequestOptions,
  ) => Promise<string | null | undefined>;
  /**
   * Abort the current request immediately, keep the generated tokens if any.
   */
  stop: () => void;
  /**
   * Update the `messages` state locally. This is useful when you want to
   * edit the messages on the client, and then trigger the `reload` method
   * manually to regenerate the AI response.
   */
  setMessages: (messages: Message[]) => void;
  /** The current value of the input */
  input: string;
  /** setState-powered method to update the input value */
  setInput: React.Dispatch<React.SetStateAction<string>>;
  /** An input/textarea-ready onChange handler to control the value of the input */
  handleInputChange: (
    e:
      | React.ChangeEvent<HTMLInputElement>
      | React.ChangeEvent<HTMLTextAreaElement>,
  ) => void;
  /** Form submission handler to automatically reset input and append a user message */
  handleSubmit: (
    e: React.FormEvent<HTMLFormElement>,
    chatRequestOptions?: ChatRequestOptions,
  ) => void;
  /** Declared for compatibility; `useChat` does not populate this field. */
  metadata?: Object;
  /** Whether the API request is in progress */
  isLoading: boolean;
  /** Additional data added on the server via StreamData */
  data?: JSONValue[];
};

/**
 * A Server Action that can be passed as `api` instead of an endpoint URL.
 * It receives the chat history plus optional request data and resolves to a
 * streaming React response whose rows are consumed one after another.
 */
type StreamingReactResponseAction = (payload: {
  messages: Message[];
  data?: ChatRequest['data'];
}) => Promise<experimental_StreamingReactResponse>;

/**
 * Returns true when a thrown value is an abort rejection, i.e. an object whose
 * `name` is `'AbortError'` (what fetch rejects with after
 * `AbortController.abort()`). Safe for any thrown value, including `null`,
 * `undefined`, and primitives.
 */
function isAbortError(error: unknown): boolean {
  return (
    typeof error === 'object' &&
    error !== null &&
    (error as { name?: unknown }).name === 'AbortError'
  );
}

/**
 * Sends `chatRequest` and streams the assistant's reply into the SWR caches.
 *
 * The request messages are written to the message cache optimistically before
 * the call. If the request fails, the messages that were cached before the
 * call are restored. For Server Actions this happens in the local `catch`. For
 * URL endpoints it happens through the `restoreMessagesOnFailure` callback
 * handed to `callChatApi`.
 *
 * @param api Endpoint URL (handled by `callChatApi`) or a Server Action.
 * @param chatRequest Messages plus optional data/options/functions/tools.
 * @param mutate Bound SWR mutator for the message list.
 * @param mutateStreamData Bound SWR mutator for the StreamData list.
 * @param existingData StreamData already received; new data is appended to it.
 * @param extraMetadataRef Hook-level `credentials`/`headers`/`body`
 *   (plus any metadata merged in by `handleSubmit`).
 * @param messagesRef Latest message list; used to snapshot messages for rollback.
 * @param abortControllerRef Controller for the in-flight URL request.
 * @param generateId Id generator for new messages.
 * @param onFinish Called with the final assistant message.
 * @param onResponse Called with the raw `Response` (URL endpoints only).
 * @param sendExtraMessageFields When false, only the API-relevant message
 *   fields (role, content, name, function_call, tool_calls, tool_call_id) are sent.
 * @returns For Server Actions, the final assistant message; for URL endpoints,
 *   whatever `callChatApi` resolves to.
 * @throws Re-throws Server Action failures after restoring the messages.
 */
const getStreamedResponse = async (
  api: string | StreamingReactResponseAction,
  chatRequest: ChatRequest,
  mutate: KeyedMutator<Message[]>,
  mutateStreamData: KeyedMutator<JSONValue[] | undefined>,
  existingData: JSONValue[] | undefined,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- shape is extended ad hoc by handleSubmit metadata
  extraMetadataRef: React.MutableRefObject<any>,
  messagesRef: React.MutableRefObject<Message[]>,
  abortControllerRef: React.MutableRefObject<AbortController | null>,
  generateId: IdGenerator,
  onFinish?: (message: Message) => void,
  onResponse?: (response: Response) => void | Promise<void>,
  sendExtraMessageFields?: boolean,
) => {
  // Do an optimistic update to the chat state to show the updated messages
  // immediately.
  const previousMessages = messagesRef.current;
  mutate(chatRequest.messages, false);

  const constructedMessagesPayload = sendExtraMessageFields
    ? chatRequest.messages
    : chatRequest.messages.map(
        ({ role, content, name, function_call, tool_calls, tool_call_id }) => ({
          role,
          content,
          tool_call_id,
          ...(name !== undefined && { name }),
          ...(function_call !== undefined && { function_call }),
          ...(tool_calls !== undefined && { tool_calls }),
        }),
      );

  if (typeof api !== 'string') {
    // Server Action: rows come back directly from the action, so no HTTP or
    // stream-protocol handling is needed here.
    const responseMessage: Message = {
      id: generateId(),
      createdAt: new Date(),
      content: '',
      role: 'assistant',
    };

    // Each row overwrites the assistant message's content/ui; `next`, when
    // present, is the promise for the following row.
    async function readRow(promise: Promise<ReactResponseRow>): Promise<void> {
      const { content, ui, next } = await promise;

      // TODO: Handle function calls.
      responseMessage.content = content;
      responseMessage.ui = await ui;

      // Publish a copy so the cached message is not the object that later
      // rows keep mutating.
      mutate([...chatRequest.messages, { ...responseMessage }], false);

      if (next) {
        await readRow(next);
      }
    }

    try {
      // The stripped payload has no ids unless sendExtraMessageFields is set,
      // hence the cast to Message[].
      const promise = api({
        messages: constructedMessagesPayload as Message[],
        data: chatRequest.data,
      }) as Promise<ReactResponseRow>;
      await readRow(promise);
    } catch (e) {
      // Restore the previous messages if the request fails.
      mutate(previousMessages, false);
      throw e;
    }

    if (onFinish) {
      onFinish(responseMessage);
    }

    return responseMessage;
  }

  return await callChatApi({
    api,
    messages: constructedMessagesPayload,
    body: {
      data: chatRequest.data,
      ...extraMetadataRef.current.body,
      ...chatRequest.options?.body,
      ...(chatRequest.functions !== undefined && {
        functions: chatRequest.functions,
      }),
      ...(chatRequest.function_call !== undefined && {
        function_call: chatRequest.function_call,
      }),
      ...(chatRequest.tools !== undefined && {
        tools: chatRequest.tools,
      }),
      ...(chatRequest.tool_choice !== undefined && {
        tool_choice: chatRequest.tool_choice,
      }),
    },
    credentials: extraMetadataRef.current.credentials,
    headers: {
      ...extraMetadataRef.current.headers,
      ...chatRequest.options?.headers,
    },
    abortController: () => abortControllerRef.current,
    restoreMessagesOnFailure() {
      mutate(previousMessages, false);
    },
    onResponse,
    onUpdate(merged, data) {
      mutate([...chatRequest.messages, ...merged], false);
      mutateStreamData([...(existingData || []), ...(data || [])], false);
    },
    onFinish,
    generateId,
  });
};

/**
 * React hook that manages a chat conversation against an API endpoint
 * (default `/api/chat`) or a Server Action.
 *
 * Messages, the loading flag, StreamData and the last error live in SWR under
 * keys derived from `api` (when it is a URL) and `id` (or a `useId()`
 * fallback). Hook instances using the same `id` and `api` therefore share that
 * state. Input state is local to each hook instance.
 *
 * @returns Helpers for reading and driving the chat; see {@link UseChatHelpers}.
 */
export function useChat({
  api = '/api/chat',
  id,
  initialMessages,
  initialInput = '',
  sendExtraMessageFields,
  experimental_onFunctionCall,
  experimental_onToolCall,
  onResponse,
  onFinish,
  onError,
  credentials,
  headers,
  body,
  generateId = generateIdFunc,
}: Omit<UseChatOptions, 'api'> & {
  api?: string | StreamingReactResponseAction;
  key?: string;
} = {}): UseChatHelpers {
  // Generate a unique id for the chat if not provided.
  const hookId = useId();
  const idKey = id ?? hookId;
  const chatKey = typeof api === 'string' ? [api, idKey] : idKey;

  // Store an empty array as the initial messages
  // (instead of using a default parameter value that gets re-created each time)
  // to avoid re-renders:
  const [initialMessagesFallback] = useState<Message[]>([]);

  // Store the chat state in SWR, using the chatId as the key to share states.
  const { data: messages, mutate } = useSWR<Message[]>(
    [chatKey, 'messages'],
    null,
    { fallbackData: initialMessages ?? initialMessagesFallback },
  );

  // We store loading state in another hook to sync loading states across hook invocations
  const { data: isLoading = false, mutate: mutateLoading } = useSWR<boolean>(
    [chatKey, 'loading'],
    null,
  );

  const { data: streamData, mutate: mutateStreamData } = useSWR<
    JSONValue[] | undefined
  >([chatKey, 'streamData'], null);

  const { data: error = undefined, mutate: setError } = useSWR<
    undefined | Error
  >([chatKey, 'error'], null);

  // Keep the latest messages in a ref so callbacks read fresh state without
  // being re-created on every message update.
  const messagesRef = useRef<Message[]>(messages ?? []);
  useEffect(() => {
    messagesRef.current = messages ?? [];
  }, [messages]);

  // Abort controller to cancel the current API call.
  const abortControllerRef = useRef<AbortController | null>(null);

  // Request metadata read at send time; refreshed whenever the props change.
  const extraMetadataRef = useRef({
    credentials,
    headers,
    body,
  });
  useEffect(() => {
    extraMetadataRef.current = {
      credentials,
      headers,
      body,
    };
  }, [credentials, headers, body]);

  /**
   * Runs one chat request (including any follow-up requests that
   * processChatStream issues after function/tool calls) and manages the
   * loading flag, error state, and abort controller around it.
   * Resolves to `null` when the request was aborted, otherwise `undefined`.
   */
  const triggerRequest = useCallback(
    async (chatRequest: ChatRequest) => {
      const abortController = new AbortController();
      try {
        mutateLoading(true);
        setError(undefined);
        abortControllerRef.current = abortController;

        await processChatStream({
          getStreamedResponse: () =>
            getStreamedResponse(
              api,
              chatRequest,
              mutate,
              mutateStreamData,
              streamData,
              extraMetadataRef,
              messagesRef,
              abortControllerRef,
              generateId,
              onFinish,
              onResponse,
              sendExtraMessageFields,
            ),
          experimental_onFunctionCall,
          experimental_onToolCall,
          updateChatRequest: chatRequestParam => {
            chatRequest = chatRequestParam;
          },
          getCurrentMessages: () => messagesRef.current,
        });
      } catch (err) {
        // Ignore abort errors as they are expected (e.g. after `stop()`).
        if (isAbortError(err)) {
          return null;
        }

        if (onError && err instanceof Error) {
          onError(err);
        }

        setError(err as Error);
      } finally {
        // Clear the ref on every outcome, but only if it still belongs to this
        // request, so a newer in-flight request stays abortable via `stop()`.
        if (abortControllerRef.current === abortController) {
          abortControllerRef.current = null;
        }
        mutateLoading(false);
      }
    },
    [
      mutate,
      mutateLoading,
      api,
      extraMetadataRef,
      onResponse,
      onFinish,
      onError,
      setError,
      mutateStreamData,
      streamData,
      sendExtraMessageFields,
      experimental_onFunctionCall,
      experimental_onToolCall,
      messagesRef,
      abortControllerRef,
      generateId,
    ],
  );

  const append = useCallback(
    async (
      message: Message | CreateMessage,
      {
        options,
        functions,
        function_call,
        tools,
        tool_choice,
        data,
      }: ChatRequestOptions = {},
    ) => {
      // Copy instead of mutating the caller's object, so appending the same
      // object again gets a fresh id rather than reusing the first one.
      const messageWithId: Message = {
        ...message,
        id: message.id || generateId(),
      };

      const chatRequest: ChatRequest = {
        messages: messagesRef.current.concat(messageWithId),
        options,
        data,
        ...(functions !== undefined && { functions }),
        ...(function_call !== undefined && { function_call }),
        ...(tools !== undefined && { tools }),
        ...(tool_choice !== undefined && { tool_choice }),
      };

      return triggerRequest(chatRequest);
    },
    [triggerRequest, generateId],
  );

  const reload = useCallback(
    async ({
      options,
      functions,
      function_call,
      tools,
      tool_choice,
      data,
    }: ChatRequestOptions = {}) => {
      const currentMessages = messagesRef.current;
      if (currentMessages.length === 0) return null;

      // Drop a trailing assistant message so it gets regenerated; otherwise
      // request a response for the history as-is.
      const lastMessage = currentMessages[currentMessages.length - 1];
      const history =
        lastMessage?.role === 'assistant'
          ? currentMessages.slice(0, -1)
          : currentMessages;

      const chatRequest: ChatRequest = {
        messages: history,
        options,
        data,
        ...(functions !== undefined && { functions }),
        ...(function_call !== undefined && { function_call }),
        ...(tools !== undefined && { tools }),
        ...(tool_choice !== undefined && { tool_choice }),
      };

      return triggerRequest(chatRequest);
    },
    [triggerRequest],
  );

  const stop = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
      abortControllerRef.current = null;
    }
  }, []);

  const setMessages = useCallback(
    (messages: Message[]) => {
      mutate(messages, false);
      messagesRef.current = messages;
    },
    [mutate],
  );

  // Input state and handlers.
  const [input, setInput] = useState(initialInput);

  const handleSubmit = useCallback(
    (
      e: React.FormEvent<HTMLFormElement>,
      options: ChatRequestOptions = {},
      metadata?: Object,
    ) => {
      if (metadata) {
        // Merged into the request metadata; it persists until the
        // credentials/headers/body props change and the effect above resets it.
        extraMetadataRef.current = {
          ...extraMetadataRef.current,
          ...metadata,
        };
      }

      e.preventDefault();
      if (!input) return;

      void append(
        {
          content: input,
          role: 'user',
          createdAt: new Date(),
        },
        options,
      );
      setInput('');
    },
    [input, append],
  );

  const handleInputChange: UseChatHelpers['handleInputChange'] = e => {
    setInput(e.target.value);
  };

  return {
    messages: messages ?? [],
    error,
    append,
    reload,
    stop,
    setMessages,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading,
    data: streamData,
  };
}