import { Dispatch, MutableRefObject, SetStateAction, useCallback, useEffect, useId, useRef, useState } from 'react';
import useSWR, { KeyedMutator } from 'swr';

/**
 * A function call requested by the model (deprecated in favour of {@link ToolCall}).
 */
export interface FunctionCall {
  /**
   * The arguments to call the function with, as generated by the model in JSON
   * format. Note that the model does not always generate valid JSON, and may
   * hallucinate parameters not defined by your function schema. Validate the
   * arguments in your code before calling your function.
   */
  arguments?: string;
  /**
   * The name of the function to call.
   */
  name?: string;
}

/**
 * The tool calls generated by the model, such as function calls.
 */
export interface ToolCall {
  /** The ID of the tool call. */
  id: string;
  /** The type of the tool. Currently, only `function` is supported. */
  type: string;
  /** The function that the model called. */
  function: {
    /** The name of the function. */
    name: string;
    /** The arguments to call the function with, as generated by the model in JSON. */
    arguments: string;
  };
}

/** Any value representable in JSON. */
export type JSONValue =
  | null
  | string
  | number
  | boolean
  | { [x: string]: JSONValue }
  | Array<JSONValue>;

/**
 * Shared types between the API and UI packages.
 */
export interface Message {
  id?: string;
  tool_call_id?: string;
  createdAt?: Date;
  content: string;
  ui?: string | JSX.Element | JSX.Element[] | null | undefined;
  // role: 'system' | 'user' | 'assistant' | 'function' | 'data' | 'tool';
  /**
   * If the message has a role of `function`, the `name` field is the name of the function.
   * Otherwise, the name field should not be set.
   */
  name?: string;
  /**
   * If the assistant role makes a function call, the `function_call` field
   * contains the function call name and arguments. Otherwise, the field should
   * not be set. (Deprecated and replaced by tool_calls.)
   */
  function_call?: string | FunctionCall;
  data?: JSONValue;
  /**
   * If the assistant role makes a tool call, the `tool_calls` field contains
   * the tool call name and arguments. Otherwise, the field should not be set.
   */
  tool_calls?: string | ToolCall[];
  /**
   * Additional message-specific information added on the server via StreamData
   */
  annotations?: JSONValue[] | undefined;
}

/**
 * Minimal error shape exposed by the hook.
 * Note: this module-local interface shadows the global `Error` type name in this file.
 */
interface Error {
  name: string;
  message: string;
  stack?: string;
}

/** One chat message as held by the hook (user messages and streamed replies alike). */
interface WebSocketMessage {
  messageId?: string;
  conversationId?: string;
  type: string;
  content: string;
  streamStatus: 'start' | 'end' | 'middle';
}

/** Options accepted by {@link useISDK_B}. */
interface WebSocketOptions {
  /** Identifier used to scope the SWR cache keys; defaults to a `useId()` value. */
  id?: string;
  /** Messages shown before anything is sent or received. */
  initialMessages?: WebSocketMessage[];
  /** If defined, sent once as the first user message when the socket first opens. */
  initialInput?: string;
  /** Accepted for API compatibility; not invoked by this hook. */
  onResponse?: (response: Response) => void | Promise<void>;
  /** Called with the current message list when a stream `end` frame arrives. */
  onFinish?: (message: WebSocketMessage[] | string[]) => void;
  /** Called with the socket error event, or with the error thrown by `send`. */
  onError?: (error: Error | any) => void;
}

/** Public surface returned by {@link useISDK_B}. */
export type WebSocketHook = {
  messages: WebSocketMessage[];
  isLoading: boolean;
  input: string;
  error: undefined | Error;
  setMessages: (messages: WebSocketMessage[]) => void;
  setInput: Dispatch<SetStateAction<string>>;
  append: (message: Message) => Promise<void>;
};

/** Delay before re-opening a socket that closed without the hook asking it to. */
const RECONNECT_DELAY_MS = 1000;

/** One decoded incoming frame. */
interface IncomingFrame {
  /** `streamStatus` of a JSON object frame; undefined for raw text chunks. */
  streamStatus?: unknown;
  /** `type` of a JSON object frame, when it is a string. */
  type?: string;
  /** Text to append to the message currently being streamed. */
  text: string;
}

/**
 * Decodes a socket payload. JSON objects are treated as protocol frames
 * (`streamStatus`/`type`/`content`); anything else is a raw text chunk and is
 * kept verbatim, so chunks such as " 1.50" or "\"hi\"" are not altered by JSON.parse.
 */
const parseFrame = (data: unknown): IncomingFrame => {
  const raw = typeof data === 'string' ? data : String(data);
  let parsed: unknown;
  try {
    parsed = JSON.parse(raw);
  } catch {
    return { text: raw };
  }
  if (parsed === null || typeof parsed !== 'object' || Array.isArray(parsed)) {
    return { text: raw };
  }
  const frame = parsed as { streamStatus?: unknown; type?: unknown; content?: unknown };
  return {
    streamStatus: frame.streamStatus,
    type: typeof frame.type === 'string' ? frame.type : undefined,
    text: typeof frame.content === 'string' ? frame.content : '',
  };
};

/** Builds an empty/partial message that streamed text is accumulated into. */
const createBuffer = (type: string, content: string): WebSocketMessage => ({
  messageId: '',
  conversationId: '',
  type,
  content,
  streamStatus: 'middle',
});

/**
 * Applies one incoming socket frame to the message list.
 *
 * - `start`: sets loading and pushes an empty buffer message.
 * - `end`: clears the input, clears loading and calls `onFinish`.
 * - anything else: appends the frame's text to the last message. A new buffer is
 *   started when the list is empty.
 *
 * Every change goes through `setMessages` so the UI re-renders, and the last
 * message is replaced rather than mutated in place.
 */
const getStreamedResponse = (
  event: MessageEvent,
  setMessages: (messages: WebSocketMessage[]) => void,
  messagesRef: MutableRefObject<WebSocketMessage[]>,
  mutateLoading: KeyedMutator<boolean>,
  setInput: Dispatch<SetStateAction<string>>,
  onFinish?: WebSocketOptions['onFinish'],
): void => {
  const frame = parseFrame(event.data);

  switch (frame.streamStatus) {
    case 'start': {
      void mutateLoading(true);
      console.log('开始接收信息', frame);
      setMessages([...messagesRef.current, createBuffer(frame.type ?? '', '')]);
      break;
    }
    case 'end': {
      console.log('结束接收信息', frame);
      setInput('');
      void mutateLoading(false);
      onFinish?.(messagesRef.current);
      break;
    }
    default: {
      const current = messagesRef.current;
      const lastMessage = current[current.length - 1];
      console.log('中间接收信息进行消息输出拼接', frame, lastMessage);
      if (lastMessage === undefined) {
        // A chunk arrived before any message exists: start a buffer instead of crashing.
        setMessages([createBuffer(frame.type ?? '', frame.text)]);
      } else {
        setMessages([
          ...current.slice(0, -1),
          { ...lastMessage, content: lastMessage.content + frame.text },
        ]);
      }
      break;
    }
  }
};

/**
 * Chat hook over a WebSocket.
 *
 * Opens a socket to `url`, reopens it after an unexpected close, and closes it
 * on URL change or unmount. Streamed frames are accumulated into `messages`.
 * Messages, loading and error state live in SWR caches keyed by `[url, id]`.
 *
 * @param url WebSocket endpoint; an empty string disables the connection.
 * @param options See {@link WebSocketOptions}.
 * @returns Current messages/state plus `append`, `setMessages` and `setInput`.
 */
const useISDK_B = (url: string, options: WebSocketOptions): WebSocketHook => {
  const { onFinish, onError } = options || {};
  const [input, setInput] = useState<string>(options?.initialInput || '');
  const [initialMessagesFallback] = useState<WebSocketMessage[]>([]);

  const hookId = useId();
  const idKey = options?.id ?? hookId;
  const chatKey = typeof url === 'string' ? [url, idKey] : idKey;

  const { data: messages, mutate } = useSWR<WebSocketMessage[]>(
    [chatKey, 'messages'],
    null,
    { fallbackData: (options?.initialMessages ?? initialMessagesFallback) || [] },
  );
  const { data: isLoading = false, mutate: mutateLoading } = useSWR<boolean>(
    [chatKey, 'loading'],
    null,
  );
  const { data: error = undefined, mutate: setError } = useSWR<undefined | Error>(
    [chatKey, 'error'],
    null,
  );

  // Synchronous source of truth for the message list. Socket handlers read it
  // between renders; every write goes through `setMessages` so SWR re-renders.
  const messagesRef = useRef<WebSocketMessage[]>(messages || []);
  const socketRef = useRef<WebSocket | null>(null);
  // Ensures `initialInput` is sent only once, not on every reconnect.
  const initialInputSentRef = useRef(false);

  useEffect(() => {
    void mutate(messagesRef.current);
  }, [mutate]);

  /** Replaces the message list and triggers a re-render. */
  const setMessages = useCallback(
    (newMessages: WebSocketMessage[]) => {
      messagesRef.current = newMessages || [];
      void mutate(messagesRef.current, false);
    },
    [mutate],
  );

  /**
   * Sends a request over the open socket. Loading stays true until the stream's
   * `end` frame, an error or a close, rather than being cleared right after `send`.
   */
  const triggerRequest = useCallback(
    async (chatRequest: Message): Promise<void> => {
      console.log('-----triggerRequest--------', chatRequest);
      const socket = socketRef.current;
      if (!socket || socket.readyState !== WebSocket.OPEN) {
        console.error('WebSocket 连接未建立或已关闭');
        void setError(new Error('WebSocket connection is not open'));
        void mutateLoading(false);
        return;
      }
      try {
        void mutateLoading(true);
        socket.send(chatRequest.content);
      } catch (err) {
        const failure = err instanceof Error ? err : new Error(String(err));
        void setError(failure);
        void mutateLoading(false);
        onError?.(failure);
      }
    },
    [mutateLoading, setError, onError],
  );

  /** Adds the user's message to the list (re-rendering), then sends it. */
  const append = useCallback(
    async (message: Message): Promise<void> => {
      setMessages([
        ...messagesRef.current,
        {
          conversationId: '',
          type: 'text',
          content: message.content,
          streamStatus: 'middle',
        },
      ]);
      return triggerRequest(message);
    },
    [setMessages, triggerRequest],
  );

  // Long-lived socket handlers read the latest per-render callbacks from here,
  // so the connection does not have to be torn down when they change.
  const latest = {
    append,
    setMessages,
    setInput,
    mutateLoading,
    setError,
    onFinish,
    onError,
    initialInput: options?.initialInput,
  };
  const latestRef = useRef(latest);
  useEffect(() => {
    latestRef.current = latest;
  });

  // Connection lifecycle: one socket per `url`, closed on URL change/unmount.
  useEffect(() => {
    if (!url) return undefined;

    // Set by cleanup so the handlers of an intentionally closed socket stay quiet.
    let disposed = false;
    let reconnectTimer: ReturnType<typeof setTimeout> | undefined;

    const connect = (): void => {
      const socket = new WebSocket(url);
      socketRef.current = socket;

      socket.onopen = () => {
        console.log('WebSocket 连接已建立');
        const { initialInput } = latestRef.current;
        if (initialInput !== undefined && !initialInputSentRef.current) {
          initialInputSentRef.current = true;
          latestRef.current.setInput('');
          void latestRef.current.append({ id: '', content: initialInput });
        }
      };

      // Attached here rather than in a separate effect so a reconnected socket
      // always has a message handler.
      socket.onmessage = (event: MessageEvent) => {
        const current = latestRef.current;
        getStreamedResponse(
          event,
          current.setMessages,
          messagesRef,
          current.mutateLoading,
          current.setInput,
          current.onFinish,
        );
      };

      socket.onerror = (event) => {
        if (disposed) return;
        console.error('WebSocket 连接发生错误:', event);
        const current = latestRef.current;
        current.onError?.(event);
        void current.setError(new Error('WebSocket connection error'));
        void current.mutateLoading(false);
      };

      socket.onclose = () => {
        console.log('WebSocket 连接已关闭');
        if (socketRef.current === socket) socketRef.current = null;
        if (disposed) return;
        void latestRef.current.mutateLoading(false);
        // Unexpected close: reconnect after a delay to avoid a tight retry loop.
        reconnectTimer = setTimeout(connect, RECONNECT_DELAY_MS);
      };
    };

    connect();

    return () => {
      disposed = true;
      clearTimeout(reconnectTimer);
      const socket = socketRef.current;
      socketRef.current = null;
      if (socket) {
        socket.onmessage = null;
        socket.close();
      }
    };
  }, [url]);

  return {
    messages: messagesRef.current || [],
    isLoading,
    error,
    append,
    setMessages,
    input,
    setInput,
  };
};

export default useISDK_B;