# File: faiss_rag_enterprise/llama_index/query_engine/retry_query_engine.py
# Viewer metadata (as captured): 137 lines, 5.2 KiB, Python

import logging
from typing import Optional
from llama_index.callbacks.base import CallbackManager
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.response.schema import RESPONSE_TYPE, Response
from llama_index.evaluation.base import BaseEvaluator
from llama_index.evaluation.guideline import GuidelineEvaluator
from llama_index.indices.query.query_transform.feedback_transform import (
FeedbackQueryTransformation,
)
from llama_index.prompts.mixin import PromptMixinType
from llama_index.schema import QueryBundle
# Module-level logger; the retry engines below report evaluation outcomes
# (and, for the guideline engine, the rewritten query) at DEBUG level.
logger = logging.getLogger(__name__)
class RetryQueryEngine(BaseQueryEngine):
    """Does retry on query engine if it fails evaluation.

    The wrapped engine is queried once; if the evaluator marks the response
    as not passing, the query is rewritten with a
    ``FeedbackQueryTransformation`` and re-issued through a fresh
    ``RetryQueryEngine`` whose retry budget is one lower.

    Args:
        query_engine (BaseQueryEngine): A query engine object
        evaluator (BaseEvaluator): An evaluator object
        max_retries (int): Maximum number of retries
        callback_manager (Optional[CallbackManager]): A callback manager object
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        evaluator: BaseEvaluator,
        max_retries: int = 3,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._query_engine = query_engine
        self._evaluator = evaluator
        self.max_retries = max_retries
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"query_engine": self._query_engine, "evaluator": self._evaluator}

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, retrying with evaluator feedback on failure.

        Args:
            query_bundle (QueryBundle): The query to answer.

        Returns:
            RESPONSE_TYPE: The first response that passes evaluation, or the
                unevaluated response of the final attempt once the retry
                budget is exhausted.
        """
        response = self._query_engine._query(query_bundle)
        if self.max_retries <= 0:
            # Budget exhausted: hand back whatever the engine produced,
            # without evaluating it.
            return response

        # The evaluator is handed a plain ``Response``; any other response
        # type is converted through its ``get_response()`` method.
        typed_response = (
            response if isinstance(response, Response) else response.get_response()
        )
        evaluation = self._evaluator.evaluate_response(
            query_bundle.query_str, typed_response
        )
        if evaluation.passing:
            logger.debug("Evaluation returned True.")
            return response

        logger.debug("Evaluation returned False.")
        # Forward the callback manager so retry attempts stay visible to the
        # caller's callbacks (mirrors RetryGuidelineQueryEngine).
        retry_engine = RetryQueryEngine(
            self._query_engine,
            self._evaluator,
            self.max_retries - 1,
            self.callback_manager,
        )
        query_transformer = FeedbackQueryTransformation()
        new_query = query_transformer.run(query_bundle, {"evaluation": evaluation})
        return retry_engine.query(new_query)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async entry point without a native async implementation.

        Delegates to the synchronous ``_query`` and does not await anything,
        so the whole retry loop runs inline in the calling coroutine.
        """
        return self._query(query_bundle)
class RetryGuidelineQueryEngine(BaseQueryEngine):
    """Does retry with evaluator feedback if query engine fails evaluation.

    The wrapped engine is queried once; if the guideline evaluator marks the
    response as not passing, the query is rewritten by ``query_transformer``
    using the evaluation as feedback and re-issued through a fresh
    ``RetryGuidelineQueryEngine`` whose retry budget is one lower.

    Args:
        query_engine (BaseQueryEngine): A query engine object
        guideline_evaluator (GuidelineEvaluator): A guideline evaluator object
        resynthesize_query (bool): Whether to resynthesize query
        max_retries (int): Maximum number of retries
        callback_manager (Optional[CallbackManager]): A callback manager object
        query_transformer (Optional[FeedbackQueryTransformation]): Transformer
            used to rewrite a failed query. When omitted, a
            ``FeedbackQueryTransformation`` configured with
            ``resynthesize_query`` is created.
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        guideline_evaluator: GuidelineEvaluator,
        resynthesize_query: bool = False,
        max_retries: int = 3,
        callback_manager: Optional[CallbackManager] = None,
        query_transformer: Optional[FeedbackQueryTransformation] = None,
    ) -> None:
        self._query_engine = query_engine
        self._guideline_evaluator = guideline_evaluator
        self.max_retries = max_retries
        self.resynthesize_query = resynthesize_query
        self.query_transformer = query_transformer or FeedbackQueryTransformation(
            resynthesize_query=self.resynthesize_query
        )
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "query_engine": self._query_engine,
            # NOTE(review): the key is misspelled ("evalator"). It is a
            # runtime string that callers may already depend on, so it is
            # deliberately left unchanged here.
            "guideline_evalator": self._guideline_evaluator,
        }

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, retrying with guideline feedback on failure.

        Args:
            query_bundle (QueryBundle): The query to answer.

        Returns:
            RESPONSE_TYPE: The first response that passes evaluation, or the
                unevaluated response of the final attempt once the retry
                budget is exhausted.
        """
        response = self._query_engine._query(query_bundle)
        if self.max_retries <= 0:
            # Budget exhausted: hand back whatever the engine produced,
            # without evaluating it.
            return response

        # The evaluator is handed a plain ``Response``; any other response
        # type is converted through its ``get_response()`` method.
        typed_response = (
            response if isinstance(response, Response) else response.get_response()
        )
        evaluation = self._guideline_evaluator.evaluate_response(
            query_bundle.query_str, typed_response
        )
        if evaluation.passing:
            logger.debug("Evaluation returned True.")
            return response

        logger.debug("Evaluation returned False.")
        # Hand the configured transformer to the retry engine; otherwise a
        # caller-supplied transformer would be replaced by a default one from
        # the second retry onward.
        retry_engine = RetryGuidelineQueryEngine(
            self._query_engine,
            self._guideline_evaluator,
            self.resynthesize_query,
            self.max_retries - 1,
            self.callback_manager,
            query_transformer=self.query_transformer,
        )
        new_query = self.query_transformer.run(
            query_bundle, {"evaluation": evaluation}
        )
        logger.debug("New query: %s", new_query.query_str)
        return retry_engine.query(new_query)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async entry point without a native async implementation.

        Delegates to the synchronous ``_query`` and does not await anything,
        so the whole retry loop runs inline in the calling coroutine.
        """
        return self._query(query_bundle)