346 lines
13 KiB
Python
346 lines
13 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
from llama_index.core.base_query_engine import BaseQueryEngine
|
|
from llama_index.core.response.schema import Response
|
|
from llama_index.indices.struct_store.sql_retriever import (
|
|
BaseSQLParser,
|
|
DefaultSQLParser,
|
|
)
|
|
from llama_index.prompts import BasePromptTemplate, PromptTemplate
|
|
from llama_index.prompts.default_prompts import DEFAULT_JSONALYZE_PROMPT
|
|
from llama_index.prompts.mixin import PromptDictType, PromptMixinType
|
|
from llama_index.prompts.prompt_type import PromptType
|
|
from llama_index.schema import QueryBundle
|
|
from llama_index.service_context import ServiceContext
|
|
from llama_index.utils import print_text
|
|
|
|
logger = logging.getLogger(__name__)

# Name of the in-memory SQLite table the JSON records are loaded into when the
# caller does not supply one.
DEFAULT_TABLE_NAME = "items"

# Prompt that turns the executed SQL and its raw result rows into a
# natural-language answer. Template variables: sql_query, table_schema,
# sql_response, query_str.
DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL = (
    "Given a query, synthesize a response based on SQL query results to "
    "satisfy the query. Only include details that are relevant to the "
    "query. If you don't know the answer, then say that.\n"
    "SQL Query: {sql_query}\n"
    "Table Schema: {table_schema}\n"
    "SQL Response: {sql_response}\n"
    "Query: {query_str}\n"
    "Response: "
)

DEFAULT_RESPONSE_SYNTHESIS_PROMPT = PromptTemplate(
    DEFAULT_RESPONSE_SYNTHESIS_PROMPT_TMPL,
    prompt_type=PromptType.SQL_RESPONSE_SYNTHESIS,
)
|
|
|
|
|
|
def default_jsonalyzer(
    list_of_dict: List[Dict[str, Any]],
    query_bundle: QueryBundle,
    service_context: ServiceContext,
    table_name: str = DEFAULT_TABLE_NAME,
    prompt: BasePromptTemplate = DEFAULT_JSONALYZE_PROMPT,
    sql_parser: Optional[BaseSQLParser] = None,
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
    """Default JSONalyzer that executes a query on a list of dictionaries.

    The records are loaded into a fresh in-memory SQLite table. The LLM of
    ``service_context`` is asked to translate the natural-language query into
    SQL against that table, and the parsed SQL is then executed.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query.
        query_bundle (QueryBundle): The query bundle.
        service_context (ServiceContext): The service context whose LLM
            generates the SQL.
        table_name (str): The table name to use, defaults to DEFAULT_TABLE_NAME.
        prompt (BasePromptTemplate): The text-to-SQL prompt to use; a falsy
            value falls back to DEFAULT_JSONALYZE_PROMPT.
        sql_parser (Optional[BaseSQLParser]): The SQL parser to use; a new
            DefaultSQLParser is created when omitted.

    Returns:
        Tuple[str, Dict[str, Any], List[Dict[str, Any]]]: The SQL Query,
            the Schema (the table's ``columns_dict``), and the Result rows.

    Raises:
        ImportError: If ``sqlite-utils`` is not installed.
        ValueError: If inserting the records raises an IntegrityError, or if
            executing the generated SQL raises an OperationalError.
    """
    try:
        import sqlite_utils
    except ImportError as exc:
        import_error_msg = (
            "sqlite-utils is needed to use this Query Engine:\n"
            "pip install sqlite-utils"
        )
        raise ImportError(import_error_msg) from exc

    # Instantiate in-memory SQLite database
    db = sqlite_utils.Database(memory=True)
    try:
        # Load list of dictionaries into SQLite database
        db[table_name].insert_all(list_of_dict)
    # sqlite-utils has no ``db_exceptions`` module. The DB-API driver it
    # actually uses is exposed as ``sqlite_utils.utils.sqlite3``, so its
    # exception classes are the ones that can be raised here.
    except sqlite_utils.utils.sqlite3.IntegrityError as exc:
        print_text(f"Error inserting into table {table_name}, expected format:")
        print_text("[{col1: val1, col2: val2, ...}, ...]")
        raise ValueError("Invalid list_of_dict") from exc

    # Get the table schema
    table_schema = db[table_name].columns_dict

    # Get the SQL query with text-to-SQL prompt
    response_str = service_context.llm.predict(
        prompt=prompt or DEFAULT_JSONALYZE_PROMPT,
        table_name=table_name,
        table_schema=table_schema,
        question=query_bundle.query_str,
    )

    # Build the parser per call rather than sharing one instance created
    # when the function was defined.
    parser = sql_parser or DefaultSQLParser()
    sql_query = parser.parse_response_to_sql(response_str, query_bundle)

    try:
        # Execute the SQL query
        results = list(db.query(sql_query))
    except sqlite_utils.utils.sqlite3.OperationalError as exc:
        print_text(f"Error executing query: {sql_query}")
        raise ValueError("Invalid query") from exc

    return sql_query, table_schema, results
|
|
|
|
|
|
async def async_default_jsonalyzer(
    list_of_dict: List[Dict[str, Any]],
    query_bundle: QueryBundle,
    service_context: ServiceContext,
    prompt: Optional[BasePromptTemplate] = None,
    sql_parser: Optional[BaseSQLParser] = None,
    table_name: str = DEFAULT_TABLE_NAME,
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
    """Default JSONalyzer (async variant).

    The records are loaded into a fresh in-memory SQLite table. The LLM of
    ``service_context`` is awaited to translate the natural-language query
    into SQL, and the parsed SQL is then executed. Only the LLM call is
    awaited; the SQLite work runs synchronously.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query.
        query_bundle (QueryBundle): The query bundle.
        service_context (ServiceContext): The service context whose LLM
            generates the SQL.
        prompt (BasePromptTemplate, optional): The text-to-SQL prompt to use;
            defaults to DEFAULT_JSONALYZE_PROMPT.
        sql_parser (BaseSQLParser, optional): The SQL parser to use; a new
            DefaultSQLParser is created when omitted.
        table_name (str, optional): The table name to use, defaults to DEFAULT_TABLE_NAME.

    Returns:
        Tuple[str, Dict[str, Any], List[Dict[str, Any]]]: The SQL Query,
            the Schema (the table's ``columns_dict``), and the Result rows.

    Raises:
        ImportError: If ``sqlite-utils`` is not installed.
        ValueError: If inserting the records raises an IntegrityError, or if
            executing the generated SQL raises an OperationalError.
    """
    try:
        import sqlite_utils
    except ImportError as exc:
        import_error_msg = (
            "sqlite-utils is needed to use this Query Engine:\n"
            "pip install sqlite-utils"
        )
        raise ImportError(import_error_msg) from exc

    # Instantiate in-memory SQLite database
    db = sqlite_utils.Database(memory=True)
    try:
        # Load list of dictionaries into SQLite database
        db[table_name].insert_all(list_of_dict)
    # sqlite-utils has no ``db_exceptions`` module. The DB-API driver it
    # actually uses is exposed as ``sqlite_utils.utils.sqlite3``, so its
    # exception classes are the ones that can be raised here.
    except sqlite_utils.utils.sqlite3.IntegrityError as exc:
        print_text(f"Error inserting into table {table_name}, expected format:")
        print_text("[{col1: val1, col2: val2, ...}, ...]")
        raise ValueError("Invalid list_of_dict") from exc

    # Get the table schema
    table_schema = db[table_name].columns_dict

    # Get the SQL query with text-to-SQL prompt
    response_str = await service_context.llm.apredict(
        prompt=prompt or DEFAULT_JSONALYZE_PROMPT,
        table_name=table_name,
        table_schema=table_schema,
        question=query_bundle.query_str,
    )

    parser = sql_parser or DefaultSQLParser()
    sql_query = parser.parse_response_to_sql(response_str, query_bundle)

    try:
        # Execute the SQL query
        results = list(db.query(sql_query))
    except sqlite_utils.utils.sqlite3.OperationalError as exc:
        print_text(f"Error executing query: {sql_query}")
        raise ValueError("Invalid query") from exc

    return sql_query, table_schema, results
|
|
|
|
|
|
def load_jsonalyzer(
    use_async: bool = False,
    custom_jsonalyzer: Optional[Callable] = None,
) -> Callable:
    """Select the JSONalyzer callable the query engine will invoke.

    Args:
        use_async (bool): Pick the async built-in analyzer; when a custom
            analyzer is supplied, it must then be a coroutine function.
        custom_jsonalyzer (Optional[Callable]): A user-provided analyzer. If it
            is given (truthy), it is returned instead of a built-in one.

    Returns:
        Callable: ``custom_jsonalyzer`` when provided, otherwise
            ``async_default_jsonalyzer`` or ``default_jsonalyzer`` according
            to ``use_async``.

    Raises:
        AssertionError: If ``use_async`` is True and ``custom_jsonalyzer`` is
            not a coroutine function (only while assertions are enabled).
    """
    if not custom_jsonalyzer:
        # Kept as separate statements rather than a conditional expression
        # so mypy does not have to join the two analyzers' signatures.
        if use_async:
            return async_default_jsonalyzer
        return default_jsonalyzer

    assert (not use_async) or asyncio.iscoroutinefunction(
        custom_jsonalyzer
    ), "custom_jsonalyzer function must be async when use_async is True"
    return custom_jsonalyzer
|
|
|
|
|
|
class JSONalyzeQueryEngine(BaseQueryEngine):
    """JSON List Shape Data Analysis Query Engine.

    Converts natural language statistical queries into SQL, which is executed
    against an in-memory SQLite table built from a list of dictionaries.

    Args:
        list_of_dict (List[Dict[str, Any]]): List of dictionaries to query.
        service_context (ServiceContext): Service context that provides the
            LLM and callback manager; a default one is created when falsy.
        jsonalyze_prompt (BasePromptTemplate): The JSONalyze (text-to-SQL)
            prompt to use.
        use_async (bool): Whether to use the async default analyzer. A custom
            analyzer must then be a coroutine function.
        analyzer (Callable): The analyzer that executes the query. It is
            called as ``analyzer(list_of_dict, query_bundle, service_context,
            table_name=..., prompt=..., sql_parser=...)`` and must return (or,
            if async, resolve to) ``(sql_query, table_schema, results)``.
        sql_parser (BaseSQLParser): The SQL parser that ensures valid SQL is
            parsed from the LLM output.
        synthesize_response (bool): Whether to synthesize a natural-language
            response. Otherwise the raw results are returned as a string.
        response_synthesis_prompt (BasePromptTemplate): The response synthesis
            prompt to use.
        table_name (str): The table name to use.
        verbose (bool): Whether to print verbose output.
    """

    def __init__(
        self,
        list_of_dict: List[Dict[str, Any]],
        service_context: ServiceContext,
        jsonalyze_prompt: Optional[BasePromptTemplate] = None,
        use_async: bool = False,
        analyzer: Optional[Callable] = None,
        sql_parser: Optional[BaseSQLParser] = None,
        synthesize_response: bool = True,
        response_synthesis_prompt: Optional[BasePromptTemplate] = None,
        table_name: str = DEFAULT_TABLE_NAME,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._list_of_dict = list_of_dict
        self._service_context = service_context or ServiceContext.from_defaults()
        self._jsonalyze_prompt = jsonalyze_prompt or DEFAULT_JSONALYZE_PROMPT
        self._use_async = use_async
        self._analyzer = load_jsonalyzer(use_async, analyzer)
        self._sql_parser = sql_parser or DefaultSQLParser()
        self._synthesize_response = synthesize_response
        self._response_synthesis_prompt = (
            response_synthesis_prompt or DEFAULT_RESPONSE_SYNTHESIS_PROMPT
        )
        self._table_name = table_name
        self._verbose = verbose

        super().__init__(self._service_context.callback_manager)

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "jsonalyze_prompt": self._jsonalyze_prompt,
            "response_synthesis_prompt": self._response_synthesis_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "jsonalyze_prompt" in prompts:
            self._jsonalyze_prompt = prompts["jsonalyze_prompt"]
        if "response_synthesis_prompt" in prompts:
            self._response_synthesis_prompt = prompts["response_synthesis_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {}

    def _analyzer_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments passed to every analyzer call."""
        return {
            "table_name": self._table_name,
            "prompt": self._jsonalyze_prompt,
            "sql_parser": self._sql_parser,
        }

    def _print_analysis(
        self,
        sql_query: str,
        table_schema: Dict[str, Any],
        results: List[Dict[str, Any]],
    ) -> None:
        """Print the generated SQL, table schema and raw results when verbose."""
        if not self._verbose:
            return
        print_text(f"SQL Query: {sql_query}\n", color="blue")
        print_text(f"Table Schema: {table_schema}\n", color="cyan")
        print_text(f"SQL Response: {results}\n", color="yellow")

    @staticmethod
    def _synthesis_kwargs(
        query_bundle: QueryBundle,
        sql_query: str,
        table_schema: Dict[str, Any],
        results: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Return the template variables for the response synthesis prompt."""
        return {
            "sql_query": sql_query,
            "table_schema": table_schema,
            "sql_response": results,
            "query_str": query_bundle.query_str,
        }

    @staticmethod
    def _make_response(
        response_str: str, sql_query: str, table_schema: Dict[str, Any]
    ) -> Response:
        """Wrap the answer in a Response that carries the SQL and schema as metadata."""
        response_metadata = {"sql_query": sql_query, "table_schema": str(table_schema)}
        return Response(response=response_str, metadata=response_metadata)

    def _query(self, query_bundle: QueryBundle) -> Response:
        """Answer an analytical query on the JSON List."""
        if self._verbose:
            print_text(f"Query: {query_bundle.query_str}\n", color="green")

        # Perform the analysis
        analysis = self._analyzer(
            self._list_of_dict,
            query_bundle,
            self._service_context,
            **self._analyzer_kwargs(),
        )
        if asyncio.iscoroutine(analysis):
            # An async analyzer (use_async=True) reached through the sync
            # entry point returns a coroutine; drive it to completion.
            # asyncio.run raises RuntimeError if this thread already has a
            # running event loop -- use the async entry point in that case.
            analysis = asyncio.run(analysis)
        sql_query, table_schema, results = analysis
        self._print_analysis(sql_query, table_schema, results)

        if self._synthesize_response:
            response_str = self._service_context.llm.predict(
                self._response_synthesis_prompt,
                **self._synthesis_kwargs(
                    query_bundle, sql_query, table_schema, results
                ),
            )
            if self._verbose:
                print_text(f"Response: {response_str}", color="magenta")
        else:
            response_str = str(results)

        return self._make_response(response_str, sql_query, table_schema)

    async def _aquery(self, query_bundle: QueryBundle) -> Response:
        """Answer an analytical query on the JSON List."""
        if self._verbose:
            print_text(f"Query: {query_bundle.query_str}\n", color="green")

        # Perform the analysis. The analyzers take the full QueryBundle (they
        # read ``query_bundle.query_str``), not the bare query string.
        analysis = self._analyzer(
            self._list_of_dict,
            query_bundle,
            self._service_context,
            **self._analyzer_kwargs(),
        )
        if asyncio.iscoroutine(analysis):
            # The async default (or a custom coroutine-function analyzer)
            # must be awaited; a sync analyzer already returned its tuple.
            analysis = await analysis
        sql_query, table_schema, results = analysis
        self._print_analysis(sql_query, table_schema, results)

        if self._synthesize_response:
            response_str = await self._service_context.llm.apredict(
                self._response_synthesis_prompt,
                **self._synthesis_kwargs(
                    query_bundle, sql_query, table_schema, results
                ),
            )
            if self._verbose:
                print_text(f"Response: {response_str}", color="magenta")
        else:
            # Same fallback as the sync path. json.dumps cannot be used here
            # because the schema values from ``columns_dict`` are not JSON
            # serializable.
            response_str = str(results)

        return self._make_response(response_str, sql_query, table_schema)
|