"""Response builder class.
|
|
|
|
This class provides general functions for taking in a set of text
|
|
and generating a response.
|
|
|
|
Will support different modes, from 1) stuffing chunks into prompt,
|
|
2) create and refine separately over each chunk, 3) tree summarization.
|
|
|
|
"""
import logging
from abc import abstractmethod
from typing import Any, Dict, Generator, List, Optional, Sequence, Type, Union

from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.query_pipeline.query_component import (
    ChainableMixin,
    InputKeys,
    OutputKeys,
    QueryComponent,
    validate_and_convert_stringable,
)
from llama_index.core.response.schema import (
    RESPONSE_TYPE,
    PydanticResponse,
    Response,
    StreamingResponse,
)
from llama_index.prompts.mixin import PromptMixin
from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle
from llama_index.service_context import ServiceContext
from llama_index.types import RESPONSE_TEXT_TYPE

logger = logging.getLogger(__name__)

QueryTextType = Union[str, QueryBundle]


class BaseSynthesizer(ChainableMixin, PromptMixin):
    """Response builder class."""

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        streaming: bool = False,
        output_cls: Optional[Type[BaseModel]] = None,
    ) -> None:
        """Init params."""
        self._service_context = service_context or ServiceContext.from_defaults()
        self._callback_manager = self._service_context.callback_manager
        self._streaming = streaming
        self._output_cls = output_cls
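
    # Hypothetical usage sketch (``Album`` and ``SomeSynthesizer`` are
    # illustrative names, not part of this module): when ``output_cls`` is a
    # pydantic model class, a concrete synthesizer is expected to return an
    # instance of it from ``get_response``, which ``_prepare_response_output``
    # below wraps in a ``PydanticResponse``:
    #
    #     class Album(BaseModel):
    #         title: str
    #         artist: str
    #
    #     synth = SomeSynthesizer(output_cls=Album)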

    def _get_prompt_modules(self) -> Dict[str, Any]:
        """Get prompt modules."""
        # TODO: keep this for now since response synthesizers don't generally have sub-modules
        return {}

    @property
    def service_context(self) -> ServiceContext:
        return self._service_context

    @property
    def callback_manager(self) -> CallbackManager:
        return self._callback_manager

    @callback_manager.setter
    def callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        self._callback_manager = callback_manager
        # TODO: please fix this later
        self._service_context.callback_manager = callback_manager
        self._service_context.llm.callback_manager = callback_manager
        self._service_context.embed_model.callback_manager = callback_manager
        self._service_context.node_parser.callback_manager = callback_manager

    @abstractmethod
    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get response."""
        ...

    @abstractmethod
    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get response."""
        ...
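
    # A minimal sketch of a concrete subclass (illustrative only; the real
    # implementations live in llama_index.response_synthesizers). Only the two
    # abstract methods above need to be supplied -- synthesize()/asynthesize()
    # and response-object construction are inherited:
    #
    #     class JoinSynthesizer(BaseSynthesizer):
    #         """Trivially joins all chunks, ignoring the query."""
    #
    #         def get_response(self, query_str, text_chunks, **response_kwargs):
    #             return "\n\n".join(text_chunks)
    #
    #         async def aget_response(self, query_str, text_chunks, **response_kwargs):
    #             return self.get_response(query_str, text_chunks, **response_kwargs)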

    def _log_prompt_and_response(
        self,
        formatted_prompt: str,
        response: RESPONSE_TEXT_TYPE,
        log_prefix: str = "",
    ) -> None:
        """Log prompt and response from LLM."""
        logger.debug(f"> {log_prefix} prompt template: {formatted_prompt}")
        self._service_context.llama_logger.add_log(
            {"formatted_prompt_template": formatted_prompt}
        )
        logger.debug(f"> {log_prefix} response: {response}")
        self._service_context.llama_logger.add_log(
            {f"{log_prefix.lower()}_response": response or "Empty Response"}
        )

    def _get_metadata_for_response(
        self,
        nodes: List[BaseNode],
    ) -> Optional[Dict[str, Any]]:
        """Get metadata for response."""
        return {node.node_id: node.metadata for node in nodes}

    def _prepare_response_output(
        self,
        response_str: Optional[RESPONSE_TEXT_TYPE],
        source_nodes: List[NodeWithScore],
    ) -> RESPONSE_TYPE:
        """Prepare response object from response string."""
        response_metadata = self._get_metadata_for_response(
            [node_with_score.node for node_with_score in source_nodes]
        )

        if isinstance(response_str, str):
            return Response(
                response_str,
                source_nodes=source_nodes,
                metadata=response_metadata,
            )
        if isinstance(response_str, Generator):
            return StreamingResponse(
                response_str,
                source_nodes=source_nodes,
                metadata=response_metadata,
            )
        # guard: isinstance(x, None) raises TypeError when no output_cls is set
        if self._output_cls is not None and isinstance(
            response_str, self._output_cls
        ):
            return PydanticResponse(
                response_str, source_nodes=source_nodes, metadata=response_metadata
            )

        raise ValueError(
            "Response must be a string, a generator, or an instance of "
            f"`output_cls`. Found {type(response_str)}"
        )

    def synthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        if len(nodes) == 0:
            return Response("Empty Response")

        if isinstance(query, str):
            query = QueryBundle(query_str=query)

        with self._callback_manager.event(
            CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str}
        ) as event:
            response_str = self.get_response(
                query_str=query.query_str,
                text_chunks=[
                    n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
                ],
                **response_kwargs,
            )

            additional_source_nodes = additional_source_nodes or []
            source_nodes = list(nodes) + list(additional_source_nodes)

            response = self._prepare_response_output(response_str, source_nodes)

            event.on_end(payload={EventPayload.RESPONSE: response})

        return response
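
    # Usage sketch (assumes ``nodes`` were retrieved elsewhere and that a
    # concrete synthesizer was obtained, e.g. via
    # ``llama_index.response_synthesizers.get_response_synthesizer``):
    #
    #     synthesizer = get_response_synthesizer()
    #     response = synthesizer.synthesize("What did the author do?", nodes=nodes)
    #     print(response.response)
    #     print(response.source_nodes)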

    async def asynthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        if len(nodes) == 0:
            return Response("Empty Response")

        if isinstance(query, str):
            query = QueryBundle(query_str=query)

        with self._callback_manager.event(
            CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str}
        ) as event:
            response_str = await self.aget_response(
                query_str=query.query_str,
                text_chunks=[
                    n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
                ],
                **response_kwargs,
            )

            additional_source_nodes = additional_source_nodes or []
            source_nodes = list(nodes) + list(additional_source_nodes)

            response = self._prepare_response_output(response_str, source_nodes)

            event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """As query component."""
        return SynthesizerComponent(synthesizer=self)


class SynthesizerComponent(QueryComponent):
    """Synthesizer component."""

    synthesizer: BaseSynthesizer = Field(..., description="Synthesizer")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager."""
        self.synthesizer.callback_manager = callback_manager

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        # make sure both query_str and nodes are there
        if "query_str" not in input:
            raise ValueError("Input must have key 'query_str'")
        input["query_str"] = validate_and_convert_stringable(input["query_str"])

        if "nodes" not in input:
            raise ValueError("Input must have key 'nodes'")
        nodes = input["nodes"]
        if not isinstance(nodes, list):
            raise ValueError("Input nodes must be a list")
        for node in nodes:
            if not isinstance(node, NodeWithScore):
                raise ValueError("Input nodes must be a list of NodeWithScore")
        return input
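
    # For reference, a well-formed input dict for this component (values are
    # hypothetical):
    #
    #     {
    #         "query_str": "What did the author do?",
    #         "nodes": [NodeWithScore(node=some_node, score=0.5)],
    #     }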

    def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run component."""
        output = self.synthesizer.synthesize(kwargs["query_str"], kwargs["nodes"])
        return {"output": output}

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run component."""
        output = await self.synthesizer.asynthesize(
            kwargs["query_str"], kwargs["nodes"]
        )
        return {"output": output}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        return InputKeys.from_keys({"query_str", "nodes"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})
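

# Hedged wiring sketch for SynthesizerComponent in a query pipeline. The
# pattern below follows the query-pipeline API this file imports from; the
# ``QueryPipeline``/``InputComponent`` import path and the ``retriever``
# module are assumptions, not part of this module:
#
#     from llama_index.query_pipeline import InputComponent, QueryPipeline
#
#     p = QueryPipeline()
#     p.add_modules(
#         {
#             "input": InputComponent(),
#             "retriever": retriever,
#             "synth": synthesizer.as_query_component(),
#         }
#     )
#     p.add_link("input", "retriever")
#     p.add_link("input", "synth", dest_key="query_str")
#     p.add_link("retriever", "synth", dest_key="nodes")
#     response = p.run(input="What did the author do?")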