"""Query for KeywordTableIndex."""
|
|
import logging
|
|
from abc import abstractmethod
|
|
from collections import defaultdict
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.indices.keyword_table.base import BaseKeywordTableIndex
|
|
from llama_index.indices.keyword_table.utils import (
|
|
extract_keywords_given_response,
|
|
rake_extract_keywords,
|
|
simple_extract_keywords,
|
|
)
|
|
from llama_index.prompts import BasePromptTemplate
|
|
from llama_index.prompts.default_prompts import (
|
|
DEFAULT_KEYWORD_EXTRACT_TEMPLATE,
|
|
DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE,
|
|
)
|
|
from llama_index.schema import NodeWithScore, QueryBundle
|
|
from llama_index.utils import truncate_text
|
|
|
|
DQKET = DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseKeywordTableRetriever(BaseRetriever):
    """Base Keyword Table Retriever.

    Arguments are shared among subclasses.

    Args:
        keyword_extract_template (Optional[BasePromptTemplate]): A Keyword
            Extraction Prompt (see :ref:`Prompt-Templates`).
        query_keyword_extract_template (Optional[BasePromptTemplate]): A Query
            Keyword Extraction Prompt (see :ref:`Prompt-Templates`).
        max_keywords_per_query (int): Maximum number of keywords to extract
            from the query.
        num_chunks_per_query (int): Maximum number of text chunks to query.

    """

    def __init__(
        self,
        index: BaseKeywordTableIndex,
        keyword_extract_template: Optional[BasePromptTemplate] = None,
        query_keyword_extract_template: Optional[BasePromptTemplate] = None,
        max_keywords_per_query: int = 10,
        num_chunks_per_query: int = 10,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._index = index
        self._index_struct = index.index_struct
        self._docstore = index.docstore
        self._service_context = index.service_context

        self.max_keywords_per_query = max_keywords_per_query
        self.num_chunks_per_query = num_chunks_per_query
        self.keyword_extract_template = (
            keyword_extract_template or DEFAULT_KEYWORD_EXTRACT_TEMPLATE
        )
        self.query_keyword_extract_template = query_keyword_extract_template or DQKET
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            verbose=verbose,
        )

    @abstractmethod
    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Get nodes for response."""
        logger.info(f"> Starting query: {query_bundle.query_str}")
        keywords = self._get_keywords(query_bundle.query_str)
        logger.info(f"query keywords: {keywords}")

        # go through text chunks in order of most matching keywords
        chunk_indices_count: Dict[str, int] = defaultdict(int)
        keywords = [k for k in keywords if k in self._index_struct.keywords]
        logger.info(f"> Extracted keywords: {keywords}")
        for k in keywords:
            for node_id in self._index_struct.table[k]:
                chunk_indices_count[node_id] += 1
        sorted_chunk_indices = sorted(
            chunk_indices_count.keys(),
            key=lambda x: chunk_indices_count[x],
            reverse=True,
        )
        sorted_chunk_indices = sorted_chunk_indices[: self.num_chunks_per_query]
        sorted_nodes = self._docstore.get_nodes(sorted_chunk_indices)

        if logging.getLogger(__name__).getEffectiveLevel() == logging.DEBUG:
            for chunk_idx, node in zip(sorted_chunk_indices, sorted_nodes):
                logger.debug(
                    f"> Querying with idx: {chunk_idx}: "
                    f"{truncate_text(node.get_content(), 50)}"
                )
        return [NodeWithScore(node=node) for node in sorted_nodes]
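
# Example (illustrative, with a hypothetical keyword table): if the index maps
# {"llama": {"node1", "node2"}, "index": {"node2"}} and the query keywords are
# ["llama", "index"], then chunk_indices_count ends up as
# {"node1": 1, "node2": 2}, so "node2" is retrieved first. Nodes are returned
# without scores; their order alone reflects keyword overlap.
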
class KeywordTableGPTRetriever(BaseKeywordTableRetriever):
    """Keyword Table Index GPT Retriever.

    Extracts keywords using GPT. Set when using `retriever_mode="default"`.

    See BaseKeywordTableRetriever for arguments.

    """

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        response = self._service_context.llm.predict(
            self.query_keyword_extract_template,
            max_keywords=self.max_keywords_per_query,
            question=query_str,
        )
        keywords = extract_keywords_given_response(response, start_token="KEYWORDS:")
        return list(keywords)
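
# Note (illustrative): the default query keyword extraction prompt asks the LLM
# to reply with a comma-separated list prefixed by "KEYWORDS:", e.g.
# "KEYWORDS: climate, policy, emissions". extract_keywords_given_response then
# parses everything after the start_token into individual keywords.
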
class KeywordTableSimpleRetriever(BaseKeywordTableRetriever):
    """Keyword Table Index Simple Retriever.

    Extracts keywords using a simple regex-based keyword extractor.
    Set when `retriever_mode="simple"`.

    See BaseKeywordTableRetriever for arguments.

    """

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        return list(
            simple_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
        )

class KeywordTableRAKERetriever(BaseKeywordTableRetriever):
    """Keyword Table Index RAKE Retriever.

    Extracts keywords using the RAKE keyword extractor.
    Set when `retriever_mode="rake"`.

    See BaseKeywordTableRetriever for arguments.

    """

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        return list(
            rake_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
        )
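
# Minimal usage sketch (illustrative): assumes the legacy top-level
# `llama_index` package exposing KeywordTableIndex and SimpleDirectoryReader,
# plus a local ./data directory of text files. Building the index extracts
# keywords per node with the configured LLM; the simple retriever below then
# matches query keywords by regex, with no LLM call at query time.
if __name__ == "__main__":
    from llama_index import KeywordTableIndex, SimpleDirectoryReader

    documents = SimpleDirectoryReader("./data").load_data()
    index = KeywordTableIndex.from_documents(documents)

    retriever = KeywordTableSimpleRetriever(index=index)
    for result in retriever.retrieve("What did the author work on?"):
        # Each result is a NodeWithScore; print a short preview of its content.
        print(truncate_text(result.node.get_content(), 80))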