faiss_rag_enterprise/llama_index/postprocessor/optimizer.py

156 lines
5.6 KiB
Python

"""Optimization related classes and functions."""
import logging
from typing import Callable, List, Optional
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.indices.query.embedding_utils import get_top_k_embeddings
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle
# Module-level logger; the optimizer reports the sentences it keeps at DEBUG level.
logger = logging.getLogger(__name__)
class SentenceEmbeddingOptimizer(BaseNodePostprocessor):
    """Shorten each node's text to the sentences most similar to the query.

    Every node's content is split into sentences, each sentence is embedded,
    and only the sentences selected by ``get_top_k_embeddings`` (optionally
    widened by a window of neighbouring sentences) are kept as the node's
    new content.

    Attributes:
        percentile_cutoff: Fraction of a node's sentences to keep; the number
            of sentences requested is ``int(num_sentences * percentile_cutoff)``.
            ``None`` disables the top-k limit.
        threshold_cutoff: Similarity cutoff forwarded to
            ``get_top_k_embeddings``. ``None`` disables the cutoff.
        context_before: Number of sentences preceding each selected sentence
            to include. ``None`` is treated as 1.
        context_after: Number of sentences following each selected sentence
            to include. ``None`` is treated as 1.
    """

    percentile_cutoff: Optional[float] = Field(
        description="Percentile cutoff for the top k sentences to use."
    )
    threshold_cutoff: Optional[float] = Field(
        description="Threshold cutoff for similarity for each sentence to use."
    )
    _embed_model: BaseEmbedding = PrivateAttr()
    _tokenizer_fn: Callable[[str], List[str]] = PrivateAttr()
    context_before: Optional[int] = Field(
        description="Number of sentences before retrieved sentence for further context"
    )
    context_after: Optional[int] = Field(
        description="Number of sentences after retrieved sentence for further context"
    )

    def __init__(
        self,
        embed_model: Optional[BaseEmbedding] = None,
        percentile_cutoff: Optional[float] = None,
        threshold_cutoff: Optional[float] = None,
        tokenizer_fn: Optional[Callable[[str], List[str]]] = None,
        context_before: Optional[int] = None,
        context_after: Optional[int] = None,
    ):
        """Create a sentence-level optimizer.

        Args:
            embed_model: Embedding model used for the query and the
                sentences. Defaults to ``OpenAIEmbedding()``.
            percentile_cutoff: Fraction of sentences to keep, e.g. ``0.5``
                keeps the top 50% of sentences.
            threshold_cutoff: Only sentences with a similarity score higher
                than this threshold are used. Both cutoffs can be combined.
            tokenizer_fn: Callable splitting a text into sentences. Defaults
                to the NLTK punkt English tokenizer, which is loaded from
                ``tokenizers/punkt/english.pickle`` and therefore requires
                that NLTK resource to be available.
            context_before: Sentences of leading context per selected
                sentence (1 when ``None``).
            context_after: Sentences of trailing context per selected
                sentence (1 when ``None``).

        Example:

        .. code-block:: python

            from llama_index.postprocessor.optimizer import (
                SentenceEmbeddingOptimizer,
            )

            optimizer = SentenceEmbeddingOptimizer(
                # Use the top 50% of sentences.
                percentile_cutoff=0.5,
                # Additionally, only use sentences whose similarity score
                # is higher than the threshold.
                threshold_cutoff=0.7,
            )
            query_engine = index.as_query_engine(
                node_postprocessors=[optimizer]
            )
            response = query_engine.query("<query_str>")
        """
        self._embed_model = embed_model or OpenAIEmbedding()

        if tokenizer_fn is None:
            # Imported lazily so NLTK is only needed when no tokenizer is given.
            import nltk.data

            tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
            tokenizer_fn = tokenizer.tokenize
        self._tokenizer_fn = tokenizer_fn

        super().__init__(
            percentile_cutoff=percentile_cutoff,
            threshold_cutoff=threshold_cutoff,
            context_after=context_after,
            context_before=context_before,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "SentenceEmbeddingOptimizer"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Optimize a node text given the query by shortening the node text.

        Args:
            nodes: Nodes whose content is rewritten in place.
            query_bundle: Query to compare sentences against. When ``None``
                the nodes are returned untouched. If the bundle has no
                embedding yet, one is computed and stored on the bundle.

        Returns:
            The same ``nodes`` list, with each node's content replaced by the
            selected sentences (plus context) joined with single spaces, in
            the order returned by ``get_top_k_embeddings``.

        Raises:
            ValueError: If no sentence is selected for a node.
        """
        if query_bundle is None or not nodes:
            return nodes

        # The query embedding does not depend on the node; compute it once.
        if query_bundle.embedding is None:
            query_bundle.embedding = (
                self._embed_model.get_agg_embedding_from_queries(
                    query_bundle.embedding_strs
                )
            )

        # Resolve context defaults into locals rather than writing them back
        # to the model, so processing a query never alters the configuration.
        context_before = 1 if self.context_before is None else self.context_before
        context_after = 1 if self.context_after is None else self.context_after
        threshold = self.threshold_cutoff
        debug_enabled = logger.isEnabledFor(logging.DEBUG)

        for node_with_score in nodes:
            text = node_with_score.node.get_content(metadata_mode=MetadataMode.LLM)
            split_text = self._tokenizer_fn(text)
            num_sentences = len(split_text)

            text_embeddings = self._embed_model._get_text_embeddings(split_text)

            num_top_k = None
            if self.percentile_cutoff is not None:
                num_top_k = int(num_sentences * self.percentile_cutoff)

            top_similarities, top_idxs = get_top_k_embeddings(
                query_embedding=query_bundle.embedding,
                embeddings=text_embeddings,
                similarity_fn=self._embed_model.similarity,
                similarity_top_k=num_top_k,
                embedding_ids=list(range(len(text_embeddings))),
                similarity_cutoff=threshold,
            )

            if len(top_idxs) == 0:
                raise ValueError("Optimizer returned zero sentences.")

            # Each selected sentence is widened by its neighbours; the window
            # is clamped to the bounds of the sentence list.
            top_sentences = [
                " ".join(
                    split_text[
                        max(idx - context_before, 0) : min(
                            idx + context_after + 1, num_sentences
                        )
                    ]
                )
                for idx in top_idxs
            ]

            if debug_enabled:
                logger.debug(
                    "> Top %d sentences with scores:\n", len(top_idxs)
                )
                for rank, (sentence, score) in enumerate(
                    zip(top_sentences, top_similarities)
                ):
                    logger.debug("%d. %s (%s)", rank, sentence, score)

            node_with_score.node.set_content(" ".join(top_sentences))

        return nodes