from abc import abstractmethod
from typing import Any, Dict, Optional, Sequence

from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.constants import (
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_NUM_INPUT_FILES,
    DEFAULT_NUM_OUTPUTS,
)
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
)
from llama_index.core.query_pipeline.query_component import (
    ChainableMixin,
    InputKeys,
    OutputKeys,
    QueryComponent,
    validate_and_convert_stringable,
)
from llama_index.schema import BaseComponent, ImageDocument


class MultiModalLLMMetadata(BaseModel):
    """Metadata describing a multi-modal LLM's capabilities."""

    context_window: Optional[int] = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=(
            "Total number of tokens the model can take as input when generating"
            " a response."
        ),
    )
    num_output: Optional[int] = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="Number of tokens the model can output when generating a response.",
    )
    num_input_files: Optional[int] = Field(
        default=DEFAULT_NUM_INPUT_FILES,
        description="Number of input files the model can take when generating a response.",
    )
    is_function_calling_model: Optional[bool] = Field(
        default=False,
        # SEE: https://openai.com/blog/function-calling-and-other-api-updates
        description=(
            "Set True if the model supports function calling messages, similar to"
            " OpenAI's function calling API. For example, converting 'Email Anya to"
            " see if she wants to get coffee next Friday' to a function call like"
            " `send_email(to: string, body: string)`."
        ),
    )
    model_name: str = Field(
        default="unknown",
        description=(
            "The model's name, used for logging, testing, and sanity checking. For"
            " some models this can be automatically discerned. For other models,"
            " like locally loaded models, this must be manually specified."
        ),
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            "Set True if the model exposes a chat interface (i.e. can be passed a"
            " sequence of messages, rather than text), like OpenAI's"
            " /v1/chat/completions endpoint."
        ),
    )

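# Example (hypothetical values, for illustration only): metadata for a
# chat-capable vision model might be declared like this.
#
#     metadata = MultiModalLLMMetadata(
#         context_window=4096,
#         num_output=256,
#         num_input_files=5,
#         is_chat_model=True,
#         model_name="my-vision-model",  # placeholder name
#     )
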
# TODO: add callback functionality


class MultiModalLLM(ChainableMixin, BaseComponent):
    """Multi-Modal LLM interface."""

    class Config:
        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def metadata(self) -> MultiModalLLMMetadata:
        """Multi-Modal LLM metadata."""

    @abstractmethod
    def complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        """Completion endpoint for Multi-Modal LLM."""

    @abstractmethod
    def stream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming completion endpoint for Multi-Modal LLM."""

    @abstractmethod
    def chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Chat endpoint for Multi-Modal LLM."""

    @abstractmethod
    def stream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseGen:
        """Streaming chat endpoint for Multi-Modal LLM."""

    # ===== Async Endpoints =====

    @abstractmethod
    async def acomplete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        """Async completion endpoint for Multi-Modal LLM."""

    @abstractmethod
    async def astream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async streaming completion endpoint for Multi-Modal LLM."""

    @abstractmethod
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Async chat endpoint for Multi-Modal LLM."""

    @abstractmethod
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        """Async streaming chat endpoint for Multi-Modal LLM."""

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Return query component."""
        # TODO: there is no separate multi-modal chat component yet, so chat
        # models reuse the completion component as well.
        return MultiModalCompleteComponent(multi_modal_llm=self, **kwargs)

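# A concrete backend subclasses MultiModalLLM and implements every abstract
# endpoint above. Minimal sketch (the class, helper function, and model name
# below are hypothetical, not part of this module):
#
#     class MyMultiModalLLM(MultiModalLLM):
#         @property
#         def metadata(self) -> MultiModalLLMMetadata:
#             return MultiModalLLMMetadata(model_name="my-vision-model")
#
#         def complete(
#             self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
#         ) -> CompletionResponse:
#             text = _call_my_backend(prompt, image_documents)  # hypothetical helper
#             return CompletionResponse(text=text)
#
#         # ...plus stream_complete, chat, stream_chat, and the async variants.
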
class BaseMultiModalComponent(QueryComponent):
    """Base multi-modal LLM component."""

    multi_modal_llm: MultiModalLLM = Field(..., description="Multi-modal LLM")
    streaming: bool = Field(default=False, description="Streaming mode")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: Any) -> None:
        """Set callback manager."""
        # TODO: make callbacks work with multi-modal


class MultiModalCompleteComponent(BaseMultiModalComponent):
    """Multi-modal completion component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "prompt" not in input:
            raise ValueError("Prompt must be in input dict.")

        # special check: chat messages are not yet supported as prompts
        prompt = input["prompt"]
        if isinstance(prompt, ChatMessage) or (
            isinstance(prompt, list)
            and prompt
            and all(isinstance(msg, ChatMessage) for msg in prompt)
        ):
            raise NotImplementedError(
                "Chat messages not yet supported as input to multi-modal model."
            )
        input["prompt"] = validate_and_convert_stringable(prompt)

        # make sure image documents are valid
        if "image_documents" in input:
            if not isinstance(input["image_documents"], list):
                raise ValueError("image_documents must be a list.")
            for doc in input["image_documents"]:
                if not isinstance(doc, ImageDocument):
                    raise ValueError(
                        "image_documents must be a list of ImageDocument objects."
                    )

        return input

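    # For reference, an input dict that passes validation looks like this
    # (values are placeholders):
    #
    #     {
    #         "prompt": "What is in this image?",
    #         "image_documents": [ImageDocument(image_path="photo.png")],
    #     }
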
    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: only complete is supported for now
        prompt = kwargs["prompt"]
        image_documents = kwargs.get("image_documents", [])
        if self.streaming:
            response = self.multi_modal_llm.stream_complete(prompt, image_documents)
        else:
            response = self.multi_modal_llm.complete(prompt, image_documents)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async)."""
        # TODO: only complete is supported for now;
        # non-trivial to figure out how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        image_documents = kwargs.get("image_documents", [])
        if self.streaming:
            response = await self.multi_modal_llm.astream_complete(
                prompt, image_documents
            )
        else:
            response = await self.multi_modal_llm.acomplete(prompt, image_documents)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: only complete is supported for now
        return InputKeys.from_keys({"prompt", "image_documents"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})

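# Example (illustrative sketch; `my_llm` stands for a hypothetical concrete
# MultiModalLLM instance, and the file path is a placeholder):
#
#     component = my_llm.as_query_component()
#     output = component.run_component(
#         prompt="Describe these images.",
#         image_documents=[ImageDocument(image_path="photo.png")],
#     )["output"]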