439 lines
10 KiB
TypeScript
439 lines
10 KiB
TypeScript
import { Dispatch, MutableRefObject, SetStateAction, useCallback, useEffect, useId, useRef, useState } from 'react';
|
||
import useSWR, { KeyedMutator } from 'swr';
|
||
|
||
/**
 * A single function invocation requested by the model.
 *
 * Both fields are optional.
 */
export interface FunctionCall {
  /** Name of the function the model wants to invoke. */
  name?: string;

  /**
   * Arguments for the call, serialized as JSON text by the model.
   *
   * The model is not guaranteed to emit valid JSON and may invent parameters
   * that your function schema does not define, so validate this string before
   * using it to call your function.
   */
  arguments?: string;
}
/**
 * A tool invocation produced by the model (for example, a function call).
 */
export interface ToolCall {
  /** Identifier of this tool call. */
  id: string;

  /** Kind of tool being called; at present only `function` is supported. */
  type: string;

  /** Details of the function the model called. */
  function: {
    /** Function name. */
    name: string;
    /** Call arguments, as JSON text generated by the model. */
    arguments: string;
  };
}
/** Any value that can be represented in JSON. */
export type JSONValue =
  | string
  | number
  | boolean
  | null
  | JSONValue[]
  | { [key: string]: JSONValue };
/**
 * A chat message. This shape is shared between the API and UI packages.
 */
export interface Message {
  /** Message identifier. */
  id?: string;

  /** Creation timestamp. */
  createdAt?: Date;

  /** Text of the message. */
  content: string;

  /** Optional UI content associated with the message. */
  ui?: string | JSX.Element | JSX.Element[] | null | undefined;

  /** Arbitrary JSON payload attached to the message. */
  data?: JSONValue;

  /** Additional message-specific information added on the server via StreamData. */
  annotations?: JSONValue[] | undefined;

  /**
   * Function name; set only when the message has the `function` role and left
   * unset otherwise.
   */
  name?: string;

  /**
   * Name and arguments of a function call made by the assistant; unset when no
   * function call was made.
   *
   * @deprecated Replaced by `tool_calls`.
   */
  function_call?: string | FunctionCall;

  /**
   * Name and arguments of tool calls made by the assistant; unset when no tool
   * call was made.
   */
  tool_calls?: string | ToolCall[];

  /** Tool call identifier associated with this message. */
  tool_call_id?: string;

  // NOTE: a `role` field ('system' | 'user' | 'assistant' | 'function' | 'data' | 'tool')
  // is currently disabled and not part of this shape.
}
/**
 * Minimal error shape used by this module.
 *
 * NOTE: because this file is a module, this local declaration shadows the
 * global `Error` *type* inside the file; the global `Error` constructor value
 * is unaffected.
 */
interface Error {
  /** Error message. */
  message: string;
  /** Error name. */
  name: string;
  /** Optional stack trace. */
  stack?: string;
}
/**
 * One entry of the conversation tracked by the WebSocket hook.
 */
interface WebSocketMessage {
  /** Which part of a streamed reply this entry belongs to. */
  streamStatus: "start" | "end" | "middle";
  /** Message kind. */
  type: string;
  /** Text of the message. */
  content: string;
  /** Optional message identifier. */
  messageId?: string;
  /** Optional conversation identifier. */
  conversationId?: string;
}
/**
 * Configuration accepted by the WebSocket chat hook.
 */
interface WebSocketOptions {
  /** Optional identifier used to key this conversation's cached state. */
  id?: string;

  /** Messages to seed the conversation with. */
  initialMessages?: WebSocketMessage[];

  /** Initial value for the input box. */
  initialInput?: string;

  /** Callback receiving a `Response`. */
  onResponse?: (response: Response) => void | Promise<void>;

  /** Completion callback receiving messages or strings. */
  onFinish?: (message: WebSocketMessage[] | string[]) => void;

  /**
   * Error callback. The parameter is typed `Error | any`, so callers must not
   * assume it is an `Error` instance.
   */
  onError?: (error: Error | any) => void;
}
/**
 * Value returned by the WebSocket chat hook.
 */
export type WebSocketHook = {
  /** Current conversation. */
  messages: WebSocketMessage[];
  /** Replace the whole conversation. */
  setMessages: (messages: WebSocketMessage[]) => void;
  /** Add a message to the conversation and send it to the server. */
  append: (
    message: Message,
  ) => Promise<string | null | undefined | void>;

  /** Contents of the input box. */
  input: string;
  /** React state setter for `input`. */
  setInput: React.Dispatch<React.SetStateAction<string>>;

  /** Loading flag. */
  isLoading: boolean;
  /** Last recorded error, if any. */
  error: undefined | Error;
};
/**
 * Handle one WebSocket frame of a streamed reply and fold it into `messagesRef`.
 *
 * Frames are JSON-parsed when possible. Anything that fails to parse is treated
 * as a raw text chunk. Dispatch is on the frame's `streamStatus`:
 * - `"start"`: sets loading to true and appends an empty placeholder message
 *   whose `type` is copied from the frame.
 * - `"end"`: clears the input box and sets loading to false.
 * - anything else: the frame is a content chunk appended to the last message.
 *   If there is no message yet, the chunk starts a new one instead of crashing.
 *   Object frames with a string `content` field contribute that field. Every
 *   other value is stringified, so a literal `null` frame appends "null".
 *
 * Whenever `messagesRef` changes, the new array is also written to the SWR
 * cache with `mutate(next, false)`, so components re-render while streaming.
 * Messages are replaced, never mutated in place.
 *
 * @param event Incoming WebSocket message event.
 * @param setMessages Accepted for interface compatibility; not used here.
 * @param mutate SWR mutator for the messages key.
 * @param messagesRef Ref holding the current conversation.
 * @param mutateLoading SWR mutator for the loading flag.
 * @param setInput Setter for the chat input box.
 */
const getStreamedResponse = async (
  event: MessageEvent,
  setMessages: (messages: WebSocketMessage[]) => void,
  mutate: KeyedMutator<WebSocketMessage[]>,
  messagesRef: MutableRefObject<WebSocketMessage[]>,
  mutateLoading: KeyedMutator<boolean>,
  setInput: Dispatch<SetStateAction<string>>
) => {
  const receivedMessage = event.data;
  let parsedMessage: unknown = receivedMessage;
  try {
    parsedMessage = JSON.parse(receivedMessage);
  } catch {
    // Not JSON: keep the raw frame as a text chunk.
  }

  // Only non-null objects can carry control fields such as `streamStatus`;
  // this also stops a literal `null` frame from throwing on property access.
  const frame =
    typeof parsedMessage === 'object' && parsedMessage !== null
      ? (parsedMessage as Partial<WebSocketMessage>)
      : undefined;

  // Replace the conversation and publish it to SWR so the UI re-renders.
  const publish = (next: WebSocketMessage[]) => {
    messagesRef.current = next;
    void mutate(next, false);
  };

  switch (frame?.streamStatus) {
    case "start": {
      void mutateLoading(true);
      console.log("开始接收信息", parsedMessage);
      publish([
        ...messagesRef.current,
        {
          messageId: '',
          conversationId: '',
          type: frame?.type ?? '',
          content: "",
          streamStatus: 'middle',
        },
      ]);
      break;
    }

    case "end":
      console.log("结束接收信息", parsedMessage);
      setInput("");
      void mutateLoading(false);
      break;

    default: {
      const history = messagesRef.current;
      const lastMessage = history[history.length - 1];
      console.log("中间接收信息进行消息输出拼接", parsedMessage, lastMessage);

      const frameContent = frame?.content;
      const chunk = typeof frameContent === 'string' ? frameContent : String(parsedMessage);

      if (lastMessage === undefined) {
        // A chunk arrived before any message exists: start one rather than crash.
        publish([
          ...history,
          { messageId: '', conversationId: '', type: '', content: chunk, streamStatus: 'middle' },
        ]);
      } else {
        // Copy instead of mutating: the old object may still be referenced by earlier state.
        publish([
          ...history.slice(0, -1),
          { ...lastMessage, content: lastMessage.content + chunk },
        ]);
      }
      break;
    }
  }
};
/** Delay before re-opening a WebSocket that closed without the hook asking it to. */
const RECONNECT_DELAY_MS = 1000;

/**
 * React hook that keeps a chat conversation in sync with a WebSocket server.
 *
 * Outgoing messages are sent as raw text (`Message.content`). Each incoming
 * frame is passed to `getStreamedResponse`, which updates `messagesRef`. After
 * every frame the hook also publishes the ref to SWR, so the UI re-renders
 * while a reply streams in. Messages, the loading flag and the last error live
 * in SWR under keys derived from `url` and `options.id` (or a `useId()` value).
 *
 * Connection lifecycle:
 * - the socket opens when `url` is truthy;
 * - a new socket replaces the old one when `url` changes;
 * - after an unexpected close, it reconnects after `RECONNECT_DELAY_MS`;
 * - it is closed, without reconnecting, on unmount.
 *
 * A non-empty `options.initialInput` is sent once, on the first successful
 * connection only, never on reconnects.
 *
 * @param url WebSocket endpoint; a falsy value leaves the hook disconnected.
 * @param options Optional id, initial messages/input and callbacks.
 *   `onResponse` and `onFinish` are accepted but not currently invoked.
 * @returns Current messages, loading/error state, input state, `append` and `setMessages`.
 */
const useISDK_B = (url: string, options: WebSocketOptions): WebSocketHook => {
  const onError = options?.onError;
  const [input, setInput] = useState<string>(options?.initialInput || '');
  const [initialMessagesFallback] = useState<WebSocketMessage[]>([]);

  const hookId = useId();
  const idKey = options?.id ?? hookId;
  const chatKey = typeof url === 'string' ? [url, idKey] : idKey;

  const { data: messages, mutate } = useSWR<WebSocketMessage[]>(
    [chatKey, 'messages'],
    null,
    { fallbackData: (options?.initialMessages ?? initialMessagesFallback) || [] },
  );
  const { data: isLoading = false, mutate: mutateLoading } = useSWR<boolean>(
    [chatKey, 'loading'],
    null,
  );
  const { data: error = undefined, mutate: setError } = useSWR<undefined | Error>(
    [chatKey, 'error'],
    null,
  );

  // Source of truth for the conversation. Every change is pushed to SWR so it renders.
  const messagesRef = useRef<WebSocketMessage[]>(messages || []);
  const socketRef = useRef<WebSocket | null>(null);
  // `initialInput` is "initial": capture it once and send it at most once.
  const initialInputRef = useRef(options?.initialInput);
  const initialInputSentRef = useRef(false);
  // Latest onError, so long-lived socket handlers never call a stale callback.
  const onErrorRef = useRef(onError);

  useEffect(() => {
    onErrorRef.current = onError;
  }, [onError]);

  // Seed the SWR cache with the starting conversation.
  useEffect(() => {
    void mutate(messagesRef.current);
  }, [mutate]);

  /** Replace the whole conversation (ref and SWR cache). */
  const setMessages = useCallback(
    (newMessages: WebSocketMessage[]) => {
      messagesRef.current = newMessages || [];
      void mutate(messagesRef.current, false);
      console.log("--setMessages---", messagesRef.current);
    },
    [mutate],
  );

  /** Store an error in the hook's `error` state and forward the raw cause to `options.onError`. */
  const recordError = useCallback(
    (cause: unknown, fallbackMessage: string) => {
      void setError(cause instanceof Error ? cause : new Error(fallbackMessage));
      onErrorRef.current?.(cause);
    },
    [setError],
  );

  // Send a message over the socket.
  const triggerRequest = useCallback(
    async (chatRequest: Message) => {
      console.log("-----triggerRequest--------", chatRequest);

      const socket = socketRef.current;
      if (!socket || socket.readyState !== WebSocket.OPEN) {
        console.error('WebSocket 连接未建立或已关闭');
        void mutateLoading(false);
        recordError(new Error('WebSocket is not open'), 'WebSocket is not open');
        return;
      }

      try {
        void setError(undefined);
        // Loading stays true until the server's "end" frame (or a close/error) clears it.
        void mutateLoading(true);
        socket.send(chatRequest.content);
      } catch (err) {
        void mutateLoading(false);
        recordError(err, 'Failed to send WebSocket message');
      }
    },
    [mutateLoading, recordError, setError],
  );

  // Append the user's message (re-rendering immediately) and send it.
  const append = useCallback(
    async (message: Message) => {
      setMessages([
        ...messagesRef.current,
        {
          conversationId: '',
          type: 'text',
          content: message.content,
          streamStatus: 'middle',
        },
      ]);
      return triggerRequest(message);
    },
    [setMessages, triggerRequest],
  );

  // Own the socket's lifecycle: open for the current url, tear down on change/unmount.
  useEffect(() => {
    if (!url) return undefined;

    let disposed = false;
    let reconnectTimer: ReturnType<typeof setTimeout> | undefined;

    const open = (): void => {
      const socket = new WebSocket(url, []);
      socketRef.current = socket;

      socket.onopen = () => {
        console.log('WebSocket 连接已建立');
        const initialInput = initialInputRef.current;
        if (initialInput && !initialInputSentRef.current) {
          initialInputSentRef.current = true;
          setInput("");
          void append({ id: '', content: initialInput });
        }
      };

      socket.onmessage = (event) => {
        void getStreamedResponse(event, setMessages, mutate, messagesRef, mutateLoading, setInput)
          .catch((err: unknown) => recordError(err, 'Failed to process WebSocket message'))
          .finally(() => {
            // Publish whatever the frame did to messagesRef so the UI re-renders.
            void mutate(messagesRef.current, false);
          });
      };

      socket.onerror = (event) => {
        console.error('WebSocket 连接发生错误:', event);
        recordError(event, 'WebSocket connection error');
        void mutateLoading(false); // after an error, clear isLoading
      };

      socket.onclose = () => {
        console.log('WebSocket 连接已关闭');
        if (socketRef.current === socket) socketRef.current = null;
        // A reply can no longer arrive on this socket.
        void mutateLoading(false);
        if (!disposed) reconnectTimer = setTimeout(open, RECONNECT_DELAY_MS);
      };
    };

    open();

    return () => {
      disposed = true;
      if (reconnectTimer !== undefined) clearTimeout(reconnectTimer);
      const socket = socketRef.current;
      socketRef.current = null;
      if (socket) {
        // Detach handlers first so an intentional close neither reports nor reconnects.
        socket.onopen = null;
        socket.onmessage = null;
        socket.onerror = null;
        socket.onclose = null;
        socket.close();
      }
    };
  }, [url, append, setMessages, mutate, mutateLoading, recordError]);

  return {
    messages: messagesRef.current || [],
    isLoading,
    error,
    append,
    setMessages,
    input,
    setInput,
  };
};

export default useISDK_B;