faiss_rag_enterprise/llama_index/query_engine/flare/output_parser.py

67 lines
2.0 KiB
Python

"""FLARE output parsers."""
from typing import Any, Callable, Optional
from llama_index.query_engine.flare.schema import QueryTask
from llama_index.types import BaseOutputParser
def default_parse_is_done_fn(response: str) -> bool:
    """Return True when the substring "done" occurs anywhere in ``response``.

    The comparison is case-insensitive, so "Done", "DONE", or even a word that
    merely contains "done" all count as a match.
    """
    lowered = response.lower()
    return lowered.find("done") != -1
def default_format_done_answer(response: str) -> str:
"""Default format done answer."""
return response.replace("done", "").strip()
class IsDoneOutputParser(BaseOutputParser):
    """Output parser that decides whether a FLARE response is finished.

    Args:
        is_done_fn: Predicate that returns True when a response signals
            completion. Defaults to ``default_parse_is_done_fn``.
        fmt_answer_fn: Function applied to a completed response to produce the
            final answer text. Defaults to ``default_format_done_answer``.
    """

    def __init__(
        self,
        is_done_fn: Optional[Callable[[str], bool]] = None,
        fmt_answer_fn: Optional[Callable[[str], str]] = None,
    ) -> None:
        """Init params."""
        self._is_done_fn = is_done_fn or default_parse_is_done_fn
        self._fmt_answer_fn = fmt_answer_fn or default_format_done_answer

    def parse(self, output: str) -> Any:
        """Parse output.

        Args:
            output: Raw LLM response text.

        Returns:
            ``(True, formatted_answer)`` when the configured ``is_done_fn``
            reports completion, otherwise ``(False, output)`` unchanged.
        """
        # Use the predicate configured in __init__; calling the module-level
        # default here would silently ignore a caller-supplied is_done_fn.
        if self._is_done_fn(output):
            return True, self._fmt_answer_fn(output)
        return False, output

    def format(self, output: str) -> str:
        """Format a query with structured output formatting instructions.

        Raises:
            NotImplementedError: Always; this parser adds no formatting.
        """
        raise NotImplementedError
class QueryTaskOutputParser(BaseOutputParser):
    """QueryTask output parser.

    By default, parses output that contains "[Search(query)]" tags.
    """

    def parse(self, output: str) -> Any:
        """Parse output into a list of ``QueryTask`` objects.

        Each ``[...]`` span whose contents contain ``(`` yields one
        ``QueryTask`` built from the text between the first ``(`` and the
        following ``)`` (or second ``(``), together with the indices of the
        enclosing ``[`` and ``]``. When several ``[`` precede a ``]``, the last
        one opens the span. A ``]`` with no open ``[`` and a span without ``(``
        (e.g. a citation like ``[1]``) are skipped rather than raising.

        Args:
            output: Raw LLM output, e.g. ``"X was born [Search(X birthplace)]."``.

        Returns:
            List of ``QueryTask`` in order of appearance.
        """
        query_tasks = []
        # None means "no unmatched '[' seen yet"; avoids an unbound local when
        # a ']' appears before any '['.
        start_idx: Optional[int] = None
        for idx, char in enumerate(output):
            if char == "[":
                start_idx = idx
            elif char == "]" and start_idx is not None:
                end_idx = idx
                raw_query_str = output[start_idx + 1 : end_idx]
                # Only "Tool(query)"-shaped spans carry a query; indexing [1]
                # on a span without "(" would raise IndexError.
                if "(" in raw_query_str:
                    query_str = raw_query_str.split("(")[1].split(")")[0]
                    query_tasks.append(QueryTask(query_str, start_idx, end_idx))
                # Consume the opening bracket so a later stray ']' cannot
                # re-use it and produce a bogus span.
                start_idx = None
        return query_tasks

    def format(self, output: str) -> str:
        """Format a query with structured output formatting instructions.

        Raises:
            NotImplementedError: Always; this parser adds no formatting.
        """
        raise NotImplementedError