faiss_rag_enterprise/llama_index/multi_modal_llms/base.py

from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, get_args

from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.constants import (
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_NUM_INPUT_FILES,
    DEFAULT_NUM_OUTPUTS,
)
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
)
from llama_index.core.query_pipeline.query_component import (
    ChainableMixin,
    InputKeys,
    OutputKeys,
    QueryComponent,
    validate_and_convert_stringable,
)
from llama_index.schema import BaseComponent, ImageDocument


class MultiModalLLMMetadata(BaseModel):
    context_window: Optional[int] = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=(
            "Total number of tokens the model can take as input when generating a"
            " response."
        ),
    )
    num_output: Optional[int] = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="Number of tokens the model can output when generating a response.",
    )
    num_input_files: Optional[int] = Field(
        default=DEFAULT_NUM_INPUT_FILES,
        description="Number of input files the model can take when generating a response.",
    )
    is_function_calling_model: Optional[bool] = Field(
        default=False,
        # SEE: https://openai.com/blog/function-calling-and-other-api-updates
        description=(
            "Set True if the model supports function calling messages, similar to"
            " OpenAI's function calling API. For example, converting 'Email Anya to"
            " see if she wants to get coffee next Friday' to a function call like"
            " `send_email(to: string, body: string)`."
        ),
    )
    model_name: str = Field(
        default="unknown",
        description=(
            "The model's name, used for logging, testing, and sanity checking. For some"
            " models this can be automatically discerned. For other models, like"
            " locally loaded models, this must be manually specified."
        ),
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            "Set True if the model exposes a chat interface (i.e. can be passed a"
            " sequence of messages, rather than text), like OpenAI's"
            " /v1/chat/completions endpoint."
        ),
    )
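
# Illustrative note (not part of the module): a concrete backend describes its
# capabilities by returning a populated MultiModalLLMMetadata from its
# `metadata` property, e.g. roughly
#
#     MultiModalLLMMetadata(
#         context_window=4096,
#         num_output=512,
#         num_input_files=1,
#         is_chat_model=False,
#         model_name="my-local-vision-model",  # hypothetical model name
#     )
#
# The field values above are assumptions for illustration only.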


# TODO add callback functionality
class MultiModalLLM(ChainableMixin, BaseComponent):
    """Multi-Modal LLM interface."""

    class Config:
        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def metadata(self) -> MultiModalLLMMetadata:
        """Multi-Modal LLM metadata."""

    @abstractmethod
    def complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        """Completion endpoint for Multi-Modal LLM."""

    @abstractmethod
    def stream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming completion endpoint for Multi-Modal LLM."""

    @abstractmethod
    def chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Chat endpoint for Multi-Modal LLM."""

    @abstractmethod
    def stream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseGen:
        """Streaming chat endpoint for Multi-Modal LLM."""

    # ===== Async Endpoints =====

    @abstractmethod
    async def acomplete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        """Async completion endpoint for Multi-Modal LLM."""

    @abstractmethod
    async def astream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async streaming completion endpoint for Multi-Modal LLM."""

    @abstractmethod
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Async chat endpoint for Multi-Modal LLM."""

    @abstractmethod
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        """Async streaming chat endpoint for Multi-Modal LLM."""

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Return query component."""
        if self.metadata.is_chat_model:
            # TODO: there is no separate chat component yet, so chat models are
            # served by the completion component as well
            return MultiModalCompleteComponent(multi_modal_llm=self, **kwargs)
        else:
            return MultiModalCompleteComponent(multi_modal_llm=self, **kwargs)
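
# Illustrative sketch (not part of the library): a concrete subclass must
# implement the `metadata` property and every abstract endpoint above, e.g.
#
#     class EchoMultiModalLLM(MultiModalLLM):  # hypothetical example backend
#         @property
#         def metadata(self) -> MultiModalLLMMetadata:
#             return MultiModalLLMMetadata(model_name="echo")
#
#         def complete(self, prompt, image_documents, **kwargs) -> CompletionResponse:
#             return CompletionResponse(
#                 text=f"{prompt} ({len(image_documents)} images attached)"
#             )
#
#         def chat(self, messages, **kwargs) -> ChatResponse:
#             return ChatResponse(message=messages[-1])
#
#         # stream_complete, stream_chat, acomplete, astream_complete, achat and
#         # astream_chat must also be defined (they may raise NotImplementedError).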


class BaseMultiModalComponent(QueryComponent):
    """Base multi-modal LLM component."""

    multi_modal_llm: MultiModalLLM = Field(..., description="LLM")
    streaming: bool = Field(default=False, description="Streaming mode")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: Any) -> None:
        """Set callback manager."""
        # TODO: make callbacks work with multi-modal


class MultiModalCompleteComponent(BaseMultiModalComponent):
    """Multi-modal completion component."""

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        if "prompt" not in input:
            raise ValueError("Prompt must be in input dict.")
        # special-case prompts given as chat messages (a single ChatMessage or a
        # list of ChatMessage objects), which are not supported yet
        if isinstance(input["prompt"], get_args(List[ChatMessage])) or (
            isinstance(input["prompt"], list)
            and any(isinstance(m, ChatMessage) for m in input["prompt"])
        ):
            raise NotImplementedError(
                "Chat messages not yet supported as input to multi-modal model."
            )
        else:
            input["prompt"] = validate_and_convert_stringable(input["prompt"])

        # make sure image documents are valid
        if "image_documents" in input:
            if not isinstance(input["image_documents"], list):
                raise ValueError("image_documents must be a list.")
            for doc in input["image_documents"]:
                if not isinstance(doc, ImageDocument):
                    raise ValueError(
                        "image_documents must be a list of ImageDocument objects."
                    )

        return input
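
    # For reference (illustrative values only): a valid input dict for this
    # component looks like
    #
    #     {
    #         "prompt": "Describe the attached chart",
    #         "image_documents": [ImageDocument(image_path="chart.png")],
    #     }
    #
    # where "image_documents" is optional and defaults to an empty list.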

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component."""
        # TODO: only complete() is supported for now
        prompt = kwargs["prompt"]
        image_documents = kwargs.get("image_documents", [])
        if self.streaming:
            response = self.multi_modal_llm.stream_complete(prompt, image_documents)
        else:
            response = self.multi_modal_llm.complete(prompt, image_documents)
        return {"output": response}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async)."""
        # TODO: only complete() is supported for now;
        # non-trivial to figure out how to support chat/complete/etc.
        prompt = kwargs["prompt"]
        image_documents = kwargs.get("image_documents", [])
        if self.streaming:
            response = await self.multi_modal_llm.astream_complete(
                prompt, image_documents
            )
        else:
            response = await self.multi_modal_llm.acomplete(prompt, image_documents)
        return {"output": response}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        # TODO: only complete() is supported for now
        return InputKeys.from_keys({"prompt", "image_documents"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})
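

# Usage sketch (illustrative; assumes a concrete MultiModalLLM subclass bound to
# `my_mm_llm` and images already loaded as ImageDocument objects in `image_docs`):
#
#     component = my_mm_llm.as_query_component()
#     output = component.run_component(
#         prompt="What is shown in these images?",
#         image_documents=image_docs,
#     )
#     print(output["output"].text)
#
# `as_query_component()` comes from ChainableMixin and dispatches to
# `_as_query_component()` above; the resulting component can also be composed
# into a QueryPipeline like any other QueryComponent.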