// hts/packages/isdk/rsc/streamable.tsx

import type { ReactNode } from 'react';
import type OpenAI from 'openai';
import { z } from 'zod';
import zodToJsonSchema from 'zod-to-json-schema';
// TODO: This needs to be externalized.
import { OpenAIStream } from '../streams';
import {
STREAMABLE_VALUE_TYPE,
DEV_DEFAULT_STREAMABLE_WARNING_TIME,
} from './constants';
import {
createResolvablePromise,
createSuspensedChunk,
consumeStream,
} from './utils';
import type { StreamablePatch, StreamableValue } from './types';
/**
* Create a piece of changeable UI that can be streamed to the client.
* On the client side, it can be rendered as a normal React node.
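*
* @example
* ```jsx
* // A minimal usage sketch (illustrative; assumes this runs inside a React Server Action):
* const ui = createStreamableUI(<div>Loading...</div>)
*
* setTimeout(() => {
*   ui.update(<div>Almost there...</div>)
*   ui.done(<div>Done!</div>)
* }, 1000)
*
* return ui.value
* ```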
*/
export function createStreamableUI(initialValue?: React.ReactNode) {
let currentValue = initialValue;
let closed = false;
let { row, resolve, reject } = createSuspensedChunk(initialValue);
function assertStream(method: string) {
if (closed) {
throw new Error(method + ': UI stream is already closed.');
}
}
let warningTimeout: NodeJS.Timeout | undefined;
function warnUnclosedStream() {
if (process.env.NODE_ENV === 'development') {
if (warningTimeout) {
clearTimeout(warningTimeout);
}
warningTimeout = setTimeout(() => {
console.warn(
'The streamable UI has been slow to update. This may be a bug or a performance issue, or you forgot to call `.done()`.',
);
}, DEV_DEFAULT_STREAMABLE_WARNING_TIME);
}
}
warnUnclosedStream();
return {
/**
* The value of the streamable UI. This can be returned from a Server Action and received by the client.
*/
value: row,
/**
* This method updates the current UI node. It takes a new UI node and replaces the old one.
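*
* @example
* ```jsx
* // Illustrative: replace the current UI node with a new one.
* ui.update(<div>loading</div>)
* ```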
*/
update(value: React.ReactNode) {
assertStream('.update()');
// There is no need to update the value if it's referentially equal.
if (value === currentValue) {
warnUnclosedStream();
return;
}
const resolvable = createResolvablePromise();
currentValue = value;
resolve({ value: currentValue, done: false, next: resolvable.promise });
resolve = resolvable.resolve;
reject = resolvable.reject;
warnUnclosedStream();
},
/**
* This method is used to append a new UI node to the end of the old one.
* Once a new UI node is appended, the previous UI node can no longer be updated.
*
* @example
* ```jsx
* const ui = createStreamableUI(<div>hello</div>)
* ui.append(<div>world</div>)
*
* // The UI node will be:
* // <>
* // <div>hello</div>
* // <div>world</div>
* // </>
* ```
*/
append(value: React.ReactNode) {
assertStream('.append()');
const resolvable = createResolvablePromise();
currentValue = value;
resolve({ value, done: false, append: true, next: resolvable.promise });
resolve = resolvable.resolve;
reject = resolvable.reject;
warnUnclosedStream();
},
/**
* This method is used to signal that there is an error in the UI stream.
* It will be thrown on the client side and caught by the nearest error boundary component.
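*
* @example
* ```jsx
* // Illustrative sketch; `fetchData` is a hypothetical async call.
* try {
*   const data = await fetchData()
*   ui.done(<div>{data}</div>)
* } catch (e) {
*   ui.error(e)
* }
* ```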
*/
error(error: any) {
assertStream('.error()');
if (warningTimeout) {
clearTimeout(warningTimeout);
}
closed = true;
reject(error);
},
/**
* This method marks the UI node as finalized. You can either call it without any parameters or with a new UI node as the final state.
* Once called, the UI node cannot be updated or appended anymore.
*
* Calling this method is always **required**; otherwise, the response will be stuck in a loading state.
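*
* @example
* ```jsx
* // Finalize with the current UI node:
* ui.done()
* // ...or replace it with a final node:
* ui.done(<div>All done</div>)
* ```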
*/
done(...args: [] | [React.ReactNode]) {
assertStream('.done()');
if (warningTimeout) {
clearTimeout(warningTimeout);
}
closed = true;
if (args.length) {
resolve({ value: args[0], done: true });
return;
}
resolve({ value: currentValue, done: true });
},
};
}
/**
* Create a wrapped, changeable value that can be streamed to the client.
* On the client side, the value can be accessed via the readStreamableValue() API.
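*
* @example
* ```jsx
* // Illustrative sketch (assumes a Server Action context):
* const streamable = createStreamableValue('')
* streamable.update('Hel')
* streamable.update('Hello')
* streamable.done('Hello, world!')
*
* return streamable.value
* ```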
*/
export function createStreamableValue<T = any, E = any>(initialValue?: T) {
let closed = false;
let resolvable = createResolvablePromise<StreamableValue<T, E>>();
let currentValue = initialValue;
let currentError: E | undefined;
let currentPromise: typeof resolvable.promise | undefined =
resolvable.promise;
let currentPatchValue: StreamablePatch;
function assertStream(method: string) {
if (closed) {
throw new Error(method + ': Value stream is already closed.');
}
}
let warningTimeout: NodeJS.Timeout | undefined;
function warnUnclosedStream() {
if (process.env.NODE_ENV === 'development') {
if (warningTimeout) {
clearTimeout(warningTimeout);
}
warningTimeout = setTimeout(() => {
console.warn(
'The streamable value has been slow to update. This may be a bug or a performance issue, or you forgot to call `.done()`.',
);
}, DEV_DEFAULT_STREAMABLE_WARNING_TIME);
}
}
warnUnclosedStream();
function createWrapped(initialChunk?: boolean): StreamableValue<T, E> {
// This makes the payload much smaller if there are mutative updates before the first read.
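// For example, the first read sends the full value as `{ curr: ... }`, while later
// string updates that only append text can be sent as `{ diff: ... }` patches instead.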
let init: Partial<StreamableValue<T, E>>;
if (currentError !== undefined) {
init = { error: currentError };
} else {
if (currentPatchValue && !initialChunk) {
init = { diff: currentPatchValue };
} else {
init = { curr: currentValue };
}
}
if (currentPromise) {
init.next = currentPromise;
}
if (initialChunk) {
init.type = STREAMABLE_VALUE_TYPE;
}
return init;
}
// Update the internal `currentValue` and `currentPatchValue` if needed.
function updateValueStates(value: T) {
// If we can only send a patch over the wire, it's better to do so.
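// e.g. updating from 'Hel' to 'Hello' yields the patch [0, 'lo'] (the leading 0 is the
// append patch type) rather than resending the full string.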
currentPatchValue = undefined;
if (typeof value === 'string') {
if (typeof currentValue === 'string') {
if (value.startsWith(currentValue)) {
currentPatchValue = [0, value.slice(currentValue.length)];
}
}
}
currentValue = value;
}
return {
/**
* The value of the streamable. This can be returned from a Server Action and
* received by the client. To read the streamed values, use the
* `readStreamableValue` or `useStreamableValue` APIs.
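*
* @example
* ```jsx
* // Client-side sketch; `valueFromServerAction` is a hypothetical value returned by a Server Action.
* for await (const value of readStreamableValue(valueFromServerAction)) {
*   console.log(value)
* }
* ```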
*/
get value() {
return createWrapped(true);
},
/**
* This method updates the current value with a new one.
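*
* @example
* ```jsx
* streamable.update('Hel')
* streamable.update('Hello') // only the appended 'lo' is sent as a patch
* ```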
*/
update(value: T) {
assertStream('.update()');
const resolvePrevious = resolvable.resolve;
resolvable = createResolvablePromise();
updateValueStates(value);
currentPromise = resolvable.promise;
resolvePrevious(createWrapped());
warnUnclosedStream();
},
error(error: any) {
assertStream('.error()');
if (warningTimeout) {
clearTimeout(warningTimeout);
}
closed = true;
currentError = error;
currentPromise = undefined;
resolvable.resolve({ error });
},
done(...args: [] | [T]) {
assertStream('.done()');
if (warningTimeout) {
clearTimeout(warningTimeout);
}
closed = true;
currentPromise = undefined;
if (args.length) {
updateValueStates(args[0]);
resolvable.resolve(createWrapped());
return;
}
resolvable.resolve({});
},
};
}
type Streamable = ReactNode | Promise<ReactNode>;
type Renderer<T> = (
props: T,
) =>
| Streamable
| Generator<Streamable, Streamable, void>
| AsyncGenerator<Streamable, Streamable, void>;
/**
* `render` is a helper function to create a streamable UI from an LLM response.
* Currently, it only supports OpenAI's GPT models with Function Calling and Assistants Tools.
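*
* @example
* ```jsx
* // Illustrative sketch; `openai`, `Spinner`, `Weather`, and `getWeather` are
* // hypothetical names used for demonstration.
* const ui = render({
*   model: 'gpt-3.5-turbo',
*   provider: openai, // e.g. new OpenAI({ apiKey: process.env.OPENAI_API_KEY })
*   messages: [{ role: 'user', content: 'What is the weather in San Francisco?' }],
*   text: ({ content, done }) => (done ? <p>{content}</p> : <p>{content}...</p>),
*   tools: {
*     get_city_weather: {
*       description: 'Get the current weather for a city',
*       parameters: z.object({ city: z.string() }),
*       render: async function* ({ city }) {
*         yield <Spinner />
*         const weather = await getWeather(city)
*         return <Weather city={city} data={weather} />
*       },
*     },
*   },
* })
* ```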
*/
export function render<
TS extends {
[name: string]: z.Schema;
} = {},
FS extends {
[name: string]: z.Schema;
} = {},
>(options: {
/**
* The model name to use. Must be OpenAI SDK compatible. Tools and Functions are only supported
* with GPT models (3.5/4), OpenAI Assistants, Mistral small and large, and Fireworks firefunction-v1.
*
* @example "gpt-3.5-turbo"
*/
model: string;
/**
* The provider instance to use. Currently the only provider available is OpenAI.
* This needs to match the model name.
*/
provider: OpenAI;
messages: Parameters<
typeof OpenAI.prototype.chat.completions.create
>[0]['messages'];
text?: Renderer<{
/**
* The full text content from the model so far.
*/
content: string;
/**
* The new appended text content from the model since the last `text` call.
*/
delta: string;
/**
* Whether the model is done generating text.
* If `true`, the `content` will be the final output and this call will be the last.
*/
done: boolean;
}>;
tools?: {
[name in keyof TS]: {
description?: string;
parameters: TS[name];
render: Renderer<z.infer<TS[name]>>;
};
};
functions?: {
[name in keyof FS]: {
description?: string;
parameters: FS[name];
render: Renderer<z.infer<FS[name]>>;
};
};
initial?: ReactNode;
temperature?: number;
}): ReactNode {
const ui = createStreamableUI(options.initial);
// The default text renderer just returns the content as a string.
const text = options.text
? options.text
: ({ content }: { content: string }) => content;
const functions = options.functions
? Object.entries(options.functions).map(
([name, { description, parameters }]) => {
return {
name,
description,
parameters: zodToJsonSchema(parameters) as Record<string, unknown>,
};
},
)
: undefined;
const tools = options.tools
? Object.entries(options.tools).map(
([name, { description, parameters }]) => {
return {
type: 'function' as const,
function: {
name,
description,
parameters: zodToJsonSchema(parameters) as Record<
string,
unknown
>,
},
};
},
)
: undefined;
if (functions && tools) {
throw new Error(
"You can't have both functions and tools defined. Please choose one or the other.",
);
}
let finished: Promise<void> | undefined;
async function handleRender(
args: any,
renderer: undefined | Renderer<any>,
res: ReturnType<typeof createStreamableUI>,
) {
if (!renderer) return;
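// Chain this render onto any in-flight ones so that `onFinal` can `await finished`
// and only call `ui.done()` once every renderer has settled.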
const resolvable = createResolvablePromise<void>();
if (finished) {
finished = finished.then(() => resolvable.promise);
} else {
finished = resolvable.promise;
}
const value = renderer(args);
if (
value instanceof Promise ||
(value &&
typeof value === 'object' &&
'then' in value &&
typeof value.then === 'function')
) {
const node = await (value as Promise<React.ReactNode>);
res.update(node);
resolvable.resolve(void 0);
} else if (
value &&
typeof value === 'object' &&
Symbol.asyncIterator in value
) {
const it = value as AsyncGenerator<
React.ReactNode,
React.ReactNode,
void
>;
while (true) {
const { done, value } = await it.next();
res.update(value);
if (done) break;
}
resolvable.resolve(void 0);
} else if (value && typeof value === 'object' && Symbol.iterator in value) {
const it = value as Generator<React.ReactNode, React.ReactNode, void>;
while (true) {
const { done, value } = it.next();
res.update(value);
if (done) break;
}
resolvable.resolve(void 0);
} else {
res.update(value);
resolvable.resolve(void 0);
}
}
(async () => {
let hasFunction = false;
let content = '';
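// Drain the stream so the OpenAIStream callbacks below are invoked even though
// nothing reads the response body directly.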
consumeStream(
OpenAIStream(
(await options.provider.chat.completions.create({
model: options.model,
messages: options.messages,
temperature: options.temperature,
stream: true,
...(functions
? {
functions,
}
: {}),
...(tools
? {
tools,
}
: {}),
})) as any,
{
...(functions
? {
async experimental_onFunctionCall(functionCallPayload) {
hasFunction = true;
handleRender(
functionCallPayload.arguments,
options.functions?.[functionCallPayload.name as any]
?.render,
ui,
);
},
}
: {}),
...(tools
? {
async experimental_onToolCall(toolCallPayload: any) {
hasFunction = true;
// TODO: We might need Promise.all here?
for (const tool of toolCallPayload.tools) {
handleRender(
tool.func.arguments,
options.tools?.[tool.func.name as any]?.render,
ui,
);
}
},
}
: {}),
onText(chunk) {
content += chunk;
handleRender({ content, done: false, delta: chunk }, text, ui);
},
async onFinal() {
if (hasFunction) {
await finished;
ui.done();
return;
}
handleRender({ content, done: true }, text, ui);
await finished;
ui.done();
},
},
),
);
})();
return ui.value;
}