faiss_rag_enterprise/llama_index/selectors/utils.py

34 lines
1.1 KiB
Python

from typing import Optional
from llama_index.core.base_selector import BaseSelector
from llama_index.selectors.llm_selectors import LLMMultiSelector, LLMSingleSelector
from llama_index.selectors.pydantic_selectors import (
PydanticMultiSelector,
PydanticSingleSelector,
)
from llama_index.service_context import ServiceContext
def get_selector_from_context(
    service_context: ServiceContext, is_multi: bool = False
) -> BaseSelector:
    """Build a selector from a service context, trying the Pydantic variant first.

    Args:
        service_context: Context whose ``llm`` attribute is handed to the
            Pydantic selector, and which is passed whole to the LLM selector
            when the fallback is taken.
        is_multi: When true, build a multi-selector; otherwise build a
            single-selector.

    Returns:
        The Pydantic selector if its ``from_defaults`` succeeds, otherwise the
        plain LLM selector of the same (single/multi) kind.
    """
    # Choose the (preferred, fallback) class pair once so the construction
    # logic below is shared between the single and multi cases.
    if is_multi:
        preferred_cls, fallback_cls = PydanticMultiSelector, LLMMultiSelector
    else:
        preferred_cls, fallback_cls = PydanticSingleSelector, LLMSingleSelector

    built: Optional[BaseSelector]
    try:
        # The attribute read stays inside the try so that a ValueError raised
        # while obtaining the LLM also triggers the fallback.
        context_llm = service_context.llm
        built = preferred_cls.from_defaults(llm=context_llm)  # type: ignore
    except ValueError:
        # NOTE(review): presumably raised when the LLM cannot back a Pydantic
        # selector -- confirm against the selector implementation.
        built = fallback_cls.from_defaults(service_context=service_context)

    assert built is not None
    return built