233 lines
9.4 KiB
Python
233 lines
9.4 KiB
Python
"""SQL wrapper around SQLDatabase in langchain."""
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
|
|
from sqlalchemy import MetaData, create_engine, insert, inspect, text
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy.exc import OperationalError, ProgrammingError
|
|
|
|
|
|
class SQLDatabase:
    """Thin wrapper around a SQLAlchemy engine for schema inspection and SQL I/O.

    The wrapper can execute raw SQL, insert rows into reflected tables, and
    describe tables (columns, column comments and foreign keys). The set of
    tables it exposes can be narrowed with ``include_tables`` or
    ``ignore_tables``, and views can optionally be treated like tables.

    Based on langchain SQLDatabase.
    https://github.com/langchain-ai/langchain/blob/e355606b1100097665207ca259de6dc548d44c78/libs/langchain/langchain/utilities/sql_database.py#L39

    Args:
        engine (Engine): The SQLAlchemy engine instance to use for database
            operations.
        schema (Optional[str]): The name of the schema to use, if any.
        metadata (Optional[MetaData]): The metadata instance to reflect into;
            a fresh ``MetaData()`` is created when omitted.
        ignore_tables (Optional[List[str]]): Table names to hide. Mutually
            exclusive with ``include_tables``.
        include_tables (Optional[List[str]]): Table names to expose. Mutually
            exclusive with ``ignore_tables``.
        sample_rows_in_table_info (int): The number of sample rows to include
            in table info.
        indexes_in_table_info (bool): Whether to include indexes in table info.
        custom_table_info (Optional[dict]): Mapping of table name to custom
            table info; entries for tables not present in the database are
            dropped.
        view_support (bool): Whether views are listed and reflected alongside
            tables.
        max_string_length (int): Maximum length of each string value returned
            by ``run_sql`` before it is truncated.

    """

    def __init__(
        self,
        engine: "Engine",
        schema: Optional[str] = None,
        metadata: Optional["MetaData"] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ):
        """Wrap an existing engine and reflect the usable tables.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or if either names a table that is not in the database.
            TypeError: If ``sample_rows_in_table_info`` is not an int, or a
                non-empty ``custom_table_info`` is not a dict.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # Views are only added to the list of known tables when view_support
        # is enabled.
        view_names = self._inspector.get_view_names(schema=schema) if view_support else []
        self._all_tables = set(self._inspector.get_table_names(schema=schema) + view_names)

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )

        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )

        # Fall back to every known table when the filters leave nothing.
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # Only keep the entries whose table is also present in the database.
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._max_string_length = max_string_length

        self._metadata = metadata or MetaData()
        # Reflect only the usable tables (and views, when view_support is on).
        self._metadata.reflect(
            views=view_support,
            bind=self._engine,
            only=list(self._usable_tables),
            schema=self._schema,
        )

    @property
    def engine(self) -> "Engine":
        """Return SQL Alchemy engine."""
        return self._engine

    @property
    def metadata_obj(self) -> "MetaData":
        """Return SQL Alchemy metadata."""
        return self._metadata

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> "SQLDatabase":
        """Create an engine from ``database_uri`` and wrap it.

        ``engine_args`` is forwarded to ``create_engine``; every other keyword
        argument is forwarded to the class constructor.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Return the sorted names of the tables this wrapper exposes.

        ``include_tables`` wins when it is set; otherwise every known table
        except those in ``ignore_tables`` is returned.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    def get_table_columns(self, table_name: str) -> List[Any]:
        """Return the inspector's column descriptions for ``table_name``.

        The configured schema is passed through so the lookup targets the same
        schema as ``get_single_table_info``.
        """
        return self._inspector.get_columns(table_name, schema=self._schema)

    def get_single_table_info(self, table_name: str) -> str:
        """Describe one table's columns and foreign keys in a single sentence."""
        template = (
            "Table '{table_name}' has columns: {columns}, "
            "and foreign keys: {foreign_keys}."
        )

        columns = []
        for column in self._inspector.get_columns(table_name, schema=self._schema):
            description = f"{column['name']} ({column['type']!s})"
            # Append the column comment only when one is present and non-empty.
            if column.get("comment"):
                description += f": '{column.get('comment')}'"
            columns.append(description)
        column_str = ", ".join(columns)

        foreign_keys = [
            f"{foreign_key['constrained_columns']} -> "
            f"{foreign_key['referred_table']}.{foreign_key['referred_columns']}"
            for foreign_key in self._inspector.get_foreign_keys(
                table_name, schema=self._schema
            )
        ]
        foreign_key_str = ", ".join(foreign_keys)

        return template.format(
            table_name=table_name, columns=column_str, foreign_keys=foreign_key_str
        )

    def _lookup_table(self, table_name: str) -> Any:
        """Return the reflected table object registered for ``table_name``.

        SQLAlchemy documents that ``MetaData.tables`` keys schema-bound tables
        as ``"schema.table"``, so the schema-qualified key is tried first when
        a schema is configured; the bare name is the fallback.

        Raises:
            KeyError: If the table was not reflected under either key.
        """
        tables = self._metadata.tables
        if self._schema:
            qualified_name = f"{self._schema}.{table_name}"
            if qualified_name in tables:
                return tables[qualified_name]
        return tables[table_name]

    def insert_into_table(self, table_name: str, data: dict) -> None:
        """Insert one row, given as a column-name -> value mapping, into a table.

        Raises:
            KeyError: If ``table_name`` is not among the reflected tables.
        """
        table = self._lookup_table(table_name)
        stmt = insert(table).values(**data)
        # engine.begin() is used as a transaction scope around the insert.
        with self._engine.begin() as connection:
            connection.execute(stmt)

    def truncate_word(self, content: Any, *, length: int, suffix: str = "...") -> str:
        """Shorten a string to at most ``length`` characters on a word boundary.

        Non-string values, non-positive ``length`` values and strings that
        already fit are returned unchanged. Otherwise the text is cut so that
        text plus ``suffix`` fits in ``length`` characters, trimmed back to the
        last space inside the cut (if any), and ``suffix`` is appended.
        """
        if not isinstance(content, str) or length <= 0:
            return content

        if len(content) <= length:
            return content

        # When the suffix alone does not fit, a negative slice bound would keep
        # almost the whole string; return as much of the suffix as fits instead.
        if length < len(suffix):
            return suffix[:length]

        return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix

    def run_sql(self, command: str) -> Tuple[str, Dict]:
        """Execute a SQL statement and return its results.

        Returns:
            A ``(text, payload)`` tuple. When the statement returns rows,
            ``text`` is ``str()`` of the list of row tuples and ``payload`` is
            ``{"result": rows, "col_keys": column_names}``, with every string
            value truncated to ``max_string_length``. When it returns no rows
            the result is ``("", {})``.

        Raises:
            NotImplementedError: If SQLAlchemy raises ``ProgrammingError`` or
                ``OperationalError`` while executing the statement.
        """
        with self._engine.begin() as connection:
            try:
                if self._schema:
                    # NOTE(review): plain text substitution. It only matches an
                    # upper-case "FROM " and also rewrites occurrences inside
                    # string literals or already-qualified names.
                    command = command.replace("FROM ", f"FROM {self._schema}.")
                cursor = connection.execute(text(command))
            except (ProgrammingError, OperationalError) as exc:
                raise NotImplementedError(
                    f"Statement {command!r} is invalid SQL."
                ) from exc
            if cursor.returns_rows:
                result = cursor.fetchall()
                # Truncate every column value ourselves, then rebuild each row
                # as a plain tuple.
                truncated_results = [
                    tuple(
                        self.truncate_word(column, length=self._max_string_length)
                        for column in row
                    )
                    for row in result
                ]
                return str(truncated_results), {
                    "result": truncated_results,
                    "col_keys": list(cursor.keys()),
                }
        return "", {}
|