from typing import List, Optional, Sequence, cast

from llama_index.llm_predictor.base import LLMPredictorType
from llama_index.output_parsers.base import StructuredOutput
from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.prompts.mixin import PromptDictType
from llama_index.prompts.prompt_type import PromptType
from llama_index.question_gen.output_parser import SubQuestionOutputParser
from llama_index.question_gen.prompts import (
    DEFAULT_SUB_QUESTION_PROMPT_TMPL,
    build_tools_text,
)
from llama_index.question_gen.types import BaseQuestionGenerator, SubQuestion
from llama_index.schema import QueryBundle
from llama_index.service_context import ServiceContext
from llama_index.tools.types import ToolMetadata
from llama_index.types import BaseOutputParser

class LLMQuestionGenerator(BaseQuestionGenerator):
    """Question generator that uses an LLM to decompose a query into sub-questions."""

    def __init__(
        self,
        llm: LLMPredictorType,
        prompt: BasePromptTemplate,
    ) -> None:
        self._llm = llm
        self._prompt = prompt

        if self._prompt.output_parser is None:
            raise ValueError("Prompt should have output parser.")

    @classmethod
    def from_defaults(
        cls,
        service_context: Optional[ServiceContext] = None,
        prompt_template_str: Optional[str] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> "LLMQuestionGenerator":
        # optionally initialize defaults
        service_context = service_context or ServiceContext.from_defaults()
        prompt_template_str = prompt_template_str or DEFAULT_SUB_QUESTION_PROMPT_TMPL
        output_parser = output_parser or SubQuestionOutputParser()

        # construct prompt
        prompt = PromptTemplate(
            template=prompt_template_str,
            output_parser=output_parser,
            prompt_type=PromptType.SUB_QUESTION,
        )
        return cls(service_context.llm, prompt)

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"question_gen_prompt": self._prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "question_gen_prompt" in prompts:
            output_parser = prompts["question_gen_prompt"].output_parser
            if output_parser is None:
                output_parser = SubQuestionOutputParser()
            self._prompt = PromptTemplate(
                prompts["question_gen_prompt"].template, output_parser=output_parser
            )

    def generate(
        self, tools: Sequence[ToolMetadata], query: QueryBundle
    ) -> List[SubQuestion]:
        tools_str = build_tools_text(tools)
        query_str = query.query_str
        prediction = self._llm.predict(
            prompt=self._prompt,
            tools_str=tools_str,
            query_str=query_str,
        )

        assert self._prompt.output_parser is not None
        parse = self._prompt.output_parser.parse(prediction)
        parse = cast(StructuredOutput, parse)
        return parse.parsed_output

    async def agenerate(
        self, tools: Sequence[ToolMetadata], query: QueryBundle
    ) -> List[SubQuestion]:
        tools_str = build_tools_text(tools)
        query_str = query.query_str
        prediction = await self._llm.apredict(
            prompt=self._prompt,
            tools_str=tools_str,
            query_str=query_str,
        )

        assert self._prompt.output_parser is not None
        parse = self._prompt.output_parser.parse(prediction)
        parse = cast(StructuredOutput, parse)
        return parse.parsed_output
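

# Usage sketch (illustrative only, not part of the original module): assumes a
# configured ServiceContext whose LLM credentials are set up, plus two
# hypothetical tools named "sales" and "support".
#
#     question_gen = LLMQuestionGenerator.from_defaults()
#     sub_questions = question_gen.generate(
#         tools=[
#             ToolMetadata(name="sales", description="Sales figures for 2023"),
#             ToolMetadata(name="support", description="Customer support tickets"),
#         ],
#         query=QueryBundle(query_str="Compare sales and support trends in 2023."),
#     )
#     # Each SubQuestion pairs a generated sub_question with the tool_name to route it to.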