# faiss_rag_enterprise/llama_index/multi_modal_llms/gemini.py

"""Google's Gemini multi-modal models."""
import os
import typing
from typing import Any, Dict, Optional, Sequence
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
from llama_index.core.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseAsyncGen,
ChatResponseGen,
CompletionResponse,
CompletionResponseAsyncGen,
CompletionResponseGen,
)
from llama_index.llms.gemini_utils import (
ROLES_FROM_GEMINI,
chat_from_gemini_response,
chat_message_to_gemini,
completion_from_gemini_response,
)
from llama_index.multi_modal_llms import (
MultiModalLLM,
MultiModalLLMMetadata,
)
from llama_index.schema import ImageDocument

if typing.TYPE_CHECKING:
    import google.generativeai as genai

# PIL is imported lazily in the ctor but referenced throughout the module.
# Import the Image submodule explicitly so that `PIL.Image` resolves below.
try:
    import PIL.Image
except ImportError:
    # Swallow the error here; it's raised in the constructor where intent is clear.
    pass

# This lists the multi-modal models - see also llms.gemini for text models.
GEMINI_MM_MODELS = (
    "models/gemini-pro-vision",
    "models/gemini-ultra-vision",
)


class GeminiMultiModal(MultiModalLLM):
    """Gemini multi-modal LLM that accepts text prompts plus image documents."""

    model_name: str = Field(
        default=GEMINI_MM_MODELS[0], description="The Gemini model to use."
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use during generation.",
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    generate_kwargs: dict = Field(
        default_factory=dict, description="Kwargs for generation."
    )

    _model: "genai.GenerativeModel" = PrivateAttr()
    _model_meta: "genai.types.Model" = PrivateAttr()

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = GEMINI_MM_MODELS[0],
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = None,
        generation_config: Optional["genai.types.GenerationConfigDict"] = None,
        safety_settings: Optional["genai.types.SafetySettingOptions"] = None,
        api_base: Optional[str] = None,
        transport: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        **generate_kwargs: Any,
    ):
"""Creates a new Gemini model interface."""
try:
import google.generativeai as genai
except ImportError:
raise ValueError(
"Gemini is not installed. Please install it with "
"`pip install 'google-generativeai>=0.3.0'`."
)
try:
import PIL # noqa: F401
except ImportError:
raise ValueError(
"Multi-modal support requires PIL. Please install it with "
"`pip install pillow`."
)
# API keys are optional. The API can be authorised via OAuth (detected
# environmentally) or by the GOOGLE_API_KEY environment variable.
config_params: Dict[str, Any] = {
"api_key": api_key or os.getenv("GOOGLE_API_KEY"),
}
if api_base:
config_params["client_options"] = {"api_endpoint": api_base}
if transport:
config_params["transport"] = transport
# transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
genai.configure(**config_params)

        base_gen_config = generation_config if generation_config else {}
        # Merge the explicit temperature argument with generation_config;
        # entries in generation_config take precedence.
        final_gen_config = {"temperature": temperature} | base_gen_config

        # Check whether the requested Gemini model is supported.
        if model_name not in GEMINI_MM_MODELS:
            raise ValueError(
                f"Invalid model {model_name}. "
                f"Available models are: {GEMINI_MM_MODELS}"
            )

        self._model = genai.GenerativeModel(
            model_name=model_name,
            generation_config=final_gen_config,
            safety_settings=safety_settings,
        )

        self._model_meta = genai.get_model(model_name)
        supported_methods = self._model_meta.supported_generation_methods
        if "generateContent" not in supported_methods:
            raise ValueError(
                f"Model {model_name} does not support content generation, only "
                f"{supported_methods}."
            )

        if not max_tokens:
            max_tokens = self._model_meta.output_token_limit
        else:
            max_tokens = min(max_tokens, self._model_meta.output_token_limit)

        super().__init__(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            generate_kwargs=generate_kwargs,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        return "Gemini_MultiModal_LLM"

    @property
    def metadata(self) -> MultiModalLLMMetadata:
        total_tokens = self._model_meta.input_token_limit + self.max_tokens
        return MultiModalLLMMetadata(
            context_window=total_tokens,
            num_output=self.max_tokens,
            model_name=self.model_name,
        )

    def complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        images = [PIL.Image.open(doc.resolve_image()) for doc in image_documents]
        result = self._model.generate_content([prompt, *images], **kwargs)
        return completion_from_gemini_response(result)

    def stream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseGen:
        images = [PIL.Image.open(doc.resolve_image()) for doc in image_documents]
        result = self._model.generate_content([prompt, *images], stream=True, **kwargs)
        yield from map(completion_from_gemini_response, result)

    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        # All but the last message form the chat history; the last one is sent now.
        *history, next_msg = map(chat_message_to_gemini, messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(next_msg)
        return chat_from_gemini_response(response)

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        *history, next_msg = map(chat_message_to_gemini, messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(next_msg, stream=True)

        def gen() -> ChatResponseGen:
            content = ""
            for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
                role = ROLES_FROM_GEMINI[top_candidate.content.role]
                raw = {
                    **(type(top_candidate).to_dict(top_candidate)),
                    **(
                        type(response.prompt_feedback).to_dict(response.prompt_feedback)
                    ),
                }
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=raw,
                )

        return gen()

    async def acomplete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        images = [PIL.Image.open(doc.resolve_image()) for doc in image_documents]
        result = await self._model.generate_content_async([prompt, *images], **kwargs)
        return completion_from_gemini_response(result)

    async def astream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        images = [PIL.Image.open(doc.resolve_image()) for doc in image_documents]
        ait = await self._model.generate_content_async(
            [prompt, *images], stream=True, **kwargs
        )

        async def gen() -> CompletionResponseAsyncGen:
            async for comp in ait:
                yield completion_from_gemini_response(comp)

        return gen()

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        *history, next_msg = map(chat_message_to_gemini, messages)
        chat = self._model.start_chat(history=history)
        response = await chat.send_message_async(next_msg)
        return chat_from_gemini_response(response)

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        *history, next_msg = map(chat_message_to_gemini, messages)
        chat = self._model.start_chat(history=history)
        response = await chat.send_message_async(next_msg, stream=True)

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            # The streamed response is an async iterator, so use `async for` here.
            async for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
                role = ROLES_FROM_GEMINI[top_candidate.content.role]
                raw = {
                    **(type(top_candidate).to_dict(top_candidate)),
                    **(
                        type(response.prompt_feedback).to_dict(response.prompt_feedback)
                    ),
                }
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=raw,
                )

        return gen()
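

# Minimal usage sketch (not part of the library API): shows how this class might
# be wired up for a single prompt-plus-image completion. The image path, prompt
# text, and reliance on the GOOGLE_API_KEY environment variable are assumptions
# for illustration only; `google-generativeai` and `pillow` must be installed.
if __name__ == "__main__":
    # Hypothetical local image; ImageDocument also accepts image_url or raw bytes.
    docs = [ImageDocument(image_path="example.jpg")]
    llm = GeminiMultiModal(model_name="models/gemini-pro-vision")
    # Synchronous completion over the prompt and the attached image(s).
    response = llm.complete(
        "Describe this image in one sentence.", image_documents=docs
    )
    print(response.text)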