faiss_rag_enterprise/llama_index/llm_predictor/mock.py

"""Mock LLM Predictor."""
from typing import Any, Dict
from deprecated import deprecated
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks.base import CallbackManager
from llama_index.constants import DEFAULT_NUM_OUTPUTS
from llama_index.core.llms.types import LLMMetadata
from llama_index.llm_predictor.base import BaseLLMPredictor
from llama_index.llms.llm import LLM
from llama_index.prompts.base import BasePromptTemplate
from llama_index.prompts.prompt_type import PromptType
from llama_index.token_counter.utils import (
    mock_extract_keywords_response,
    mock_extract_kg_triplets_response,
)
from llama_index.types import TokenAsyncGen, TokenGen
from llama_index.utils import get_tokenizer


# TODO: consolidate with unit tests in tests/mock_utils/mock_predict.py
def _mock_summary_predict(max_tokens: int, prompt_args: Dict) -> str:
    """Mock summary predict."""
    # tokens in response shouldn't be larger than tokens in `context_str`
    num_text_tokens = len(get_tokenizer()(prompt_args["context_str"]))
    token_limit = min(num_text_tokens, max_tokens)
    return " ".join(["summary"] * token_limit)


def _mock_insert_predict() -> str:
    """Mock insert predict."""
    return "ANSWER: 1"


def _mock_query_select() -> str:
    """Mock query select."""
    return "ANSWER: 1"


def _mock_query_select_multiple(num_chunks: int) -> str:
    """Mock query select multiple."""
    nums_str = ", ".join([str(i) for i in range(num_chunks)])
    return f"ANSWER: {nums_str}"


def _mock_answer(max_tokens: int, prompt_args: Dict) -> str:
    """Mock answer."""
    # tokens in response shouldn't be larger than tokens in `context_str`
    num_ctx_tokens = len(get_tokenizer()(prompt_args["context_str"]))
    token_limit = min(num_ctx_tokens, max_tokens)
    return " ".join(["answer"] * token_limit)


def _mock_refine(
    max_tokens: int, prompt: BasePromptTemplate, prompt_args: Dict
) -> str:
    """Mock refine."""
    # tokens in response shouldn't be larger than tokens in
    # `existing_answer` + `context_msg`
    # NOTE: if existing_answer is not in prompt_args, we need to get it from the prompt
    if "existing_answer" not in prompt_args:
        existing_answer = prompt.kwargs["existing_answer"]
    else:
        existing_answer = prompt_args["existing_answer"]
    num_ctx_tokens = len(get_tokenizer()(prompt_args["context_msg"]))
    num_exist_tokens = len(get_tokenizer()(existing_answer))
    token_limit = min(num_ctx_tokens + num_exist_tokens, max_tokens)
    return " ".join(["answer"] * token_limit)


def _mock_keyword_extract(prompt_args: Dict) -> str:
    """Mock keyword extract."""
    return mock_extract_keywords_response(prompt_args["text"])


def _mock_query_keyword_extract(prompt_args: Dict) -> str:
    """Mock query keyword extract."""
    return mock_extract_keywords_response(prompt_args["question"])


def _mock_knowledge_graph_triplet_extract(prompt_args: Dict, max_triplets: int) -> str:
    """Mock knowledge graph triplet extract."""
    return mock_extract_kg_triplets_response(
        prompt_args["text"], max_triplets=max_triplets
    )


@deprecated("MockLLMPredictor is deprecated. Use MockLLM instead.")
class MockLLMPredictor(BaseLLMPredictor):
    """Mock LLM Predictor."""

    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS, description="Number of tokens to mock generate."
    )

    _callback_manager: CallbackManager = PrivateAttr(default_factory=CallbackManager)

    @classmethod
    def class_name(cls) -> str:
        return "MockLLMPredictor"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata()

    @property
    def callback_manager(self) -> CallbackManager:
        return self._callback_manager

    @property
    def llm(self) -> LLM:
        raise NotImplementedError("MockLLMPredictor does not have an LLM model.")

    def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        """Mock predict."""
        prompt_str = prompt.metadata["prompt_type"]
        if prompt_str == PromptType.SUMMARY:
            output = _mock_summary_predict(self.max_tokens, prompt_args)
        elif prompt_str == PromptType.TREE_INSERT:
            output = _mock_insert_predict()
        elif prompt_str == PromptType.TREE_SELECT:
            output = _mock_query_select()
        elif prompt_str == PromptType.TREE_SELECT_MULTIPLE:
            output = _mock_query_select_multiple(prompt_args["num_chunks"])
        elif prompt_str == PromptType.REFINE:
            output = _mock_refine(self.max_tokens, prompt, prompt_args)
        elif prompt_str == PromptType.QUESTION_ANSWER:
            output = _mock_answer(self.max_tokens, prompt_args)
        elif prompt_str == PromptType.KEYWORD_EXTRACT:
            output = _mock_keyword_extract(prompt_args)
        elif prompt_str == PromptType.QUERY_KEYWORD_EXTRACT:
            output = _mock_query_keyword_extract(prompt_args)
        elif prompt_str == PromptType.KNOWLEDGE_TRIPLET_EXTRACT:
            output = _mock_knowledge_graph_triplet_extract(
                prompt_args,
                int(prompt.kwargs.get("max_knowledge_triplets", 2)),
            )
        elif prompt_str == PromptType.CUSTOM:
            # we don't know the specific prompt type, so return a generic response
            output = ""
        else:
            raise ValueError("Invalid prompt type.")
        return output

    def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen:
        raise NotImplementedError

    async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        return self.predict(prompt, **prompt_args)

    async def astream(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> TokenAsyncGen:
        raise NotImplementedError
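

# Illustrative usage sketch (not part of the original module): exercises the mock
# predictor with a summary-style prompt. It assumes PromptTemplate accepts a
# `prompt_type` argument, as in the legacy llama_index prompt API; adjust the
# import or constructor if your version differs.
if __name__ == "__main__":
    from llama_index.prompts.base import PromptTemplate

    predictor = MockLLMPredictor(max_tokens=8)
    summary_prompt = PromptTemplate(
        "Summarize the following:\n{context_str}",
        prompt_type=PromptType.SUMMARY,
    )
    # The mock caps the output at min(len(tokenized context_str), max_tokens) tokens,
    # so this prints a short string of repeated "summary" tokens.
    print(predictor.predict(summary_prompt, context_str="some text to summarize"))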