137 lines
5.2 KiB
Python
137 lines
5.2 KiB
Python
import logging
|
|
from typing import Optional
|
|
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.core.base_query_engine import BaseQueryEngine
|
|
from llama_index.core.response.schema import RESPONSE_TYPE, Response
|
|
from llama_index.evaluation.base import BaseEvaluator
|
|
from llama_index.evaluation.guideline import GuidelineEvaluator
|
|
from llama_index.indices.query.query_transform.feedback_transform import (
|
|
FeedbackQueryTransformation,
|
|
)
|
|
from llama_index.prompts.mixin import PromptMixinType
|
|
from llama_index.schema import QueryBundle
|
|
|
|
# Module-level logger, named after this module; the query engines in this
# file emit their debug messages through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RetryQueryEngine(BaseQueryEngine):
    """Does retry on query engine if it fails evaluation.

    The wrapped engine is queried once. If the evaluator rejects the
    response, the query is rewritten using the evaluation as feedback and
    handed to a nested ``RetryQueryEngine`` whose retry budget is one smaller.

    Args:
        query_engine (BaseQueryEngine): A query engine object
        evaluator (BaseEvaluator): An evaluator object
        max_retries (int): Maximum number of retries
        callback_manager (Optional[CallbackManager]): A callback manager object

    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        evaluator: BaseEvaluator,
        max_retries: int = 3,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._query_engine = query_engine
        self._evaluator = evaluator
        self.max_retries = max_retries
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"query_engine": self._query_engine, "evaluator": self._evaluator}

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, retrying with evaluator feedback on failure.

        Args:
            query_bundle (QueryBundle): The query to answer.

        Returns:
            RESPONSE_TYPE: The first response that passes evaluation, or the
            unevaluated response of the wrapped engine once the retry budget
            is exhausted.
        """
        response = self._query_engine._query(query_bundle)

        # Retry budget exhausted: hand back whatever the engine produced
        # without spending an evaluation on it.
        if self.max_retries <= 0:
            return response

        # The evaluator is given a ``Response``; any other response object is
        # converted through its ``get_response()`` method first.
        typed_response = (
            response if isinstance(response, Response) else response.get_response()
        )
        evaluation = self._evaluator.evaluate_response(
            query_bundle.query_str, typed_response
        )

        if evaluation.passing:
            logger.debug("Evaluation returned True.")
            return response

        logger.debug("Evaluation returned False.")
        # Forward the callback manager so retries keep reporting to the
        # manager this engine was configured with instead of silently
        # dropping it after the first attempt.
        retry_engine = RetryQueryEngine(
            self._query_engine,
            self._evaluator,
            max_retries=self.max_retries - 1,
            callback_manager=self.callback_manager,
        )
        query_transformer = FeedbackQueryTransformation()
        new_query = query_transformer.run(query_bundle, {"evaluation": evaluation})
        return retry_engine.query(new_query)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query by delegating to the synchronous ``_query``.

        There is no native asynchronous implementation here; the synchronous
        path is executed directly inside the coroutine.
        """
        return self._query(query_bundle)
|
|
|
|
|
|
class RetryGuidelineQueryEngine(BaseQueryEngine):
    """Does retry with evaluator feedback
    if query engine fails evaluation.

    The wrapped engine is queried once. If the guideline evaluator rejects
    the response, the query is rewritten by the query transformer using the
    evaluation as feedback and handed to a nested
    ``RetryGuidelineQueryEngine`` whose retry budget is one smaller.

    Args:
        query_engine (BaseQueryEngine): A query engine object
        guideline_evaluator (GuidelineEvaluator): A guideline evaluator object
        resynthesize_query (bool): Whether to resynthesize query
        max_retries (int): Maximum number of retries
        callback_manager (Optional[CallbackManager]): A callback manager object
        query_transformer (Optional[FeedbackQueryTransformation]): Transformer
            used to rewrite a failed query. When omitted, a
            ``FeedbackQueryTransformation`` is created with
            ``resynthesize_query`` passed through.

    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        guideline_evaluator: GuidelineEvaluator,
        resynthesize_query: bool = False,
        max_retries: int = 3,
        callback_manager: Optional[CallbackManager] = None,
        query_transformer: Optional[FeedbackQueryTransformation] = None,
    ) -> None:
        self._query_engine = query_engine
        self._guideline_evaluator = guideline_evaluator
        self.max_retries = max_retries
        self.resynthesize_query = resynthesize_query
        self.query_transformer = query_transformer or FeedbackQueryTransformation(
            resynthesize_query=self.resynthesize_query
        )
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "query_engine": self._query_engine,
            # NOTE(review): the key is spelled "guideline_evalator" (sic). It
            # is kept byte-for-byte because callers may look the module up by
            # this exact name; confirm before renaming.
            "guideline_evalator": self._guideline_evaluator,
        }

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, retrying with guideline feedback on failure.

        Args:
            query_bundle (QueryBundle): The query to answer.

        Returns:
            RESPONSE_TYPE: The first response that passes evaluation, or the
            unevaluated response of the wrapped engine once the retry budget
            is exhausted.
        """
        response = self._query_engine._query(query_bundle)

        # Retry budget exhausted: hand back whatever the engine produced
        # without spending an evaluation on it.
        if self.max_retries <= 0:
            return response

        # The evaluator is given a ``Response``; any other response object is
        # converted through its ``get_response()`` method first.
        typed_response = (
            response if isinstance(response, Response) else response.get_response()
        )
        evaluation = self._guideline_evaluator.evaluate_response(
            query_bundle.query_str, typed_response
        )

        if evaluation.passing:
            logger.debug("Evaluation returned True.")
            return response

        logger.debug("Evaluation returned False.")
        # Forward the configured transformer as well: without it, a custom
        # ``query_transformer`` would only apply to the first retry and every
        # later retry would fall back to a default transformation.
        retry_engine = RetryGuidelineQueryEngine(
            self._query_engine,
            self._guideline_evaluator,
            resynthesize_query=self.resynthesize_query,
            max_retries=self.max_retries - 1,
            callback_manager=self.callback_manager,
            query_transformer=self.query_transformer,
        )
        new_query = self.query_transformer.run(
            query_bundle, {"evaluation": evaluation}
        )
        logger.debug("New query: %s", new_query.query_str)
        return retry_engine.query(new_query)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query by delegating to the synchronous ``_query``.

        There is no native asynchronous implementation here; the synchronous
        path is executed directly inside the coroutine.
        """
        return self._query(query_bundle)
|