faiss_rag_enterprise/llama_index/response_synthesizers/generation.py

from typing import Any, Optional, Sequence
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.default_prompts import DEFAULT_SIMPLE_INPUT_PROMPT
from llama_index.prompts.mixin import PromptDictType
from llama_index.response_synthesizers.base import BaseSynthesizer
from llama_index.service_context import ServiceContext
from llama_index.types import RESPONSE_TEXT_TYPE


class Generation(BaseSynthesizer):
    """Response synthesizer that answers directly from the query string.

    Retrieved text chunks are ignored; only the query is formatted into
    ``simple_template`` and sent to the LLM.
    """

    def __init__(
        self,
        simple_template: Optional[BasePromptTemplate] = None,
        service_context: Optional[ServiceContext] = None,
        streaming: bool = False,
    ) -> None:
        super().__init__(service_context=service_context, streaming=streaming)
        self._input_prompt = simple_template or DEFAULT_SIMPLE_INPUT_PROMPT

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"simple_template": self._input_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "simple_template" in prompts:
            self._input_prompt = prompts["simple_template"]

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        # NOTE: ignore text chunks and previous response
        del text_chunks

        if not self._streaming:
            return await self._service_context.llm.apredict(
                self._input_prompt,
                query_str=query_str,
                **response_kwargs,
            )
        else:
            return self._service_context.llm.stream(
                self._input_prompt,
                query_str=query_str,
                **response_kwargs,
            )

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        # NOTE: ignore text chunks and previous response
        del text_chunks

        if not self._streaming:
            return self._service_context.llm.predict(
                self._input_prompt,
                query_str=query_str,
                **response_kwargs,
            )
        else:
            return self._service_context.llm.stream(
                self._input_prompt,
                query_str=query_str,
                **response_kwargs,
            )