faiss_rag_enterprise/llama_index/retrievers/pathway_retriever.py

172 lines
6.1 KiB
Python

"""Pathway Retriever."""
import logging
from typing import Any, Callable, List, Optional, Tuple, Union
from llama_index.callbacks.base import CallbackManager
from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.base_retriever import BaseRetriever
from llama_index.embeddings import BaseEmbedding
from llama_index.indices.query.schema import QueryBundle
from llama_index.ingestion.pipeline import run_transformations
from llama_index.schema import (
BaseNode,
NodeWithScore,
QueryBundle,
TextNode,
TransformComponent,
)
logger = logging.getLogger(__name__)
def node_transformer(x: str) -> List[BaseNode]:
    """Wrap a raw text string in a single-element list of nodes.

    Args:
        x: Text to wrap.

    Returns:
        A list holding exactly one ``TextNode`` whose ``text`` is ``x``.
    """
    wrapped = TextNode(text=x)
    return [wrapped]
def node_to_pathway(x: List[BaseNode]) -> List[Tuple[str, dict]]:
    """Convert nodes into ``(text, metadata)`` pairs.

    Args:
        x: Nodes to convert. Every element must expose ``text`` and
            ``extra_info`` attributes.

    Returns:
        One ``(node.text, node.extra_info)`` tuple per input node, in the
        same order as the input. An empty input yields an empty list.
    """
    return [(node.text, node.extra_info) for node in x]
class PathwayVectorServer:
    """
    Build an autoupdating document indexing pipeline
    for approximate nearest neighbor search.

    Args:
        docs (list): Pathway tables, may be pw.io connectors or custom tables.
        transformations (List[TransformComponent]): list of transformation steps, has to
            include embedding as last step, optionally splitter and other
            TransformComponent in the middle. The list passed in is not modified.
        parser (Callable[[bytes], list[tuple[str, dict]]]): optional, callable that
            parses file contents into a list of documents. If None, defaults to
            `utf-8` decoding of the file contents. Defaults to None.

    Raises:
        ImportError: if the `pathway` package cannot be imported.
        ValueError: if `transformations` is None or empty, or if its last
            element is not a `BaseEmbedding` instance.
    """

    def __init__(
        self,
        *docs: Any,
        transformations: List[Union[TransformComponent, Callable[[Any], Any]]],
        parser: Optional[Callable[[bytes], List[Tuple[str, dict]]]] = None,
        **kwargs: Any,
    ) -> None:
        try:
            from pathway.xpacks.llm import vector_store
        except ImportError as err:
            raise ImportError(
                "Could not import pathway python package. "
                "Please install it with `pip install pathway`."
            ) from err

        if not transformations:
            raise ValueError("Transformations list cannot be None or empty.")

        # Work on a private copy: the steps below pop/insert/append, and doing
        # that on the caller's list would make it unusable for a second server.
        steps = list(transformations)

        if not isinstance(steps[-1], BaseEmbedding):
            raise ValueError(
                f"Last step of transformations should be an instance of {BaseEmbedding.__name__}, "
                f"found {type(steps[-1])}."
            )

        # The embedding step is handed to the server separately from the
        # remaining transformation steps.
        embedder: BaseEmbedding = steps.pop()  # type: ignore

        def embedding_callable(x: str) -> List[float]:
            return embedder.get_text_embedding(x)

        # str -> TextNode on the way in, TextNode -> (str, dict) on the way out.
        pipeline = [node_transformer, *steps, node_to_pathway]

        def generic_transformer(x: List[str]) -> List[Tuple[str, dict]]:
            return run_transformations(x, pipeline)  # type: ignore

        self.vector_store_server = vector_store.VectorStoreServer(
            *docs,
            embedder=embedding_callable,
            parser=parser,
            splitter=generic_transformer,
            **kwargs,
        )

    def run_server(
        self,
        host: str,
        port: Union[str, int],
        threaded: bool = False,
        with_cache: bool = True,
        cache_backend: Any = None,
    ) -> Any:
        """
        Run the server and start answering queries.

        Args:
            host (str): host to bind the HTTP listener
            port (str | int): port to bind the HTTP listener
            threaded (bool): if True, run in a thread. Else block computation
            with_cache (bool): if True, embedding requests for the same contents are cached
            cache_backend: the backend to use for caching if it is enabled. The
                default is the disk cache, hosted locally in the folder ``./Cache``. You
                can use ``Backend`` class of the [`persistence API`]
                (/developers/api-docs/persistence-api/#pathway.persistence.Backend)
                to override it.

        Returns:
            If threaded, return the Thread object. Else, does not return.

        Raises:
            ImportError: if the `pathway` package cannot be imported.
        """
        try:
            import pathway as pw
        except ImportError as err:
            raise ImportError(
                "Could not import pathway python package. "
                "Please install it with `pip install pathway`."
            ) from err

        # Only build the default filesystem backend when caching is requested
        # and the caller did not supply a backend of their own.
        if with_cache and cache_backend is None:
            cache_backend = pw.persistence.Backend.filesystem("./Cache")

        return self.vector_store_server.run_server(
            host,
            port,
            threaded=threaded,
            with_cache=with_cache,
            cache_backend=cache_backend,
        )
class PathwayRetriever(BaseRetriever):
    """Pathway retriever.

    Pathway is an open data processing framework.
    It allows you to easily develop data transformation pipelines
    that work with live data sources and changing data.

    This is the client that implements Retriever API for PathwayVectorServer.
    """

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Create the client for a server listening on ``host``/``port``.

        Args:
            host: host passed to ``VectorStoreClient``.
            port: port passed to ``VectorStoreClient``.
            similarity_top_k: number of results requested per query.
            callback_manager: forwarded to ``BaseRetriever.__init__``.

        Raises:
            ImportError: if the `pathway` package cannot be imported.
        """
        try:
            from pathway.xpacks.llm.vector_store import VectorStoreClient
        except ImportError:
            raise ImportError(
                "`pathway` package not found, please run `pip install pathway`"
            )

        self.client = VectorStoreClient(host, port)
        self.similarity_top_k = similarity_top_k
        super().__init__(callback_manager)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Query the server and return the hits ordered by descending score."""
        hits = self.client(query=query_bundle.query_str, k=self.similarity_top_k)

        scored: List[NodeWithScore] = []
        for hit in hits:
            node = TextNode(text=hit["text"], extra_info=hit["metadata"])
            # Transform cosine distance into a similarity score
            # (higher is more similar)
            similarity = 1 - hit["dist"]
            scored.append(NodeWithScore(node=node, score=similarity))

        # A missing/zero score is ranked as 0.0; best match comes first.
        scored.sort(key=lambda item: item.score or 0.0, reverse=True)
        return scored