faiss_rag_enterprise/llama_index/indices/multi_modal/retriever.py

368 lines
15 KiB
Python

"""Base vector store index query."""
import asyncio
import copy
from typing import Any, Dict, List, Optional

from llama_index.callbacks.base import CallbackManager
from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.base_multi_modal_retriever import (
    MultiModalRetriever,
)
from llama_index.data_structs.data_structs import IndexDict
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.multi_modal_base import MultiModalEmbedding
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
from llama_index.indices.utils import log_vector_store_query_result
from llama_index.schema import NodeWithScore, ObjectType, QueryBundle, QueryType
from llama_index.vector_stores.types import (
    MetadataFilters,
    VectorStore,
    VectorStoreQuery,
    VectorStoreQueryMode,
    VectorStoreQueryResult,
)
class MultiModalVectorIndexRetriever(MultiModalRetriever):
    """Multi Modal Vector index retriever.

    Queries the index's text vector store and its separate image vector
    store and returns the matching nodes together with their scores.

    Args:
        index (MultiModalVectorStoreIndex): Multi Modal vector store index
            for images and texts.
        similarity_top_k (int): number of top k results to return from the
            text vector store.
        image_similarity_top_k (int): number of top k results to return from
            the image vector store.
        vector_store_query_mode (str): vector store query mode
            See reference for VectorStoreQueryMode for full list of supported
            modes.
        filters (Optional[MetadataFilters]): metadata filters, defaults to None
        alpha (float): weight for sparse/dense retrieval, only used for
            hybrid query mode.
        node_ids (Optional[List[str]]): list of nodes to constrain search.
        doc_ids (Optional[List[str]]): list of documents to constrain search.
        sparse_top_k (Optional[int]): forwarded unchanged to the vector store
            query as ``sparse_top_k``.
        callback_manager (Optional[CallbackManager]): callback manager to
            use; an empty one is created when omitted.
        vector_store_kwargs (dict): Additional vector store specific kwargs
            to pass through to the vector store at query time.
    """

    def __init__(
        self,
        index: MultiModalVectorStoreIndex,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        image_similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT,
        filters: Optional[MetadataFilters] = None,
        alpha: Optional[float] = None,
        node_ids: Optional[List[str]] = None,
        doc_ids: Optional[List[str]] = None,
        sparse_top_k: Optional[int] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._index = index
        self._vector_store = self._index.vector_store
        # separate image vector store for image retrieval
        self._image_vector_store = self._index.image_vector_store

        assert isinstance(self._index.image_embed_model, BaseEmbedding)
        self._image_embed_model = self._index.image_embed_model

        self._service_context = self._index.service_context
        self._docstore = self._index.docstore

        self._similarity_top_k = similarity_top_k
        self._image_similarity_top_k = image_similarity_top_k
        self._vector_store_query_mode = VectorStoreQueryMode(vector_store_query_mode)
        self._alpha = alpha
        self._node_ids = node_ids
        self._doc_ids = doc_ids
        self._filters = filters
        self._sparse_top_k = sparse_top_k

        # Only the "vector_store_kwargs" entry of **kwargs is used; it is
        # forwarded verbatim to every vector store query.
        self._kwargs: Dict[str, Any] = kwargs.get("vector_store_kwargs", {})
        self.callback_manager = callback_manager or CallbackManager([])

    @property
    def similarity_top_k(self) -> int:
        """Return similarity top k."""
        return self._similarity_top_k

    @similarity_top_k.setter
    def similarity_top_k(self, similarity_top_k: int) -> None:
        """Set similarity top k."""
        self._similarity_top_k = similarity_top_k

    @property
    def image_similarity_top_k(self) -> int:
        """Return image similarity top k."""
        return self._image_similarity_top_k

    @image_similarity_top_k.setter
    def image_similarity_top_k(self, image_similarity_top_k: int) -> None:
        """Set image similarity top k."""
        self._image_similarity_top_k = image_similarity_top_k

    @staticmethod
    def _with_embedding(query_bundle: QueryBundle, embedding: Any) -> QueryBundle:
        """Return a shallow copy of ``query_bundle`` carrying ``embedding``.

        The image retrieval paths use this so that the image-model embedding
        never replaces the embedding stored on the caller's bundle. Otherwise
        a later text retrieval with the same bundle would see a non-``None``
        embedding and reuse the image-model vector for the text search.
        """
        bundle_copy = copy.copy(query_bundle)
        bundle_copy.embedding = embedding
        return bundle_copy

    def _build_vector_store_query(
        self, query_bundle_with_embeddings: QueryBundle, similarity_top_k: int
    ) -> VectorStoreQuery:
        """Build a ``VectorStoreQuery`` from the bundle and retriever settings."""
        return VectorStoreQuery(
            query_embedding=query_bundle_with_embeddings.embedding,
            similarity_top_k=similarity_top_k,
            node_ids=self._node_ids,
            doc_ids=self._doc_ids,
            query_str=query_bundle_with_embeddings.query_str,
            mode=self._vector_store_query_mode,
            alpha=self._alpha,
            filters=self._filters,
            sparse_top_k=self._sparse_top_k,
        )

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Retrieve text nodes followed by image nodes for a text query."""
        res: List[NodeWithScore] = []
        # If the index has a text vector store, retrieve text nodes.
        # To skip text retrieval, create the index without a text vector store.
        if self._vector_store is not None:
            res.extend(self._text_retrieve(query_bundle))

        # If the index has an image vector store, retrieve image nodes.
        # To skip image retrieval, create the index without an image vector store.
        if self._image_vector_store is not None:
            res.extend(self._text_to_image_retrieve(query_bundle))
        return res

    def _text_retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Query the text vector store; returns ``[]`` when it is empty.

        When the store is an embedding-query store and the bundle has no
        embedding yet, one is computed with the service context's embed model
        and stored on ``query_bundle``.
        """
        if self._index.is_text_vector_store_empty:
            return []

        if self._vector_store.is_embedding_query:
            if (
                query_bundle.embedding is None
                and len(query_bundle.embedding_strs) > 0
            ):
                embed_model = self._service_context.embed_model
                query_bundle.embedding = embed_model.get_agg_embedding_from_queries(
                    query_bundle.embedding_strs
                )
        return self._get_nodes_with_embeddings(
            query_bundle, self._similarity_top_k, self._vector_store
        )

    def text_retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
        """Retrieve text nodes for a query string or ``QueryBundle``."""
        if isinstance(str_or_query_bundle, str):
            str_or_query_bundle = QueryBundle(str_or_query_bundle)
        return self._text_retrieve(str_or_query_bundle)

    def _text_to_image_retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Query the image vector store with a text query.

        Returns ``[]`` when the image vector store is empty. The query text is
        embedded with the image embed model; the caller's ``query_bundle`` is
        left unmodified.
        """
        if self._index.is_image_vector_store_empty:
            return []

        image_query = query_bundle
        if self._image_vector_store.is_embedding_query:
            # Embed the query text with the multi modal text encoder, on a
            # copy so the caller's bundle keeps its own embedding.
            image_query = self._with_embedding(
                query_bundle,
                self._image_embed_model.get_agg_embedding_from_queries(
                    query_bundle.embedding_strs
                ),
            )
        return self._get_nodes_with_embeddings(
            image_query, self._image_similarity_top_k, self._image_vector_store
        )

    def text_to_image_retrieve(
        self, str_or_query_bundle: QueryType
    ) -> List[NodeWithScore]:
        """Retrieve image nodes for a query string or ``QueryBundle``."""
        if isinstance(str_or_query_bundle, str):
            str_or_query_bundle = QueryBundle(str_or_query_bundle)
        return self._text_to_image_retrieve(str_or_query_bundle)

    def _image_to_image_retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Query the image vector store with an image query.

        Returns ``[]`` when the image vector store is empty. Only the first
        entry of ``query_bundle.embedding_image`` is embedded; the caller's
        ``query_bundle`` is left unmodified.
        """
        if self._index.is_image_vector_store_empty:
            return []

        image_query = query_bundle
        if self._image_vector_store.is_embedding_query:
            # Embed the query image with the multi modal image encoder.
            assert isinstance(self._index.image_embed_model, MultiModalEmbedding)
            image_query = self._with_embedding(
                query_bundle,
                self._image_embed_model.get_image_embedding(
                    query_bundle.embedding_image[0]
                ),
            )
        return self._get_nodes_with_embeddings(
            image_query, self._image_similarity_top_k, self._image_vector_store
        )

    def image_to_image_retrieve(
        self, str_or_query_bundle: QueryType
    ) -> List[NodeWithScore]:
        """Retrieve image nodes for an image path (``str``) or ``QueryBundle``."""
        if isinstance(str_or_query_bundle, str):
            # leave query_str empty since image_path drives the retrieval
            str_or_query_bundle = QueryBundle(
                query_str="", image_path=str_or_query_bundle
            )
        return self._image_to_image_retrieve(str_or_query_bundle)

    def _get_nodes_with_embeddings(
        self,
        query_bundle_with_embeddings: QueryBundle,
        similarity_top_k: int,
        vector_store: VectorStore,
    ) -> List[NodeWithScore]:
        """Run the query against ``vector_store`` and build the scored nodes."""
        query = self._build_vector_store_query(
            query_bundle_with_embeddings, similarity_top_k
        )
        query_result = vector_store.query(query, **self._kwargs)
        return self._build_node_list_from_query_result(query_result)

    def _build_node_list_from_query_result(
        self, query_result: VectorStoreQueryResult
    ) -> List[NodeWithScore]:
        """Convert a vector store result into a list of ``NodeWithScore``.

        ``query_result.nodes`` is filled in (or partially replaced) from the
        docstore in place before the scored list is built.

        Raises:
            ValueError: if the result carries neither nodes nor ids.
        """
        if query_result.nodes is None:
            # NOTE: vector store does not keep text and returns node indices.
            # Need to recover all nodes from docstore
            if query_result.ids is None:
                raise ValueError(
                    "Vector store query result should return at "
                    "least one of nodes or ids."
                )
            assert isinstance(self._index.index_struct, IndexDict)
            nodes_dict = self._index.index_struct.nodes_dict
            node_ids = [nodes_dict[idx] for idx in query_result.ids]
            query_result.nodes = self._docstore.get_nodes(node_ids)
        else:
            # NOTE: vector store keeps text, returns nodes.
            # Only need to recover image or index nodes from docstore
            # NOTE(review): the text store's `stores_text` flag is consulted
            # even when the result came from the image store - confirm this
            # is intended.
            for i, result_node in enumerate(query_result.nodes):
                source_node = result_node.source_node
                if (not self._vector_store.stores_text) or (
                    source_node is not None and source_node.node_type != ObjectType.TEXT
                ):
                    node_id = result_node.node_id
                    if self._docstore.document_exists(node_id):
                        query_result.nodes[
                            i
                        ] = self._docstore.get_node(  # type: ignore[index]
                            node_id
                        )

        log_vector_store_query_result(query_result)

        node_with_scores: List[NodeWithScore] = []
        for ind, node in enumerate(query_result.nodes):
            score: Optional[float] = None
            if query_result.similarities is not None:
                score = query_result.similarities[ind]
            node_with_scores.append(NodeWithScore(node=node, score=score))
        return node_with_scores

    # Async Retrieval Methods

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Run text and image retrieval concurrently and concatenate results.

        A store that is ``None`` is skipped, mirroring ``_retrieve``. Text
        results come before image results.
        """
        tasks = []
        if self._vector_store is not None:
            tasks.append(self._atext_retrieve(query_bundle))
        if self._image_vector_store is not None:
            tasks.append(self._atext_to_image_retrieve(query_bundle))
        if not tasks:
            return []

        results: List[NodeWithScore] = []
        for task_result in await asyncio.gather(*tasks):
            results.extend(task_result)
        return results

    async def _atext_retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Async counterpart of ``_text_retrieve`` with the same embedding rule.

        An embedding is computed with the service context's embed model only
        when the bundle has none and has strings to embed.
        """
        if self._index.is_text_vector_store_empty:
            return []

        if self._vector_store.is_embedding_query:
            if (
                query_bundle.embedding is None
                and len(query_bundle.embedding_strs) > 0
            ):
                embed_model = self._service_context.embed_model
                query_bundle.embedding = (
                    await embed_model.aget_agg_embedding_from_queries(
                        query_bundle.embedding_strs
                    )
                )
        return await self._aget_nodes_with_embeddings(
            query_bundle, self._similarity_top_k, self._vector_store
        )

    async def atext_retrieve(
        self, str_or_query_bundle: QueryType
    ) -> List[NodeWithScore]:
        """Async: retrieve text nodes for a query string or ``QueryBundle``."""
        if isinstance(str_or_query_bundle, str):
            str_or_query_bundle = QueryBundle(str_or_query_bundle)
        return await self._atext_retrieve(str_or_query_bundle)

    async def _atext_to_image_retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Async counterpart of ``_text_to_image_retrieve``.

        The caller's ``query_bundle`` is left unmodified.
        """
        if self._index.is_image_vector_store_empty:
            return []

        image_query = query_bundle
        if self._image_vector_store.is_embedding_query:
            # Embed the query text with the multi modal text encoder, on a
            # copy so the caller's bundle keeps its own embedding.
            image_query = self._with_embedding(
                query_bundle,
                await self._image_embed_model.aget_agg_embedding_from_queries(
                    query_bundle.embedding_strs
                ),
            )
        return await self._aget_nodes_with_embeddings(
            image_query, self._image_similarity_top_k, self._image_vector_store
        )

    async def atext_to_image_retrieve(
        self, str_or_query_bundle: QueryType
    ) -> List[NodeWithScore]:
        """Async: retrieve image nodes for a query string or ``QueryBundle``."""
        if isinstance(str_or_query_bundle, str):
            str_or_query_bundle = QueryBundle(str_or_query_bundle)
        return await self._atext_to_image_retrieve(str_or_query_bundle)

    async def _aget_nodes_with_embeddings(
        self,
        query_bundle_with_embeddings: QueryBundle,
        similarity_top_k: int,
        vector_store: VectorStore,
    ) -> List[NodeWithScore]:
        """Async: run the query against ``vector_store`` and build scored nodes."""
        query = self._build_vector_store_query(
            query_bundle_with_embeddings, similarity_top_k
        )
        query_result = await vector_store.aquery(query, **self._kwargs)
        return self._build_node_list_from_query_result(query_result)

    async def _aimage_to_image_retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Async counterpart of ``_image_to_image_retrieve``.

        The caller's ``query_bundle`` is left unmodified.
        """
        if self._index.is_image_vector_store_empty:
            return []

        image_query = query_bundle
        if self._image_vector_store.is_embedding_query:
            # Embed the query image with the multi modal image encoder.
            assert isinstance(self._index.image_embed_model, MultiModalEmbedding)
            # Using the first image in the list for image retrieval
            image_query = self._with_embedding(
                query_bundle,
                await self._image_embed_model.aget_image_embedding(
                    query_bundle.embedding_image[0]
                ),
            )
        return await self._aget_nodes_with_embeddings(
            image_query, self._image_similarity_top_k, self._image_vector_store
        )

    async def aimage_to_image_retrieve(
        self, str_or_query_bundle: QueryType
    ) -> List[NodeWithScore]:
        """Async: retrieve image nodes for an image path or ``QueryBundle``."""
        if isinstance(str_or_query_bundle, str):
            # leave query_str as empty since we are using image_path for image retrieval
            str_or_query_bundle = QueryBundle(
                query_str="", image_path=str_or_query_bundle
            )
        return await self._aimage_to_image_retrieve(str_or_query_bundle)