from queue import Queue
from threading import Event
from typing import Any, Generator, List, Optional
from uuid import UUID

from llama_index.bridge.langchain import BaseCallbackHandler, LLMResult


class StreamingGeneratorCallbackHandler(BaseCallbackHandler):
    """Streaming callback handler.

    Collects tokens emitted by a LangChain LLM into a thread-safe queue so
    they can be consumed as a generator via `get_response_gen`.
    """

    def __init__(self) -> None:
        self._token_queue: Queue = Queue()
        self._done = Event()

    def __deepcopy__(self, memo: Any) -> "StreamingGeneratorCallbackHandler":
        # NOTE: hack to bypass deepcopy in langchain
        return self

    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
        """Run on new LLM token. Only available when streaming is enabled."""
        self._token_queue.put_nowait(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when the LLM finishes; signals the response generator to stop."""
        self._done.set()

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Run on LLM error; also signals the response generator to stop."""
        self._done.set()

    def get_response_gen(self) -> Generator:
        """Yield queued tokens until the LLM signals completion or error.

        NOTE: busy-polls the queue; any tokens still queued are drained
        before the done flag is honored.
        """
        while True:
            if not self._token_queue.empty():
                token = self._token_queue.get_nowait()
                yield token
            elif self._done.is_set():
                break
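

if __name__ == "__main__":
    # Usage sketch: register the handler on a streaming LangChain LLM, run the
    # blocking call on a worker thread, and stream tokens from
    # get_response_gen() on the calling thread. Assumptions (not part of this
    # module): a ChatOpenAI model with OPENAI_API_KEY set, and the pre-0.1
    # `langchain.chat_models` import path; adjust for your langchain version.
    from threading import Thread

    from langchain.chat_models import ChatOpenAI

    handler = StreamingGeneratorCallbackHandler()
    llm = ChatOpenAI(streaming=True, callbacks=[handler])

    # predict() blocks until the full response is generated, so it runs in the
    # background while the main thread consumes tokens as they arrive.
    Thread(target=llm.predict, args=("Tell me a short joke.",), daemon=True).start()

    for token in handler.get_response_gen():
        print(token, end="", flush=True)
    print()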