import logging
|
|
from typing import Callable, List, Optional, cast
|
|
|
|
from nltk.stem import PorterStemmer
|
|
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.indices.keyword_table.utils import simple_extract_keywords
|
|
from llama_index.indices.vector_store.base import VectorStoreIndex
|
|
from llama_index.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
|
|
from llama_index.storage.docstore.types import BaseDocumentStore
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def tokenize_remove_stopwords(text: str) -> List[str]:
    """Tokenize *text* into lowercased, Porter-stemmed keyword tokens.

    Lowercases the input, extracts keywords via
    ``simple_extract_keywords`` (which drops stopwords), and stems each
    resulting token with the NLTK Porter stemmer.
    """
    stemmer = PorterStemmer()
    keywords = simple_extract_keywords(text.lower())
    return [stemmer.stem(token) for token in keywords]
|
|
|
|
|
|
class BM25Retriever(BaseRetriever):
    """Retriever that ranks nodes with the Okapi BM25 scoring function.

    Builds a BM25 index (via the third-party ``rank_bm25`` package) over
    the tokenized content of a fixed set of nodes. Query embeddings are
    ignored; only the raw query string is used.
    """

    def __init__(
        self,
        nodes: List[BaseNode],
        tokenizer: Optional[Callable[[str], List[str]]] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        """Initialize the retriever.

        Args:
            nodes: Nodes forming the retrieval corpus.
            tokenizer: Callable mapping raw text to a token list.
                Defaults to ``tokenize_remove_stopwords``.
            similarity_top_k: Number of top-scoring nodes to return.
            callback_manager: Optional callback manager for the base class.
            objects: Optional index nodes for object retrieval.
            object_map: Optional id-to-object mapping.
            verbose: Whether to log verbosely.

        Raises:
            ImportError: If ``rank_bm25`` is not installed.
        """
        try:
            from rank_bm25 import BM25Okapi
        except ImportError:
            raise ImportError("Please install rank_bm25: pip install rank-bm25")

        self._nodes = nodes
        self._tokenizer = tokenizer or tokenize_remove_stopwords
        self._similarity_top_k = similarity_top_k
        # Tokenize every node's content once up front; BM25Okapi derives
        # its term statistics from this corpus.
        self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
        self.bm25 = BM25Okapi(self._corpus)
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    @classmethod
    def from_defaults(
        cls,
        index: Optional[VectorStoreIndex] = None,
        nodes: Optional[List[BaseNode]] = None,
        docstore: Optional[BaseDocumentStore] = None,
        tokenizer: Optional[Callable[[str], List[str]]] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        verbose: bool = False,
    ) -> "BM25Retriever":
        """Construct a retriever from exactly one node source.

        Exactly one of ``index``, ``nodes``, or ``docstore`` must be
        provided; an index's docstore (or a docstore directly) is
        flattened into a node list.

        Raises:
            ValueError: If zero or more than one source is provided.
        """
        # Compare against None (not truthiness) so an explicitly-passed
        # empty nodes list still counts as the chosen source.
        if sum(val is not None for val in (index, nodes, docstore)) != 1:
            raise ValueError("Please pass exactly one of index, nodes, or docstore.")

        if index is not None:
            docstore = index.docstore

        if docstore is not None:
            nodes = cast(List[BaseNode], list(docstore.docs.values()))

        # Defensive check (unreachable given the exactly-one validation
        # above); raise rather than assert so it survives `python -O`.
        if nodes is None:
            raise ValueError("Please pass exactly one of index, nodes, or docstore.")

        tokenizer = tokenizer or tokenize_remove_stopwords
        return cls(
            nodes=nodes,
            tokenizer=tokenizer,
            similarity_top_k=similarity_top_k,
            verbose=verbose,
        )

    def _get_scored_nodes(self, query: str) -> List[NodeWithScore]:
        """Score every corpus node against ``query``; result is unsorted."""
        tokenized_query = self._tokenizer(query)
        doc_scores = self.bm25.get_scores(tokenized_query)
        return [
            NodeWithScore(node=node, score=doc_scores[i])
            for i, node in enumerate(self._nodes)
        ]

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return the top-k BM25-ranked nodes for the query string."""
        if query_bundle.custom_embedding_strs or query_bundle.embedding:
            logger.warning("BM25Retriever does not support embeddings, skipping...")

        scored_nodes = self._get_scored_nodes(query_bundle.query_str)

        # Sort descending by BM25 score (non-negative, unbounded; higher
        # means more relevant) and keep the top-k.
        nodes = sorted(scored_nodes, key=lambda x: x.score or 0.0, reverse=True)
        return nodes[: self._similarity_top_k]
|