# faiss_rag_enterprise/llama_index/vector_stores/postgres.py

import logging
from typing import Any, List, NamedTuple, Optional, Type
from llama_index.bridge.pydantic import PrivateAttr
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.vector_stores.types import (
BasePydanticVectorStore,
FilterOperator,
MetadataFilters,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import metadata_dict_to_node, node_to_metadata_dict
class DBEmbeddingRow(NamedTuple):
    node_id: str  # node IDs are strings on llama_index's BaseNode
    text: str
    metadata: dict
    similarity: float
_logger = logging.getLogger(__name__)
def get_data_model(
base: Type,
index_name: str,
schema_name: str,
hybrid_search: bool,
text_search_config: str,
cache_okay: bool,
embed_dim: int = 1536,
use_jsonb: bool = False,
) -> Any:
"""
This part create a dynamic sqlalchemy model with a new table.
"""
from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, Computed
from sqlalchemy.dialects.postgresql import BIGINT, JSON, JSONB, TSVECTOR, VARCHAR
from sqlalchemy.schema import Index
from sqlalchemy.types import TypeDecorator
class TSVector(TypeDecorator):
impl = TSVECTOR
cache_ok = cache_okay
    tablename = "data_%s" % index_name  # dynamic table name
    class_name = "Data%s" % index_name  # dynamic class name
    indexname = "%s_idx" % index_name  # dynamic index name
metadata_dtype = JSONB if use_jsonb else JSON
if hybrid_search:
class HybridAbstractData(base): # type: ignore
            __abstract__ = True  # keep SQLAlchemy from mapping the abstract base itself
id = Column(BIGINT, primary_key=True, autoincrement=True)
text = Column(VARCHAR, nullable=False)
metadata_ = Column(metadata_dtype)
node_id = Column(VARCHAR)
embedding = Column(Vector(embed_dim)) # type: ignore
text_search_tsv = Column( # type: ignore
TSVector(),
Computed(
"to_tsvector('%s', text)" % text_search_config, persisted=True
),
)
model = type(
class_name,
(HybridAbstractData,),
{"__tablename__": tablename, "__table_args__": {"schema": schema_name}},
)
Index(
indexname,
model.text_search_tsv, # type: ignore
postgresql_using="gin",
)
else:
class AbstractData(base): # type: ignore
            __abstract__ = True  # keep SQLAlchemy from mapping the abstract base itself
id = Column(BIGINT, primary_key=True, autoincrement=True)
text = Column(VARCHAR, nullable=False)
metadata_ = Column(metadata_dtype)
node_id = Column(VARCHAR)
embedding = Column(Vector(embed_dim)) # type: ignore
model = type(
class_name,
(AbstractData,),
{"__tablename__": tablename, "__table_args__": {"schema": schema_name}},
)
return model
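
# Illustration (hypothetical index_name="demo"): the mapped class produced above
# is roughly equivalent to this hand-written model, minus the tsvector column
# when hybrid_search is False:
#
#     class Datademo(base):
#         __tablename__ = "data_demo"
#         __table_args__ = {"schema": schema_name}
#         id = Column(BIGINT, primary_key=True, autoincrement=True)
#         text = Column(VARCHAR, nullable=False)
#         metadata_ = Column(JSONB)  # JSON when use_jsonb=False
#         node_id = Column(VARCHAR)
#         embedding = Column(Vector(embed_dim))
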
class PGVectorStore(BasePydanticVectorStore):
from sqlalchemy.sql.selectable import Select
stores_text = True
flat_metadata = False
connection_string: str
async_connection_string: str
table_name: str
schema_name: str
embed_dim: int
hybrid_search: bool
text_search_config: str
cache_ok: bool
perform_setup: bool
debug: bool
use_jsonb: bool
_base: Any = PrivateAttr()
_table_class: Any = PrivateAttr()
_engine: Any = PrivateAttr()
_session: Any = PrivateAttr()
_async_engine: Any = PrivateAttr()
_async_session: Any = PrivateAttr()
_is_initialized: bool = PrivateAttr(default=False)
def __init__(
self,
connection_string: str,
async_connection_string: str,
table_name: str,
schema_name: str,
hybrid_search: bool = False,
text_search_config: str = "english",
embed_dim: int = 1536,
cache_ok: bool = False,
perform_setup: bool = True,
debug: bool = False,
use_jsonb: bool = False,
) -> None:
try:
import asyncpg # noqa
import pgvector # noqa
import psycopg2 # noqa
import sqlalchemy
import sqlalchemy.ext.asyncio # noqa
except ImportError:
            raise ImportError(
                "`sqlalchemy[asyncio]`, `pgvector`, `psycopg2-binary`, and `asyncpg` "
                "must be installed to use PGVectorStore"
            )
table_name = table_name.lower()
schema_name = schema_name.lower()
if hybrid_search and text_search_config is None:
raise ValueError(
"Sparse vector index creation requires "
"a text search configuration specification."
)
from sqlalchemy.orm import declarative_base
# sqlalchemy model
self._base = declarative_base()
self._table_class = get_data_model(
self._base,
table_name,
schema_name,
hybrid_search,
text_search_config,
cache_ok,
embed_dim=embed_dim,
use_jsonb=use_jsonb,
)
super().__init__(
connection_string=connection_string,
async_connection_string=async_connection_string,
table_name=table_name,
schema_name=schema_name,
hybrid_search=hybrid_search,
text_search_config=text_search_config,
embed_dim=embed_dim,
cache_ok=cache_ok,
perform_setup=perform_setup,
debug=debug,
use_jsonb=use_jsonb,
)
    async def close(self) -> None:
        if not self._is_initialized:
            return
        from sqlalchemy.orm.session import close_all_sessions

        # sessionmaker does not expose Session.close_all; use the module-level
        # helper instead, then dispose both engines.
        close_all_sessions()
        self._engine.dispose()
        await self._async_engine.dispose()
@classmethod
def class_name(cls) -> str:
return "PGVectorStore"
@classmethod
def from_params(
cls,
host: Optional[str] = None,
port: Optional[str] = None,
database: Optional[str] = None,
user: Optional[str] = None,
password: Optional[str] = None,
table_name: str = "llamaindex",
schema_name: str = "public",
connection_string: Optional[str] = None,
async_connection_string: Optional[str] = None,
hybrid_search: bool = False,
text_search_config: str = "english",
embed_dim: int = 1536,
cache_ok: bool = False,
perform_setup: bool = True,
debug: bool = False,
use_jsonb: bool = False,
) -> "PGVectorStore":
"""Return connection string from database parameters."""
conn_str = (
connection_string
or f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
)
async_conn_str = async_connection_string or (
f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
)
return cls(
connection_string=conn_str,
async_connection_string=async_conn_str,
table_name=table_name,
schema_name=schema_name,
hybrid_search=hybrid_search,
text_search_config=text_search_config,
embed_dim=embed_dim,
cache_ok=cache_ok,
perform_setup=perform_setup,
debug=debug,
use_jsonb=use_jsonb,
)
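
    # A minimal construction sketch (host, credentials, and database below are
    # placeholder values, not defaults from this module):
    #
    #     vector_store = PGVectorStore.from_params(
    #         host="localhost",
    #         port="5432",
    #         database="vector_db",
    #         user="postgres",
    #         password="password",
    #         table_name="llamaindex",
    #         embed_dim=1536,
    #     )
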
@property
def client(self) -> Any:
if not self._is_initialized:
return None
return self._engine
def _connect(self) -> Any:
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
self._engine = create_engine(self.connection_string, echo=self.debug)
self._session = sessionmaker(self._engine)
self._async_engine = create_async_engine(self.async_connection_string)
self._async_session = sessionmaker(self._async_engine, class_=AsyncSession) # type: ignore
def _create_schema_if_not_exists(self) -> None:
with self._session() as session, session.begin():
from sqlalchemy import text
            # Check whether the specified schema already exists
            check_schema_statement = text(
                "SELECT schema_name FROM information_schema.schemata "
                "WHERE schema_name = :schema_name"
            ).bindparams(schema_name=self.schema_name)
result = session.execute(check_schema_statement).fetchone()
# If the schema does not exist, then create it
if not result:
create_schema_statement = text(
f"CREATE SCHEMA IF NOT EXISTS {self.schema_name}"
)
session.execute(create_schema_statement)
session.commit()
def _create_tables_if_not_exists(self) -> None:
with self._session() as session, session.begin():
self._base.metadata.create_all(session.connection())
def _create_extension(self) -> None:
import sqlalchemy
with self._session() as session, session.begin():
statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")
session.execute(statement)
session.commit()
def _initialize(self) -> None:
if not self._is_initialized:
self._connect()
if self.perform_setup:
self._create_extension()
self._create_schema_if_not_exists()
self._create_tables_if_not_exists()
self._is_initialized = True
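    # Note: setup is lazy. The first add/query call runs _initialize, which
    # connects and, when perform_setup is True, creates the `vector` extension,
    # the schema, and the table if they do not already exist.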
def _node_to_table_row(self, node: BaseNode) -> Any:
return self._table_class(
node_id=node.node_id,
embedding=node.get_embedding(),
text=node.get_content(metadata_mode=MetadataMode.NONE),
metadata_=node_to_metadata_dict(
node,
remove_text=True,
flat_metadata=self.flat_metadata,
),
)
def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
self._initialize()
ids = []
with self._session() as session, session.begin():
for node in nodes:
ids.append(node.node_id)
item = self._node_to_table_row(node)
session.add(item)
session.commit()
return ids
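
    # Insertion sketch (node content and embedding are illustrative; the
    # embedding length must match embed_dim):
    #
    #     node = TextNode(text="hello world", embedding=[0.1] * 1536)
    #     ids = vector_store.add([node])
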
async def async_add(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
self._initialize()
ids = []
async with self._async_session() as session, session.begin():
for node in nodes:
ids.append(node.node_id)
item = self._node_to_table_row(node)
session.add(item)
await session.commit()
return ids
def _to_postgres_operator(self, operator: FilterOperator) -> str:
if operator == FilterOperator.EQ:
return "="
elif operator == FilterOperator.GT:
return ">"
elif operator == FilterOperator.LT:
return "<"
elif operator == FilterOperator.NE:
return "!="
elif operator == FilterOperator.GTE:
return ">="
elif operator == FilterOperator.LTE:
return "<="
elif operator == FilterOperator.IN:
return "@>"
else:
_logger.warning(f"Unknown operator: {operator}, fallback to '='")
return "="
def _apply_filters_and_limit(
self,
stmt: Select,
limit: int,
metadata_filters: Optional[MetadataFilters] = None,
) -> Any:
import sqlalchemy
sqlalchemy_conditions = {
"or": sqlalchemy.sql.or_,
"and": sqlalchemy.sql.and_,
}
if metadata_filters:
if metadata_filters.condition not in sqlalchemy_conditions:
raise ValueError(
f"Invalid condition: {metadata_filters.condition}. "
f"Must be one of {list(sqlalchemy_conditions.keys())}"
)
stmt = stmt.where( # type: ignore
sqlalchemy_conditions[metadata_filters.condition](
*(
(
sqlalchemy.text(
f"metadata_::jsonb->'{filter_.key}' "
f"{self._to_postgres_operator(filter_.operator)} "
f"'[\"{filter_.value}\"]'"
)
if filter_.operator == FilterOperator.IN
else sqlalchemy.text(
f"metadata_->>'{filter_.key}' "
f"{self._to_postgres_operator(filter_.operator)} "
f"'{filter_.value}'"
)
)
for filter_ in metadata_filters.filters
)
)
)
return stmt.limit(limit) # type: ignore
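
    # Filter sketch (assuming the MetadataFilter class from the same types
    # module; key/value are illustrative). FilterOperator.IN is rendered above
    # as JSONB containment, e.g. metadata_::jsonb->'tag' @> '["news"]':
    #
    #     filters = MetadataFilters(
    #         filters=[
    #             MetadataFilter(key="tag", value="news", operator=FilterOperator.IN)
    #         ],
    #         condition="and",
    #     )
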
def _build_query(
self,
embedding: Optional[List[float]],
limit: int = 10,
metadata_filters: Optional[MetadataFilters] = None,
) -> Any:
from sqlalchemy import select, text
stmt = select( # type: ignore
self._table_class.id,
self._table_class.node_id,
self._table_class.text,
self._table_class.metadata_,
self._table_class.embedding.cosine_distance(embedding).label("distance"),
).order_by(text("distance asc"))
return self._apply_filters_and_limit(stmt, limit, metadata_filters)
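    # Note: pgvector's cosine_distance is 1 - cosine similarity, so ordering by
    # ascending distance yields the most similar rows first; _query_with_score
    # maps it back via similarity = 1 - distance.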
def _query_with_score(
self,
embedding: Optional[List[float]],
limit: int = 10,
metadata_filters: Optional[MetadataFilters] = None,
**kwargs: Any,
) -> List[DBEmbeddingRow]:
stmt = self._build_query(embedding, limit, metadata_filters)
with self._session() as session, session.begin():
from sqlalchemy import text
if kwargs.get("ivfflat_probes"):
session.execute(
text(f"SET ivfflat.probes = {kwargs.get('ivfflat_probes')}")
)
if kwargs.get("hnsw_ef_search"):
session.execute(
text(f"SET hnsw.ef_search = {kwargs.get('hnsw_ef_search')}")
)
res = session.execute(
stmt,
)
return [
DBEmbeddingRow(
node_id=item.node_id,
text=item.text,
metadata=item.metadata_,
similarity=(1 - item.distance) if item.distance is not None else 0,
)
for item in res.all()
]
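
    # Tuning sketch (values are illustrative): the kwargs above issue
    # session-level SETs that trade recall for speed, provided an IVFFlat or
    # HNSW index exists on the embedding column:
    #
    #     rows = store._query_with_score(
    #         embedding, limit=10, ivfflat_probes=10, hnsw_ef_search=64
    #     )
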
async def _aquery_with_score(
self,
embedding: Optional[List[float]],
limit: int = 10,
metadata_filters: Optional[MetadataFilters] = None,
**kwargs: Any,
) -> List[DBEmbeddingRow]:
stmt = self._build_query(embedding, limit, metadata_filters)
async with self._async_session() as async_session, async_session.begin():
from sqlalchemy import text
if kwargs.get("hnsw_ef_search"):
await async_session.execute(
text(f"SET hnsw.ef_search = {kwargs.get('hnsw_ef_search')}")
)
if kwargs.get("ivfflat_probes"):
await async_session.execute(
text(f"SET ivfflat.probes = {kwargs.get('ivfflat_probes')}")
)
res = await async_session.execute(stmt)
return [
DBEmbeddingRow(
node_id=item.node_id,
text=item.text,
metadata=item.metadata_,
similarity=(1 - item.distance) if item.distance is not None else 0,
)
for item in res.all()
]
def _build_sparse_query(
self,
query_str: Optional[str],
limit: int,
metadata_filters: Optional[MetadataFilters] = None,
) -> Any:
from sqlalchemy import select, type_coerce
from sqlalchemy.sql import func, text
from sqlalchemy.types import UserDefinedType
class REGCONFIG(UserDefinedType):
def get_col_spec(self, **kw: Any) -> str:
return "regconfig"
if query_str is None:
raise ValueError("query_str must be specified for a sparse vector query.")
ts_query = func.plainto_tsquery(
type_coerce(self.text_search_config, REGCONFIG), query_str
)
stmt = (
select( # type: ignore
self._table_class.id,
self._table_class.node_id,
self._table_class.text,
self._table_class.metadata_,
func.ts_rank(self._table_class.text_search_tsv, ts_query).label("rank"),
)
.where(self._table_class.text_search_tsv.op("@@")(ts_query))
.order_by(text("rank desc"))
)
return self._apply_filters_and_limit(stmt, limit, metadata_filters)
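    # plainto_tsquery parses the raw query string into a tsquery (AND-ing its
    # terms), "@@" matches it against the precomputed tsvector column, and
    # ts_rank orders the hits by relevance. The REGCONFIG coercion passes
    # text_search_config (e.g. 'english') as the regconfig type Postgres expects.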
async def _async_sparse_query_with_rank(
self,
query_str: Optional[str] = None,
limit: int = 10,
metadata_filters: Optional[MetadataFilters] = None,
) -> List[DBEmbeddingRow]:
stmt = self._build_sparse_query(query_str, limit, metadata_filters)
async with self._async_session() as async_session, async_session.begin():
res = await async_session.execute(stmt)
return [
DBEmbeddingRow(
node_id=item.node_id,
text=item.text,
metadata=item.metadata_,
similarity=item.rank,
)
for item in res.all()
]
def _sparse_query_with_rank(
self,
query_str: Optional[str] = None,
limit: int = 10,
metadata_filters: Optional[MetadataFilters] = None,
) -> List[DBEmbeddingRow]:
stmt = self._build_sparse_query(query_str, limit, metadata_filters)
with self._session() as session, session.begin():
res = session.execute(stmt)
return [
DBEmbeddingRow(
node_id=item.node_id,
text=item.text,
metadata=item.metadata_,
similarity=item.rank,
)
for item in res.all()
]
async def _async_hybrid_query(
self, query: VectorStoreQuery, **kwargs: Any
) -> List[DBEmbeddingRow]:
import asyncio
if query.alpha is not None:
_logger.warning("postgres hybrid search does not support alpha parameter.")
sparse_top_k = query.sparse_top_k or query.similarity_top_k
results = await asyncio.gather(
self._aquery_with_score(
query.query_embedding,
query.similarity_top_k,
query.filters,
**kwargs,
),
self._async_sparse_query_with_rank(
query.query_str, sparse_top_k, query.filters
),
)
dense_results, sparse_results = results
all_results = dense_results + sparse_results
return _dedup_results(all_results)
def _hybrid_query(
self, query: VectorStoreQuery, **kwargs: Any
) -> List[DBEmbeddingRow]:
if query.alpha is not None:
_logger.warning("postgres hybrid search does not support alpha parameter.")
sparse_top_k = query.sparse_top_k or query.similarity_top_k
dense_results = self._query_with_score(
query.query_embedding,
query.similarity_top_k,
query.filters,
**kwargs,
)
sparse_results = self._sparse_query_with_rank(
query.query_str, sparse_top_k, query.filters
)
all_results = dense_results + sparse_results
return _dedup_results(all_results)
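    # Dense results precede sparse results in the concatenation, so when a node
    # matches both queries _dedup_results keeps its dense (cosine) score.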
def _db_rows_to_query_result(
self, rows: List[DBEmbeddingRow]
) -> VectorStoreQueryResult:
nodes = []
similarities = []
ids = []
for db_embedding_row in rows:
try:
node = metadata_dict_to_node(db_embedding_row.metadata)
node.set_content(str(db_embedding_row.text))
except Exception:
# NOTE: deprecated legacy logic for backward compatibility
node = TextNode(
id_=db_embedding_row.node_id,
text=db_embedding_row.text,
metadata=db_embedding_row.metadata,
)
similarities.append(db_embedding_row.similarity)
ids.append(db_embedding_row.node_id)
nodes.append(node)
return VectorStoreQueryResult(
nodes=nodes,
similarities=similarities,
ids=ids,
)
async def aquery(
self, query: VectorStoreQuery, **kwargs: Any
) -> VectorStoreQueryResult:
self._initialize()
if query.mode == VectorStoreQueryMode.HYBRID:
results = await self._async_hybrid_query(query, **kwargs)
elif query.mode in [
VectorStoreQueryMode.SPARSE,
VectorStoreQueryMode.TEXT_SEARCH,
]:
sparse_top_k = query.sparse_top_k or query.similarity_top_k
results = await self._async_sparse_query_with_rank(
query.query_str, sparse_top_k, query.filters
)
elif query.mode == VectorStoreQueryMode.DEFAULT:
results = await self._aquery_with_score(
query.query_embedding,
query.similarity_top_k,
query.filters,
**kwargs,
)
else:
raise ValueError(f"Invalid query mode: {query.mode}")
return self._db_rows_to_query_result(results)
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
self._initialize()
if query.mode == VectorStoreQueryMode.HYBRID:
results = self._hybrid_query(query, **kwargs)
elif query.mode in [
VectorStoreQueryMode.SPARSE,
VectorStoreQueryMode.TEXT_SEARCH,
]:
sparse_top_k = query.sparse_top_k or query.similarity_top_k
results = self._sparse_query_with_rank(
query.query_str, sparse_top_k, query.filters
)
elif query.mode == VectorStoreQueryMode.DEFAULT:
results = self._query_with_score(
query.query_embedding,
query.similarity_top_k,
query.filters,
**kwargs,
)
else:
raise ValueError(f"Invalid query mode: {query.mode}")
return self._db_rows_to_query_result(results)
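
    # Query sketch (embedding values and top-k are illustrative):
    #
    #     result = vector_store.query(
    #         VectorStoreQuery(
    #             query_embedding=[0.1] * 1536,
    #             similarity_top_k=5,
    #             mode=VectorStoreQueryMode.DEFAULT,
    #         )
    #     )
    #     for node, score in zip(result.nodes, result.similarities):
    #         print(node.node_id, score)
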
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
import sqlalchemy
self._initialize()
with self._session() as session, session.begin():
            # Identifiers cannot be bound parameters, but the value can be,
            # which avoids interpolating ref_doc_id into raw SQL.
            stmt = sqlalchemy.text(
                f"DELETE FROM {self.schema_name}.data_{self.table_name} "
                "WHERE (metadata_->>'doc_id')::text = :ref_doc_id"
            ).bindparams(ref_doc_id=ref_doc_id)
session.execute(stmt)
session.commit()
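    # Deletion keys on the "doc_id" field that node_to_metadata_dict stores in
    # metadata_, so every node derived from the same source document is removed
    # in one statement.
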
def _dedup_results(results: List[DBEmbeddingRow]) -> List[DBEmbeddingRow]:
seen_ids = set()
deduped_results = []
for result in results:
if result.node_id not in seen_ids:
deduped_results.append(result)
seen_ids.add(result.node_id)
return deduped_results