# faiss_rag_enterprise/llama_index/vector_stores/lantern.py

import logging
from typing import Any, List, NamedTuple, Optional, Type
from llama_index.bridge.pydantic import PrivateAttr
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.vector_stores.types import (
BasePydanticVectorStore,
MetadataFilters,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import metadata_dict_to_node, node_to_metadata_dict
class DBEmbeddingRow(NamedTuple):
    node_id: str
text: str
metadata: dict
similarity: float
_logger = logging.getLogger(__name__)
def get_data_model(
base: Type,
index_name: str,
schema_name: str,
hybrid_search: bool,
text_search_config: str,
cache_okay: bool,
embed_dim: int = 1536,
m: int = 16,
ef_construction: int = 128,
ef: int = 64,
) -> Any:
"""
    Create a dynamic SQLAlchemy model class bound to a new table.
"""
from sqlalchemy import Column, Computed
from sqlalchemy.dialects.postgresql import (
ARRAY,
BIGINT,
JSON,
REAL,
TSVECTOR,
VARCHAR,
)
from sqlalchemy.schema import Index
from sqlalchemy.types import TypeDecorator
class TSVector(TypeDecorator):
impl = TSVECTOR
cache_ok = cache_okay
tablename = "data_%s" % index_name # dynamic table name
class_name = "Data%s" % index_name # dynamic class name
indexname = "%s_idx" % index_name # dynamic index name
hnsw_indexname = "%s_hnsw_idx" % index_name # dynamic hnsw index name
if hybrid_search:
class HybridAbstractData(base): # type: ignore
__abstract__ = True # this line is necessary
id = Column(BIGINT, primary_key=True, autoincrement=True)
text = Column(VARCHAR, nullable=False)
metadata_ = Column(JSON)
node_id = Column(VARCHAR)
embedding = Column(ARRAY(REAL, embed_dim)) # type: ignore
text_search_tsv = Column( # type: ignore
TSVector(),
Computed(
"to_tsvector('%s', text)" % text_search_config, persisted=True
),
)
model = type(
class_name,
(HybridAbstractData,),
{"__tablename__": tablename, "__table_args__": {"schema": schema_name}},
)
Index(
indexname,
model.text_search_tsv, # type: ignore
postgresql_using="gin",
)
else:
class AbstractData(base): # type: ignore
__abstract__ = True # this line is necessary
id = Column(BIGINT, primary_key=True, autoincrement=True)
text = Column(VARCHAR, nullable=False)
metadata_ = Column(JSON)
node_id = Column(VARCHAR)
embedding = Column(ARRAY(REAL, embed_dim)) # type: ignore
model = type(
class_name,
(AbstractData,),
{"__tablename__": tablename, "__table_args__": {"schema": schema_name}},
)
Index(
hnsw_indexname,
model.embedding, # type: ignore
postgresql_using="hnsw",
postgresql_with={
"m": m,
"ef_construction": ef_construction,
"ef": ef,
"dim": embed_dim,
},
postgresql_ops={"embedding": "dist_cos_ops"},
)
return model
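
# Illustrative sketch of building the dynamic model in isolation (`Base` and
# the index name "paper_chunks" are assumptions for the example);
# LanternVectorStore.__init__ makes the equivalent call internally:
#
#   from sqlalchemy.orm import declarative_base
#
#   Base = declarative_base()
#   PaperChunks = get_data_model(
#       Base,
#       "paper_chunks",           # creates table "data_paper_chunks"
#       "public",
#       hybrid_search=True,       # adds the tsvector column + GIN index
#       text_search_config="english",
#       cache_okay=False,
#       embed_dim=1536,
#   )
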
class LanternVectorStore(BasePydanticVectorStore):
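    """Postgres vector store backed by the Lantern extension.

    Rows live in a dynamically created table (see `get_data_model`); dense
    queries use an HNSW index with cosine distance, and hybrid mode adds
    Postgres full-text search over a generated tsvector column.
    """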
from sqlalchemy.sql.selectable import Select
stores_text = True
flat_metadata = False
connection_string: str
async_connection_string: str
table_name: str
schema_name: str
embed_dim: int
hybrid_search: bool
text_search_config: str
cache_ok: bool
perform_setup: bool
debug: bool
_base: Any = PrivateAttr()
_table_class: Any = PrivateAttr()
_engine: Any = PrivateAttr()
_session: Any = PrivateAttr()
_async_engine: Any = PrivateAttr()
_async_session: Any = PrivateAttr()
_is_initialized: bool = PrivateAttr(default=False)
def __init__(
self,
connection_string: str,
async_connection_string: str,
table_name: str,
schema_name: str,
hybrid_search: bool = False,
text_search_config: str = "english",
embed_dim: int = 1536,
m: int = 16,
ef_construction: int = 128,
ef: int = 64,
cache_ok: bool = False,
perform_setup: bool = True,
debug: bool = False,
) -> None:
try:
import asyncpg # noqa
import psycopg2 # noqa
import sqlalchemy
import sqlalchemy.ext.asyncio # noqa
except ImportError:
            raise ImportError(
                "`sqlalchemy[asyncio]`, `psycopg2-binary` and `asyncpg` "
                "packages must be installed"
            )
table_name = table_name.lower()
schema_name = schema_name.lower()
if hybrid_search and text_search_config is None:
raise ValueError(
"Sparse vector index creation requires "
"a text search configuration specification."
)
from sqlalchemy.orm import declarative_base
# sqlalchemy model
self._base = declarative_base()
self._table_class = get_data_model(
self._base,
table_name,
schema_name,
hybrid_search,
text_search_config,
cache_ok,
embed_dim=embed_dim,
m=m,
ef_construction=ef_construction,
ef=ef,
)
super().__init__(
connection_string=connection_string,
async_connection_string=async_connection_string,
table_name=table_name,
schema_name=schema_name,
hybrid_search=hybrid_search,
text_search_config=text_search_config,
embed_dim=embed_dim,
cache_ok=cache_ok,
perform_setup=perform_setup,
debug=debug,
)
async def close(self) -> None:
if not self._is_initialized:
return
self._session.close_all()
self._engine.dispose()
await self._async_engine.dispose()
@classmethod
def class_name(cls) -> str:
return "LanternStore"
@classmethod
def from_params(
cls,
host: Optional[str] = None,
port: Optional[str] = None,
database: Optional[str] = None,
user: Optional[str] = None,
password: Optional[str] = None,
table_name: str = "llamaindex",
schema_name: str = "public",
connection_string: Optional[str] = None,
async_connection_string: Optional[str] = None,
hybrid_search: bool = False,
text_search_config: str = "english",
embed_dim: int = 1536,
m: int = 16,
ef_construction: int = 128,
ef: int = 64,
cache_ok: bool = False,
perform_setup: bool = True,
debug: bool = False,
    ) -> "LanternVectorStore":
        """Construct a LanternVectorStore from database parameters."""
conn_str = (
connection_string
or f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
)
async_conn_str = async_connection_string or (
f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
)
return cls(
connection_string=conn_str,
async_connection_string=async_conn_str,
table_name=table_name,
schema_name=schema_name,
hybrid_search=hybrid_search,
text_search_config=text_search_config,
embed_dim=embed_dim,
m=m,
ef_construction=ef_construction,
ef=ef,
cache_ok=cache_ok,
perform_setup=perform_setup,
debug=debug,
)
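
    # Usage sketch with hypothetical connection parameters; alternatively,
    # pass full `connection_string` / `async_connection_string` values:
    #
    #   store = LanternVectorStore.from_params(
    #       host="localhost",
    #       port="5432",
    #       database="vectordb",
    #       user="postgres",
    #       password="password",
    #       table_name="llamaindex",
    #       embed_dim=1536,
    #   )
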
@property
def client(self) -> Any:
if not self._is_initialized:
return None
return self._engine
def _connect(self) -> Any:
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
self._engine = create_engine(self.connection_string, echo=self.debug)
self._session = sessionmaker(self._engine)
self._async_engine = create_async_engine(self.async_connection_string)
self._async_session = sessionmaker(self._async_engine, class_=AsyncSession) # type: ignore
def _create_schema_if_not_exists(self) -> None:
with self._session() as session, session.begin():
from sqlalchemy import text
statement = text(f"CREATE SCHEMA IF NOT EXISTS {self.schema_name}")
session.execute(statement)
session.commit()
def _create_tables_if_not_exists(self) -> None:
with self._session() as session, session.begin():
self._base.metadata.create_all(session.connection())
def _create_extension(self) -> None:
import sqlalchemy
with self._session() as session, session.begin():
statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS lantern")
session.execute(statement)
session.commit()
def _initialize(self) -> None:
if not self._is_initialized:
self._connect()
if self.perform_setup:
self._create_extension()
self._create_schema_if_not_exists()
self._create_tables_if_not_exists()
self._is_initialized = True
def _node_to_table_row(self, node: BaseNode) -> Any:
return self._table_class(
node_id=node.node_id,
embedding=node.get_embedding(),
text=node.get_content(metadata_mode=MetadataMode.NONE),
metadata_=node_to_metadata_dict(
node,
remove_text=True,
flat_metadata=self.flat_metadata,
),
)
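
    # Note on the row written above: the node's text is stored without its
    # metadata (MetadataMode.NONE), the embedding goes into the REAL[] column,
    # and node_to_metadata_dict(remove_text=True) keeps the JSON metadata
    # blob from duplicating the text.
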
def add(self, nodes: List[BaseNode]) -> List[str]:
self._initialize()
ids = []
with self._session() as session, session.begin():
for node in nodes:
ids.append(node.node_id)
item = self._node_to_table_row(node)
session.add(item)
session.commit()
return ids
async def async_add(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
self._initialize()
ids = []
async with self._async_session() as session, session.begin():
for node in nodes:
ids.append(node.node_id)
item = self._node_to_table_row(node)
session.add(item)
await session.commit()
return ids
def _apply_filters_and_limit(
self,
stmt: Select,
limit: int,
metadata_filters: Optional[MetadataFilters] = None,
) -> Any:
import sqlalchemy
if metadata_filters:
for filter_ in metadata_filters.legacy_filters():
bind_parameter = f"value_{filter_.key}"
stmt = stmt.where( # type: ignore
sqlalchemy.text(f"metadata_->>'{filter_.key}' = :{bind_parameter}")
)
stmt = stmt.params( # type: ignore
**{bind_parameter: str(filter_.value)}
)
return stmt.limit(limit) # type: ignore
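
    # Sketch of the clause built above: a legacy ExactMatchFilter with
    # key="author", value="alice" (names are illustrative) becomes
    #   WHERE metadata_->>'author' = :value_author   -- bound to 'alice'
    # Values are compared as text because ->> extracts text from the JSON.
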
def _build_query(
self,
embedding: Optional[List[float]],
limit: int = 10,
metadata_filters: Optional[MetadataFilters] = None,
) -> Any:
from sqlalchemy import func, select
stmt = select( # type: ignore
self._table_class,
func.cos_dist(self._table_class.embedding, embedding),
).order_by(self._table_class.embedding.op("<=>")(embedding))
return self._apply_filters_and_limit(stmt, limit, metadata_filters)
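
    # The ORDER BY uses the `<=>` operator so it can be served by the HNSW
    # index declared with dist_cos_ops in get_data_model; cos_dist() is
    # selected alongside so callers can report similarity as 1 - distance.
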
def _prepare_query(self, session: Any, limit: int) -> None:
from sqlalchemy import text
session.execute(text("SET enable_seqscan=OFF")) # always use index
        session.execute(text(f"SET hnsw.init_k={limit}"))  # size the initial candidate list to the query limit
async def _aprepare_query(self, session: Any, limit: int) -> None:
from sqlalchemy import text
await session.execute(text("SET enable_seqscan=OFF")) # always use index
        await session.execute(text(f"SET hnsw.init_k={limit}"))  # size the initial candidate list to the query limit
def _query_with_score(
self,
embedding: Optional[List[float]],
limit: int = 10,
metadata_filters: Optional[MetadataFilters] = None,
) -> List[DBEmbeddingRow]:
stmt = self._build_query(embedding, limit, metadata_filters)
with self._session() as session, session.begin():
self._prepare_query(session, limit)
res = session.execute(
stmt,
)
return [
DBEmbeddingRow(
node_id=item.node_id,
text=item.text,
metadata=item.metadata_,
similarity=(1 - distance) if distance is not None else 0,
)
for item, distance in res.all()
]
async def _aquery_with_score(
self,
embedding: Optional[List[float]],
limit: int = 10,
metadata_filters: Optional[MetadataFilters] = None,
) -> List[DBEmbeddingRow]:
stmt = self._build_query(embedding, limit, metadata_filters)
async with self._async_session() as async_session, async_session.begin():
await self._aprepare_query(async_session, limit)
res = await async_session.execute(stmt)
return [
DBEmbeddingRow(
node_id=item.node_id,
text=item.text,
metadata=item.metadata_,
similarity=(1 - distance) if distance is not None else 0,
)
for item, distance in res.all()
]
def _build_sparse_query(
self,
query_str: Optional[str],
limit: int,
metadata_filters: Optional[MetadataFilters] = None,
) -> Any:
from sqlalchemy import select, type_coerce
from sqlalchemy.sql import func, text
from sqlalchemy.types import UserDefinedType
class REGCONFIG(UserDefinedType):
def get_col_spec(self, **kw: Any) -> str:
return "regconfig"
if query_str is None:
raise ValueError("query_str must be specified for a sparse vector query.")
ts_query = func.plainto_tsquery(
type_coerce(self.text_search_config, REGCONFIG), query_str
)
stmt = (
select( # type: ignore
self._table_class,
func.ts_rank(self._table_class.text_search_tsv, ts_query).label("rank"),
)
.where(self._table_class.text_search_tsv.op("@@")(ts_query))
.order_by(text("rank desc"))
)
return self._apply_filters_and_limit(stmt, limit, metadata_filters)
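
    # Roughly the SQL produced (with :q bound to query_str and the configured
    # text search config, e.g. 'english'):
    #   SELECT *, ts_rank(text_search_tsv, plainto_tsquery('english', :q)) AS rank
    #   FROM <schema>.data_<table>
    #   WHERE text_search_tsv @@ plainto_tsquery('english', :q)
    #   ORDER BY rank DESC LIMIT <limit>
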
async def _async_sparse_query_with_rank(
self,
query_str: Optional[str] = None,
limit: int = 10,
metadata_filters: Optional[MetadataFilters] = None,
) -> List[DBEmbeddingRow]:
stmt = self._build_sparse_query(query_str, limit, metadata_filters)
async with self._async_session() as async_session, async_session.begin():
res = await async_session.execute(stmt)
return [
DBEmbeddingRow(
node_id=item.node_id,
text=item.text,
metadata=item.metadata_,
similarity=rank,
)
for item, rank in res.all()
]
def _sparse_query_with_rank(
self,
query_str: Optional[str] = None,
limit: int = 10,
metadata_filters: Optional[MetadataFilters] = None,
) -> List[DBEmbeddingRow]:
stmt = self._build_sparse_query(query_str, limit, metadata_filters)
with self._session() as session, session.begin():
res = session.execute(stmt)
return [
DBEmbeddingRow(
node_id=item.node_id,
text=item.text,
metadata=item.metadata_,
similarity=rank,
)
for item, rank in res.all()
]
async def _async_hybrid_query(
self, query: VectorStoreQuery
) -> List[DBEmbeddingRow]:
import asyncio
if query.alpha is not None:
            _logger.warning("Lantern hybrid search does not support the alpha parameter.")
sparse_top_k = query.sparse_top_k or query.similarity_top_k
results = await asyncio.gather(
self._aquery_with_score(
query.query_embedding, query.similarity_top_k, query.filters
),
self._async_sparse_query_with_rank(
query.query_str, sparse_top_k, query.filters
),
)
dense_results, sparse_results = results
all_results = dense_results + sparse_results
return _dedup_results(all_results)
def _hybrid_query(self, query: VectorStoreQuery) -> List[DBEmbeddingRow]:
if query.alpha is not None:
            _logger.warning("Lantern hybrid search does not support the alpha parameter.")
sparse_top_k = query.sparse_top_k or query.similarity_top_k
dense_results = self._query_with_score(
query.query_embedding, query.similarity_top_k, query.filters
)
sparse_results = self._sparse_query_with_rank(
query.query_str, sparse_top_k, query.filters
)
all_results = dense_results + sparse_results
return _dedup_results(all_results)
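
    # Dense results precede sparse ones in all_results, so when a node shows
    # up in both result sets, _dedup_results keeps the dense (similarity)
    # score rather than the ts_rank score.
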
def _db_rows_to_query_result(
self, rows: List[DBEmbeddingRow]
) -> VectorStoreQueryResult:
nodes = []
similarities = []
ids = []
for db_embedding_row in rows:
try:
node = metadata_dict_to_node(db_embedding_row.metadata)
node.set_content(str(db_embedding_row.text))
except Exception:
# NOTE: deprecated legacy logic for backward compatibility
node = TextNode(
id_=db_embedding_row.node_id,
text=db_embedding_row.text,
metadata=db_embedding_row.metadata,
)
similarities.append(db_embedding_row.similarity)
ids.append(db_embedding_row.node_id)
nodes.append(node)
return VectorStoreQueryResult(
nodes=nodes,
similarities=similarities,
ids=ids,
)
async def aquery(
self, query: VectorStoreQuery, **kwargs: Any
) -> VectorStoreQueryResult:
self._initialize()
if query.mode == VectorStoreQueryMode.HYBRID:
results = await self._async_hybrid_query(query)
elif query.mode in [
VectorStoreQueryMode.SPARSE,
VectorStoreQueryMode.TEXT_SEARCH,
]:
sparse_top_k = query.sparse_top_k or query.similarity_top_k
results = await self._async_sparse_query_with_rank(
query.query_str, sparse_top_k, query.filters
)
elif query.mode == VectorStoreQueryMode.DEFAULT:
results = await self._aquery_with_score(
query.query_embedding, query.similarity_top_k, query.filters
)
else:
raise ValueError(f"Invalid query mode: {query.mode}")
return self._db_rows_to_query_result(results)
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
self._initialize()
if query.mode == VectorStoreQueryMode.HYBRID:
results = self._hybrid_query(query)
elif query.mode in [
VectorStoreQueryMode.SPARSE,
VectorStoreQueryMode.TEXT_SEARCH,
]:
sparse_top_k = query.sparse_top_k or query.similarity_top_k
results = self._sparse_query_with_rank(
query.query_str, sparse_top_k, query.filters
)
elif query.mode == VectorStoreQueryMode.DEFAULT:
results = self._query_with_score(
query.query_embedding, query.similarity_top_k, query.filters
)
else:
raise ValueError(f"Invalid query mode: {query.mode}")
return self._db_rows_to_query_result(results)
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
import sqlalchemy
self._initialize()
with self._session() as session, session.begin():
            stmt = sqlalchemy.text(
                f"DELETE FROM {self.schema_name}.data_{self.table_name} "
                "WHERE (metadata_->>'doc_id')::text = :ref_doc_id"
            )
            session.execute(stmt, {"ref_doc_id": ref_doc_id})
session.commit()
def _dedup_results(results: List[DBEmbeddingRow]) -> List[DBEmbeddingRow]:
seen_ids = set()
deduped_results = []
for result in results:
if result.node_id not in seen_ids:
deduped_results.append(result)
seen_ids.add(result.node_id)
return deduped_results
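
# End-to-end usage sketch (hypothetical values; real embeddings would come
# from an embedding model matching `embed_dim`):
#
#   from llama_index.schema import TextNode
#   from llama_index.vector_stores.types import VectorStoreQuery
#
#   store = LanternVectorStore.from_params(
#       host="localhost", port="5432", database="vectordb",
#       user="postgres", password="password", table_name="demo",
#       embed_dim=3,
#   )
#   store.add([TextNode(text="hello lantern", embedding=[0.1, 0.2, 0.3])])
#   result = store.query(
#       VectorStoreQuery(query_embedding=[0.1, 0.2, 0.3], similarity_top_k=1)
#   )
#   print(result.ids, result.similarities)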