"""Mock LLM Predictor."""
|
|
from typing import Any, Dict
|
|
|
|
from deprecated import deprecated
|
|
|
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.constants import DEFAULT_NUM_OUTPUTS
|
|
from llama_index.core.llms.types import LLMMetadata
|
|
from llama_index.llm_predictor.base import BaseLLMPredictor
|
|
from llama_index.llms.llm import LLM
|
|
from llama_index.prompts.base import BasePromptTemplate
|
|
from llama_index.prompts.prompt_type import PromptType
|
|
from llama_index.token_counter.utils import (
|
|
mock_extract_keywords_response,
|
|
mock_extract_kg_triplets_response,
|
|
)
|
|
from llama_index.types import TokenAsyncGen, TokenGen
|
|
from llama_index.utils import get_tokenizer
|
|
|
|
# TODO: consolidate with unit tests in tests/mock_utils/mock_predict.py
|
|
|
|
|
|


def _mock_summary_predict(max_tokens: int, prompt_args: Dict) -> str:
    """Mock summary predict."""
    # tokens in response shouldn't be larger than tokens in `context_str`
    num_text_tokens = len(get_tokenizer()(prompt_args["context_str"]))
    token_limit = min(num_text_tokens, max_tokens)
    return " ".join(["summary"] * token_limit)


def _mock_insert_predict() -> str:
    """Mock insert predict."""
    return "ANSWER: 1"


def _mock_query_select() -> str:
    """Mock query select."""
    return "ANSWER: 1"


def _mock_query_select_multiple(num_chunks: int) -> str:
    """Mock query select (multiple chunks)."""
    nums_str = ", ".join([str(i) for i in range(num_chunks)])
    return f"ANSWER: {nums_str}"


def _mock_answer(max_tokens: int, prompt_args: Dict) -> str:
    """Mock answer."""
    # tokens in response shouldn't be larger than tokens in `context_str`
    num_ctx_tokens = len(get_tokenizer()(prompt_args["context_str"]))
    token_limit = min(num_ctx_tokens, max_tokens)
    return " ".join(["answer"] * token_limit)


def _mock_refine(max_tokens: int, prompt: BasePromptTemplate, prompt_args: Dict) -> str:
    """Mock refine."""
    # tokens in response shouldn't be larger than tokens in
    # `existing_answer` + `context_msg`
    # NOTE: if existing_answer is not in prompt_args, we need to get it from the prompt
    if "existing_answer" not in prompt_args:
        existing_answer = prompt.kwargs["existing_answer"]
    else:
        existing_answer = prompt_args["existing_answer"]
    num_ctx_tokens = len(get_tokenizer()(prompt_args["context_msg"]))
    num_exist_tokens = len(get_tokenizer()(existing_answer))
    token_limit = min(num_ctx_tokens + num_exist_tokens, max_tokens)
    return " ".join(["answer"] * token_limit)


def _mock_keyword_extract(prompt_args: Dict) -> str:
    """Mock keyword extract."""
    return mock_extract_keywords_response(prompt_args["text"])


def _mock_query_keyword_extract(prompt_args: Dict) -> str:
    """Mock query keyword extract."""
    return mock_extract_keywords_response(prompt_args["question"])


def _mock_knowledge_graph_triplet_extract(prompt_args: Dict, max_triplets: int) -> str:
    """Mock knowledge graph triplet extract."""
    return mock_extract_kg_triplets_response(
        prompt_args["text"], max_triplets=max_triplets
    )


@deprecated("MockLLMPredictor is deprecated. Use MockLLM instead.")
class MockLLMPredictor(BaseLLMPredictor):
    """Mock LLM Predictor."""

    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS, description="Number of tokens to mock generate."
    )

    _callback_manager: CallbackManager = PrivateAttr(default_factory=CallbackManager)

    @classmethod
    def class_name(cls) -> str:
        return "MockLLMPredictor"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata()

    @property
    def callback_manager(self) -> CallbackManager:
        # Return the private attribute; returning `self.callback_manager` here
        # would recurse into this property indefinitely.
        return self._callback_manager

    @property
    def llm(self) -> LLM:
        raise NotImplementedError("MockLLMPredictor does not have an LLM model.")

    def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        """Mock predict."""
        # Dispatch on the prompt type recorded in the template's metadata.
        prompt_type = prompt.metadata["prompt_type"]
        if prompt_type == PromptType.SUMMARY:
            output = _mock_summary_predict(self.max_tokens, prompt_args)
        elif prompt_type == PromptType.TREE_INSERT:
            output = _mock_insert_predict()
        elif prompt_type == PromptType.TREE_SELECT:
            output = _mock_query_select()
        elif prompt_type == PromptType.TREE_SELECT_MULTIPLE:
            output = _mock_query_select_multiple(prompt_args["num_chunks"])
        elif prompt_type == PromptType.REFINE:
            output = _mock_refine(self.max_tokens, prompt, prompt_args)
        elif prompt_type == PromptType.QUESTION_ANSWER:
            output = _mock_answer(self.max_tokens, prompt_args)
        elif prompt_type == PromptType.KEYWORD_EXTRACT:
            output = _mock_keyword_extract(prompt_args)
        elif prompt_type == PromptType.QUERY_KEYWORD_EXTRACT:
            output = _mock_query_keyword_extract(prompt_args)
        elif prompt_type == PromptType.KNOWLEDGE_TRIPLET_EXTRACT:
            output = _mock_knowledge_graph_triplet_extract(
                prompt_args,
                int(prompt.kwargs.get("max_knowledge_triplets", 2)),
            )
        elif prompt_type == PromptType.CUSTOM:
            # the specific prompt type is unknown for custom prompts,
            # so return an empty response
            output = ""
        else:
            raise ValueError(f"Invalid prompt type: {prompt_type}")

        return output

    def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen:
        raise NotImplementedError

    async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        return self.predict(prompt, **prompt_args)

    async def astream(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> TokenAsyncGen:
        raise NotImplementedError
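

# Usage sketch (illustrative only, not executed as part of this module). The
# `PromptTemplate` import below assumes the same legacy `llama_index` package
# layout as the imports at the top of this file; treat it as an assumption,
# not a guaranteed public API.
#
#     from llama_index.prompts.base import PromptTemplate
#
#     predictor = MockLLMPredictor(max_tokens=16)
#     qa_prompt = PromptTemplate(
#         "Context: {context_str}\nQuestion: {query_str}\nAnswer:",
#         prompt_type=PromptType.QUESTION_ANSWER,
#     )
#     # `predict` dispatches on the prompt type; for QUESTION_ANSWER it returns
#     # "answer answer ..." capped at min(number of context tokens, max_tokens).
#     output = predictor.predict(
#         qa_prompt, context_str="some context", query_str="some question"
#     )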