"""SQL Structured Store."""
|
|
from collections import defaultdict
|
|
from enum import Enum
|
|
from typing import Any, Optional, Sequence, Union
|
|
|
|
from sqlalchemy import Table
|
|
|
|
from llama_index.core.base_query_engine import BaseQueryEngine
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.data_structs.table import SQLStructTable
|
|
from llama_index.indices.common.struct_store.schema import SQLContextContainer
|
|
from llama_index.indices.common.struct_store.sql import SQLStructDatapointExtractor
|
|
from llama_index.indices.struct_store.base import BaseStructStoreIndex
|
|
from llama_index.indices.struct_store.container_builder import (
|
|
SQLContextContainerBuilder,
|
|
)
|
|
from llama_index.schema import BaseNode
|
|
from llama_index.service_context import ServiceContext
|
|
from llama_index.utilities.sql_wrapper import SQLDatabase
|
|
|
|
|
|
class SQLQueryMode(str, Enum):
|
|
SQL = "sql"
|
|
NL = "nl"
|
|
|
|
|
|
class SQLStructStoreIndex(BaseStructStoreIndex[SQLStructTable]):
|
|
"""SQL Struct Store Index.
|
|
|
|
The SQLStructStoreIndex is an index that uses a SQL database
|
|
under the hood. During index construction, the data can be inferred
|
|
from unstructured documents given a schema extract prompt,
|
|
or it can be pre-loaded in the database.
|
|
|
|
During query time, the user can either specify a raw SQL query
|
|
or a natural language query to retrieve their data.
|
|
|
|
NOTE: this is deprecated.
|
|
|
|
Args:
|
|
documents (Optional[Sequence[DOCUMENTS_INPUT]]): Documents to index.
|
|
NOTE: in the SQL index, this is an optional field.
|
|
sql_database (Optional[SQLDatabase]): SQL database to use,
|
|
including table names to specify.
|
|
See :ref:`Ref-Struct-Store` for more details.
|
|
table_name (Optional[str]): Name of the table to use
|
|
for extracting data.
|
|
Either table_name or table must be specified.
|
|
table (Optional[Table]): SQLAlchemy Table object to use.
|
|
Specifying the Table object explicitly, instead of
|
|
the table name, allows you to pass in a view.
|
|
Either table_name or table must be specified.
|
|
sql_context_container (Optional[SQLContextContainer]): SQL context container.
|
|
an be generated from a SQLContextContainerBuilder.
|
|
See :ref:`Ref-Struct-Store` for more details.
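
    Example:
        A minimal usage sketch (illustrative only; this index is deprecated).
        The SQLite connection string and the ``city_stats`` table below are
        assumptions made for the example, not part of this module.

        .. code-block:: python

            from sqlalchemy import create_engine

            from llama_index.utilities.sql_wrapper import SQLDatabase

            # hypothetical SQLite database with a pre-populated `city_stats` table
            engine = create_engine("sqlite:///city_stats.db")
            sql_database = SQLDatabase(engine, include_tables=["city_stats"])

            # build the index over the existing table (no documents required)
            index = SQLStructStoreIndex(
                sql_database=sql_database,
                table_name="city_stats",
            )

            # query with raw SQL ...
            sql_query_engine = index.as_query_engine(query_mode="sql")
            response = sql_query_engine.query("SELECT city_name FROM city_stats;")

            # ... or in natural language (uses the LLM from the service context)
            nl_query_engine = index.as_query_engine(query_mode="nl")
            response = nl_query_engine.query("Which city has the highest population?")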

    """

    index_struct_cls = SQLStructTable

    def __init__(
        self,
        nodes: Optional[Sequence[BaseNode]] = None,
        index_struct: Optional[SQLStructTable] = None,
        service_context: Optional[ServiceContext] = None,
        sql_database: Optional[SQLDatabase] = None,
        table_name: Optional[str] = None,
        table: Optional[Table] = None,
        ref_doc_id_column: Optional[str] = None,
        sql_context_container: Optional[SQLContextContainer] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        if sql_database is None:
            raise ValueError("sql_database must be specified")
        self.sql_database = sql_database
        # needed here for data extractor
        self._ref_doc_id_column = ref_doc_id_column
        self._table_name = table_name
        self._table = table

        # if documents aren't specified, pass in a blank []
        if index_struct is None:
            nodes = nodes or []

        super().__init__(
            nodes=nodes,
            index_struct=index_struct,
            service_context=service_context,
            **kwargs,
        )

        # TODO: index_struct context_dict is deprecated,
        # we're migrating storage of information to here.
        if sql_context_container is None:
            container_builder = SQLContextContainerBuilder(sql_database)
            sql_context_container = container_builder.build_context_container()
        self.sql_context_container = sql_context_container

    @property
    def ref_doc_id_column(self) -> Optional[str]:
        """Return the name of the column used to store each row's source doc id."""
        return self._ref_doc_id_column

    def _build_index_from_nodes(self, nodes: Sequence[BaseNode]) -> SQLStructTable:
        """Build index from nodes."""
        index_struct = self.index_struct_cls()
        if len(nodes) == 0:
            return index_struct
        else:
            data_extractor = SQLStructDatapointExtractor(
                self._service_context.llm,
                self.schema_extract_prompt,
                self.output_parser,
                self.sql_database,
                table_name=self._table_name,
                table=self._table,
                ref_doc_id_column=self._ref_doc_id_column,
            )
            # group nodes by ids
            source_to_node = defaultdict(list)
            for node in nodes:
                source_to_node[node.ref_doc_id].append(node)

            for node_set in source_to_node.values():
                data_extractor.insert_datapoint_from_nodes(node_set)
        return index_struct

    def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
        """Insert a document."""
        data_extractor = SQLStructDatapointExtractor(
            self._service_context.llm,
            self.schema_extract_prompt,
            self.output_parser,
            self.sql_database,
            table_name=self._table_name,
            table=self._table,
            ref_doc_id_column=self._ref_doc_id_column,
        )
        data_extractor.insert_datapoint_from_nodes(nodes)

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Not supported for this index."""
        raise NotImplementedError("Not supported")

    def as_query_engine(
        self, query_mode: Union[str, SQLQueryMode] = SQLQueryMode.NL, **kwargs: Any
    ) -> BaseQueryEngine:
        """Return a query engine for either natural language or raw SQL queries."""
        # NOTE: lazy import
        from llama_index.indices.struct_store.sql_query import (
            NLStructStoreQueryEngine,
            SQLStructStoreQueryEngine,
        )

        if query_mode == SQLQueryMode.NL:
            return NLStructStoreQueryEngine(self, **kwargs)
        elif query_mode == SQLQueryMode.SQL:
            return SQLStructStoreQueryEngine(self, **kwargs)
        else:
            raise ValueError(f"Unknown query mode: {query_mode}")


# Legacy name, kept for backwards compatibility.
GPTSQLStructStoreIndex = SQLStructStoreIndex