"""Google's hosted Gemini API."""

import os
import typing
from typing import Any, Dict, Optional, Sequence

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.custom import CustomLLM
from llama_index.llms.gemini_utils import (
    ROLES_FROM_GEMINI,
    chat_from_gemini_response,
    chat_message_to_gemini,
    completion_from_gemini_response,
    merge_neighboring_same_role_messages,
)

if typing.TYPE_CHECKING:
    import google.generativeai as genai
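
# Hosted Gemini model identifiers; the first entry is used as the default
# ``model_name``. Other hosted models may also work, provided they support
# the "generateContent" method (checked in ``__init__`` below).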
GEMINI_MODELS = (
    "models/gemini-pro",
    "models/gemini-ultra",
)


class Gemini(CustomLLM):
    """Gemini LLM (Google's hosted Gemini API)."""

    model_name: str = Field(
        default=GEMINI_MODELS[0], description="The Gemini model to use."
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use during generation.",
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    generate_kwargs: dict = Field(
        default_factory=dict, description="Kwargs for generation."
    )

    _model: "genai.GenerativeModel" = PrivateAttr()
    _model_meta: "genai.types.Model" = PrivateAttr()

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = GEMINI_MODELS[0],
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: Optional[int] = None,
        generation_config: Optional["genai.types.GenerationConfigDict"] = None,
        safety_settings: "genai.types.SafetySettingOptions" = None,
        callback_manager: Optional[CallbackManager] = None,
        api_base: Optional[str] = None,
        transport: Optional[str] = None,
        **generate_kwargs: Any,
    ):
        """Creates a new Gemini model interface."""
        try:
            import google.generativeai as genai
        except ImportError:
            raise ValueError(
                "Gemini is not installed. Please install it with "
                "`pip install 'google-generativeai>=0.3.0'`."
            )

        # API keys are optional. The API can be authorised via OAuth (detected
        # environmentally) or by the GOOGLE_API_KEY environment variable.
        config_params: Dict[str, Any] = {
            "api_key": api_key or os.getenv("GOOGLE_API_KEY"),
        }
        if api_base:
            config_params["client_options"] = {"api_endpoint": api_base}
        if transport:
            # transport: a string, one of `rest`, `grpc`, or `grpc_asyncio`.
            config_params["transport"] = transport
        genai.configure(**config_params)

        # Merge the explicit temperature into any user-provided generation_config;
        # values set in generation_config take precedence.
        base_gen_config = generation_config if generation_config else {}
        final_gen_config = {"temperature": temperature, **base_gen_config}

        self._model = genai.GenerativeModel(
            model_name=model_name,
            generation_config=final_gen_config,
            safety_settings=safety_settings,
        )

        self._model_meta = genai.get_model(model_name)

        supported_methods = self._model_meta.supported_generation_methods
        if "generateContent" not in supported_methods:
            raise ValueError(
                f"Model {model_name} does not support content generation, only "
                f"{supported_methods}."
            )

        # Cap max_tokens at the model's advertised output limit.
        if not max_tokens:
            max_tokens = self._model_meta.output_token_limit
        else:
            max_tokens = min(max_tokens, self._model_meta.output_token_limit)

        super().__init__(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            generate_kwargs=generate_kwargs,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        return "Gemini_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        total_tokens = self._model_meta.input_token_limit + self.max_tokens
        return LLMMetadata(
            context_window=total_tokens,
            num_output=self.max_tokens,
            model_name=self.model_name,
            is_chat_model=True,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        result = self._model.generate_content(prompt, **kwargs)
        return completion_from_gemini_response(result)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        it = self._model.generate_content(prompt, stream=True, **kwargs)
        yield from map(completion_from_gemini_response, it)

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        # Gemini expects alternating user/model turns, so neighboring messages
        # from the same role are merged before building the chat history.
        merged_messages = merge_neighboring_same_role_messages(messages)
        *history, next_msg = map(chat_message_to_gemini, merged_messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(next_msg)
        return chat_from_gemini_response(response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        merged_messages = merge_neighboring_same_role_messages(messages)
        *history, next_msg = map(chat_message_to_gemini, merged_messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(next_msg, stream=True)

        def gen() -> ChatResponseGen:
            content = ""
            for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
                role = ROLES_FROM_GEMINI[top_candidate.content.role]
                raw = {
                    **(type(top_candidate).to_dict(top_candidate)),
                    **(
                        type(response.prompt_feedback).to_dict(
                            response.prompt_feedback
                        )
                    ),
                }
                # Accumulate deltas so each yielded message carries the full
                # response text generated so far.
                content += content_delta
                yield ChatResponse(
                    message=ChatMessage(role=role, content=content),
                    delta=content_delta,
                    raw=raw,
                )

        return gen()
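

# Example usage (a minimal sketch, kept as a comment rather than executable
# code; the import paths follow the package layout implied by the imports
# above). It assumes GOOGLE_API_KEY is set in the environment, or that OAuth
# credentials are available, and that `google-generativeai>=0.3.0` is
# installed:
#
#     from llama_index.llms.gemini import Gemini
#     from llama_index.core.llms.types import ChatMessage, MessageRole
#
#     llm = Gemini(model_name="models/gemini-pro", temperature=0.2)
#     print(llm.complete("Write a haiku about tide pools.").text)
#
#     messages = [ChatMessage(role=MessageRole.USER, content="Hello!")]
#     print(llm.chat(messages).message.content)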