faiss_rag_enterprise/llama_index/query_engine/flare/answer_inserter.py

"""Answer inserter."""
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType
from llama_index.query_engine.flare.schema import QueryTask
from llama_index.service_context import ServiceContext


class BaseLookaheadAnswerInserter(PromptMixin):
    """Lookahead answer inserter.

    These are responsible for inserting answers into a lookahead answer template.

    E.g.
    lookahead answer: Red is for [Search(What is the meaning of Ghana's
        flag being red?)], green for forests, and gold for mineral wealth.
    query: What is the meaning of Ghana's flag being red?
    query answer: "the blood of those who died in the country's struggle
        for independence"
    final answer: Red is for the blood of those who died in the country's
        struggle for independence, green for forests, and gold for mineral wealth.

    """

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {}

    @abstractmethod
    def insert(
        self,
        response: str,
        query_tasks: List[QueryTask],
        answers: List[str],
        prev_response: Optional[str] = None,
    ) -> str:
        """Insert answers into response."""


DEFAULT_ANSWER_INSERT_PROMPT_TMPL = """
An existing 'lookahead response' is given below. The lookahead response
contains `[Search(query)]` tags. Some queries have been executed and the
response retrieved. The queries and answers are also given below.
Also the previous response (the response before the lookahead response)
is given below.
Given the lookahead template, previous response, and also queries and answers,
please 'fill in' the lookahead template with the appropriate answers.

NOTE: Please make sure that the final response grammatically follows
the previous response + lookahead template. For example, if the previous
response is "New York City has a population of " and the lookahead
template is "[Search(What is the population of New York City?)]", then
the final response should be "8.4 million".

NOTE: the lookahead template may not be a complete sentence and may
contain trailing/leading commas, etc. Please preserve the original
formatting of the lookahead template if possible.

NOTE: the exception to the above rule is if the answer to a query
is equivalent to "I don't know" or "I don't have an answer". In this case,
modify the lookahead template to indicate that the answer is not known.

NOTE: the lookahead template may contain multiple `[Search(query)]` tags
and only a subset of these queries have been executed.
Do not replace the `[Search(query)]` tags that have not been executed.

Previous Response:


Lookahead Template:
Red is for [Search(What is the meaning of Ghana's \
flag being red?)], green for forests, and gold for mineral wealth.

Query-Answer Pairs:
Query: What is the meaning of Ghana's flag being red?
Answer: The red represents the blood of those who died in the country's struggle \
for independence

Filled in Answers:
Red is for the blood of those who died in the country's struggle for independence, \
green for forests, and gold for mineral wealth.

Previous Response:
One of the largest cities in the world

Lookahead Template:
, the city contains a population of [Search(What is the population \
of New York City?)]

Query-Answer Pairs:
Query: What is the population of New York City?
Answer: The population of New York City is 8.4 million

Synthesized Response:
, the city contains a population of 8.4 million

Previous Response:
the city contains a population of

Lookahead Template:
[Search(What is the population of New York City?)]

Query-Answer Pairs:
Query: What is the population of New York City?
Answer: The population of New York City is 8.4 million

Synthesized Response:
8.4 million

Previous Response:
{prev_response}

Lookahead Template:
{lookahead_response}

Query-Answer Pairs:
{query_answer_pairs}

Synthesized Response:
"""
DEFAULT_ANSWER_INSERT_PROMPT = PromptTemplate(DEFAULT_ANSWER_INSERT_PROMPT_TMPL)
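
# Illustrative sketch (not part of the original module): how the three template
# variables are filled. `PromptTemplate.format` is used here only to show the shape
# of the inputs; the LLM inserter below passes the same kwargs to `llm.predict`.
#
#     _example_prompt_str = DEFAULT_ANSWER_INSERT_PROMPT.format(
#         prev_response="the city contains a population of ",
#         lookahead_response="[Search(What is the population of New York City?)]",
#         query_answer_pairs=(
#             "Query: What is the population of New York City?\n"
#             "Answer: The population of New York City is 8.4 million\n"
#         ),
#     )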


class LLMLookaheadAnswerInserter(BaseLookaheadAnswerInserter):
    """LLM Lookahead answer inserter.

    Takes in a lookahead response, a list of query tasks, and the corresponding
    answers, and uses an LLM to insert the answers into the lookahead response.

    Args:
        service_context (ServiceContext): Service context.
        answer_insert_prompt (BasePromptTemplate): Prompt used to insert answers.

    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        answer_insert_prompt: Optional[BasePromptTemplate] = None,
    ) -> None:
        """Init params."""
        self._service_context = service_context or ServiceContext.from_defaults()
        self._answer_insert_prompt = (
            answer_insert_prompt or DEFAULT_ANSWER_INSERT_PROMPT
        )

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {
            "answer_insert_prompt": self._answer_insert_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "answer_insert_prompt" in prompts:
            self._answer_insert_prompt = prompts["answer_insert_prompt"]

    def insert(
        self,
        response: str,
        query_tasks: List[QueryTask],
        answers: List[str],
        prev_response: Optional[str] = None,
    ) -> str:
        """Insert answers into response."""
        prev_response = prev_response or ""
        # Format the executed queries and their answers as "Query: ...\nAnswer: ..." pairs.
        query_answer_pairs = ""
        for query_task, answer in zip(query_tasks, answers):
            query_answer_pairs += f"Query: {query_task.query_str}\nAnswer: {answer}\n"
        return self._service_context.llm.predict(
            self._answer_insert_prompt,
            lookahead_response=response,
            query_answer_pairs=query_answer_pairs,
            prev_response=prev_response,
        )
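
# Usage sketch (illustrative only; assumes a configured ServiceContext with an LLM
# and QueryTask objects produced upstream by the FLARE query engine):
#
#     inserter = LLMLookaheadAnswerInserter(service_context=service_context)
#     final_response = inserter.insert(
#         response=(
#             "Red is for [Search(What is the meaning of Ghana's flag being red?)], "
#             "green for forests, and gold for mineral wealth."
#         ),
#         query_tasks=[query_task],  # QueryTask for the [Search(...)] tag above
#         answers=["the blood of those who died in the struggle for independence"],
#         prev_response="",
#     )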


class DirectLookaheadAnswerInserter(BaseLookaheadAnswerInserter):
    """Direct lookahead answer inserter.

    Simple inserter module that directly inserts answers into
    the [Search(query)] tags in the lookahead response, without
    calling an LLM.

    """

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""

    def insert(
        self,
        response: str,
        query_tasks: List[QueryTask],
        answers: List[str],
        prev_response: Optional[str] = None,
    ) -> str:
        """Insert answers into response."""
        for query_task, answer in zip(query_tasks, answers):
            # Splice the answer over the [Search(query)] span identified by the
            # query task's start/end character indices in the response.
            response = (
                response[: query_task.start_idx]
                + answer
                + response[query_task.end_idx + 1 :]
            )
        return response
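

# Minimal, self-contained sketch of the direct insertion path (illustrative only,
# not part of the original module). It assumes `QueryTask` can be constructed with
# `query_str`, `start_idx`, and `end_idx` fields, where the indices span the
# `[Search(...)]` tag inclusively, matching how `insert` slices the response above.
if __name__ == "__main__":
    lookahead = (
        "Red is for [Search(What is the meaning of Ghana's flag being red?)], "
        "green for forests, and gold for mineral wealth."
    )
    start_idx = lookahead.index("[Search(")
    end_idx = lookahead.index(")]") + 1  # index of the closing ']'
    task = QueryTask(
        query_str="What is the meaning of Ghana's flag being red?",
        start_idx=start_idx,
        end_idx=end_idx,
    )
    answer = "the blood of those who died in the country's struggle for independence"
    print(DirectLookaheadAnswerInserter().insert(lookahead, [task], [answer]))
    # Expected: "Red is for the blood of those who died in the country's struggle
    # for independence, green for forests, and gold for mineral wealth."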