faiss_rag_enterprise/llama_index/query_engine/jsonalyze_query_engine.py

346 lines
13 KiB
Python

import asyncio
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.response.schema import Response
from llama_index.indices.struct_store.sql_retriever import (
BaseSQLParser,
DefaultSQLParser,
)
from llama_index.prompts import BasePromptTemplate, PromptTemplate
from llama_index.prompts.default_prompts import DEFAULT_JSONALYZE_PROMPT
from llama_index.prompts.mixin import PromptDictType, PromptMixinType
from llama_index.prompts.prompt_type import PromptType
from llama_index.schema import QueryBundle
from llama_index.service_context import ServiceContext
from llama_index.utils import print_text
logger = logging.getLogger(__name__)

# Prompt used to turn raw SQL output back into a natural-language answer.
# Placeholders filled at call time: sql_query, table_schema, sql_response,
# query_str.
DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL = (
    "Given a query, synthesize a response based on SQL query results to "
    "satisfy the query. Only include details that are relevant to the query. "
    "If you don't know the answer, then say that.\n"
    "SQL Query: {sql_query}\n"
    "Table Schema: {table_schema}\n"
    "SQL Response: {sql_response}\n"
    "Query: {query_str}\n"
    "Response: "
)
DEFAULT_RESPONSE_SYNTHESIS_PROMPT = PromptTemplate(
    DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL,
    prompt_type=PromptType.SQL_RESPONSE_SYNTHESIS,
)

# Table name the JSON rows are loaded under when no table_name is given.
DEFAULT_TABLE_NAME = "items"
def default_jsonalyzer(
    list_of_dict: List[Dict[str, Any]],
    query_bundle: QueryBundle,
    service_context: ServiceContext,
    table_name: str = DEFAULT_TABLE_NAME,
    prompt: Optional[BasePromptTemplate] = None,
    sql_parser: Optional[BaseSQLParser] = None,
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
    """Default JSONalyzer that executes a query on a list of dictionaries.

    The rows are loaded into a throwaway in-memory SQLite table. The LLM
    then translates the natural-language query into SQL against that table,
    and the SQL is executed.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query;
            each dictionary becomes one row. Keys that first appear in later
            rows are added as new columns.
        query_bundle (QueryBundle): The query bundle; its ``query_str`` is
            sent to the LLM.
        service_context (ServiceContext): The service context whose ``llm``
            generates the SQL.
        table_name (str): The table name to use, defaults to DEFAULT_TABLE_NAME.
        prompt (Optional[BasePromptTemplate]): The text-to-SQL prompt;
            ``None`` selects DEFAULT_JSONALYZE_PROMPT.
        sql_parser (Optional[BaseSQLParser]): The parser that extracts SQL
            from the LLM output; ``None`` selects a new DefaultSQLParser.

    Returns:
        Tuple[str, Dict[str, Any], List[Dict[str, Any]]]: The SQL Query,
        the Schema, and the Result.

    Raises:
        ImportError: If ``sqlite-utils`` is not installed.
        ValueError: If ``list_of_dict`` is empty, its rows cannot be
            inserted, or the generated SQL fails to execute.
    """
    import sqlite3

    try:
        import sqlite_utils
    except ImportError as exc:
        raise ImportError(
            "sqlite-utils is needed to use this Query Engine:\n"
            "pip install sqlite-utils"
        ) from exc

    if not list_of_dict:
        # With no rows no table is created, so any generated SQL could not
        # run against the data; fail before spending an LLM call.
        raise ValueError("Invalid list_of_dict: expected at least one row")

    # Throwaway database that only lives for this call.
    db = sqlite_utils.Database(memory=True)
    # NOTE: sqlite-utils lets the driver's exceptions propagate, so the stdlib
    # sqlite3 error classes are caught below. This assumes sqlite-utils runs on
    # the stdlib driver (not e.g. pysqlite3) — confirm for your deployment.
    try:
        # alter=True adds columns for keys that first appear in later rows
        # instead of failing the insert.
        db[table_name].insert_all(list_of_dict, alter=True)
    except sqlite3.IntegrityError as exc:
        print_text(f"Error inserting into table {table_name}, expected format:")
        print_text("[{col1: val1, col2: val2, ...}, ...]")
        raise ValueError("Invalid list_of_dict") from exc

    # Column-name -> type mapping as reported by sqlite-utils.
    table_schema = db[table_name].columns_dict

    # Ask the LLM to translate the question into SQL.
    response_str = service_context.llm.predict(
        prompt=prompt or DEFAULT_JSONALYZE_PROMPT,
        table_name=table_name,
        table_schema=table_schema,
        question=query_bundle.query_str,
    )
    parser = sql_parser or DefaultSQLParser()
    sql_query = parser.parse_response_to_sql(response_str, query_bundle)

    # NOTE(review): sql_query is LLM output and runs unrestricted against the
    # in-memory connection; consider an authorizer if inputs are untrusted.
    try:
        results = list(db.query(sql_query))
    except sqlite3.Error as exc:
        # Any DB-API failure (syntax, unknown column, misuse) means the
        # generated SQL is unusable.
        print_text(f"Error executing query: {sql_query}")
        raise ValueError("Invalid query") from exc

    return sql_query, table_schema, results
async def async_default_jsonalyzer(
    list_of_dict: List[Dict[str, Any]],
    query_bundle: QueryBundle,
    service_context: ServiceContext,
    prompt: Optional[BasePromptTemplate] = None,
    sql_parser: Optional[BaseSQLParser] = None,
    table_name: str = DEFAULT_TABLE_NAME,
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
    """Default async JSONalyzer.

    Same pipeline as ``default_jsonalyzer``, except the text-to-SQL LLM call is
    awaited via ``apredict``. The SQLite work itself is synchronous.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query;
            each dictionary becomes one row. Keys that first appear in later
            rows are added as new columns.
        query_bundle (QueryBundle): The query bundle; its ``query_str`` is
            sent to the LLM.
        service_context (ServiceContext): The service context whose ``llm``
            generates the SQL.
        prompt (BasePromptTemplate, optional): The text-to-SQL prompt;
            ``None`` selects DEFAULT_JSONALYZE_PROMPT.
        sql_parser (BaseSQLParser, optional): The parser that extracts SQL
            from the LLM output; ``None`` selects a new DefaultSQLParser.
        table_name (str, optional): The table name to use, defaults to
            DEFAULT_TABLE_NAME.

    Returns:
        Tuple[str, Dict[str, Any], List[Dict[str, Any]]]: The SQL Query,
        the Schema, and the Result.

    Raises:
        ImportError: If ``sqlite-utils`` is not installed.
        ValueError: If ``list_of_dict`` is empty, its rows cannot be
            inserted, or the generated SQL fails to execute.
    """
    import sqlite3

    try:
        import sqlite_utils
    except ImportError as exc:
        raise ImportError(
            "sqlite-utils is needed to use this Query Engine:\n"
            "pip install sqlite-utils"
        ) from exc

    if not list_of_dict:
        # With no rows no table is created, so any generated SQL could not
        # run against the data; fail before spending an LLM call.
        raise ValueError("Invalid list_of_dict: expected at least one row")

    # Throwaway database that only lives for this call.
    db = sqlite_utils.Database(memory=True)
    # NOTE: sqlite-utils lets the driver's exceptions propagate, so the stdlib
    # sqlite3 error classes are caught below. This assumes sqlite-utils runs on
    # the stdlib driver (not e.g. pysqlite3) — confirm for your deployment.
    try:
        # alter=True adds columns for keys that first appear in later rows
        # instead of failing the insert.
        db[table_name].insert_all(list_of_dict, alter=True)
    except sqlite3.IntegrityError as exc:
        print_text(f"Error inserting into table {table_name}, expected format:")
        print_text("[{col1: val1, col2: val2, ...}, ...]")
        raise ValueError("Invalid list_of_dict") from exc

    # Column-name -> type mapping as reported by sqlite-utils.
    table_schema = db[table_name].columns_dict

    # Ask the LLM to translate the question into SQL.
    response_str = await service_context.llm.apredict(
        prompt=prompt or DEFAULT_JSONALYZE_PROMPT,
        table_name=table_name,
        table_schema=table_schema,
        question=query_bundle.query_str,
    )
    parser = sql_parser or DefaultSQLParser()
    sql_query = parser.parse_response_to_sql(response_str, query_bundle)

    # NOTE(review): sql_query is LLM output and runs unrestricted against the
    # in-memory connection; consider an authorizer if inputs are untrusted.
    try:
        results = list(db.query(sql_query))
    except sqlite3.Error as exc:
        # Any DB-API failure (syntax, unknown column, misuse) means the
        # generated SQL is unusable.
        print_text(f"Error executing query: {sql_query}")
        raise ValueError("Invalid query") from exc

    return sql_query, table_schema, results
def load_jsonalyzer(
    use_async: bool = False,
    custom_jsonalyzer: Optional[Callable] = None,
) -> Callable:
    """Select the JSONalyzer callable to use.

    Args:
        use_async (bool): Whether an async (coroutine) JSONalyzer is required.
        custom_jsonalyzer (Optional[Callable]): A user-supplied JSONalyzer.
            When given, it is returned as-is. It must be a coroutine function
            if ``use_async`` is True.

    Returns:
        Callable: ``custom_jsonalyzer`` if provided. Otherwise
        ``async_default_jsonalyzer`` when ``use_async`` is True, or
        ``default_jsonalyzer`` when it is False.

    Raises:
        ValueError: If ``use_async`` is True but ``custom_jsonalyzer`` is not
            a coroutine function.
    """
    if custom_jsonalyzer is not None:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if use_async and not asyncio.iscoroutinefunction(custom_jsonalyzer):
            raise ValueError(
                "custom_jsonalyzer function must be async when use_async is True"
            )
        return custom_jsonalyzer
    return async_default_jsonalyzer if use_async else default_jsonalyzer
class JSONalyzeQueryEngine(BaseQueryEngine):
    """JSON List Shape Data Analysis Query Engine.

    Converts natural language statistical queries into SQL and runs them
    against an in-memory SQLite table built from a list of dictionaries.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query.
        service_context (ServiceContext): ServiceContext; when falsy, a
            default one is built with ``ServiceContext.from_defaults()``.
        jsonalyze_prompt (BasePromptTemplate): The JSONalyze (text-to-SQL)
            prompt to use.
        use_async (bool): Whether to use the async default analyzer. A custom
            ``analyzer`` must then be a coroutine function.
        analyzer (Callable): The analyzer that executes the query.
        sql_parser (BaseSQLParser): The SQL parser that extracts valid SQL
            from the LLM output.
        synthesize_response (bool): Whether to synthesize a natural-language
            response. Otherwise the SQL results are returned as a string.
        response_synthesis_prompt (BasePromptTemplate): The response synthesis
            prompt to use.
        table_name (str): The table name to use.
        verbose (bool): Whether to print verbose output.
    """

    def __init__(
        self,
        list_of_dict: List[Dict[str, Any]],
        service_context: ServiceContext,
        jsonalyze_prompt: Optional[BasePromptTemplate] = None,
        use_async: bool = False,
        analyzer: Optional[Callable] = None,
        sql_parser: Optional[BaseSQLParser] = None,
        synthesize_response: bool = True,
        response_synthesis_prompt: Optional[BasePromptTemplate] = None,
        table_name: str = DEFAULT_TABLE_NAME,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._list_of_dict = list_of_dict
        self._service_context = service_context or ServiceContext.from_defaults()
        self._jsonalyze_prompt = jsonalyze_prompt or DEFAULT_JSONALYZE_PROMPT
        self._use_async = use_async
        # Sync or async callable depending on use_async; both query paths
        # handle either kind (see _query / _aquery).
        self._analyzer = load_jsonalyzer(use_async, analyzer)
        self._sql_parser = sql_parser or DefaultSQLParser()
        self._synthesize_response = synthesize_response
        self._response_synthesis_prompt = (
            response_synthesis_prompt or DEFAULT_RESPONSE_SYNTHESIS_PROMPT
        )
        self._table_name = table_name
        self._verbose = verbose
        super().__init__(self._service_context.callback_manager)

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "jsonalyze_prompt": self._jsonalyze_prompt,
            "response_synthesis_prompt": self._response_synthesis_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "jsonalyze_prompt" in prompts:
            self._jsonalyze_prompt = prompts["jsonalyze_prompt"]
        if "response_synthesis_prompt" in prompts:
            self._response_synthesis_prompt = prompts["response_synthesis_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {}

    def _run_analyzer(self, query_bundle: QueryBundle) -> Any:
        """Invoke the analyzer with this engine's data and settings.

        Returns the analyzer's raw return value: the result tuple for a sync
        analyzer, or an awaitable for an async one.
        """
        return self._analyzer(
            self._list_of_dict,
            query_bundle,
            self._service_context,
            table_name=self._table_name,
            prompt=self._jsonalyze_prompt,
            sql_parser=self._sql_parser,
        )

    def _print_query(self, query_bundle: QueryBundle) -> None:
        """Print the incoming query when verbose."""
        if self._verbose:
            print_text(f"Query: {query_bundle.query_str}\n", color="green")

    def _print_analysis(
        self, sql_query: str, table_schema: Any, results: Any
    ) -> None:
        """Print the generated SQL, table schema and raw results when verbose."""
        if self._verbose:
            print_text(f"SQL Query: {sql_query}\n", color="blue")
            print_text(f"Table Schema: {table_schema}\n", color="cyan")
            print_text(f"SQL Response: {results}\n", color="yellow")

    @staticmethod
    def _build_response(
        response_str: str, sql_query: str, table_schema: Any
    ) -> Response:
        """Wrap the answer with the SQL query and schema as metadata."""
        response_metadata = {"sql_query": sql_query, "table_schema": str(table_schema)}
        return Response(response=response_str, metadata=response_metadata)

    def _query(self, query_bundle: QueryBundle) -> Response:
        """Answer an analytical query on the JSON List."""
        self._print_query(query_bundle)
        analysis = self._run_analyzer(query_bundle)
        if asyncio.iscoroutine(analysis):
            # With use_async=True the analyzer is a coroutine function; drive
            # it to completion so the sync path still works. asyncio.run()
            # raises RuntimeError inside a running event loop, where the async
            # path (_aquery) should be used instead.
            analysis = asyncio.run(analysis)
        sql_query, table_schema, results = analysis
        self._print_analysis(sql_query, table_schema, results)

        if self._synthesize_response:
            response_str = self._service_context.llm.predict(
                self._response_synthesis_prompt,
                sql_query=sql_query,
                table_schema=table_schema,
                sql_response=results,
                query_str=query_bundle.query_str,
            )
            if self._verbose:
                print_text(f"Response: {response_str}", color="magenta")
        else:
            response_str = str(results)
        return self._build_response(response_str, sql_query, table_schema)

    async def _aquery(self, query_bundle: QueryBundle) -> Response:
        """Answer an analytical query on the JSON List (async)."""
        import inspect

        self._print_query(query_bundle)
        # Analyzers read query_bundle.query_str, so the full bundle is passed.
        # Async analyzers return an awaitable; sync ones return the tuple.
        analysis = self._run_analyzer(query_bundle)
        if inspect.isawaitable(analysis):
            analysis = await analysis
        sql_query, table_schema, results = analysis
        self._print_analysis(sql_query, table_schema, results)

        if self._synthesize_response:
            response_str = await self._service_context.llm.apredict(
                self._response_synthesis_prompt,
                sql_query=sql_query,
                table_schema=table_schema,
                sql_response=results,
                query_str=query_bundle.query_str,
            )
            if self._verbose:
                print_text(f"Response: {response_str}", color="magenta")
        else:
            # Same plain-string form as _query. json.dumps() cannot serialize
            # the default analyzer's schema: sqlite-utils' columns_dict maps
            # column names to Python type objects.
            response_str = str(results)
        return self._build_response(response_str, sql_query, table_schema)