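"""Generation-only response synthesizer: answers the query directly from the LLM, without using retrieved context."""
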
from typing import Any, Optional, Sequence

from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.default_prompts import DEFAULT_SIMPLE_INPUT_PROMPT
from llama_index.prompts.mixin import PromptDictType
from llama_index.response_synthesizers.base import BaseSynthesizer
from llama_index.service_context import ServiceContext
from llama_index.types import RESPONSE_TEXT_TYPE


class Generation(BaseSynthesizer):
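    """Synthesizer that answers the query directly from the LLM.

    Retrieved text chunks are ignored; the query string is formatted into
    ``simple_template`` and sent to the LLM as-is.
    """
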
    def __init__(
        self,
        simple_template: Optional[BasePromptTemplate] = None,
        service_context: Optional[ServiceContext] = None,
        streaming: bool = False,
    ) -> None:
        super().__init__(service_context=service_context, streaming=streaming)
        self._input_prompt = simple_template or DEFAULT_SIMPLE_INPUT_PROMPT

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"simple_template": self._input_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "simple_template" in prompts:
            self._input_prompt = prompts["simple_template"]

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        # NOTE: ignore text chunks and previous response
        del text_chunks

        if not self._streaming:
            return await self._service_context.llm.apredict(
                self._input_prompt,
                query_str=query_str,
                **response_kwargs,
            )
        else:
            return self._service_context.llm.stream(
                self._input_prompt,
                query_str=query_str,
                **response_kwargs,
            )

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        # NOTE: ignore text chunks and previous response
        del text_chunks

        if not self._streaming:
            return self._service_context.llm.predict(
                self._input_prompt,
                query_str=query_str,
                **response_kwargs,
            )
        else:
            return self._service_context.llm.stream(
                self._input_prompt,
                query_str=query_str,
                **response_kwargs,
            )
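
# Example usage (a minimal sketch, not part of the original module; assumes the
# legacy ServiceContext API and a default LLM configured via the environment):
#
#     from llama_index.response_synthesizers import Generation
#     from llama_index.service_context import ServiceContext
#
#     service_context = ServiceContext.from_defaults()
#     synthesizer = Generation(service_context=service_context)
#     answer = synthesizer.get_response(
#         "What is retrieval-augmented generation?", text_chunks=[]
#     )
#     print(answer)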