"""Google's Gemini multi-modal models."""
|
|
import os
|
|
import typing
|
|
from typing import Any, Dict, Optional, Sequence
|
|
|
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
|
from llama_index.callbacks import CallbackManager
|
|
from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
|
|
from llama_index.core.llms.types import (
|
|
ChatMessage,
|
|
ChatResponse,
|
|
ChatResponseAsyncGen,
|
|
ChatResponseGen,
|
|
CompletionResponse,
|
|
CompletionResponseAsyncGen,
|
|
CompletionResponseGen,
|
|
)
|
|
from llama_index.llms.gemini_utils import (
|
|
ROLES_FROM_GEMINI,
|
|
chat_from_gemini_response,
|
|
chat_message_to_gemini,
|
|
completion_from_gemini_response,
|
|
)
|
|
from llama_index.multi_modal_llms import (
|
|
MultiModalLLM,
|
|
MultiModalLLMMetadata,
|
|
)
|
|
from llama_index.schema import ImageDocument
|
|
|
|
if typing.TYPE_CHECKING:
|
|
import google.generativeai as genai
|
|
|
|
# PIL is imported lazily in the ctor but referenced throughout the module.
try:
    # Import the Image submodule explicitly; `PIL.Image.open` is used below.
    import PIL.Image
except ImportError:
    # Swallow the error here; it's raised in the constructor where intent is clear.
    pass

# This lists the multi-modal models - see also llms.gemini for text models.
GEMINI_MM_MODELS = (
    "models/gemini-pro-vision",
    "models/gemini-ultra-vision",
)


class GeminiMultiModal(MultiModalLLM):
    """Gemini multimodal."""

    model_name: str = Field(
        default=GEMINI_MM_MODELS[0], description="The Gemini model to use."
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use during generation.",
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    generate_kwargs: dict = Field(
        default_factory=dict, description="Kwargs for generation."
    )

    _model: "genai.GenerativeModel" = PrivateAttr()
    _model_meta: "genai.types.Model" = PrivateAttr()

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = GEMINI_MM_MODELS[0],
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = None,
        generation_config: Optional["genai.types.GenerationConfigDict"] = None,
        safety_settings: Optional["genai.types.SafetySettingOptions"] = None,
        api_base: Optional[str] = None,
        transport: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        **generate_kwargs: Any,
    ):
        """Creates a new Gemini model interface."""
        try:
            import google.generativeai as genai
        except ImportError:
            raise ValueError(
                "Gemini is not installed. Please install it with "
                "`pip install 'google-generativeai>=0.3.0'`."
            )
        try:
            import PIL  # noqa: F401
        except ImportError:
            raise ValueError(
                "Multi-modal support requires PIL. Please install it with "
                "`pip install pillow`."
            )

        # API keys are optional. The API can be authorised via OAuth (detected
        # environmentally) or by the GOOGLE_API_KEY environment variable.
        config_params: Dict[str, Any] = {
            "api_key": api_key or os.getenv("GOOGLE_API_KEY"),
        }
        if api_base:
            config_params["client_options"] = {"api_endpoint": api_base}
        if transport:
            config_params["transport"] = transport
            # transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
        genai.configure(**config_params)

        base_gen_config = generation_config if generation_config else {}
        # Values set in generation_config take precedence over the temperature
        # argument (the right-hand side of the dict union wins).
        final_gen_config = {"temperature": temperature} | base_gen_config

        # Check whether the requested Gemini model is supported.
        if model_name not in GEMINI_MM_MODELS:
            raise ValueError(
                f"Invalid model {model_name}. "
                f"Available models are: {GEMINI_MM_MODELS}"
            )

        self._model = genai.GenerativeModel(
            model_name=model_name,
            generation_config=final_gen_config,
            safety_settings=safety_settings,
        )

        self._model_meta = genai.get_model(model_name)

        supported_methods = self._model_meta.supported_generation_methods
        if "generateContent" not in supported_methods:
            raise ValueError(
                f"Model {model_name} does not support content generation, only "
                f"{supported_methods}."
            )

        # Cap max_tokens at the model's advertised output limit.
        if not max_tokens:
            max_tokens = self._model_meta.output_token_limit
        else:
            max_tokens = min(max_tokens, self._model_meta.output_token_limit)

        super().__init__(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            generate_kwargs=generate_kwargs,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        return "Gemini_MultiModal_LLM"

    @property
    def metadata(self) -> MultiModalLLMMetadata:
        total_tokens = self._model_meta.input_token_limit + self.max_tokens
        return MultiModalLLMMetadata(
            context_window=total_tokens,
            num_output=self.max_tokens,
            model_name=self.model_name,
        )

    def complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        images = [PIL.Image.open(doc.resolve_image()) for doc in image_documents]
        result = self._model.generate_content([prompt, *images], **kwargs)
        return completion_from_gemini_response(result)

    def stream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseGen:
        images = [PIL.Image.open(doc.resolve_image()) for doc in image_documents]
        result = self._model.generate_content([prompt, *images], stream=True, **kwargs)
        yield from map(completion_from_gemini_response, result)

    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        *history, next_msg = map(chat_message_to_gemini, messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(next_msg)
        return chat_from_gemini_response(response)

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        *history, next_msg = map(chat_message_to_gemini, messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(next_msg, stream=True)

        def gen() -> ChatResponseGen:
            content = ""
            for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
                role = ROLES_FROM_GEMINI[top_candidate.content.role]
                # Merge candidate metadata and prompt feedback into the raw payload.
                raw = {
                    **(type(top_candidate).to_dict(top_candidate)),
                    **(
                        type(response.prompt_feedback).to_dict(response.prompt_feedback)
                    ),
                }
                # Accumulate the full message so far; delta carries just this chunk.
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=raw,
                )

        return gen()

    async def acomplete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        images = [PIL.Image.open(doc.resolve_image()) for doc in image_documents]
        result = await self._model.generate_content_async([prompt, *images], **kwargs)
        return completion_from_gemini_response(result)

    async def astream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        images = [PIL.Image.open(doc.resolve_image()) for doc in image_documents]
        ait = await self._model.generate_content_async(
            [prompt, *images], stream=True, **kwargs
        )

        async def gen() -> CompletionResponseAsyncGen:
            async for comp in ait:
                yield completion_from_gemini_response(comp)

        return gen()

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        *history, next_msg = map(chat_message_to_gemini, messages)
        chat = self._model.start_chat(history=history)
        response = await chat.send_message_async(next_msg)
        return chat_from_gemini_response(response)

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        *history, next_msg = map(chat_message_to_gemini, messages)
        chat = self._model.start_chat(history=history)
        response = await chat.send_message_async(next_msg, stream=True)

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            # The streamed response is an async iterator, so use `async for`.
            async for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
                role = ROLES_FROM_GEMINI[top_candidate.content.role]
                raw = {
                    **(type(top_candidate).to_dict(top_candidate)),
                    **(
                        type(response.prompt_feedback).to_dict(response.prompt_feedback)
                    ),
                }
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=raw,
                )

        return gen()
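

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal example of driving GeminiMultiModal, assuming GOOGLE_API_KEY is set
# in the environment and that an image exists at the hypothetical path below.
# It is a sketch of intended use, not part of the module's public surface.
if __name__ == "__main__":
    gemini_mm = GeminiMultiModal(model_name="models/gemini-pro-vision")
    image_docs = [ImageDocument(image_path="/tmp/example.jpg")]  # hypothetical path
    completion = gemini_mm.complete(
        "Describe this image in one sentence.", image_documents=image_docs
    )
    print(completion.text)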