faiss_rag_enterprise/llama_index/selectors/embedding_selectors.py

92 lines
3.0 KiB
Python

import asyncio
from typing import Any, Dict, List, Optional, Sequence

from llama_index.core.base_selector import (
    BaseSelector,
    SelectorResult,
    SingleSelection,
)
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.utils import resolve_embed_model
from llama_index.indices.query.embedding_utils import get_top_k_embeddings
from llama_index.prompts.mixin import PromptDictType
from llama_index.schema import QueryBundle
from llama_index.tools.types import ToolMetadata
class EmbeddingSingleSelector(BaseSelector):
    """Embedding selector.

    Embedding selector that chooses one out of many options. The query string
    and each choice's ``description`` are embedded with the same model, and
    the choice whose description embedding scores highest against the query
    embedding (via ``get_top_k_embeddings`` with ``similarity_top_k=1``) is
    selected.

    Args:
        embed_model (BaseEmbedding): An embedding model.
    """

    def __init__(
        self,
        embed_model: BaseEmbedding,
    ) -> None:
        self._embed_model = embed_model

    @classmethod
    def from_defaults(
        cls,
        embed_model: Optional[BaseEmbedding] = None,
    ) -> "EmbeddingSingleSelector":
        """Build a selector, falling back to the default embedding model.

        Args:
            embed_model: Embedding model to use. If not provided (or falsy),
                ``resolve_embed_model("default")`` is used instead.

        Returns:
            A new ``EmbeddingSingleSelector`` wrapping the chosen model.
        """
        # optionally initialize defaults
        embed_model = embed_model or resolve_embed_model("default")
        # construct the selector (embedding selection involves no prompt)
        return cls(embed_model)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts.

        This selector uses no prompts, so the mapping is always empty.
        """
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts.

        No-op: this selector has no prompts to update.
        """

    @staticmethod
    def _validate_choices(choices: Sequence[ToolMetadata]) -> None:
        """Raise ``ValueError`` when there is nothing to select from."""
        # Without this guard an empty ``choices`` only fails later with an
        # opaque IndexError on ``top_similarities[0]``, after an embedding call.
        if not choices:
            raise ValueError("EmbeddingSingleSelector requires at least one choice.")

    @staticmethod
    def _build_result(
        choices: Sequence[ToolMetadata],
        query_embedding: List[float],
        text_embeddings: List[List[float]],
    ) -> SelectorResult:
        """Pick the choice most similar to the query and wrap it in a result.

        Args:
            choices: The candidate options; must be non-empty.
            query_embedding: Embedding of the query string.
            text_embeddings: Embeddings of each choice's description, in the
                same order as ``choices``.

        Returns:
            A ``SelectorResult`` holding exactly one ``SingleSelection``.
        """
        # Ids are positions in ``choices`` so the winner maps straight back.
        top_similarities, top_ids = get_top_k_embeddings(
            query_embedding,
            text_embeddings,
            similarity_top_k=1,
            embedding_ids=list(range(len(choices))),
        )
        top_id = top_ids[0]
        top_selection_reason = (
            f"Top similarity match: {top_similarities[0]:.2f}, "
            f"{choices[top_id].name}"
        )
        top_selection = SingleSelection(index=top_id, reason=top_selection_reason)
        return SelectorResult(selections=[top_selection])

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """Select the single choice whose description best matches the query.

        Raises:
            ValueError: If ``choices`` is empty.
        """
        self._validate_choices(choices)
        query_embedding = self._embed_model.get_query_embedding(query.query_str)
        text_embeddings = [
            self._embed_model.get_text_embedding(choice.description)
            for choice in choices
        ]
        return self._build_result(choices, query_embedding, text_embeddings)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """Async variant of ``_select``.

        Raises:
            ValueError: If ``choices`` is empty.
        """
        self._validate_choices(choices)
        query_embedding = await self._embed_model.aget_query_embedding(
            query.query_str
        )
        # Embed all descriptions concurrently rather than awaiting them one by
        # one; ``asyncio.gather`` returns results in argument order, so the
        # embeddings stay aligned with ``choices``.
        text_embeddings = list(
            await asyncio.gather(
                *(
                    self._embed_model.aget_text_embedding(choice.description)
                    for choice in choices
                )
            )
        )
        return self._build_result(choices, query_embedding, text_embeddings)