"""Query engines based on the FLARE paper. Active Retrieval Augmented Generation. """ from typing import Any, Dict, Optional from llama_index.callbacks.base import CallbackManager from llama_index.core.base_query_engine import BaseQueryEngine from llama_index.core.response.schema import RESPONSE_TYPE, Response from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.query_engine.flare.answer_inserter import ( BaseLookaheadAnswerInserter, LLMLookaheadAnswerInserter, ) from llama_index.query_engine.flare.output_parser import ( IsDoneOutputParser, QueryTaskOutputParser, ) from llama_index.schema import QueryBundle from llama_index.service_context import ServiceContext from llama_index.utils import print_text # These prompts are taken from the FLARE repo: # https://github.com/jzbjyb/FLARE/blob/main/src/templates.py DEFAULT_EXAMPLES = """ Query: But what are the risks during production of nanomaterials? Answer: [Search(What are some nanomaterial production risks?)] Query: The colors on the flag of Ghana have the following meanings. Answer: Red is for [Search(What is the meaning of Ghana's flag being red?)], \ green for forests, and gold for mineral wealth. Query: What did the author do during his time in college? Answer: The author took classes in [Search(What classes did the author take in \ college?)]. """ DEFAULT_FIRST_SKILL = f"""\ Skill 1. Use the Search API to look up relevant information by writing \ "[Search(query)]" where "query" is the search query you want to look up. \ For example: {DEFAULT_EXAMPLES} """ DEFAULT_SECOND_SKILL = """\ Skill 2. Solve more complex generation tasks by thinking step by step. For example: Query: Give a summary of the author's life and career. Answer: The author was born in 1990. Growing up, he [Search(What did the \ author do during his childhood?)]. Query: Can you write a summary of the Great Gatsby. Answer: The Great Gatsby is a novel written by F. Scott Fitzgerald. It is about \ [Search(What is the Great Gatsby about?)]. """ DEFAULT_END = """ Now given the following task, and the stub of an existing answer, generate the \ next portion of the answer. You may use the Search API \ "[Search(query)]" whenever possible. If the answer is complete and no longer contains any "[Search(query)]" tags, write \ "done" to finish the task. Do not write "done" if the answer still contains "[Search(query)]" tags. Do not make up answers. It is better to generate one "[Search(query)]" tag and stop \ generation than to fill in the answer with made up information with no "[Search(query)]" tags or multiple "[Search(query)]" tags that assume a structure in the answer. Try to limit generation to one sentence if possible. """ DEFAULT_INSTRUCT_PROMPT_TMPL = ( DEFAULT_FIRST_SKILL + DEFAULT_SECOND_SKILL + DEFAULT_END + ( """ Query: {query_str} Existing Answer: {existing_answer} Answer: """ ) ) DEFAULT_INSTRUCT_PROMPT = PromptTemplate(DEFAULT_INSTRUCT_PROMPT_TMPL) class FLAREInstructQueryEngine(BaseQueryEngine): """FLARE Instruct query engine. This is the version of FLARE that uses retrieval-encouraging instructions. NOTE: this is a beta feature. Interfaces might change, and it might not always give correct answers. Args: query_engine (BaseQueryEngine): query engine to use service_context (Optional[ServiceContext]): service context. Defaults to None. instruct_prompt (Optional[PromptTemplate]): instruct prompt. Defaults to None. 


class FLAREInstructQueryEngine(BaseQueryEngine):
    """FLARE Instruct query engine.

    This is the version of FLARE that uses retrieval-encouraging instructions.

    NOTE: this is a beta feature. Interfaces might change, and it might not
    always give correct answers.

    Args:
        query_engine (BaseQueryEngine): query engine to use
        service_context (Optional[ServiceContext]): service context.
            Defaults to None.
        instruct_prompt (Optional[BasePromptTemplate]): instruct prompt.
            Defaults to None.
        lookahead_answer_inserter (Optional[BaseLookaheadAnswerInserter]):
            lookahead answer inserter. Defaults to None.
        done_output_parser (Optional[IsDoneOutputParser]): done output parser.
            Defaults to None.
        query_task_output_parser (Optional[QueryTaskOutputParser]):
            query task output parser. Defaults to None.
        max_iterations (int): max iterations. Defaults to 10.
        max_lookahead_query_tasks (int): max lookahead query tasks.
            Defaults to 1.
        callback_manager (Optional[CallbackManager]): callback manager.
            Defaults to None.
        verbose (bool): give verbose outputs. Defaults to False.

    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        service_context: Optional[ServiceContext] = None,
        instruct_prompt: Optional[BasePromptTemplate] = None,
        lookahead_answer_inserter: Optional[BaseLookaheadAnswerInserter] = None,
        done_output_parser: Optional[IsDoneOutputParser] = None,
        query_task_output_parser: Optional[QueryTaskOutputParser] = None,
        max_iterations: int = 10,
        max_lookahead_query_tasks: int = 1,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
    ) -> None:
        """Init params."""
        super().__init__(callback_manager=callback_manager)
        self._query_engine = query_engine
        self._service_context = service_context or ServiceContext.from_defaults()
        self._instruct_prompt = instruct_prompt or DEFAULT_INSTRUCT_PROMPT
        self._lookahead_answer_inserter = lookahead_answer_inserter or (
            LLMLookaheadAnswerInserter(service_context=self._service_context)
        )
        self._done_output_parser = done_output_parser or IsDoneOutputParser()
        self._query_task_output_parser = (
            query_task_output_parser or QueryTaskOutputParser()
        )
        self._max_iterations = max_iterations
        self._max_lookahead_query_tasks = max_lookahead_query_tasks
        self._verbose = verbose

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {
            "instruct_prompt": self._instruct_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "instruct_prompt" in prompts:
            self._instruct_prompt = prompts["instruct_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "query_engine": self._query_engine,
            "lookahead_answer_inserter": self._lookahead_answer_inserter,
        }

    def _get_relevant_lookahead_response(self, updated_lookahead_resp: str) -> str:
        """Get relevant lookahead response."""
        # if there are remaining query tasks, truncate the response
        # at the start position of the first remaining tag.
        # tasks may remain because _max_lookahead_query_tasks is less than
        # the total number of generated [Search(query)] tags
        remaining_query_tasks = self._query_task_output_parser.parse(
            updated_lookahead_resp
        )
        if len(remaining_query_tasks) == 0:
            relevant_lookahead_resp = updated_lookahead_resp
        else:
            first_task = remaining_query_tasks[0]
            relevant_lookahead_resp = updated_lookahead_resp[: first_task.start_idx]
        return relevant_lookahead_resp
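
    # Illustrative example of the truncation above, with hypothetical strings
    # and max_lookahead_query_tasks=1: if the updated lookahead response is
    #   "Red is for the blood of martyrs, green for [Search(green meaning?)]"
    # the second tag was never answered this iteration, so the response is
    # truncated at that tag's start_idx, yielding
    #   "Red is for the blood of martyrs, green for "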

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Query and get response."""
        print_text(f"Query: {query_bundle.query_str}\n", color="green")
        cur_response = ""
        source_nodes = []
        for _ in range(self._max_iterations):
            if self._verbose:
                print_text(f"Current response: {cur_response}\n", color="blue")
            # generate "lookahead response" that contains "[Search(query)]" tags
            # e.g.
            # The colors on the flag of Ghana have the following meanings. Red is
            # for [Search(Ghana flag meaning)],...
            lookahead_resp = self._service_context.llm.predict(
                self._instruct_prompt,
                query_str=query_bundle.query_str,
                existing_answer=cur_response,
            )
            lookahead_resp = lookahead_resp.strip()
            if self._verbose:
                print_text(f"Lookahead response: {lookahead_resp}\n", color="pink")

            is_done, fmt_lookahead = self._done_output_parser.parse(lookahead_resp)
            if is_done:
                cur_response = cur_response.strip() + " " + fmt_lookahead.strip()
                break

            # parse lookahead response into query tasks
            query_tasks = self._query_task_output_parser.parse(lookahead_resp)

            # get answers for each query task
            query_tasks = query_tasks[: self._max_lookahead_query_tasks]
            query_answers = []
            for query_task in query_tasks:
                answer_obj = self._query_engine.query(query_task.query_str)
                if not isinstance(answer_obj, Response):
                    raise ValueError(
                        f"Expected Response object, got {type(answer_obj)} instead."
                    )
                query_answer = str(answer_obj)
                query_answers.append(query_answer)
                source_nodes.extend(answer_obj.source_nodes)

            # fill in the lookahead response template with the query answers
            # from the query engine
            updated_lookahead_resp = self._lookahead_answer_inserter.insert(
                lookahead_resp, query_tasks, query_answers, prev_response=cur_response
            )

            # get "relevant" lookahead response by truncating the updated
            # lookahead response until the start position of the first tag
            # also remove the prefix from the lookahead response, so that
            # we can concatenate it with the existing response
            relevant_lookahead_resp_wo_prefix = self._get_relevant_lookahead_response(
                updated_lookahead_resp
            )

            if self._verbose:
                print_text(
                    "Updated lookahead response: "
                    + f"{relevant_lookahead_resp_wo_prefix}\n",
                    color="pink",
                )

            # append the relevant lookahead response to the final response
            cur_response = (
                cur_response.strip() + " " + relevant_lookahead_resp_wo_prefix.strip()
            )

        # NOTE: at the moment, does not support streaming
        return Response(response=cur_response, source_nodes=source_nodes)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Query and get response (async).

        NOTE: currently delegates to the synchronous path; the sub-query and
        answer-insertion steps do not run concurrently.
        """
        return self._query(query_bundle)
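
# Usage sketch: wiring the engine around an index-backed query engine. The
# directory name and index setup below are assumptions for illustration, not
# part of this module.
#
#   from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
#   from llama_index.query_engine import FLAREInstructQueryEngine
#
#   documents = SimpleDirectoryReader("data").load_data()
#   index = VectorStoreIndex.from_documents(documents)
#   flare_query_engine = FLAREInstructQueryEngine(
#       query_engine=index.as_query_engine(),
#       service_context=ServiceContext.from_defaults(),
#       max_iterations=7,
#       verbose=True,
#   )
#   response = flare_query_engine.query("What did the author do growing up?")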