# faiss_rag_enterprise/llama_index/llms/llm.py

from collections import ChainMap
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Protocol,
    Sequence,
    get_args,
    runtime_checkable,
)

from llama_index.bridge.pydantic import BaseModel, Field, validator
from llama_index.callbacks import CBEventType, EventPayload
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    MessageRole,
)
from llama_index.core.query_pipeline.query_component import (
    InputKeys,
    OutputKeys,
    QueryComponent,
    StringableInput,
    validate_and_convert_stringable,
)
from llama_index.llms.base import BaseLLM
from llama_index.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
    prompt_to_messages,
)
from llama_index.prompts import BasePromptTemplate, PromptTemplate
from llama_index.types import (
    BaseOutputParser,
    PydanticProgramMode,
    TokenAsyncGen,
    TokenGen,
)


# NOTE: These two protocols are needed to appease mypy
@runtime_checkable
class MessagesToPromptType(Protocol):
    def __call__(self, messages: Sequence[ChatMessage]) -> str:
        pass


@runtime_checkable
class CompletionToPromptType(Protocol):
    def __call__(self, prompt: str) -> str:
        pass
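

# Example (illustrative sketch): any plain function with a matching signature
# satisfies these runtime-checkable protocols; e.g. a hypothetical role-tagged
# chat format:
#
#     def tagged_messages_to_prompt(messages: Sequence[ChatMessage]) -> str:
#         return "\n".join(f"[{m.role.value}]: {m.content}" for m in messages)
#
#     assert isinstance(tagged_messages_to_prompt, MessagesToPromptType)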


def stream_completion_response_to_tokens(
    completion_response_gen: CompletionResponseGen,
) -> TokenGen:
    """Convert a stream completion response to a stream of tokens."""

    def gen() -> TokenGen:
        for response in completion_response_gen:
            yield response.delta or ""

    return gen()


def stream_chat_response_to_tokens(
    chat_response_gen: ChatResponseGen,
) -> TokenGen:
    """Convert a stream chat response to a stream of tokens."""

    def gen() -> TokenGen:
        for response in chat_response_gen:
            yield response.delta or ""

    return gen()


async def astream_completion_response_to_tokens(
    completion_response_gen: CompletionResponseAsyncGen,
) -> TokenAsyncGen:
    """Convert a stream completion response to a stream of tokens."""

    async def gen() -> TokenAsyncGen:
        async for response in completion_response_gen:
            yield response.delta or ""

    return gen()


async def astream_chat_response_to_tokens(
    chat_response_gen: ChatResponseAsyncGen,
) -> TokenAsyncGen:
    """Convert a stream chat response to a stream of tokens."""

    async def gen() -> TokenAsyncGen:
        async for response in chat_response_gen:
            yield response.delta or ""

    return gen()
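

# Usage sketch (assumes `llm` is a concrete LLM instance): the helpers adapt a
# response stream into a plain token stream that can be printed incrementally.
#
#     tokens = stream_completion_response_to_tokens(
#         llm.stream_complete("Hello, world!")
#     )
#     for token in tokens:
#         print(token, end="", flush=True)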


def default_completion_to_prompt(prompt: str) -> str:
    return prompt


class LLM(BaseLLM):
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for LLM calls."
    )
    messages_to_prompt: MessagesToPromptType = Field(
        description="Function to convert a list of messages to an LLM prompt.",
        default=generic_messages_to_prompt,
        exclude=True,
    )
    completion_to_prompt: CompletionToPromptType = Field(
        description="Function to convert a completion to an LLM prompt.",
        default=default_completion_to_prompt,
        exclude=True,
    )
    output_parser: Optional[BaseOutputParser] = Field(
        description="Output parser to parse, validate, and correct errors programmatically.",
        default=None,
        exclude=True,
    )
    pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT

    # deprecated
    query_wrapper_prompt: Optional[BasePromptTemplate] = Field(
        description="Query wrapper prompt for LLM calls.",
        default=None,
        exclude=True,
    )

    @validator("messages_to_prompt", pre=True)
    def set_messages_to_prompt(
        cls, messages_to_prompt: Optional[MessagesToPromptType]
    ) -> MessagesToPromptType:
        return messages_to_prompt or generic_messages_to_prompt

    @validator("completion_to_prompt", pre=True)
    def set_completion_to_prompt(
        cls, completion_to_prompt: Optional[CompletionToPromptType]
    ) -> CompletionToPromptType:
        return completion_to_prompt or default_completion_to_prompt

    def _log_template_data(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> None:
        template_vars = {
            k: v
            for k, v in ChainMap(prompt.kwargs, prompt_args).items()
            if k in prompt.template_vars
        }
        with self.callback_manager.event(
            CBEventType.TEMPLATING,
            payload={
                EventPayload.TEMPLATE: prompt.get_template(llm=self),
                EventPayload.TEMPLATE_VARS: template_vars,
                EventPayload.SYSTEM_PROMPT: self.system_prompt,
                EventPayload.QUERY_WRAPPER_PROMPT: self.query_wrapper_prompt,
            },
        ):
            pass

    def _get_prompt(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        formatted_prompt = prompt.format(
            llm=self,
            messages_to_prompt=self.messages_to_prompt,
            completion_to_prompt=self.completion_to_prompt,
            **prompt_args,
        )
        if self.output_parser is not None:
            formatted_prompt = self.output_parser.format(formatted_prompt)
        return self._extend_prompt(formatted_prompt)

    def _get_messages(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> List[ChatMessage]:
        messages = prompt.format_messages(llm=self, **prompt_args)
        if self.output_parser is not None:
            messages = self.output_parser.format_messages(messages)
        return self._extend_messages(messages)

    def structured_predict(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> BaseModel:
        from llama_index.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self,
            pydantic_program_mode=self.pydantic_program_mode,
        )
        return program(**prompt_args)
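
    # Usage sketch (hypothetical `Album` model; assumes a concrete LLM):
    #
    #     class Album(BaseModel):
    #         name: str
    #         artist: str
    #
    #     album = llm.structured_predict(
    #         Album, PromptTemplate("Make up an album about {topic}."), topic="jazz"
    #     )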

    async def astructured_predict(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> BaseModel:
        from llama_index.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self,
            pydantic_program_mode=self.pydantic_program_mode,
        )
        return await program.acall(**prompt_args)

    def _parse_output(self, output: str) -> str:
        if self.output_parser is not None:
            return str(self.output_parser.parse(output))
        return output

    def predict(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> str:
        """Predict."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = self.chat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            response = self.complete(formatted_prompt, formatted=True)
            output = response.text

        return self._parse_output(output)
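
    # Usage sketch: `predict` dispatches to chat or completion based on
    # `metadata.is_chat_model`, so callers only supply a template and its
    # variables.
    #
    #     qa_prompt = PromptTemplate("Answer the question: {question}")
    #     answer = llm.predict(qa_prompt, question="What is FAISS?")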

    def stream(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> TokenGen:
        """Stream."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = self.stream_chat(messages)
            stream_tokens = stream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            stream_response = self.stream_complete(formatted_prompt, formatted=True)
            stream_tokens = stream_completion_response_to_tokens(stream_response)

        if prompt.output_parser is not None or self.output_parser is not None:
            raise NotImplementedError("Output parser is not supported for streaming.")

        return stream_tokens
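
    # Usage sketch: `stream` yields tokens as they arrive (output parsers are
    # rejected above because they need the complete text).
    #
    #     for token in llm.stream(qa_prompt, question="What is FAISS?"):
    #         print(token, end="", flush=True)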

    async def apredict(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> str:
        """Async predict."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = await self.achat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            response = await self.acomplete(formatted_prompt, formatted=True)
            output = response.text

        return self._parse_output(output)

    async def astream(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> TokenAsyncGen:
        """Async stream."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = await self.astream_chat(messages)
            stream_tokens = await astream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            stream_response = await self.astream_complete(
                formatted_prompt, formatted=True
            )
            stream_tokens = await astream_completion_response_to_tokens(stream_response)

        if prompt.output_parser is not None or self.output_parser is not None:
            raise NotImplementedError("Output parser is not supported for streaming.")

        return stream_tokens

    def _extend_prompt(
        self,
        formatted_prompt: str,
    ) -> str:
        """Add system and query wrapper prompts to base prompt."""
        extended_prompt = formatted_prompt

        if self.system_prompt:
            extended_prompt = self.system_prompt + "\n\n" + extended_prompt

        if self.query_wrapper_prompt:
            extended_prompt = self.query_wrapper_prompt.format(
                query_str=extended_prompt
            )

        return extended_prompt

    def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        """Add system prompt to chat message list."""
        if self.system_prompt:
            messages = [
                ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
                *messages,
            ]
        return messages
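
    # Illustrative sketch: with system_prompt="You are terse." and
    # query_wrapper_prompt=PromptTemplate("<query>{query_str}</query>"),
    # _extend_prompt("Hi") first prepends the system prompt, then wraps the
    # combined text, producing "<query>You are terse.\n\nHi</query>".
    # _extend_messages instead prepends a SYSTEM-role ChatMessage to the list.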

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Return query component."""
        if self.metadata.is_chat_model:
            return LLMChatComponent(llm=self, **kwargs)
        else:
            return LLMCompleteComponent(llm=self, **kwargs)
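

# Usage sketch (assumes a concrete LLM; the QueryPipeline import path may vary
# across llama_index versions):
#
#     from llama_index.query_pipeline import QueryPipeline
#
#     pipeline = QueryPipeline(chain=[qa_prompt, llm])
#     output = pipeline.run(question="What is FAISS?")
#
# The pipeline calls `_as_query_component` internally, so chat models get an
# LLMChatComponent and completion models an LLMCompleteComponent.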


class BaseLLMComponent(QueryComponent):
    """Base LLM component."""

    llm: LLM = Field(..., description="LLM")
    streaming: bool = Field(default=False, description="Streaming mode")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: Any) -> None:
        """Set callback manager."""
        self.llm.callback_manager = callback_manager


class LLMCompleteComponent(BaseLLMComponent):
    """LLM completion component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "prompt" not in input:
            raise ValueError("Prompt must be in input dict.")

        # do special check to see if prompt is a list of chat messages
        if isinstance(input["prompt"], list) and all(
            isinstance(m, ChatMessage) for m in input["prompt"]
        ):
            input["prompt"] = self.llm.messages_to_prompt(input["prompt"])
            input["prompt"] = validate_and_convert_stringable(input["prompt"])
        else:
            input["prompt"] = validate_and_convert_stringable(input["prompt"])
            input["prompt"] = self.llm.completion_to_prompt(input["prompt"])

        return input
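
    # Illustrative sketch of the two accepted input shapes:
    #
    #     component.run_component(prompt="Tell me a joke.")   # plain string
    #     component.run_component(prompt=chat_messages)       # List[ChatMessage]
    #
    # A list of ChatMessage objects is flattened with messages_to_prompt; a
    # plain string goes through completion_to_prompt instead.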

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: only `complete` is supported for now;
        # non-trivial to figure out how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        # ignore all other kwargs for now
        if self.streaming:
            response = self.llm.stream_complete(prompt, formatted=True)
        else:
            response = self.llm.complete(prompt, formatted=True)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async)."""
        # TODO: only `complete` is supported for now;
        # non-trivial to figure out how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        # ignore all other kwargs for now
        # NOTE: unlike the sync path, streaming mode is not applied here
        response = await self.llm.acomplete(prompt, formatted=True)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: only a single `prompt` key is supported for now
        return InputKeys.from_keys({"prompt"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})


class LLMChatComponent(BaseLLMComponent):
    """LLM chat component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "messages" not in input:
            raise ValueError("Messages must be in input dict.")

        # if `messages` is a string, convert to a list of chat messages
        if isinstance(input["messages"], get_args(StringableInput)):
            input["messages"] = validate_and_convert_stringable(input["messages"])
            input["messages"] = prompt_to_messages(str(input["messages"]))

        for message in input["messages"]:
            if not isinstance(message, ChatMessage):
                raise ValueError("Messages must be a list of ChatMessage objects.")

        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: only `chat` is supported for now;
        # non-trivial to figure out how to support chat/complete/etc.
        messages = kwargs["messages"]
        if self.streaming:
            response = self.llm.stream_chat(messages)
        else:
            response = self.llm.chat(messages)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async)."""
        # TODO: only `chat` is supported for now;
        # non-trivial to figure out how to support chat/complete/etc.
        messages = kwargs["messages"]
        if self.streaming:
            response = await self.llm.astream_chat(messages)
        else:
            response = await self.llm.achat(messages)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: only a single `messages` key is supported for now
        return InputKeys.from_keys({"messages"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})
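

# Usage sketch (assumes a chat-capable LLM exposing `as_query_component` via
# ChainableMixin):
#
#     component = llm.as_query_component(streaming=False)
#     output = component.run_component(messages="Hello!")["output"]
#
# A plain-string input is converted to a single user message with
# prompt_to_messages before being passed to `chat`.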