faiss_rag_enterprise/llama_index/indices/keyword_table/retrievers.py

"""Query for KeywordTableIndex."""
import logging
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Optional
from llama_index.callbacks.base import CallbackManager
from llama_index.core.base_retriever import BaseRetriever
from llama_index.indices.keyword_table.base import BaseKeywordTableIndex
from llama_index.indices.keyword_table.utils import (
    extract_keywords_given_response,
    rake_extract_keywords,
    simple_extract_keywords,
)
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.default_prompts import (
    DEFAULT_KEYWORD_EXTRACT_TEMPLATE,
    DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE,
)
from llama_index.schema import NodeWithScore, QueryBundle
from llama_index.utils import truncate_text

DQKET = DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE

logger = logging.getLogger(__name__)


class BaseKeywordTableRetriever(BaseRetriever):
    """Base Keyword Table Retriever.

    Arguments are shared among subclasses.

    Args:
        keyword_extract_template (Optional[BasePromptTemplate]): A Keyword
            Extraction Prompt (see :ref:`Prompt-Templates`).
        query_keyword_extract_template (Optional[BasePromptTemplate]): A Query
            Keyword Extraction Prompt (see :ref:`Prompt-Templates`).
        max_keywords_per_query (int): Maximum number of keywords to extract
            from the query.
        num_chunks_per_query (int): Maximum number of text chunks to query.

    """

    def __init__(
        self,
        index: BaseKeywordTableIndex,
        keyword_extract_template: Optional[BasePromptTemplate] = None,
        query_keyword_extract_template: Optional[BasePromptTemplate] = None,
        max_keywords_per_query: int = 10,
        num_chunks_per_query: int = 10,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._index = index
        self._index_struct = index.index_struct
        self._docstore = index.docstore
        self._service_context = index.service_context
        self.max_keywords_per_query = max_keywords_per_query
        self.num_chunks_per_query = num_chunks_per_query
        self.keyword_extract_template = (
            keyword_extract_template or DEFAULT_KEYWORD_EXTRACT_TEMPLATE
        )
        self.query_keyword_extract_template = query_keyword_extract_template or DQKET
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            verbose=verbose,
        )

    @abstractmethod
    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Get nodes for response."""
        logger.info(f"> Starting query: {query_bundle.query_str}")
        keywords = self._get_keywords(query_bundle.query_str)
        logger.info(f"query keywords: {keywords}")

        # Rank text chunks by the number of query keywords that map to them.
        chunk_indices_count: Dict[str, int] = defaultdict(int)
        # Keep only keywords that actually appear in the index's keyword table.
        keywords = [k for k in keywords if k in self._index_struct.keywords]
        logger.info(f"> Extracted keywords: {keywords}")
        for k in keywords:
            for node_id in self._index_struct.table[k]:
                chunk_indices_count[node_id] += 1
        sorted_chunk_indices = sorted(
            chunk_indices_count.keys(),
            key=lambda x: chunk_indices_count[x],
            reverse=True,
        )
        # Truncate to the top ``num_chunks_per_query`` matching chunks.
        sorted_chunk_indices = sorted_chunk_indices[: self.num_chunks_per_query]
        sorted_nodes = self._docstore.get_nodes(sorted_chunk_indices)

        if logging.getLogger(__name__).getEffectiveLevel() == logging.DEBUG:
            for chunk_idx, node in zip(sorted_chunk_indices, sorted_nodes):
                logger.debug(
                    f"> Querying with idx: {chunk_idx}: "
                    f"{truncate_text(node.get_content(), 50)}"
                )
        # Keyword matching produces no similarity score, so nodes are returned
        # with ``score=None``.
        return [NodeWithScore(node=node) for node in sorted_nodes]
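
# NOTE: subclasses only need to implement ``_get_keywords``; keyword lookup,
# match counting, and truncation to ``num_chunks_per_query`` are handled by
# ``BaseKeywordTableRetriever._retrieve`` above. A minimal custom-retriever
# sketch (illustrative only; ``my_keyword_fn`` is a hypothetical helper, not
# part of this module):
#
#     class KeywordTableCustomRetriever(BaseKeywordTableRetriever):
#         def _get_keywords(self, query_str: str) -> List[str]:
#             return my_keyword_fn(query_str)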


class KeywordTableGPTRetriever(BaseKeywordTableRetriever):
    """Keyword Table Index GPT Retriever.

    Extracts keywords using GPT. Set when using `retriever_mode="default"`.

    See BaseKeywordTableRetriever for arguments.

    """

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        response = self._service_context.llm.predict(
            self.query_keyword_extract_template,
            max_keywords=self.max_keywords_per_query,
            question=query_str,
        )
        keywords = extract_keywords_given_response(response, start_token="KEYWORDS:")
        return list(keywords)
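
# The query keyword-extraction prompt asks the LLM to answer in the form
# ``KEYWORDS: kw1, kw2, ...``; ``extract_keywords_given_response`` strips the
# ``KEYWORDS:`` start token and splits the comma-separated remainder into the
# keyword set used for table lookup.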


class KeywordTableSimpleRetriever(BaseKeywordTableRetriever):
    """Keyword Table Index Simple Retriever.

    Extracts keywords using a simple regex-based keyword extractor.
    Set when `retriever_mode="simple"`.

    See BaseKeywordTableRetriever for arguments.

    """

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        return list(
            simple_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
        )


class KeywordTableRAKERetriever(BaseKeywordTableRetriever):
    """Keyword Table Index RAKE Retriever.

    Extracts keywords using the RAKE keyword extractor.
    Set when `retriever_mode="rake"`.

    See BaseKeywordTableRetriever for arguments.

    """

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        return list(
            rake_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
        )
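

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not executed). Assumes the legacy ``llama_index``
# package layout this module is vendored from, where ``SimpleKeywordTableIndex``
# is exported at the package root; adjust imports and LLM/embedding settings
# for your environment.
#
#     from llama_index import SimpleKeywordTableIndex
#     from llama_index.schema import Document
#
#     docs = [
#         Document(text="FAISS is a library for efficient similarity search."),
#         Document(text="A keyword table maps extracted keywords to node ids."),
#     ]
#     # The "simple" variants use regex keyword extraction, so no LLM calls
#     # are made while building the table or extracting query keywords.
#     index = SimpleKeywordTableIndex.from_documents(docs)
#     retriever = KeywordTableSimpleRetriever(index=index)
#     for node_with_score in retriever.retrieve("How does FAISS search work?"):
#         print(node_with_score.node.get_content())
# ---------------------------------------------------------------------------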