from collections import ChainMap
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Protocol,
    Sequence,
    get_args,
    runtime_checkable,
)

from llama_index.bridge.pydantic import BaseModel, Field, validator
from llama_index.callbacks import CBEventType, EventPayload
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    MessageRole,
)
from llama_index.core.query_pipeline.query_component import (
    InputKeys,
    OutputKeys,
    QueryComponent,
    StringableInput,
    validate_and_convert_stringable,
)
from llama_index.llms.base import BaseLLM
from llama_index.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.llms.generic_utils import (
    prompt_to_messages,
)
from llama_index.prompts import BasePromptTemplate, PromptTemplate
from llama_index.types import (
    BaseOutputParser,
    PydanticProgramMode,
    TokenAsyncGen,
    TokenGen,
)

# NOTE: These two protocols are needed to appease mypy
@runtime_checkable
class MessagesToPromptType(Protocol):
    def __call__(self, messages: Sequence[ChatMessage]) -> str:
        pass


@runtime_checkable
class CompletionToPromptType(Protocol):
    def __call__(self, prompt: str) -> str:
        pass

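# Usage sketch (illustrative, not part of this module): any plain function
# with a matching signature satisfies these protocols at runtime, e.g.:
#
#     def simple_messages_to_prompt(messages: Sequence[ChatMessage]) -> str:
#         return "\n".join(f"{m.role.value}: {m.content}" for m in messages)
#
#     assert isinstance(simple_messages_to_prompt, MessagesToPromptType)
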
def stream_completion_response_to_tokens(
    completion_response_gen: CompletionResponseGen,
) -> TokenGen:
    """Convert a stream completion response to a stream of tokens."""

    def gen() -> TokenGen:
        for response in completion_response_gen:
            yield response.delta or ""

    return gen()


def stream_chat_response_to_tokens(
    chat_response_gen: ChatResponseGen,
) -> TokenGen:
    """Convert a stream chat response to a stream of tokens."""

    def gen() -> TokenGen:
        for response in chat_response_gen:
            yield response.delta or ""

    return gen()


async def astream_completion_response_to_tokens(
    completion_response_gen: CompletionResponseAsyncGen,
) -> TokenAsyncGen:
    """Convert an async stream completion response to a stream of tokens."""

    async def gen() -> TokenAsyncGen:
        async for response in completion_response_gen:
            yield response.delta or ""

    return gen()


async def astream_chat_response_to_tokens(
    chat_response_gen: ChatResponseAsyncGen,
) -> TokenAsyncGen:
    """Convert an async stream chat response to a stream of tokens."""

    async def gen() -> TokenAsyncGen:
        async for response in chat_response_gen:
            yield response.delta or ""

    return gen()


def default_completion_to_prompt(prompt: str) -> str:
    return prompt

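# Usage sketch (illustrative, not part of this module): the helpers above
# adapt a response stream into a plain token stream by yielding each delta,
# substituting "" when a chunk carries no delta, e.g.:
#
#     response_gen = llm.stream_complete("Hello")  # some concrete LLM
#     for token in stream_completion_response_to_tokens(response_gen):
#         print(token, end="")
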
class LLM(BaseLLM):
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for LLM calls."
    )
    messages_to_prompt: MessagesToPromptType = Field(
        description="Function to convert a list of messages to an LLM prompt.",
        default=generic_messages_to_prompt,
        exclude=True,
    )
    completion_to_prompt: CompletionToPromptType = Field(
        description="Function to convert a completion to an LLM prompt.",
        default=default_completion_to_prompt,
        exclude=True,
    )
    output_parser: Optional[BaseOutputParser] = Field(
        description="Output parser to parse, validate, and correct errors programmatically.",
        default=None,
        exclude=True,
    )
    pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT

    # deprecated
    query_wrapper_prompt: Optional[BasePromptTemplate] = Field(
        description="Query wrapper prompt for LLM calls.",
        default=None,
        exclude=True,
    )

    @validator("messages_to_prompt", pre=True)
    def set_messages_to_prompt(
        cls, messages_to_prompt: Optional[MessagesToPromptType]
    ) -> MessagesToPromptType:
        return messages_to_prompt or generic_messages_to_prompt

    @validator("completion_to_prompt", pre=True)
    def set_completion_to_prompt(
        cls, completion_to_prompt: Optional[CompletionToPromptType]
    ) -> CompletionToPromptType:
        return completion_to_prompt or default_completion_to_prompt

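    # Configuration sketch (illustrative): subclasses inherit these fields, so
    # a concrete LLM can be constructed with custom prompt formatting. The
    # `MyLLM` name below is hypothetical:
    #
    #     llm = MyLLM(
    #         system_prompt="You are a helpful assistant.",
    #         completion_to_prompt=lambda p: f"### Instruction:\n{p}\n### Response:\n",
    #     )
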
    def _log_template_data(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> None:
        # Merge template kwargs with call-time args (template kwargs win in
        # ChainMap lookup order) and keep only declared template variables.
        template_vars = {
            k: v
            for k, v in ChainMap(prompt.kwargs, prompt_args).items()
            if k in prompt.template_vars
        }
        # Emit a templating event; the context manager is used purely for its
        # enter/exit callbacks, hence the empty body.
        with self.callback_manager.event(
            CBEventType.TEMPLATING,
            payload={
                EventPayload.TEMPLATE: prompt.get_template(llm=self),
                EventPayload.TEMPLATE_VARS: template_vars,
                EventPayload.SYSTEM_PROMPT: self.system_prompt,
                EventPayload.QUERY_WRAPPER_PROMPT: self.query_wrapper_prompt,
            },
        ):
            pass

    def _get_prompt(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        formatted_prompt = prompt.format(
            llm=self,
            messages_to_prompt=self.messages_to_prompt,
            completion_to_prompt=self.completion_to_prompt,
            **prompt_args,
        )
        if self.output_parser is not None:
            formatted_prompt = self.output_parser.format(formatted_prompt)
        return self._extend_prompt(formatted_prompt)

    def _get_messages(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> List[ChatMessage]:
        messages = prompt.format_messages(llm=self, **prompt_args)
        if self.output_parser is not None:
            messages = self.output_parser.format_messages(messages)
        return self._extend_messages(messages)

    def structured_predict(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> BaseModel:
        # deferred import to avoid a circular dependency at module load
        from llama_index.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self,
            pydantic_program_mode=self.pydantic_program_mode,
        )

        return program(**prompt_args)

    async def astructured_predict(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> BaseModel:
        # deferred import to avoid a circular dependency at module load
        from llama_index.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self,
            pydantic_program_mode=self.pydantic_program_mode,
        )

        return await program.acall(**prompt_args)

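    # Usage sketch (illustrative; `Song` and the template are hypothetical):
    #
    #     class Song(BaseModel):
    #         name: str
    #         artist: str
    #
    #     song = llm.structured_predict(
    #         Song, PromptTemplate("Write a song about {topic}."), topic="rain"
    #     )
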
    def _parse_output(self, output: str) -> str:
        if self.output_parser is not None:
            return str(self.output_parser.parse(output))

        return output

    def predict(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> str:
        """Predict."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = self.chat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            response = self.complete(formatted_prompt, formatted=True)
            output = response.text

        return self._parse_output(output)

    def stream(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> TokenGen:
        """Stream."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = self.stream_chat(messages)
            stream_tokens = stream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            stream_response = self.stream_complete(formatted_prompt, formatted=True)
            stream_tokens = stream_completion_response_to_tokens(stream_response)

        if prompt.output_parser is not None or self.output_parser is not None:
            raise NotImplementedError("Output parser is not supported for streaming.")

        return stream_tokens

    async def apredict(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> str:
        """Async predict."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = await self.achat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            response = await self.acomplete(formatted_prompt, formatted=True)
            output = response.text

        return self._parse_output(output)

    async def astream(
        self,
        prompt: BasePromptTemplate,
        **prompt_args: Any,
    ) -> TokenAsyncGen:
        """Async stream."""
        self._log_template_data(prompt, **prompt_args)

        if self.metadata.is_chat_model:
            messages = self._get_messages(prompt, **prompt_args)
            chat_response = await self.astream_chat(messages)
            stream_tokens = await astream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = self._get_prompt(prompt, **prompt_args)
            stream_response = await self.astream_complete(
                formatted_prompt, formatted=True
            )
            stream_tokens = await astream_completion_response_to_tokens(stream_response)

        if prompt.output_parser is not None or self.output_parser is not None:
            raise NotImplementedError("Output parser is not supported for streaming.")

        return stream_tokens

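    # Usage sketch (illustrative): `predict` dispatches to the chat or
    # completion endpoint based on `metadata.is_chat_model`, so callers stay
    # uniform across model types:
    #
    #     qa_prompt = PromptTemplate("Answer the question: {question}")
    #     answer = llm.predict(qa_prompt, question="What is an LLM?")
    #     for token in llm.stream(qa_prompt, question="What is an LLM?"):
    #         print(token, end="")
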
    def _extend_prompt(
        self,
        formatted_prompt: str,
    ) -> str:
        """Add system and query wrapper prompts to base prompt."""
        extended_prompt = formatted_prompt

        if self.system_prompt:
            extended_prompt = self.system_prompt + "\n\n" + extended_prompt

        if self.query_wrapper_prompt:
            extended_prompt = self.query_wrapper_prompt.format(
                query_str=extended_prompt
            )

        return extended_prompt

    def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        """Add system prompt to chat message list."""
        if self.system_prompt:
            messages = [
                ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
                *messages,
            ]
        return messages

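    # Worked example (illustrative): with system_prompt="Be brief." and a
    # (deprecated) query_wrapper_prompt of "INST: {query_str}", the prompt
    # "Hi" is extended in two steps:
    #
    #     "Be brief.\n\nHi"          # system prompt prepended first
    #     "INST: Be brief.\n\nHi"    # then the whole thing is wrapped
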
    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Return query component."""
        if self.metadata.is_chat_model:
            return LLMChatComponent(llm=self, **kwargs)
        else:
            return LLMCompleteComponent(llm=self, **kwargs)


class BaseLLMComponent(QueryComponent):
    """Base LLM component."""

    llm: LLM = Field(..., description="LLM")
    streaming: bool = Field(default=False, description="Streaming mode")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: Any) -> None:
        """Set callback manager."""
        self.llm.callback_manager = callback_manager


class LLMCompleteComponent(BaseLLMComponent):
    """LLM completion component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "prompt" not in input:
            raise ValueError("Prompt must be in input dict.")

        # special check: if `prompt` is a list of chat messages, render it
        # into a prompt string (isinstance cannot test List[ChatMessage]
        # directly, so inspect the list contents explicitly)
        if isinstance(input["prompt"], list) and all(
            isinstance(m, ChatMessage) for m in input["prompt"]
        ):
            input["prompt"] = self.llm.messages_to_prompt(input["prompt"])
            input["prompt"] = validate_and_convert_stringable(input["prompt"])
        else:
            input["prompt"] = validate_and_convert_stringable(input["prompt"])
            input["prompt"] = self.llm.completion_to_prompt(input["prompt"])

        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only complete for now
        # non-trivial to figure out how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        # ignore all other kwargs for now
        if self.streaming:
            response = self.llm.stream_complete(prompt, formatted=True)
        else:
            response = self.llm.complete(prompt, formatted=True)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Async run component."""
        # TODO: support only complete for now
        # non-trivial to figure out how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        # ignore all other kwargs for now
        response = await self.llm.acomplete(prompt, formatted=True)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: support only complete for now
        return InputKeys.from_keys({"prompt"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})


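# Usage sketch (illustrative; assumes the QueryPipeline API available
# alongside this version of the codebase):
#
#     from llama_index.query_pipeline import QueryPipeline
#
#     pipeline = QueryPipeline(chain=[prompt_template, llm.as_query_component()])
#     output = pipeline.run(topic="rain")
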
class LLMChatComponent(BaseLLMComponent):
    """LLM chat component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "messages" not in input:
            raise ValueError("Messages must be in input dict.")

        # if `messages` is a string, convert it to a list of chat messages
        if isinstance(input["messages"], get_args(StringableInput)):
            input["messages"] = validate_and_convert_stringable(input["messages"])
            input["messages"] = prompt_to_messages(str(input["messages"]))

        for message in input["messages"]:
            if not isinstance(message, ChatMessage):
                raise ValueError("Messages must be a list of ChatMessage.")
        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: support only chat for now
        # non-trivial to figure out how to support chat/complete/etc.
        messages = kwargs["messages"]
        if self.streaming:
            response = self.llm.stream_chat(messages)
        else:
            response = self.llm.chat(messages)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Async run component."""
        # TODO: support only chat for now
        # non-trivial to figure out how to support chat/complete/etc.
        messages = kwargs["messages"]
        if self.streaming:
            response = await self.llm.astream_chat(messages)
        else:
            response = await self.llm.achat(messages)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: support only chat for now
        return InputKeys.from_keys({"messages"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})
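

# Usage sketch (illustrative): the chat component accepts either a ready list
# of ChatMessage objects or a plain string, which is promoted to messages via
# prompt_to_messages:
#
#     component = LLMChatComponent(llm=llm)
#     result = component.run_component(messages="Tell me a joke.")
#     print(result["output"])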