import logging
from typing import Any, List, NamedTuple, Optional, Type

from llama_index.bridge.pydantic import PrivateAttr
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.vector_stores.types import (
    BasePydanticVectorStore,
    FilterOperator,
    MetadataFilters,
    VectorStoreQuery,
    VectorStoreQueryMode,
    VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import metadata_dict_to_node, node_to_metadata_dict


class DBEmbeddingRow(NamedTuple):
    node_id: str
    text: str
    metadata: dict
    similarity: float


_logger = logging.getLogger(__name__)


def get_data_model(
    base: Type,
    index_name: str,
    schema_name: str,
    hybrid_search: bool,
    text_search_config: str,
    cache_okay: bool,
    embed_dim: int = 1536,
    use_jsonb: bool = False,
) -> Any:
    """
    Create a dynamic SQLAlchemy model with a new table.
    """
    from pgvector.sqlalchemy import Vector
    from sqlalchemy import Column, Computed
    from sqlalchemy.dialects.postgresql import BIGINT, JSON, JSONB, TSVECTOR, VARCHAR
    from sqlalchemy.schema import Index
    from sqlalchemy.types import TypeDecorator

    class TSVector(TypeDecorator):
        impl = TSVECTOR
        cache_ok = cache_okay

    tablename = "data_%s" % index_name  # dynamic table name
    class_name = "Data%s" % index_name  # dynamic class name
    indexname = "%s_idx" % index_name  # dynamic index name

    metadata_dtype = JSONB if use_jsonb else JSON

    if hybrid_search:

        class HybridAbstractData(base):  # type: ignore
            __abstract__ = True  # required: keeps SQLAlchemy from mapping this base class to a table
            id = Column(BIGINT, primary_key=True, autoincrement=True)
            text = Column(VARCHAR, nullable=False)
            metadata_ = Column(metadata_dtype)
            node_id = Column(VARCHAR)
            embedding = Column(Vector(embed_dim))  # type: ignore
            text_search_tsv = Column(  # type: ignore
                TSVector(),
                Computed(
                    "to_tsvector('%s', text)" % text_search_config, persisted=True
                ),
            )

        model = type(
            class_name,
            (HybridAbstractData,),
            {"__tablename__": tablename, "__table_args__": {"schema": schema_name}},
        )

        Index(
            indexname,
            model.text_search_tsv,  # type: ignore
            postgresql_using="gin",
        )
    else:

        class AbstractData(base):  # type: ignore
            __abstract__ = True  # required: keeps SQLAlchemy from mapping this base class to a table
            id = Column(BIGINT, primary_key=True, autoincrement=True)
            text = Column(VARCHAR, nullable=False)
            metadata_ = Column(metadata_dtype)
            node_id = Column(VARCHAR)
            embedding = Column(Vector(embed_dim))  # type: ignore

        model = type(
            class_name,
            (AbstractData,),
            {"__tablename__": tablename, "__table_args__": {"schema": schema_name}},
        )

    return model
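
# A minimal sketch of how get_data_model is typically used (the names below are
# illustrative, not part of this module's API):
#
#     from sqlalchemy.orm import declarative_base
#
#     Base = declarative_base()
#     ExampleModel = get_data_model(
#         Base,
#         "example",
#         "public",
#         hybrid_search=False,
#         text_search_config="english",
#         cache_okay=False,
#     )
#     # ExampleModel maps to table "public.data_example" with columns
#     # id, text, metadata_, node_id, and embedding.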


class PGVectorStore(BasePydanticVectorStore):
    from sqlalchemy.sql.selectable import Select

    stores_text = True
    flat_metadata = False

    connection_string: str
    async_connection_string: str
    table_name: str
    schema_name: str
    embed_dim: int
    hybrid_search: bool
    text_search_config: str
    cache_ok: bool
    perform_setup: bool
    debug: bool
    use_jsonb: bool

    _base: Any = PrivateAttr()
    _table_class: Any = PrivateAttr()
    _engine: Any = PrivateAttr()
    _session: Any = PrivateAttr()
    _async_engine: Any = PrivateAttr()
    _async_session: Any = PrivateAttr()
    _is_initialized: bool = PrivateAttr(default=False)

    def __init__(
        self,
        connection_string: str,
        async_connection_string: str,
        table_name: str,
        schema_name: str,
        hybrid_search: bool = False,
        text_search_config: str = "english",
        embed_dim: int = 1536,
        cache_ok: bool = False,
        perform_setup: bool = True,
        debug: bool = False,
        use_jsonb: bool = False,
    ) -> None:
        try:
            import asyncpg  # noqa
            import pgvector  # noqa
            import psycopg2  # noqa
            import sqlalchemy
            import sqlalchemy.ext.asyncio  # noqa
        except ImportError:
            raise ImportError(
                "`sqlalchemy[asyncio]`, `pgvector`, `psycopg2-binary` and `asyncpg` "
                "packages must be installed"
            )

        table_name = table_name.lower()
        schema_name = schema_name.lower()

        if hybrid_search and text_search_config is None:
            raise ValueError(
                "Sparse vector index creation requires "
                "a text search configuration specification."
            )

        from sqlalchemy.orm import declarative_base

        # sqlalchemy model
        self._base = declarative_base()
        self._table_class = get_data_model(
            self._base,
            table_name,
            schema_name,
            hybrid_search,
            text_search_config,
            cache_ok,
            embed_dim=embed_dim,
            use_jsonb=use_jsonb,
        )

        super().__init__(
            connection_string=connection_string,
            async_connection_string=async_connection_string,
            table_name=table_name,
            schema_name=schema_name,
            hybrid_search=hybrid_search,
            text_search_config=text_search_config,
            embed_dim=embed_dim,
            cache_ok=cache_ok,
            perform_setup=perform_setup,
            debug=debug,
            use_jsonb=use_jsonb,
        )
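
    # Note: construction does not open a database connection. Engines, sessions,
    # the pgvector extension, the schema, and the tables are created lazily by
    # _initialize() on the first add/query/delete call (when perform_setup=True).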

    async def close(self) -> None:
        if not self._is_initialized:
            return

        self._session.close_all()
        self._engine.dispose()

        await self._async_engine.dispose()

    @classmethod
    def class_name(cls) -> str:
        return "PGVectorStore"

    @classmethod
    def from_params(
        cls,
        host: Optional[str] = None,
        port: Optional[str] = None,
        database: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        table_name: str = "llamaindex",
        schema_name: str = "public",
        connection_string: Optional[str] = None,
        async_connection_string: Optional[str] = None,
        hybrid_search: bool = False,
        text_search_config: str = "english",
        embed_dim: int = 1536,
        cache_ok: bool = False,
        perform_setup: bool = True,
        debug: bool = False,
        use_jsonb: bool = False,
    ) -> "PGVectorStore":
        """Construct a PGVectorStore, building the sync and async connection
        strings from database parameters when not given explicitly."""
        conn_str = (
            connection_string
            or f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
        )
        async_conn_str = async_connection_string or (
            f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
        )
        return cls(
            connection_string=conn_str,
            async_connection_string=async_conn_str,
            table_name=table_name,
            schema_name=schema_name,
            hybrid_search=hybrid_search,
            text_search_config=text_search_config,
            embed_dim=embed_dim,
            cache_ok=cache_ok,
            perform_setup=perform_setup,
            debug=debug,
            use_jsonb=use_jsonb,
        )
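
    # A minimal usage sketch, assuming a local Postgres with the pgvector
    # extension available (the connection details are illustrative):
    #
    #     store = PGVectorStore.from_params(
    #         host="localhost",
    #         port="5432",
    #         database="vectordb",
    #         user="postgres",
    #         password="password",
    #         table_name="example",
    #         embed_dim=1536,  # must match your embedding model
    #     )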

    @property
    def client(self) -> Any:
        if not self._is_initialized:
            return None
        return self._engine

    def _connect(self) -> Any:
        from sqlalchemy import create_engine
        from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
        from sqlalchemy.orm import sessionmaker

        self._engine = create_engine(self.connection_string, echo=self.debug)
        self._session = sessionmaker(self._engine)

        self._async_engine = create_async_engine(self.async_connection_string)
        self._async_session = sessionmaker(self._async_engine, class_=AsyncSession)  # type: ignore

    def _create_schema_if_not_exists(self) -> None:
        with self._session() as session, session.begin():
            from sqlalchemy import text

            # Check whether the specified schema already exists.
            check_schema_statement = text(
                f"SELECT schema_name FROM information_schema.schemata WHERE schema_name = '{self.schema_name}'"
            )
            result = session.execute(check_schema_statement).fetchone()

            # If the schema does not exist, then create it
            if not result:
                create_schema_statement = text(
                    f"CREATE SCHEMA IF NOT EXISTS {self.schema_name}"
                )
                session.execute(create_schema_statement)

            session.commit()

    def _create_tables_if_not_exists(self) -> None:
        with self._session() as session, session.begin():
            self._base.metadata.create_all(session.connection())

    def _create_extension(self) -> None:
        import sqlalchemy

        with self._session() as session, session.begin():
            statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")
            session.execute(statement)
            session.commit()

    def _initialize(self) -> None:
        if not self._is_initialized:
            self._connect()
            if self.perform_setup:
                self._create_extension()
                self._create_schema_if_not_exists()
                self._create_tables_if_not_exists()
            self._is_initialized = True

    def _node_to_table_row(self, node: BaseNode) -> Any:
        return self._table_class(
            node_id=node.node_id,
            embedding=node.get_embedding(),
            text=node.get_content(metadata_mode=MetadataMode.NONE),
            metadata_=node_to_metadata_dict(
                node,
                remove_text=True,
                flat_metadata=self.flat_metadata,
            ),
        )

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        self._initialize()
        ids = []
        with self._session() as session, session.begin():
            for node in nodes:
                ids.append(node.node_id)
                item = self._node_to_table_row(node)
                session.add(item)
            session.commit()
        return ids

    async def async_add(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
        self._initialize()
        ids = []
        async with self._async_session() as session, session.begin():
            for node in nodes:
                ids.append(node.node_id)
                item = self._node_to_table_row(node)
                session.add(item)
            await session.commit()
        return ids

    def _to_postgres_operator(self, operator: FilterOperator) -> str:
        if operator == FilterOperator.EQ:
            return "="
        elif operator == FilterOperator.GT:
            return ">"
        elif operator == FilterOperator.LT:
            return "<"
        elif operator == FilterOperator.NE:
            return "!="
        elif operator == FilterOperator.GTE:
            return ">="
        elif operator == FilterOperator.LTE:
            return "<="
        elif operator == FilterOperator.IN:
            return "@>"
        else:
            _logger.warning(f"Unknown operator: {operator}, fallback to '='")
            return "="

    def _apply_filters_and_limit(
        self,
        stmt: Select,
        limit: int,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> Any:
        import sqlalchemy

        sqlalchemy_conditions = {
            "or": sqlalchemy.sql.or_,
            "and": sqlalchemy.sql.and_,
        }

        if metadata_filters:
            if metadata_filters.condition not in sqlalchemy_conditions:
                raise ValueError(
                    f"Invalid condition: {metadata_filters.condition}. "
                    f"Must be one of {list(sqlalchemy_conditions.keys())}"
                )
            stmt = stmt.where(  # type: ignore
                sqlalchemy_conditions[metadata_filters.condition](
                    *(
                        (
                            sqlalchemy.text(
                                f"metadata_::jsonb->'{filter_.key}' "
                                f"{self._to_postgres_operator(filter_.operator)} "
                                f"'[\"{filter_.value}\"]'"
                            )
                            if filter_.operator == FilterOperator.IN
                            else sqlalchemy.text(
                                f"metadata_->>'{filter_.key}' "
                                f"{self._to_postgres_operator(filter_.operator)} "
                                f"'{filter_.value}'"
                            )
                        )
                        for filter_ in metadata_filters.filters
                    )
                )
            )
        return stmt.limit(limit)  # type: ignore
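
    # Example of what this emits: an EQ filter on key "author" with value
    # "alice" becomes the raw clause  metadata_->>'author' = 'alice'.  Note
    # that filter values are interpolated into the SQL text rather than bound
    # as parameters.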

    def _build_query(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> Any:
        from sqlalchemy import select, text

        stmt = select(  # type: ignore
            self._table_class.id,
            self._table_class.node_id,
            self._table_class.text,
            self._table_class.metadata_,
            self._table_class.embedding.cosine_distance(embedding).label("distance"),
        ).order_by(text("distance asc"))

        return self._apply_filters_and_limit(stmt, limit, metadata_filters)

    def _query_with_score(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
        **kwargs: Any,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_query(embedding, limit, metadata_filters)
        with self._session() as session, session.begin():
            from sqlalchemy import text

            if kwargs.get("ivfflat_probes"):
                session.execute(
                    text(f"SET ivfflat.probes = {kwargs.get('ivfflat_probes')}")
                )
            if kwargs.get("hnsw_ef_search"):
                session.execute(
                    text(f"SET hnsw.ef_search = {kwargs.get('hnsw_ef_search')}")
                )

            res = session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=(1 - item.distance) if item.distance is not None else 0,
                )
                for item in res.all()
            ]
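
    # The optional kwargs tune pgvector index scans for the session; they only
    # have an effect when the corresponding index type exists on the embedding
    # column. A hedged example (vector_store_query is a placeholder):
    #
    #     results = store.query(
    #         vector_store_query,
    #         ivfflat_probes=10,   # widen IVFFlat scans
    #         hnsw_ef_search=64,   # widen HNSW scans
    #     )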

    async def _aquery_with_score(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
        **kwargs: Any,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_query(embedding, limit, metadata_filters)
        async with self._async_session() as async_session, async_session.begin():
            from sqlalchemy import text

            if kwargs.get("hnsw_ef_search"):
                await async_session.execute(
                    text(f"SET hnsw.ef_search = {kwargs.get('hnsw_ef_search')}")
                )
            if kwargs.get("ivfflat_probes"):
                await async_session.execute(
                    text(f"SET ivfflat.probes = {kwargs.get('ivfflat_probes')}")
                )

            res = await async_session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=(1 - item.distance) if item.distance is not None else 0,
                )
                for item in res.all()
            ]

    def _build_sparse_query(
        self,
        query_str: Optional[str],
        limit: int,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> Any:
        from sqlalchemy import select, type_coerce
        from sqlalchemy.sql import func, text
        from sqlalchemy.types import UserDefinedType

        class REGCONFIG(UserDefinedType):
            def get_col_spec(self, **kw: Any) -> str:
                return "regconfig"

        if query_str is None:
            raise ValueError("query_str must be specified for a sparse vector query.")

        ts_query = func.plainto_tsquery(
            type_coerce(self.text_search_config, REGCONFIG), query_str
        )
        stmt = (
            select(  # type: ignore
                self._table_class.id,
                self._table_class.node_id,
                self._table_class.text,
                self._table_class.metadata_,
                func.ts_rank(self._table_class.text_search_tsv, ts_query).label("rank"),
            )
            .where(self._table_class.text_search_tsv.op("@@")(ts_query))
            .order_by(text("rank desc"))
        )

        return self._apply_filters_and_limit(stmt, limit, metadata_filters)
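
    # The generated statement is roughly equivalent to this SQL (schema and
    # table names are illustrative):
    #
    #     SELECT id, node_id, text, metadata_,
    #            ts_rank(text_search_tsv, plainto_tsquery('english', :q)) AS rank
    #     FROM public.data_example
    #     WHERE text_search_tsv @@ plainto_tsquery('english', :q)
    #     ORDER BY rank DESC;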

    async def _async_sparse_query_with_rank(
        self,
        query_str: Optional[str] = None,
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_sparse_query(query_str, limit, metadata_filters)
        async with self._async_session() as async_session, async_session.begin():
            res = await async_session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=item.rank,
                )
                for item in res.all()
            ]

    def _sparse_query_with_rank(
        self,
        query_str: Optional[str] = None,
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_sparse_query(query_str, limit, metadata_filters)
        with self._session() as session, session.begin():
            res = session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=item.rank,
                )
                for item in res.all()
            ]

    async def _async_hybrid_query(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> List[DBEmbeddingRow]:
        import asyncio

        if query.alpha is not None:
            _logger.warning("postgres hybrid search does not support alpha parameter.")

        sparse_top_k = query.sparse_top_k or query.similarity_top_k

        results = await asyncio.gather(
            self._aquery_with_score(
                query.query_embedding,
                query.similarity_top_k,
                query.filters,
                **kwargs,
            ),
            self._async_sparse_query_with_rank(
                query.query_str, sparse_top_k, query.filters
            ),
        )

        dense_results, sparse_results = results
        all_results = dense_results + sparse_results
        return _dedup_results(all_results)

    def _hybrid_query(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> List[DBEmbeddingRow]:
        if query.alpha is not None:
            _logger.warning("postgres hybrid search does not support alpha parameter.")

        sparse_top_k = query.sparse_top_k or query.similarity_top_k

        dense_results = self._query_with_score(
            query.query_embedding,
            query.similarity_top_k,
            query.filters,
            **kwargs,
        )

        sparse_results = self._sparse_query_with_rank(
            query.query_str, sparse_top_k, query.filters
        )

        all_results = dense_results + sparse_results
        return _dedup_results(all_results)
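
    # Note that dense scores (1 - cosine distance) and sparse scores (ts_rank)
    # are on different scales; hybrid results are concatenated dense-first and
    # deduplicated by node_id, not fused by a weighted score.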

    def _db_rows_to_query_result(
        self, rows: List[DBEmbeddingRow]
    ) -> VectorStoreQueryResult:
        nodes = []
        similarities = []
        ids = []
        for db_embedding_row in rows:
            try:
                node = metadata_dict_to_node(db_embedding_row.metadata)
                node.set_content(str(db_embedding_row.text))
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                node = TextNode(
                    id_=db_embedding_row.node_id,
                    text=db_embedding_row.text,
                    metadata=db_embedding_row.metadata,
                )
            similarities.append(db_embedding_row.similarity)
            ids.append(db_embedding_row.node_id)
            nodes.append(node)

        return VectorStoreQueryResult(
            nodes=nodes,
            similarities=similarities,
            ids=ids,
        )

    async def aquery(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> VectorStoreQueryResult:
        self._initialize()
        if query.mode == VectorStoreQueryMode.HYBRID:
            results = await self._async_hybrid_query(query, **kwargs)
        elif query.mode in [
            VectorStoreQueryMode.SPARSE,
            VectorStoreQueryMode.TEXT_SEARCH,
        ]:
            sparse_top_k = query.sparse_top_k or query.similarity_top_k
            results = await self._async_sparse_query_with_rank(
                query.query_str, sparse_top_k, query.filters
            )
        elif query.mode == VectorStoreQueryMode.DEFAULT:
            results = await self._aquery_with_score(
                query.query_embedding,
                query.similarity_top_k,
                query.filters,
                **kwargs,
            )
        else:
            raise ValueError(f"Invalid query mode: {query.mode}")

        return self._db_rows_to_query_result(results)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        self._initialize()
        if query.mode == VectorStoreQueryMode.HYBRID:
            results = self._hybrid_query(query, **kwargs)
        elif query.mode in [
            VectorStoreQueryMode.SPARSE,
            VectorStoreQueryMode.TEXT_SEARCH,
        ]:
            sparse_top_k = query.sparse_top_k or query.similarity_top_k
            results = self._sparse_query_with_rank(
                query.query_str, sparse_top_k, query.filters
            )
        elif query.mode == VectorStoreQueryMode.DEFAULT:
            results = self._query_with_score(
                query.query_embedding,
                query.similarity_top_k,
                query.filters,
                **kwargs,
            )
        else:
            raise ValueError(f"Invalid query mode: {query.mode}")

        return self._db_rows_to_query_result(results)
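
    # A hedged end-to-end sketch (the embedding values are placeholders; a real
    # query embedding must have embed_dim dimensions):
    #
    #     from llama_index.vector_stores.types import (
    #         VectorStoreQuery,
    #         VectorStoreQueryMode,
    #     )
    #
    #     result = store.query(
    #         VectorStoreQuery(
    #             query_embedding=[0.1] * 1536,
    #             similarity_top_k=5,
    #             mode=VectorStoreQueryMode.DEFAULT,
    #         )
    #     )
    #     # result.nodes, result.similarities, and result.ids line up by index.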

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        import sqlalchemy

        self._initialize()
        with self._session() as session, session.begin():
            stmt = sqlalchemy.text(
                f"DELETE FROM {self.schema_name}.data_{self.table_name} where "
                f"(metadata_->>'doc_id')::text = '{ref_doc_id}' "
            )

            session.execute(stmt)
            session.commit()
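
    # delete() matches on the ref_doc_id stored under the "doc_id" metadata
    # key, so it removes every node derived from that source document rather
    # than a single node_id.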


def _dedup_results(results: List[DBEmbeddingRow]) -> List[DBEmbeddingRow]:
    seen_ids = set()
    deduped_results = []
    for result in results:
        if result.node_id not in seen_ids:
            deduped_results.append(result)
            seen_ids.add(result.node_id)
    return deduped_results