# faiss_rag_enterprise/llama_index/indices/struct_store/sql.py

"""SQL Structured Store."""
from collections import defaultdict
from enum import Enum
from typing import Any, Optional, Sequence, Union
from sqlalchemy import Table
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.base_retriever import BaseRetriever
from llama_index.data_structs.table import SQLStructTable
from llama_index.indices.common.struct_store.schema import SQLContextContainer
from llama_index.indices.common.struct_store.sql import SQLStructDatapointExtractor
from llama_index.indices.struct_store.base import BaseStructStoreIndex
from llama_index.indices.struct_store.container_builder import (
SQLContextContainerBuilder,
)
from llama_index.schema import BaseNode
from llama_index.service_context import ServiceContext
from llama_index.utilities.sql_wrapper import SQLDatabase


class SQLQueryMode(str, Enum):
    SQL = "sql"
    NL = "nl"
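
# SQLQueryMode selects which engine as_query_engine() returns below: NL
# translates a natural-language question into SQL via the LLM before
# executing it, while SQL executes the given query string verbatim.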


class SQLStructStoreIndex(BaseStructStoreIndex[SQLStructTable]):
"""SQL Struct Store Index.
The SQLStructStoreIndex is an index that uses a SQL database
under the hood. During index construction, the data can be inferred
from unstructured documents given a schema extract prompt,
or it can be pre-loaded in the database.
During query time, the user can either specify a raw SQL query
or a natural language query to retrieve their data.
NOTE: this is deprecated.
Args:
documents (Optional[Sequence[DOCUMENTS_INPUT]]): Documents to index.
NOTE: in the SQL index, this is an optional field.
sql_database (Optional[SQLDatabase]): SQL database to use,
including table names to specify.
See :ref:`Ref-Struct-Store` for more details.
table_name (Optional[str]): Name of the table to use
for extracting data.
Either table_name or table must be specified.
table (Optional[Table]): SQLAlchemy Table object to use.
Specifying the Table object explicitly, instead of
the table name, allows you to pass in a view.
Either table_name or table must be specified.
sql_context_container (Optional[SQLContextContainer]): SQL context container.
an be generated from a SQLContextContainerBuilder.
See :ref:`Ref-Struct-Store` for more details.
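
    Example:
        A minimal construction sketch; the engine URL and the ``city_stats``
        table are illustrative assumptions, not part of this module::

            from sqlalchemy import create_engine

            engine = create_engine("sqlite:///:memory:")
            sql_database = SQLDatabase(engine, include_tables=["city_stats"])
            index = SQLStructStoreIndex(
                nodes=[],
                sql_database=sql_database,
                table_name="city_stats",
            )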
"""

    index_struct_cls = SQLStructTable

def __init__(
self,
nodes: Optional[Sequence[BaseNode]] = None,
index_struct: Optional[SQLStructTable] = None,
service_context: Optional[ServiceContext] = None,
sql_database: Optional[SQLDatabase] = None,
table_name: Optional[str] = None,
table: Optional[Table] = None,
ref_doc_id_column: Optional[str] = None,
sql_context_container: Optional[SQLContextContainer] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
if sql_database is None:
raise ValueError("sql_database must be specified")
self.sql_database = sql_database
# needed here for data extractor
self._ref_doc_id_column = ref_doc_id_column
self._table_name = table_name
self._table = table
# if documents aren't specified, pass in a blank []
if index_struct is None:
nodes = nodes or []
super().__init__(
nodes=nodes,
index_struct=index_struct,
service_context=service_context,
**kwargs,
)

        # TODO: index_struct context_dict is deprecated,
# we're migrating storage of information to here.
if sql_context_container is None:
container_builder = SQLContextContainerBuilder(sql_database)
sql_context_container = container_builder.build_context_container()
self.sql_context_container = sql_context_container

    @property
def ref_doc_id_column(self) -> Optional[str]:
return self._ref_doc_id_column

    def _build_index_from_nodes(self, nodes: Sequence[BaseNode]) -> SQLStructTable:
"""Build index from nodes."""
index_struct = self.index_struct_cls()
if len(nodes) == 0:
return index_struct
else:
data_extractor = SQLStructDatapointExtractor(
self._service_context.llm,
self.schema_extract_prompt,
self.output_parser,
self.sql_database,
table_name=self._table_name,
table=self._table,
ref_doc_id_column=self._ref_doc_id_column,
)
            # group nodes by source document id so that all nodes from the
            # same document are inserted together as one datapoint
source_to_node = defaultdict(list)
for node in nodes:
source_to_node[node.ref_doc_id].append(node)
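            # e.g. {"doc_1": [node_a, node_b], "doc_2": [node_c]} (hypothetical
            # ids); each per-document group becomes one extracted datapoint below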
for node_set in source_to_node.values():
data_extractor.insert_datapoint_from_nodes(node_set)
return index_struct

    def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
        """Insert nodes into the underlying SQL database."""
data_extractor = SQLStructDatapointExtractor(
self._service_context.llm,
self.schema_extract_prompt,
self.output_parser,
self.sql_database,
table_name=self._table_name,
table=self._table,
ref_doc_id_column=self._ref_doc_id_column,
)
data_extractor.insert_datapoint_from_nodes(nodes)

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        raise NotImplementedError("Not supported; use as_query_engine() instead.")

    def as_query_engine(
self, query_mode: Union[str, SQLQueryMode] = SQLQueryMode.NL, **kwargs: Any
) -> BaseQueryEngine:
# NOTE: lazy import
from llama_index.indices.struct_store.sql_query import (
NLStructStoreQueryEngine,
SQLStructStoreQueryEngine,
)

        if query_mode == SQLQueryMode.NL:
return NLStructStoreQueryEngine(self, **kwargs)
elif query_mode == SQLQueryMode.SQL:
return SQLStructStoreQueryEngine(self, **kwargs)
else:
raise ValueError(f"Unknown query mode: {query_mode}")


# Legacy alias for backwards compatibility.
GPTSQLStructStoreIndex = SQLStructStoreIndex
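

# Usage sketch (illustrative; assumes an `index` built as in the class
# docstring, over a hypothetical "city_stats" table):
#
#     nl_engine = index.as_query_engine(query_mode=SQLQueryMode.NL)
#     response = nl_engine.query("Which city has the largest population?")
#
#     sql_engine = index.as_query_engine(query_mode=SQLQueryMode.SQL)
#     response = sql_engine.query("SELECT city_name FROM city_stats LIMIT 5")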