import type { ReactNode } from 'react';
import type OpenAI from 'openai';
import { z } from 'zod';
import zodToJsonSchema from 'zod-to-json-schema';

// TODO: This needs to be externalized.
import { OpenAIStream } from '../streams';

import {
  STREAMABLE_VALUE_TYPE,
  DEV_DEFAULT_STREAMABLE_WARNING_TIME,
} from './constants';
import {
  createResolvablePromise,
  createSuspensedChunk,
  consumeStream,
} from './utils';
import type { StreamablePatch, StreamableValue } from './types';

/**
 * Create a piece of changeable UI that can be streamed to the client.
 * On the client side, it can be rendered as a normal React node.
 *
 * @param initialValue - The node the stream starts with (handed to
 *   `createSuspensedChunk`). It is also what `.done()` resolves with if it is
 *   called without arguments before any `.update()`/`.append()`.
 */
export function createStreamableUI(initialValue?: React.ReactNode) {
  const {
    row,
    resolve: resolveFirstChunk,
    reject: rejectFirstChunk,
  } = createSuspensedChunk(initialValue);

  // Settle / fail the chunk the stream is currently waiting on. Both are
  // swapped for the next chunk's callbacks every time a chunk is sent.
  let settle = resolveFirstChunk;
  let fail = rejectFirstChunk;
  // Last node sent; `.done()` without arguments re-sends it as the final state.
  let latestNode = initialValue;
  let isClosed = false;
  let slowUpdateTimer: NodeJS.Timeout | undefined;

  /** Throw if `.error()` or `.done()` has already been called. */
  const assertOpen = (method: string) => {
    if (!isClosed) return;
    throw new Error(`${method}: UI stream is already closed.`);
  };

  /** (Re)start the development-only timer that warns when the stream goes quiet. */
  const restartSlowUpdateWarning = () => {
    if (process.env.NODE_ENV !== 'development') return;
    if (slowUpdateTimer) clearTimeout(slowUpdateTimer);
    slowUpdateTimer = setTimeout(() => {
      console.warn(
        'The streamable UI has been slow to update. This may be a bug or a performance issue or you forgot to call `.done()`.',
      );
    }, DEV_DEFAULT_STREAMABLE_WARNING_TIME);
  };

  /** Cancel the warning timer and refuse any further calls. */
  const markClosed = () => {
    if (slowUpdateTimer) clearTimeout(slowUpdateTimer);
    isClosed = true;
  };

  /**
   * Settle the pending chunk with `node`, attaching a fresh promise for the
   * following chunk. `append` marks the chunk as appended rather than replacing.
   */
  const pushChunk = (node: React.ReactNode, append: boolean) => {
    const upcoming = createResolvablePromise();
    latestNode = node;
    settle(
      append
        ? { value: node, done: false, append: true, next: upcoming.promise }
        : { value: node, done: false, next: upcoming.promise },
    );
    settle = upcoming.resolve;
    fail = upcoming.reject;
    restartSlowUpdateWarning();
  };

  restartSlowUpdateWarning();

  return {
    /**
     * The value of the streamable UI. This can be returned from a Server Action and received by the client.
     */
    value: row,

    /**
     * Replace the current UI node with `value`.
     * Passing the exact same (referentially equal) node sends nothing.
     */
    update(value: React.ReactNode) {
      assertOpen('.update()');
      if (value !== latestNode) {
        pushChunk(value, false);
        return;
      }
      restartSlowUpdateWarning();
    },

    /**
     * Append a new UI node after the existing one.
     * Once a node has been appended, the previous node cannot be updated anymore.
     *
     * @example
     * ```jsx
     * const ui = createStreamableUI(<div>hello</div>)
     * ui.append(<div>world</div>)
     *
     * // The UI node will be:
     * // <>
     * //   <div>hello</div>
     * //   <div>world</div>
     * // </>
     * ```
     */
    append(value: React.ReactNode) {
      assertOpen('.append()');
      pushChunk(value, true);
    },

    /**
     * Close the stream by rejecting the pending chunk with `error`.
     * It will be thrown on the client side and caught by the nearest error boundary component.
     */
    error(error: any) {
      assertOpen('.error()');
      markClosed();
      fail(error);
    },

    /**
     * Finalize the UI node, optionally with a new node as the final state.
     * Without an argument, the last node sent (or the initial one) becomes final.
     * Once called, the UI node cannot be updated or appended anymore.
     *
     * This method is always **required** to be called, otherwise the response will be stuck in a loading state.
     */
    done(...args: [] | [React.ReactNode]) {
      assertOpen('.done()');
      markClosed();
      settle({ value: args.length !== 0 ? args[0] : latestNode, done: true });
    },
  };
}

/**
 * Create a wrapped, changeable value that can be streamed to the client.
 * On the client side, the value can be accessed via the readStreamableValue() API.
*/ export function createStreamableValue(initialValue?: T) { let closed = false; let resolvable = createResolvablePromise>(); let currentValue = initialValue; let currentError: E | undefined; let currentPromise: typeof resolvable.promise | undefined = resolvable.promise; let currentPatchValue: StreamablePatch; function assertStream(method: string) { if (closed) { throw new Error(method + ': Value stream is already closed.'); } } let warningTimeout: NodeJS.Timeout | undefined; function warnUnclosedStream() { if (process.env.NODE_ENV === 'development') { if (warningTimeout) { clearTimeout(warningTimeout); } warningTimeout = setTimeout(() => { console.warn( 'The streamable UI has been slow to update. This may be a bug or a performance issue or you forgot to call `.done()`.', ); }, DEV_DEFAULT_STREAMABLE_WARNING_TIME); } } warnUnclosedStream(); function createWrapped(initialChunk?: boolean): StreamableValue { // This makes the payload much smaller if there're mutative updates before the first read. let init: Partial>; if (currentError !== undefined) { init = { error: currentError }; } else { if (currentPatchValue && !initialChunk) { init = { diff: currentPatchValue }; } else { init = { curr: currentValue }; } } if (currentPromise) { init.next = currentPromise; } if (initialChunk) { init.type = STREAMABLE_VALUE_TYPE; } return init; } // Update the internal `currentValue` and `currentPatchValue` if needed. function updateValueStates(value: T) { // If we can only send a patch over the wire, it's better to do so. currentPatchValue = undefined; if (typeof value === 'string') { if (typeof currentValue === 'string') { if (value.startsWith(currentValue)) { currentPatchValue = [0, value.slice(currentValue.length)]; } } } currentValue = value; } return { /** * The value of the streamable. This can be returned from a Server Action and * received by the client. To read the streamed values, use the * `readStreamableValue` or `useStreamableValue` APIs. 
*/ get value() { return createWrapped(true); }, /** * This method updates the current value with a new one. */ update(value: T) { assertStream('.update()'); const resolvePrevious = resolvable.resolve; resolvable = createResolvablePromise(); updateValueStates(value); currentPromise = resolvable.promise; resolvePrevious(createWrapped()); warnUnclosedStream(); }, error(error: any) { assertStream('.error()'); if (warningTimeout) { clearTimeout(warningTimeout); } closed = true; currentError = error; currentPromise = undefined; resolvable.resolve({ error }); }, done(...args: [] | [T]) { assertStream('.done()'); if (warningTimeout) { clearTimeout(warningTimeout); } closed = true; currentPromise = undefined; if (args.length) { updateValueStates(args[0]); resolvable.resolve(createWrapped()); return; } resolvable.resolve({}); }, }; } type Streamable = ReactNode | Promise; type Renderer = ( props: T, ) => | Streamable | Generator | AsyncGenerator; /** * `render` is a helper function to create a streamable UI from some LLMs. * Currently, it only supports OpenAI's GPT models with Function Calling and Assistants Tools. */ export function render< TS extends { [name: string]: z.Schema; } = {}, FS extends { [name: string]: z.Schema; } = {}, >(options: { /** * The model name to use. Must be OpenAI SDK compatible. Tools and Functions are only supported * GPT models (3.5/4), OpenAI Assistants, Mistral small and large, and Fireworks firefunction-v1. * * @example "gpt-3.5-turbo" */ model: string; /** * The provider instance to use. Currently the only provider available is OpenAI. * This needs to match the model name. */ provider: OpenAI; messages: Parameters< typeof OpenAI.prototype.chat.completions.create >[0]['messages']; text?: Renderer<{ /** * The full text content from the model so far. */ content: string; /** * The new appended text content from the model since the last `text` call. */ delta: string; /** * Whether the model is done generating text. 
* If `true`, the `content` will be the final output and this call will be the last. */ done: boolean; }>; tools?: { [name in keyof TS]: { description?: string; parameters: TS[name]; render: Renderer>; }; }; functions?: { [name in keyof FS]: { description?: string; parameters: FS[name]; render: Renderer>; }; }; initial?: ReactNode; temperature?: number; }): ReactNode { const ui = createStreamableUI(options.initial); // The default text renderer just returns the content as string. const text = options.text ? options.text : ({ content }: { content: string }) => content; const functions = options.functions ? Object.entries(options.functions).map( ([name, { description, parameters }]) => { return { name, description, parameters: zodToJsonSchema(parameters) as Record, }; }, ) : undefined; const tools = options.tools ? Object.entries(options.tools).map( ([name, { description, parameters }]) => { return { type: 'function' as const, function: { name, description, parameters: zodToJsonSchema(parameters) as Record< string, unknown >, }, }; }, ) : undefined; if (functions && tools) { throw new Error( "You can't have both functions and tools defined. 
Please choose one or the other.", ); } let finished: Promise | undefined; async function handleRender( args: any, renderer: undefined | Renderer, res: ReturnType, ) { if (!renderer) return; const resolvable = createResolvablePromise(); if (finished) { finished = finished.then(() => resolvable.promise); } else { finished = resolvable.promise; } const value = renderer(args); if ( value instanceof Promise || (value && typeof value === 'object' && 'then' in value && typeof value.then === 'function') ) { const node = await (value as Promise); res.update(node); resolvable.resolve(void 0); } else if ( value && typeof value === 'object' && Symbol.asyncIterator in value ) { const it = value as AsyncGenerator< React.ReactNode, React.ReactNode, void >; while (true) { const { done, value } = await it.next(); res.update(value); if (done) break; } resolvable.resolve(void 0); } else if (value && typeof value === 'object' && Symbol.iterator in value) { const it = value as Generator; while (true) { const { done, value } = it.next(); res.update(value); if (done) break; } resolvable.resolve(void 0); } else { res.update(value); resolvable.resolve(void 0); } } (async () => { let hasFunction = false; let content = ''; consumeStream( OpenAIStream( (await options.provider.chat.completions.create({ model: options.model, messages: options.messages, temperature: options.temperature, stream: true, ...(functions ? { functions, } : {}), ...(tools ? { tools, } : {}), })) as any, { ...(functions ? { async experimental_onFunctionCall(functionCallPayload) { hasFunction = true; handleRender( functionCallPayload.arguments, options.functions?.[functionCallPayload.name as any] ?.render, ui, ); }, } : {}), ...(tools ? { async experimental_onToolCall(toolCallPayload: any) { hasFunction = true; // TODO: We might need Promise.all here? 
for (const tool of toolCallPayload.tools) { handleRender( tool.func.arguments, options.tools?.[tool.func.name as any]?.render, ui, ); } }, } : {}), onText(chunk) { content += chunk; handleRender({ content, done: false, delta: chunk }, text, ui); }, async onFinal() { if (hasFunction) { await finished; ui.done(); return; } handleRender({ content, done: true }, text, ui); await finished; ui.done(); }, }, ), ); })(); return ui.value; }