import logging
from typing import Any, Callable, Generator, Optional, Sequence, Type, cast

from llama_index.bridge.pydantic import BaseModel, Field, ValidationError
from llama_index.indices.utils import truncate_text
from llama_index.llm_predictor.base import LLMPredictorType
from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.prompts.default_prompt_selectors import (
    DEFAULT_REFINE_PROMPT_SEL,
    DEFAULT_TEXT_QA_PROMPT_SEL,
)
from llama_index.prompts.mixin import PromptDictType
from llama_index.response.utils import get_response_text
from llama_index.response_synthesizers.base import BaseSynthesizer
from llama_index.service_context import ServiceContext
from llama_index.types import RESPONSE_TEXT_TYPE, BasePydanticProgram

logger = logging.getLogger(__name__)


class StructuredRefineResponse(BaseModel):
    """
    Used to answer a given query based on the provided context.

    Also indicates if the query was satisfied with the provided answer.
    """

    answer: str = Field(
        description="The answer for the given query, based on the context and not "
        "prior knowledge."
    )
    query_satisfied: bool = Field(
        description="True if there was enough context given to provide an answer "
        "that satisfies the query."
    )
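
# Example (illustrative sketch, not part of the module itself): a structured
# answer-filtering program is expected to produce JSON that validates against
# StructuredRefineResponse, e.g.
#
#   StructuredRefineResponse.parse_raw(
#       '{"answer": "Paris is the capital of France.", "query_satisfied": true}'
#   )
#
# `query_satisfied` is what the refine loop below inspects to decide whether to
# keep (or replace) the running answer.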


class DefaultRefineProgram(BasePydanticProgram):
    """
    Runs the query on the LLM as normal and always returns the answer with
    query_satisfied=True. In effect, doesn't do any answer filtering.
    """

    def __init__(
        self, prompt: BasePromptTemplate, llm: LLMPredictorType, output_cls: BaseModel
    ):
        self._prompt = prompt
        self._llm = llm
        self._output_cls = output_cls

    @property
    def output_cls(self) -> Type[BaseModel]:
        return StructuredRefineResponse

    def __call__(self, *args: Any, **kwds: Any) -> StructuredRefineResponse:
        if self._output_cls is not None:
            answer = self._llm.structured_predict(
                self._output_cls,
                self._prompt,
                **kwds,
            )
            answer = answer.json()
        else:
            answer = self._llm.predict(
                self._prompt,
                **kwds,
            )
        return StructuredRefineResponse(answer=answer, query_satisfied=True)

    async def acall(self, *args: Any, **kwds: Any) -> StructuredRefineResponse:
        if self._output_cls is not None:
            answer = await self._llm.astructured_predict(
                self._output_cls,
                self._prompt,
                **kwds,
            )
            answer = answer.json()
        else:
            answer = await self._llm.apredict(
                self._prompt,
                **kwds,
            )
        return StructuredRefineResponse(answer=answer, query_satisfied=True)
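
# Example (sketch; `text_qa_template` and `llm` are assumed to exist, mirroring
# how Refine._default_program_factory wires this class up): the default program
# is a pass-through around the LLM, so it never filters an answer out:
#
#   program = DefaultRefineProgram(prompt=text_qa_template, llm=llm, output_cls=None)
#   result = program(context_str="<chunk text>")
#   assert result.query_satisfied is True  # always True for the default program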


class Refine(BaseSynthesizer):
    """Refine a response to a query across text chunks."""

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        text_qa_template: Optional[BasePromptTemplate] = None,
        refine_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[BaseModel] = None,
        streaming: bool = False,
        verbose: bool = False,
        structured_answer_filtering: bool = False,
        program_factory: Optional[
            Callable[[BasePromptTemplate], BasePydanticProgram]
        ] = None,
    ) -> None:
        super().__init__(service_context=service_context, streaming=streaming)
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
        self._refine_template = refine_template or DEFAULT_REFINE_PROMPT_SEL
        self._verbose = verbose
        self._structured_answer_filtering = structured_answer_filtering
        self._output_cls = output_cls

        if self._streaming and self._structured_answer_filtering:
            raise ValueError(
                "Streaming not supported with structured answer filtering."
            )
        if not self._structured_answer_filtering and program_factory is not None:
            raise ValueError(
                "Program factory not supported without structured answer filtering."
            )
        self._program_factory = program_factory or self._default_program_factory
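
    # Example (illustrative sketch; assumes a configured `service_context`):
    #
    #   synth = Refine(
    #       service_context=service_context,
    #       structured_answer_filtering=True,  # skip chunks that cannot answer the query
    #   )
    #
    # Note that streaming=True combined with structured_answer_filtering=True
    # raises ValueError, as enforced above.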

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "text_qa_template": self._text_qa_template,
            "refine_template": self._refine_template,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "text_qa_template" in prompts:
            self._text_qa_template = prompts["text_qa_template"]
        if "refine_template" in prompts:
            self._refine_template = prompts["refine_template"]

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Give response over chunks."""
        response: Optional[RESPONSE_TEXT_TYPE] = None
        for text_chunk in text_chunks:
            if prev_response is None:
                # if this is the first chunk, and text chunk already
                # is an answer, then return it
                response = self._give_response_single(
                    query_str, text_chunk, **response_kwargs
                )
            else:
                # refine response if possible
                response = self._refine_response_single(
                    prev_response, query_str, text_chunk, **response_kwargs
                )
            prev_response = response
        if isinstance(response, str):
            if self._output_cls is not None:
                response = self._output_cls.parse_raw(response)
            else:
                response = response or "Empty Response"
        else:
            response = cast(Generator, response)
        return response
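
    # Example (sketch; chunk text and query are placeholders): a typical call
    # synthesizes an answer across chunks, answering from the first chunk with
    # the text-QA template and refining with each subsequent chunk:
    #
    #   answer = synth.get_response(
    #       query_str="What did the author work on?",
    #       text_chunks=["chunk one ...", "chunk two ..."],
    #   )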

    def _default_program_factory(self, prompt: PromptTemplate) -> BasePydanticProgram:
        if self._structured_answer_filtering:
            from llama_index.program.utils import get_program_for_llm

            return get_program_for_llm(
                StructuredRefineResponse,
                prompt,
                self._service_context.llm,
                verbose=self._verbose,
            )
        else:
            return DefaultRefineProgram(
                prompt=prompt,
                llm=self._service_context.llm,
                output_cls=self._output_cls,
            )

    def _give_response_single(
        self,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Give response given a query and a corresponding text chunk."""
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
        text_chunks = self._service_context.prompt_helper.repack(
            text_qa_template, [text_chunk]
        )

        response: Optional[RESPONSE_TEXT_TYPE] = None
        program = self._program_factory(text_qa_template)
        # TODO: consolidate with loop in get_response_default
        for cur_text_chunk in text_chunks:
            query_satisfied = False
            if response is None and not self._streaming:
                try:
                    structured_response = cast(
                        StructuredRefineResponse,
                        program(
                            context_str=cur_text_chunk,
                            **response_kwargs,
                        ),
                    )
                    query_satisfied = structured_response.query_satisfied
                    if query_satisfied:
                        response = structured_response.answer
                except ValidationError as e:
                    logger.warning(
                        f"Validation error on structured response: {e}", exc_info=True
                    )
            elif response is None and self._streaming:
                response = self._service_context.llm.stream(
                    text_qa_template,
                    context_str=cur_text_chunk,
                    **response_kwargs,
                )
                query_satisfied = True
            else:
                response = self._refine_response_single(
                    cast(RESPONSE_TEXT_TYPE, response),
                    query_str,
                    cur_text_chunk,
                    **response_kwargs,
                )
        if response is None:
            response = "Empty Response"
        if isinstance(response, str):
            response = response or "Empty Response"
        else:
            response = cast(Generator, response)
        return response

    def _refine_response_single(
        self,
        response: RESPONSE_TEXT_TYPE,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> Optional[RESPONSE_TEXT_TYPE]:
        """Refine response."""
        # TODO: consolidate with logic in response/schema.py
        if isinstance(response, Generator):
            response = get_response_text(response)

        fmt_text_chunk = truncate_text(text_chunk, 50)
        logger.debug(f"> Refine context: {fmt_text_chunk}")
        if self._verbose:
            print(f"> Refine context: {fmt_text_chunk}")

        # NOTE: partial format refine template with query_str and existing_answer here
        refine_template = self._refine_template.partial_format(
            query_str=query_str, existing_answer=response
        )

        # compute available chunk size to see if there is any available space
        # determine if the refine template is too big (which can happen if
        # prompt template + query + existing answer is too large)
        avail_chunk_size = (
            self._service_context.prompt_helper._get_available_chunk_size(
                refine_template
            )
        )

        if avail_chunk_size < 0:
            # if the available chunk size is negative, then the refine template
            # is too big and we just return the original response
            return response

        # obtain text chunks to add to the refine template
        text_chunks = self._service_context.prompt_helper.repack(
            refine_template, text_chunks=[text_chunk]
        )

        program = self._program_factory(refine_template)
        for cur_text_chunk in text_chunks:
            query_satisfied = False
            if not self._streaming:
                try:
                    structured_response = cast(
                        StructuredRefineResponse,
                        program(
                            context_msg=cur_text_chunk,
                            **response_kwargs,
                        ),
                    )
                    query_satisfied = structured_response.query_satisfied
                    if query_satisfied:
                        response = structured_response.answer
                except ValidationError as e:
                    logger.warning(
                        f"Validation error on structured response: {e}", exc_info=True
                    )
            else:
                # TODO: structured response not supported for streaming
                if isinstance(response, Generator):
                    response = "".join(response)

                refine_template = self._refine_template.partial_format(
                    query_str=query_str, existing_answer=response
                )

                response = self._service_context.llm.stream(
                    refine_template,
                    context_msg=cur_text_chunk,
                    **response_kwargs,
                )

        return response

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        response: Optional[RESPONSE_TEXT_TYPE] = None
        for text_chunk in text_chunks:
            if prev_response is None:
                # if this is the first chunk, and text chunk already
                # is an answer, then return it
                response = await self._agive_response_single(
                    query_str, text_chunk, **response_kwargs
                )
            else:
                response = await self._arefine_response_single(
                    prev_response, query_str, text_chunk, **response_kwargs
                )
            prev_response = response
        if response is None:
            response = "Empty Response"
        if isinstance(response, str):
            if self._output_cls is not None:
                response = self._output_cls.parse_raw(response)
            else:
                response = response or "Empty Response"
        else:
            response = cast(Generator, response)
        return response

    async def _arefine_response_single(
        self,
        response: RESPONSE_TEXT_TYPE,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> Optional[RESPONSE_TEXT_TYPE]:
        """Refine response."""
        # TODO: consolidate with logic in response/schema.py
        if isinstance(response, Generator):
            response = get_response_text(response)

        fmt_text_chunk = truncate_text(text_chunk, 50)
        logger.debug(f"> Refine context: {fmt_text_chunk}")

        # NOTE: partial format refine template with query_str and existing_answer here
        refine_template = self._refine_template.partial_format(
            query_str=query_str, existing_answer=response
        )

        # compute available chunk size to see if there is any available space
        # determine if the refine template is too big (which can happen if
        # prompt template + query + existing answer is too large)
        avail_chunk_size = (
            self._service_context.prompt_helper._get_available_chunk_size(
                refine_template
            )
        )

        if avail_chunk_size < 0:
            # if the available chunk size is negative, then the refine template
            # is too big and we just return the original response
            return response

        # obtain text chunks to add to the refine template
        text_chunks = self._service_context.prompt_helper.repack(
            refine_template, text_chunks=[text_chunk]
        )

        program = self._program_factory(refine_template)
        for cur_text_chunk in text_chunks:
            query_satisfied = False
            if not self._streaming:
                try:
                    structured_response = await program.acall(
                        context_msg=cur_text_chunk,
                        **response_kwargs,
                    )
                    structured_response = cast(
                        StructuredRefineResponse, structured_response
                    )
                    query_satisfied = structured_response.query_satisfied
                    if query_satisfied:
                        response = structured_response.answer
                except ValidationError as e:
                    logger.warning(
                        f"Validation error on structured response: {e}", exc_info=True
                    )
            else:
                raise ValueError("Streaming not supported for async")

            if query_satisfied:
                refine_template = self._refine_template.partial_format(
                    query_str=query_str, existing_answer=response
                )

        return response

    async def _agive_response_single(
        self,
        query_str: str,
        text_chunk: str,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Give response given a query and a corresponding text chunk."""
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
        text_chunks = self._service_context.prompt_helper.repack(
            text_qa_template, [text_chunk]
        )

        response: Optional[RESPONSE_TEXT_TYPE] = None
        program = self._program_factory(text_qa_template)
        # TODO: consolidate with loop in get_response_default
        for cur_text_chunk in text_chunks:
            if response is None and not self._streaming:
                try:
                    structured_response = await program.acall(
                        context_str=cur_text_chunk,
                        **response_kwargs,
                    )
                    structured_response = cast(
                        StructuredRefineResponse, structured_response
                    )
                    query_satisfied = structured_response.query_satisfied
                    if query_satisfied:
                        response = structured_response.answer
                except ValidationError as e:
                    logger.warning(
                        f"Validation error on structured response: {e}", exc_info=True
                    )
            elif response is None and self._streaming:
                raise ValueError("Streaming not supported for async")
            else:
                response = await self._arefine_response_single(
                    cast(RESPONSE_TEXT_TYPE, response),
                    query_str,
                    cur_text_chunk,
                    **response_kwargs,
                )
        if response is None:
            response = "Empty Response"
        if isinstance(response, str):
            response = response or "Empty Response"
        else:
            response = cast(Generator, response)
        return response