92 lines
3.0 KiB
Python
92 lines
3.0 KiB
Python
from typing import Any, Dict, Optional, Sequence
|
|
|
|
from llama_index.core.base_selector import (
|
|
BaseSelector,
|
|
SelectorResult,
|
|
SingleSelection,
|
|
)
|
|
from llama_index.embeddings.base import BaseEmbedding
|
|
from llama_index.embeddings.utils import resolve_embed_model
|
|
from llama_index.indices.query.embedding_utils import get_top_k_embeddings
|
|
from llama_index.prompts.mixin import PromptDictType
|
|
from llama_index.schema import QueryBundle
|
|
from llama_index.tools.types import ToolMetadata
|
|
|
|
|
|
class EmbeddingSingleSelector(BaseSelector):
|
|
"""Embedding selector.
|
|
|
|
Embedding selector that chooses one out of many options.
|
|
|
|
Args:
|
|
embed_model (BaseEmbedding): An embedding model.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
embed_model: BaseEmbedding,
|
|
) -> None:
|
|
self._embed_model = embed_model
|
|
|
|
@classmethod
|
|
def from_defaults(
|
|
cls,
|
|
embed_model: Optional[BaseEmbedding] = None,
|
|
) -> "EmbeddingSingleSelector":
|
|
# optionally initialize defaults
|
|
embed_model = embed_model or resolve_embed_model("default")
|
|
|
|
# construct prompt
|
|
return cls(embed_model)
|
|
|
|
def _get_prompts(self) -> Dict[str, Any]:
|
|
"""Get prompts."""
|
|
return {}
|
|
|
|
def _update_prompts(self, prompts: PromptDictType) -> None:
|
|
"""Update prompts."""
|
|
|
|
def _select(
|
|
self, choices: Sequence[ToolMetadata], query: QueryBundle
|
|
) -> SelectorResult:
|
|
query_embedding = self._embed_model.get_query_embedding(query.query_str)
|
|
text_embeddings = [
|
|
self._embed_model.get_text_embedding(choice.description)
|
|
for choice in choices
|
|
]
|
|
|
|
top_similarities, top_ids = get_top_k_embeddings(
|
|
query_embedding,
|
|
text_embeddings,
|
|
similarity_top_k=1,
|
|
embedding_ids=list(range(len(choices))),
|
|
)
|
|
# get top choice
|
|
top_selection_reason = f"Top similarity match: {top_similarities[0]:.2f}, {choices[top_ids[0]].name}"
|
|
top_selection = SingleSelection(index=top_ids[0], reason=top_selection_reason)
|
|
|
|
# parse output
|
|
return SelectorResult(selections=[top_selection])
|
|
|
|
async def _aselect(
|
|
self, choices: Sequence[ToolMetadata], query: QueryBundle
|
|
) -> SelectorResult:
|
|
query_embedding = await self._embed_model.aget_query_embedding(query.query_str)
|
|
text_embeddings = [
|
|
await self._embed_model.aget_text_embedding(choice.description)
|
|
for choice in choices
|
|
]
|
|
|
|
top_similarities, top_ids = get_top_k_embeddings(
|
|
query_embedding,
|
|
text_embeddings,
|
|
similarity_top_k=1,
|
|
embedding_ids=list(range(len(choices))),
|
|
)
|
|
# get top choice
|
|
top_selection_reason = f"Top similarity match: {top_similarities[0]:.2f}, {choices[top_ids[0]].name}"
|
|
top_selection = SingleSelection(index=top_ids[0], reason=top_selection_reason)
|
|
|
|
# parse output
|
|
return SelectorResult(selections=[top_selection])
|