# faiss_rag_enterprise/llama_index/response_synthesizers/base.py

"""Response builder class.

This class provides general functions for taking in a set of text
and generating a response.

Supports different modes: 1) stuffing chunks into the prompt,
2) creating and refining separately over each chunk, and
3) tree summarization.
"""
import logging
from abc import abstractmethod
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.query_pipeline.query_component import (
ChainableMixin,
InputKeys,
OutputKeys,
QueryComponent,
validate_and_convert_stringable,
)
from llama_index.core.response.schema import (
RESPONSE_TYPE,
PydanticResponse,
Response,
StreamingResponse,
)
from llama_index.prompts.mixin import PromptMixin
from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle
from llama_index.service_context import ServiceContext
from llama_index.types import RESPONSE_TEXT_TYPE
logger = logging.getLogger(__name__)
QueryTextType = Union[str, QueryBundle]
class BaseSynthesizer(ChainableMixin, PromptMixin):
    """Response builder class.

    Combines a query string with a sequence of text chunks to produce a
    final response object. Subclasses implement ``get_response`` /
    ``aget_response`` to define how the chunks are consumed (stuffing,
    create-and-refine, tree summarization, ...).
    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        streaming: bool = False,
        output_cls: Optional[BaseModel] = None,
    ) -> None:
        """Init params.

        Args:
            service_context: Service context supplying the LLM, callback
                manager, etc. A default one is built when omitted.
            streaming: If True, implementations are expected to return a
                token generator rather than a completed string.
            output_cls: Optional pydantic class for structured output.
                When provided, a response of this type is wrapped in a
                ``PydanticResponse``.
        """
        self._service_context = service_context or ServiceContext.from_defaults()
        self._callback_manager = self._service_context.callback_manager
        self._streaming = streaming
        self._output_cls = output_cls

    def _get_prompt_modules(self) -> Dict[str, Any]:
        """Get prompt modules."""
        # TODO: keep this for now since response synthesizers don't generally have sub-modules
        return {}

    @property
    def service_context(self) -> ServiceContext:
        """The service context used by this synthesizer."""
        return self._service_context

    @property
    def callback_manager(self) -> CallbackManager:
        """The callback manager receiving synthesize events."""
        return self._callback_manager

    @callback_manager.setter
    def callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager.

        Propagates the manager to the service context and its components
        so all of them report through the same callback chain.
        """
        self._callback_manager = callback_manager
        # TODO: please fix this later
        self._service_context.callback_manager = callback_manager
        self._service_context.llm.callback_manager = callback_manager
        self._service_context.embed_model.callback_manager = callback_manager
        self._service_context.node_parser.callback_manager = callback_manager

    @abstractmethod
    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get response."""
        ...

    @abstractmethod
    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Get response."""
        ...

    def _log_prompt_and_response(
        self,
        formatted_prompt: str,
        response: RESPONSE_TEXT_TYPE,
        log_prefix: str = "",
    ) -> None:
        """Log prompt and response from LLM."""
        logger.debug(f"> {log_prefix} prompt template: {formatted_prompt}")
        self._service_context.llama_logger.add_log(
            {"formatted_prompt_template": formatted_prompt}
        )
        logger.debug(f"> {log_prefix} response: {response}")
        self._service_context.llama_logger.add_log(
            {f"{log_prefix.lower()}_response": response or "Empty Response"}
        )

    def _get_metadata_for_response(
        self,
        nodes: List[BaseNode],
    ) -> Optional[Dict[str, Any]]:
        """Get metadata for response, keyed by node id."""
        return {node.node_id: node.metadata for node in nodes}

    def _prepare_response_output(
        self,
        response_str: Optional[RESPONSE_TEXT_TYPE],
        source_nodes: List[NodeWithScore],
    ) -> RESPONSE_TYPE:
        """Prepare response object from response string.

        Wraps ``response_str`` in the response type matching its runtime
        type: ``str`` -> ``Response``, ``Generator`` (streaming) ->
        ``StreamingResponse``, an ``output_cls`` instance ->
        ``PydanticResponse``.

        Raises:
            ValueError: If ``response_str`` matches none of the above.
        """
        response_metadata = self._get_metadata_for_response(
            [node_with_score.node for node_with_score in source_nodes]
        )
        if isinstance(response_str, str):
            return Response(
                response_str,
                source_nodes=source_nodes,
                metadata=response_metadata,
            )
        if isinstance(response_str, Generator):
            return StreamingResponse(
                response_str,
                source_nodes=source_nodes,
                metadata=response_metadata,
            )
        # Guard against output_cls being None: isinstance(x, None) raises
        # TypeError instead of falling through to the ValueError below.
        if self._output_cls is not None and isinstance(response_str, self._output_cls):
            return PydanticResponse(
                response_str, source_nodes=source_nodes, metadata=response_metadata
            )
        raise ValueError(
            f"Response must be a string or a generator. Found {type(response_str)}"
        )

    def synthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        """Synthesize a response from the query and retrieved nodes.

        Args:
            query: Query string or ``QueryBundle``.
            nodes: Retrieved nodes whose content feeds the LLM.
            additional_source_nodes: Extra nodes attached to the response
                as sources but not sent to the LLM.

        Returns:
            A response object; ``Response("Empty Response")`` when no
            nodes are provided.
        """
        if len(nodes) == 0:
            return Response("Empty Response")

        if isinstance(query, str):
            query = QueryBundle(query_str=query)

        with self._callback_manager.event(
            CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str}
        ) as event:
            response_str = self.get_response(
                query_str=query.query_str,
                text_chunks=[
                    n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
                ],
                **response_kwargs,
            )

            additional_source_nodes = additional_source_nodes or []
            source_nodes = list(nodes) + list(additional_source_nodes)

            response = self._prepare_response_output(response_str, source_nodes)

            event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def asynthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        """Async version of ``synthesize``; see that method for details."""
        if len(nodes) == 0:
            return Response("Empty Response")

        if isinstance(query, str):
            query = QueryBundle(query_str=query)

        with self._callback_manager.event(
            CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str}
        ) as event:
            response_str = await self.aget_response(
                query_str=query.query_str,
                text_chunks=[
                    n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
                ],
                **response_kwargs,
            )

            additional_source_nodes = additional_source_nodes or []
            source_nodes = list(nodes) + list(additional_source_nodes)

            response = self._prepare_response_output(response_str, source_nodes)

            event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """As query component."""
        return SynthesizerComponent(synthesizer=self)
class SynthesizerComponent(QueryComponent):
    """Synthesizer component.

    Query-pipeline wrapper around a ``BaseSynthesizer``: consumes a
    ``query_str`` and a list of ``NodeWithScore`` and emits the
    synthesized response under the ``output`` key.
    """

    synthesizer: BaseSynthesizer = Field(..., description="Synthesizer")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager on the wrapped synthesizer."""
        self.synthesizer.callback_manager = callback_manager

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        # Require the query string, coercing it to a plain string form.
        if "query_str" not in input:
            raise ValueError("Input must have key 'query_str'")
        input["query_str"] = validate_and_convert_stringable(input["query_str"])

        # Require a list of NodeWithScore objects.
        if "nodes" not in input:
            raise ValueError("Input must have key 'nodes'")
        nodes = input["nodes"]
        if not isinstance(nodes, list):
            raise ValueError("Input nodes must be a list")
        if any(not isinstance(item, NodeWithScore) for item in nodes):
            raise ValueError("Input nodes must be a list of NodeWithScore")
        return input

    def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run component."""
        synthesized = self.synthesizer.synthesize(kwargs["query_str"], kwargs["nodes"])
        return {"output": synthesized}

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run component."""
        synthesized = await self.synthesizer.asynthesize(
            kwargs["query_str"], kwargs["nodes"]
        )
        return {"output": synthesized}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        return InputKeys.from_keys({"query_str", "nodes"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})