# faiss_rag_enterprise/llama_index/response_synthesizers/refine.py

import logging
from typing import Any, Callable, Generator, Optional, Sequence, Type, cast

from llama_index.bridge.pydantic import BaseModel, Field, ValidationError
from llama_index.indices.utils import truncate_text
from llama_index.llm_predictor.base import LLMPredictorType
from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.prompts.default_prompt_selectors import (
    DEFAULT_REFINE_PROMPT_SEL,
    DEFAULT_TEXT_QA_PROMPT_SEL,
)
from llama_index.prompts.mixin import PromptDictType
from llama_index.response.utils import get_response_text
from llama_index.response_synthesizers.base import BaseSynthesizer
from llama_index.service_context import ServiceContext
from llama_index.types import RESPONSE_TEXT_TYPE, BasePydanticProgram

logger = logging.getLogger(__name__)


class StructuredRefineResponse(BaseModel):
    """
    Used to answer a given query based on the provided context.
    Also indicates if the query was satisfied with the provided answer.
    """

    answer: str = Field(
        description="The answer for the given query, based on the context and not "
        "prior knowledge."
    )
    query_satisfied: bool = Field(
        description="True if there was enough context given to provide an answer "
        "that satisfies the query."
    )
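
# Example (illustrative, not executed): with structured answer filtering on,
# the LLM is asked to emit this model directly, e.g.
#   StructuredRefineResponse(answer="Paris.", query_satisfied=True)
# A chunk that cannot answer the query would come back with
# query_satisfied=False, and the refine loop below discards its answer.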


class DefaultRefineProgram(BasePydanticProgram):
    """
    Runs the query on the LLM as normal and always returns the answer with
    query_satisfied=True. In effect, it does not do any answer filtering.
    """

    def __init__(
        self,
        prompt: BasePromptTemplate,
        llm: LLMPredictorType,
        output_cls: Optional[BaseModel] = None,
    ):
        self._prompt = prompt
        self._llm = llm
        self._output_cls = output_cls

    @property
    def output_cls(self) -> Type[BaseModel]:
        return StructuredRefineResponse

    def __call__(self, *args: Any, **kwds: Any) -> StructuredRefineResponse:
        if self._output_cls is not None:
            answer = self._llm.structured_predict(
                self._output_cls,
                self._prompt,
                **kwds,
            )
            answer = answer.json()
        else:
            answer = self._llm.predict(
                self._prompt,
                **kwds,
            )
        return StructuredRefineResponse(answer=answer, query_satisfied=True)

    async def acall(self, *args: Any, **kwds: Any) -> StructuredRefineResponse:
        if self._output_cls is not None:
            answer = await self._llm.astructured_predict(
                self._output_cls,
                self._prompt,
                **kwds,
            )
            answer = answer.json()
        else:
            answer = await self._llm.apredict(
                self._prompt,
                **kwds,
            )
        return StructuredRefineResponse(answer=answer, query_satisfied=True)
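
# Illustrative call (hypothetical prompt/llm objects): the default program is a
# thin pass-through, so
#   DefaultRefineProgram(prompt=qa_prompt, llm=llm, output_cls=None)(context_str="...")
# is equivalent to llm.predict(qa_prompt, context_str="...") wrapped in a
# StructuredRefineResponse with query_satisfied=True.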


class Refine(BaseSynthesizer):
    """Refine a response to a query across text chunks."""

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        text_qa_template: Optional[BasePromptTemplate] = None,
        refine_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[BaseModel] = None,
        streaming: bool = False,
        verbose: bool = False,
        structured_answer_filtering: bool = False,
        program_factory: Optional[
            Callable[[BasePromptTemplate], BasePydanticProgram]
        ] = None,
    ) -> None:
        super().__init__(service_context=service_context, streaming=streaming)
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
        self._refine_template = refine_template or DEFAULT_REFINE_PROMPT_SEL
        self._verbose = verbose
        self._structured_answer_filtering = structured_answer_filtering
        self._output_cls = output_cls

        if self._streaming and self._structured_answer_filtering:
            raise ValueError(
                "Streaming not supported with structured answer filtering."
            )
        if not self._structured_answer_filtering and program_factory is not None:
            raise ValueError(
                "Program factory not supported without structured answer filtering."
            )
        self._program_factory = program_factory or self._default_program_factory
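
    # Example construction (illustrative; assumes a configured ServiceContext):
    #     synth = Refine(
    #         service_context=service_context,
    #         structured_answer_filtering=True,
    #         verbose=True,
    #     )
    # With structured_answer_filtering=True, each chunk's answer is kept only
    # if the LLM reports query_satisfied=True.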

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "text_qa_template": self._text_qa_template,
            "refine_template": self._refine_template,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "text_qa_template" in prompts:
            self._text_qa_template = prompts["text_qa_template"]
        if "refine_template" in prompts:
            self._refine_template = prompts["refine_template"]

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Give response over chunks."""
        response: Optional[RESPONSE_TEXT_TYPE] = None
        for text_chunk in text_chunks:
            if prev_response is None:
                # if this is the first chunk, answer from the chunk directly
                response = self._give_response_single(
                    query_str, text_chunk, **response_kwargs
                )
            else:
                # refine response if possible
                response = self._refine_response_single(
                    prev_response, query_str, text_chunk, **response_kwargs
                )
            prev_response = response

        # mirror the async path: guard against an empty chunk list
        if response is None:
            response = "Empty Response"
        if isinstance(response, str):
            if self._output_cls is not None:
                response = self._output_cls.parse_raw(response)
            else:
                response = response or "Empty Response"
        else:
            response = cast(Generator, response)
        return response
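
    # Example (illustrative): given two chunks,
    #     synth.get_response("Who wrote X?", ["chunk A", "chunk B"])
    # first answers from "chunk A" via _give_response_single, then feeds that
    # answer plus "chunk B" through _refine_response_single.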

    def _default_program_factory(self, prompt: PromptTemplate) -> BasePydanticProgram:
        if self._structured_answer_filtering:
            from llama_index.program.utils import get_program_for_llm

            return get_program_for_llm(
                StructuredRefineResponse,
                prompt,
                self._service_context.llm,
                verbose=self._verbose,
            )
        else:
            return DefaultRefineProgram(
                prompt=prompt,
                llm=self._service_context.llm,
                output_cls=self._output_cls,
            )

    def _give_response_single(
        self,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Give response given a query and a corresponding text chunk."""
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
        text_chunks = self._service_context.prompt_helper.repack(
            text_qa_template, [text_chunk]
        )

        response: Optional[RESPONSE_TEXT_TYPE] = None
        program = self._program_factory(text_qa_template)
        # TODO: consolidate with loop in get_response_default
        for cur_text_chunk in text_chunks:
            query_satisfied = False
            if response is None and not self._streaming:
                try:
                    structured_response = cast(
                        StructuredRefineResponse,
                        program(
                            context_str=cur_text_chunk,
                            **response_kwargs,
                        ),
                    )
                    query_satisfied = structured_response.query_satisfied
                    if query_satisfied:
                        response = structured_response.answer
                except ValidationError as e:
                    logger.warning(
                        f"Validation error on structured response: {e}", exc_info=True
                    )
            elif response is None and self._streaming:
                response = self._service_context.llm.stream(
                    text_qa_template,
                    context_str=cur_text_chunk,
                    **response_kwargs,
                )
                query_satisfied = True
            else:
                response = self._refine_response_single(
                    cast(RESPONSE_TEXT_TYPE, response),
                    query_str,
                    cur_text_chunk,
                    **response_kwargs,
                )
        if response is None:
            response = "Empty Response"
        if isinstance(response, str):
            response = response or "Empty Response"
        else:
            response = cast(Generator, response)
        return response

    def _refine_response_single(
        self,
        response: RESPONSE_TEXT_TYPE,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> Optional[RESPONSE_TEXT_TYPE]:
        """Refine response."""
        # TODO: consolidate with logic in response/schema.py
        if isinstance(response, Generator):
            response = get_response_text(response)

        fmt_text_chunk = truncate_text(text_chunk, 50)
        logger.debug(f"> Refine context: {fmt_text_chunk}")
        if self._verbose:
            print(f"> Refine context: {fmt_text_chunk}")

        # NOTE: partial format refine template with query_str and existing_answer here
        refine_template = self._refine_template.partial_format(
            query_str=query_str, existing_answer=response
        )

        # compute the available chunk size: the refine template can be too big
        # if prompt template + query + existing answer exceed the context window
        avail_chunk_size = (
            self._service_context.prompt_helper._get_available_chunk_size(
                refine_template
            )
        )
        if avail_chunk_size < 0:
            # if the available chunk size is negative, the refine template is
            # too big, so just return the original response
            return response

        # obtain text chunks to add to the refine template
        text_chunks = self._service_context.prompt_helper.repack(
            refine_template, text_chunks=[text_chunk]
        )

        program = self._program_factory(refine_template)
        for cur_text_chunk in text_chunks:
            query_satisfied = False
            if not self._streaming:
                try:
                    structured_response = cast(
                        StructuredRefineResponse,
                        program(
                            context_msg=cur_text_chunk,
                            **response_kwargs,
                        ),
                    )
                    query_satisfied = structured_response.query_satisfied
                    if query_satisfied:
                        response = structured_response.answer
                except ValidationError as e:
                    logger.warning(
                        f"Validation error on structured response: {e}", exc_info=True
                    )
            else:
                # TODO: structured response not supported for streaming
                if isinstance(response, Generator):
                    response = "".join(response)

                refine_template = self._refine_template.partial_format(
                    query_str=query_str, existing_answer=response
                )
                response = self._service_context.llm.stream(
                    refine_template,
                    context_msg=cur_text_chunk,
                    **response_kwargs,
                )
        return response
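
    # Example (illustrative): if the context window is 4096 tokens and the
    # refine prompt plus query plus existing answer already consume more than
    # that, _get_available_chunk_size returns a negative number and the
    # existing answer is returned unchanged rather than truncated.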

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Give response over chunks (async)."""
        response: Optional[RESPONSE_TEXT_TYPE] = None
        for text_chunk in text_chunks:
            if prev_response is None:
                # if this is the first chunk, answer from the chunk directly
                response = await self._agive_response_single(
                    query_str, text_chunk, **response_kwargs
                )
            else:
                response = await self._arefine_response_single(
                    prev_response, query_str, text_chunk, **response_kwargs
                )
            prev_response = response

        if response is None:
            response = "Empty Response"
        if isinstance(response, str):
            if self._output_cls is not None:
                response = self._output_cls.parse_raw(response)
            else:
                response = response or "Empty Response"
        else:
            response = cast(Generator, response)
        return response

    async def _arefine_response_single(
        self,
        response: RESPONSE_TEXT_TYPE,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> Optional[RESPONSE_TEXT_TYPE]:
        """Refine response."""
        # TODO: consolidate with logic in response/schema.py
        if isinstance(response, Generator):
            response = get_response_text(response)

        fmt_text_chunk = truncate_text(text_chunk, 50)
        logger.debug(f"> Refine context: {fmt_text_chunk}")

        # NOTE: partial format refine template with query_str and existing_answer here
        refine_template = self._refine_template.partial_format(
            query_str=query_str, existing_answer=response
        )

        # compute the available chunk size: the refine template can be too big
        # if prompt template + query + existing answer exceed the context window
        avail_chunk_size = (
            self._service_context.prompt_helper._get_available_chunk_size(
                refine_template
            )
        )
        if avail_chunk_size < 0:
            # if the available chunk size is negative, the refine template is
            # too big, so just return the original response
            return response

        # obtain text chunks to add to the refine template
        text_chunks = self._service_context.prompt_helper.repack(
            refine_template, text_chunks=[text_chunk]
        )

        program = self._program_factory(refine_template)
        for cur_text_chunk in text_chunks:
            query_satisfied = False
            if not self._streaming:
                try:
                    structured_response = await program.acall(
                        context_msg=cur_text_chunk,
                        **response_kwargs,
                    )
                    structured_response = cast(
                        StructuredRefineResponse, structured_response
                    )
                    query_satisfied = structured_response.query_satisfied
                    if query_satisfied:
                        response = structured_response.answer
                except ValidationError as e:
                    logger.warning(
                        f"Validation error on structured response: {e}", exc_info=True
                    )
            else:
                raise ValueError("Streaming not supported for async")

            if query_satisfied:
                refine_template = self._refine_template.partial_format(
                    query_str=query_str, existing_answer=response
                )
        return response

    async def _agive_response_single(
        self,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Give response given a query and a corresponding text chunk."""
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
        text_chunks = self._service_context.prompt_helper.repack(
            text_qa_template, [text_chunk]
        )

        response: Optional[RESPONSE_TEXT_TYPE] = None
        program = self._program_factory(text_qa_template)
        # TODO: consolidate with loop in get_response_default
        for cur_text_chunk in text_chunks:
            if response is None and not self._streaming:
                try:
                    structured_response = await program.acall(
                        context_str=cur_text_chunk,
                        **response_kwargs,
                    )
                    structured_response = cast(
                        StructuredRefineResponse, structured_response
                    )
                    query_satisfied = structured_response.query_satisfied
                    if query_satisfied:
                        response = structured_response.answer
                except ValidationError as e:
                    logger.warning(
                        f"Validation error on structured response: {e}", exc_info=True
                    )
            elif response is None and self._streaming:
                raise ValueError("Streaming not supported for async")
            else:
                response = await self._arefine_response_single(
                    cast(RESPONSE_TEXT_TYPE, response),
                    query_str,
                    cur_text_chunk,
                    **response_kwargs,
                )
        if response is None:
            response = "Empty Response"
        if isinstance(response, str):
            response = response or "Empty Response"
        else:
            response = cast(Generator, response)
        return response
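

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the library). Assumes an LLM
# is configured for ServiceContext.from_defaults(), e.g. via an OpenAI API key.
if __name__ == "__main__":
    service_context = ServiceContext.from_defaults()
    synth = Refine(service_context=service_context, verbose=True)
    answer = synth.get_response(
        query_str="What does the refine synthesizer do?",
        text_chunks=[
            "The refine synthesizer answers a query from the first chunk.",
            "It then refines that answer against each subsequent chunk.",
        ],
    )
    print(answer)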