# faiss_rag_enterprise/llama_index/playground/base.py
"""Experiment with different indices, models, and more."""
from __future__ import annotations
import time
from typing import Any, Dict, List, Type
import pandas as pd
from llama_index.callbacks import CallbackManager, TokenCountingHandler
from llama_index.indices.base import BaseIndex
from llama_index.indices.list.base import ListRetrieverMode, SummaryIndex
from llama_index.indices.tree.base import TreeIndex, TreeRetrieverMode
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.llm_predictor.base import LLMPredictor
from llama_index.schema import Document
from llama_index.utils import get_color_mapping, print_text
# Index classes built by Playground.from_docs when the caller does not pass
# `index_classes` explicitly.
DEFAULT_INDEX_CLASSES: List[Type[BaseIndex]] = [
    VectorStoreIndex,
    TreeIndex,
    SummaryIndex,
]

# Type alias: mapping from an index class to the retriever-mode strings to
# exercise for that index type.
INDEX_SPECIFIC_QUERY_MODES_TYPE = Dict[Type[BaseIndex], List[str]]

# Default retriever modes per index type: every enum value for tree and
# summary (list) indices, and only "default" for vector stores.
DEFAULT_MODES: INDEX_SPECIFIC_QUERY_MODES_TYPE = {
    TreeIndex: [e.value for e in TreeRetrieverMode],
    SummaryIndex: [e.value for e in ListRetrieverMode],
    VectorStoreIndex: ["default"],
}
class Playground:
    """Experiment with indices, models, embeddings, retriever_modes, and more."""

    def __init__(
        self,
        indices: List[BaseIndex],
        retriever_modes: INDEX_SPECIFIC_QUERY_MODES_TYPE | None = None,
    ):
        """Initialize with indices to experiment with.

        Args:
            indices: A list of BaseIndex's to experiment with.
            retriever_modes: A mapping from index type to the retriever_modes
                that specify which nodes are chosen from the index when a query
                is made. Defaults to ``DEFAULT_MODES``. A full list of
                retriever_modes available to each index can be found here:
                https://docs.llamaindex.ai/en/stable/module_guides/querying/retriever/retriever_modes.html

        Raises:
            ValueError: If ``indices`` is empty or contains a non-BaseIndex,
                or if ``retriever_modes`` is an empty mapping.
        """
        # Resolve the default lazily instead of using a shared mutable
        # default-argument value.
        if retriever_modes is None:
            retriever_modes = DEFAULT_MODES
        self._validate_indices(indices)
        self._indices = indices
        self._validate_modes(retriever_modes)
        self._retriever_modes = retriever_modes
        # One print color per index, keyed by the index's stringified position.
        index_range = [str(i) for i in range(len(indices))]
        self.index_colors = get_color_mapping(index_range)

    @classmethod
    def from_docs(
        cls,
        documents: List[Document],
        index_classes: List[Type[BaseIndex]] | None = None,
        retriever_modes: INDEX_SPECIFIC_QUERY_MODES_TYPE | None = None,
        **kwargs: Any,
    ) -> Playground:
        """Initialize with Documents using the default list of indices.

        Args:
            documents: A List of Documents to experiment with.
            index_classes: Index classes to build from ``documents``.
                Defaults to ``DEFAULT_INDEX_CLASSES``.
            retriever_modes: Mapping from index type to retriever modes.
                Defaults to ``DEFAULT_MODES``.
            **kwargs: Forwarded to each ``index_class.from_documents`` call.

        Returns:
            A Playground holding one index per entry of ``index_classes``.

        Raises:
            ValueError: If ``documents`` is empty.
        """
        # Resolve defaults lazily rather than via shared mutable defaults.
        if index_classes is None:
            index_classes = DEFAULT_INDEX_CLASSES
        if not documents:
            raise ValueError(
                "Playground must be initialized with a nonempty list of Documents."
            )
        indices = [
            index_class.from_documents(documents, **kwargs)
            for index_class in index_classes
        ]
        return cls(indices, retriever_modes)

    def _validate_indices(self, indices: List[BaseIndex]) -> None:
        """Raise ValueError unless ``indices`` is a non-empty list of BaseIndex."""
        if not indices:
            raise ValueError("Playground must have a non-empty list of indices.")
        for index in indices:
            if not isinstance(index, BaseIndex):
                raise ValueError(
                    "Every index in Playground should be an instance of BaseIndex."
                )

    @property
    def indices(self) -> List[BaseIndex]:
        """Get Playground's indices."""
        return self._indices

    @indices.setter
    def indices(self, indices: List[BaseIndex]) -> None:
        """Set Playground's indices (validated)."""
        self._validate_indices(indices)
        self._indices = indices

    def _validate_modes(self, retriever_modes: INDEX_SPECIFIC_QUERY_MODES_TYPE) -> None:
        """Raise ValueError if ``retriever_modes`` is empty."""
        if not retriever_modes:
            # Fixed: the original message concatenation lacked a space between
            # sentences ("retriever_modes.Initialize").
            raise ValueError(
                "Playground must have a nonzero number of retriever_modes. "
                "Initialize without the `retriever_modes` "
                "argument to use the default list."
            )

    @property
    def retriever_modes(self) -> dict:
        """Get Playground's retriever modes."""
        return self._retriever_modes

    @retriever_modes.setter
    def retriever_modes(self, retriever_modes: INDEX_SPECIFIC_QUERY_MODES_TYPE) -> None:
        """Set Playground's retriever modes (validated)."""
        self._validate_modes(retriever_modes)
        self._retriever_modes = retriever_modes

    def compare(
        self, query_text: str, to_pandas: bool | None = True
    ) -> pd.DataFrame | List[Dict[str, Any]]:
        """Compare index outputs on an input query.

        Args:
            query_text (str): Query to run all indices on.
            to_pandas (Optional[bool]): Return results in a pandas dataframe.
                True by default.

        Returns:
            The output of each index along with other data, such as the time it
            took to compute. Results are stored in a Pandas Dataframe or a list
            of Dicts.

        Raises:
            KeyError: If an index's type has no entry in ``retriever_modes``.
        """
        print(f"\033[1mQuery:\033[0m\n{query_text}\n")
        result = []
        for i, index in enumerate(self._indices):
            for retriever_mode in self._retriever_modes[type(index)]:
                # perf_counter is monotonic, so the duration is immune to
                # wall-clock adjustments (unlike time.time()).
                start_time = time.perf_counter()
                index_name = type(index).__name__
                print_text(
                    f"\033[1m{index_name}\033[0m, retriever mode = {retriever_mode}",
                    end="\n",
                )
                # Insert a fresh token counter into the index's service context
                # so each (index, retriever_mode) run is counted separately.
                # NOTE(review): this mutates the shared service context's
                # callback managers in place.
                service_context = index.service_context
                token_counter = TokenCountingHandler()
                callback_manager = CallbackManager([token_counter])
                if isinstance(service_context.llm_predictor, LLMPredictor):
                    service_context.llm_predictor.llm.callback_manager = (
                        callback_manager
                    )
                service_context.embed_model.callback_manager = callback_manager
                try:
                    query_engine = index.as_query_engine(
                        retriever_mode=retriever_mode, service_context=service_context
                    )
                except ValueError:
                    # Skip retriever modes this index type does not support.
                    continue
                output = query_engine.query(query_text)
                print_text(str(output), color=self.index_colors[str(i)], end="\n\n")
                duration = time.perf_counter() - start_time
                result.append(
                    {
                        "Index": index_name,
                        "Retriever Mode": retriever_mode,
                        "Output": str(output),
                        "Duration": duration,
                        "Prompt Tokens": token_counter.prompt_llm_token_count,
                        "Completion Tokens": token_counter.completion_llm_token_count,
                        "Embed Tokens": token_counter.total_embedding_token_count,
                    }
                )
        print(f"\nRan {len(result)} combinations in total.")
        return pd.DataFrame(result) if to_pandas else result