208 lines
5.9 KiB
TypeScript
208 lines
5.9 KiB
TypeScript
import { AsyncLocalStorage } from 'async_hooks';
|
|
import * as jsondiffpatch from 'jsondiffpatch';
|
|
import { createResolvablePromise, isFunction } from './utils';
|
|
import type {
|
|
AIProvider,
|
|
InternalAIStateStorageOptions,
|
|
InferAIState,
|
|
MutableAIState,
|
|
ValueOrUpdater,
|
|
} from './types';
|
|
|
|
// It is possible that multiple AI requests get in concurrently, for different
|
|
// AI instances. So ALS is necessary here for a simpler API.
|
|
const asyncAIStateStorage = new AsyncLocalStorage<{
|
|
currentState: any;
|
|
originalState: any;
|
|
sealed: boolean;
|
|
options: InternalAIStateStorageOptions;
|
|
mutationDeltaPromise?: Promise<any>;
|
|
mutationDeltaResolve?: (v: any) => void;
|
|
}>();
|
|
|
|
function getAIStateStoreOrThrow(message: string) {
|
|
const store = asyncAIStateStorage.getStore();
|
|
if (!store) {
|
|
throw new Error(message);
|
|
}
|
|
return store;
|
|
}
|
|
|
|
export function withAIState<S, T>(
|
|
{ state, options }: { state: S; options: InternalAIStateStorageOptions },
|
|
fn: () => T,
|
|
): T {
|
|
return asyncAIStateStorage.run(
|
|
{
|
|
currentState: state,
|
|
originalState: state,
|
|
sealed: false,
|
|
options,
|
|
},
|
|
fn,
|
|
);
|
|
}
|
|
|
|
export function getAIStateDeltaPromise() {
|
|
const store = getAIStateStoreOrThrow('Internal error occurred.');
|
|
return store.mutationDeltaPromise;
|
|
}
|
|
|
|
// Internal method. This will be called after the AI Action has been returned
|
|
// and you can no longer call `getMutableAIState()` inside any async callbacks
|
|
// created by that Action.
|
|
export function sealMutableAIState() {
|
|
const store = getAIStateStoreOrThrow('Internal error occurred.');
|
|
store.sealed = true;
|
|
}
|
|
|
|
/**
|
|
* Get the current AI state.
|
|
* If `key` is provided, it will return the value of the specified key in the
|
|
* AI state, if it's an object. If it's not an object, it will throw an error.
|
|
*
|
|
* @example const state = getAIState() // Get the entire AI state
|
|
* @example const field = getAIState('key') // Get the value of the key
|
|
*/
|
|
function getAIState<AI extends AIProvider = any>(): InferAIState<AI, any>;
|
|
function getAIState<AI extends AIProvider = any>(
|
|
key: keyof InferAIState<AI, any>,
|
|
): InferAIState<AI, any>[typeof key];
|
|
function getAIState<AI extends AIProvider = any>(
|
|
...args: [] | [key: keyof InferAIState<AI, any>]
|
|
) {
|
|
const store = getAIStateStoreOrThrow(
|
|
'`getAIState` must be called within an AI Action.',
|
|
);
|
|
|
|
if (args.length > 0) {
|
|
const key = args[0];
|
|
if (typeof store.currentState !== 'object') {
|
|
throw new Error(
|
|
`You can't get the "${String(
|
|
key,
|
|
)}" field from the AI state because it's not an object.`,
|
|
);
|
|
}
|
|
return store.currentState[key as keyof typeof store.currentState];
|
|
}
|
|
|
|
return store.currentState;
|
|
}
|
|
|
|
/**
|
|
* Get the mutable AI state. Note that you must call `.close()` when finishing
|
|
* updating the AI state.
|
|
*
|
|
* @example
|
|
* ```tsx
|
|
* const state = getMutableAIState()
|
|
* state.update({ ...state.get(), key: 'value' })
|
|
* state.update((currentState) => ({ ...currentState, key: 'value' }))
|
|
* state.done()
|
|
* ```
|
|
*
|
|
* @example
|
|
* ```tsx
|
|
* const state = getMutableAIState()
|
|
* state.done({ ...state.get(), key: 'value' }) // Done with a new state
|
|
* ```
|
|
*/
|
|
function getMutableAIState<AI extends AIProvider = any>(): MutableAIState<
|
|
InferAIState<AI, any>
|
|
>;
|
|
function getMutableAIState<AI extends AIProvider = any>(
|
|
key: keyof InferAIState<AI, any>,
|
|
): MutableAIState<InferAIState<AI, any>[typeof key]>;
|
|
function getMutableAIState<AI extends AIProvider = any>(
|
|
...args: [] | [key: keyof InferAIState<AI, any>]
|
|
) {
|
|
type AIState = InferAIState<AI, any>;
|
|
type AIStateWithKey = typeof args extends [key: keyof AIState]
|
|
? AIState[(typeof args)[0]]
|
|
: AIState;
|
|
type NewStateOrUpdater = ValueOrUpdater<AIStateWithKey>;
|
|
|
|
const store = getAIStateStoreOrThrow(
|
|
'`getMutableAIState` must be called within an AI Action.',
|
|
);
|
|
|
|
if (store.sealed) {
|
|
throw new Error(
|
|
"`getMutableAIState` must be called before returning from an AI Action. Please move it to the top level of the Action's function body.",
|
|
);
|
|
}
|
|
|
|
if (!store.mutationDeltaPromise) {
|
|
const { promise, resolve } = createResolvablePromise();
|
|
store.mutationDeltaPromise = promise;
|
|
store.mutationDeltaResolve = resolve;
|
|
}
|
|
|
|
function doUpdate(newState: NewStateOrUpdater, done: boolean) {
|
|
if (args.length > 0) {
|
|
if (typeof store.currentState !== 'object') {
|
|
const key = args[0];
|
|
throw new Error(
|
|
`You can't modify the "${String(
|
|
key,
|
|
)}" field of the AI state because it's not an object.`,
|
|
);
|
|
}
|
|
}
|
|
|
|
if (isFunction(newState)) {
|
|
if (args.length > 0) {
|
|
store.currentState[args[0]] = newState(store.currentState[args[0]]);
|
|
} else {
|
|
store.currentState = newState(store.currentState);
|
|
}
|
|
} else {
|
|
if (args.length > 0) {
|
|
store.currentState[args[0]] = newState;
|
|
} else {
|
|
store.currentState = newState;
|
|
}
|
|
}
|
|
|
|
store.options.onSetAIState?.({
|
|
key: args.length > 0 ? args[0] : undefined,
|
|
state: store.currentState,
|
|
done,
|
|
});
|
|
}
|
|
|
|
const mutableState = {
|
|
get: () => {
|
|
if (args.length > 0) {
|
|
const key = args[0];
|
|
if (typeof store.currentState !== 'object') {
|
|
throw new Error(
|
|
`You can't get the "${String(
|
|
key,
|
|
)}" field from the AI state because it's not an object.`,
|
|
);
|
|
}
|
|
return store.currentState[key];
|
|
}
|
|
|
|
return store.currentState as AIState;
|
|
},
|
|
update: function update(newAIState: NewStateOrUpdater) {
|
|
doUpdate(newAIState, false);
|
|
},
|
|
done: function done(...doneArgs: [] | [NewStateOrUpdater]) {
|
|
if (doneArgs.length > 0) {
|
|
doUpdate(doneArgs[0] as NewStateOrUpdater, true);
|
|
}
|
|
|
|
const delta = jsondiffpatch.diff(store.originalState, store.currentState);
|
|
store.mutationDeltaResolve!(delta);
|
|
},
|
|
};
|
|
|
|
return mutableState;
|
|
}
|
|
|
|
export { getAIState, getMutableAIState };
|