# faiss_rag_enterprise/llama_index/multi_modal_llms/openai.py

from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, cast
import httpx
from openai import AsyncOpenAI
from openai import OpenAI as SyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
ChoiceDelta,
ChoiceDeltaToolCall,
)
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import (
DEFAULT_CONTEXT_WINDOW,
DEFAULT_NUM_OUTPUTS,
DEFAULT_TEMPERATURE,
)
from llama_index.core.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseAsyncGen,
ChatResponseGen,
CompletionResponse,
CompletionResponseAsyncGen,
CompletionResponseGen,
MessageRole,
)
from llama_index.llms.generic_utils import (
messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.llms.openai_utils import (
    from_openai_message,
    resolve_openai_credentials,
    to_openai_message_dicts,
    update_tool_calls,  # merges streamed tool-call deltas into a full list
)
from llama_index.multi_modal_llms import (
MultiModalLLM,
MultiModalLLMMetadata,
)
from llama_index.multi_modal_llms.openai_utils import (
GPT4V_MODELS,
generate_openai_multi_modal_chat_message,
)
from llama_index.schema import ImageDocument


class OpenAIMultiModal(MultiModalLLM):
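    """OpenAI multi-modal LLM (GPT-4V family).

    Wraps the OpenAI chat completions API for prompts that combine text with
    images, with synchronous/asynchronous and streaming/non-streaming variants.

    Example (a minimal sketch; it assumes a valid ``OPENAI_API_KEY`` in the
    environment and a local file ``image.png``)::

        from llama_index.schema import ImageDocument

        llm = OpenAIMultiModal(model="gpt-4-vision-preview", max_new_tokens=300)
        response = llm.complete(
            prompt="Describe this image.",
            image_documents=[ImageDocument(image_path="image.png")],
        )
        print(response.text)
    """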
model: str = Field(description="The Multi-Modal model to use from OpenAI.")
temperature: float = Field(description="The temperature to use for sampling.")
    max_new_tokens: Optional[int] = Field(
        description="The maximum number of tokens to generate, ignoring the number of tokens in the prompt.",
        gt=0,
    )
context_window: Optional[int] = Field(
description="The maximum number of context tokens for the model.",
gt=0,
)
    image_detail: str = Field(
        description="The level of detail for images in API calls. Can be 'low', 'high', or 'auto'."
    )
    max_retries: int = Field(
        default=3,
        description="Maximum number of retries.",
        ge=0,
    )
    timeout: float = Field(
        default=60.0,
        description="The timeout, in seconds, for API requests.",
        ge=0,
    )
    api_key: Optional[str] = Field(default=None, description="The OpenAI API key.", exclude=True)
    api_base: Optional[str] = Field(default=None, description="The base URL for the OpenAI API.")
    api_version: str = Field(description="The API version for the OpenAI API.")
additional_kwargs: Dict[str, Any] = Field(
default_factory=dict, description="Additional kwargs for the OpenAI API."
)
    default_headers: Optional[Dict[str, str]] = Field(
        default=None, description="The default headers for API requests."
    )
_messages_to_prompt: Callable = PrivateAttr()
_completion_to_prompt: Callable = PrivateAttr()
_client: SyncOpenAI = PrivateAttr()
_aclient: AsyncOpenAI = PrivateAttr()
_http_client: Optional[httpx.Client] = PrivateAttr()

    def __init__(
self,
model: str = "gpt-4-vision-preview",
temperature: float = DEFAULT_TEMPERATURE,
max_new_tokens: Optional[int] = 300,
additional_kwargs: Optional[Dict[str, Any]] = None,
context_window: Optional[int] = DEFAULT_CONTEXT_WINDOW,
max_retries: int = 3,
timeout: float = 60.0,
image_detail: str = "low",
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
messages_to_prompt: Optional[Callable] = None,
completion_to_prompt: Optional[Callable] = None,
callback_manager: Optional[CallbackManager] = None,
default_headers: Optional[Dict[str, str]] = None,
http_client: Optional[httpx.Client] = None,
**kwargs: Any,
) -> None:
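        """Init params."""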
self._messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
self._completion_to_prompt = completion_to_prompt or (lambda x: x)
api_key, api_base, api_version = resolve_openai_credentials(
api_key=api_key,
api_base=api_base,
api_version=api_version,
)
super().__init__(
model=model,
temperature=temperature,
max_new_tokens=max_new_tokens,
additional_kwargs=additional_kwargs or {},
context_window=context_window,
image_detail=image_detail,
max_retries=max_retries,
timeout=timeout,
api_key=api_key,
api_base=api_base,
api_version=api_version,
callback_manager=callback_manager,
default_headers=default_headers,
**kwargs,
)
self._http_client = http_client
self._client, self._aclient = self._get_clients(**kwargs)

    def _get_clients(self, **kwargs: Any) -> Tuple[SyncOpenAI, AsyncOpenAI]:
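        """Construct the synchronous and asynchronous OpenAI clients."""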
client = SyncOpenAI(**self._get_credential_kwargs())
aclient = AsyncOpenAI(**self._get_credential_kwargs())
return client, aclient

    @classmethod
def class_name(cls) -> str:
return "openai_multi_modal_llm"

    @property
def metadata(self) -> MultiModalLLMMetadata:
"""Multi Modal LLM metadata."""
return MultiModalLLMMetadata(
num_output=self.max_new_tokens or DEFAULT_NUM_OUTPUTS,
model_name=self.model,
)

    def _get_credential_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
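        """Build the keyword arguments used to construct both OpenAI clients."""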
return {
"api_key": self.api_key,
"base_url": self.api_base,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"http_client": self._http_client,
"timeout": self.timeout,
**kwargs,
}

    def _get_multi_modal_chat_messages(
self,
prompt: str,
role: str,
image_documents: Sequence[ImageDocument],
**kwargs: Any,
) -> List[ChatCompletionMessageParam]:
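        """Convert a prompt and image documents into OpenAI chat message dicts."""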
return to_openai_message_dicts(
[
generate_openai_multi_modal_chat_message(
prompt=prompt,
role=role,
image_documents=image_documents,
image_detail=self.image_detail,
)
]
)

    # Model parameters for the OpenAI GPT-4V API call.
def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
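        """Assemble the per-request model parameters, validating the model name."""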
if self.model not in GPT4V_MODELS:
raise ValueError(
f"Invalid model {self.model}. "
f"Available models are: {list(GPT4V_MODELS.keys())}"
)
base_kwargs = {"model": self.model, "temperature": self.temperature, **kwargs}
if self.max_new_tokens is not None:
# If max_tokens is None, don't include in the payload:
# https://platform.openai.com/docs/api-reference/chat
# https://platform.openai.com/docs/api-reference/completions
base_kwargs["max_tokens"] = self.max_new_tokens
return {**base_kwargs, **self.additional_kwargs}

    def _get_response_token_counts(self, raw_response: Any) -> dict:
"""Get the token usage reported by the response."""
if not isinstance(raw_response, dict):
return {}
usage = raw_response.get("usage", {})
# NOTE: other model providers that use the OpenAI client may not report usage
if usage is None:
return {}
return {
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
}

    def _complete(
self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
) -> CompletionResponse:
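        """Send a non-streaming text-plus-image completion request."""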
all_kwargs = self._get_model_kwargs(**kwargs)
message_dict = self._get_multi_modal_chat_messages(
prompt=prompt, role=MessageRole.USER, image_documents=image_documents
)
response = self._client.chat.completions.create(
messages=message_dict,
stream=False,
**all_kwargs,
)
return CompletionResponse(
text=response.choices[0].message.content,
raw=response,
additional_kwargs=self._get_response_token_counts(response),
)

    def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
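        """Send a non-streaming chat request."""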
all_kwargs = self._get_model_kwargs(**kwargs)
message_dicts = to_openai_message_dicts(messages)
response = self._client.chat.completions.create(
messages=message_dicts,
stream=False,
**all_kwargs,
)
openai_message = response.choices[0].message
message = from_openai_message(openai_message)
return ChatResponse(
message=message,
raw=response,
additional_kwargs=self._get_response_token_counts(response),
)

    def _stream_complete(
self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
) -> CompletionResponseGen:
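        """Stream a text-plus-image completion, yielding incremental text deltas."""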
all_kwargs = self._get_model_kwargs(**kwargs)
message_dict = self._get_multi_modal_chat_messages(
prompt=prompt, role=MessageRole.USER, image_documents=image_documents
)

        def gen() -> CompletionResponseGen:
text = ""
for response in self._client.chat.completions.create(
messages=message_dict,
stream=True,
**all_kwargs,
):
response = cast(ChatCompletionChunk, response)
if len(response.choices) > 0:
delta = response.choices[0].delta
else:
delta = ChoiceDelta()
# update using deltas
content_delta = delta.content or ""
text += content_delta
yield CompletionResponse(
delta=content_delta,
text=text,
raw=response,
additional_kwargs=self._get_response_token_counts(response),
)
return gen()

    def _stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
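        """Stream a chat response, accumulating content and tool-call deltas."""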
message_dicts = to_openai_message_dicts(messages)

        def gen() -> ChatResponseGen:
content = ""
tool_calls: List[ChoiceDeltaToolCall] = []
is_function = False
for response in self._client.chat.completions.create(
messages=message_dicts,
stream=True,
**self._get_model_kwargs(**kwargs),
):
response = cast(ChatCompletionChunk, response)
if len(response.choices) > 0:
delta = response.choices[0].delta
else:
delta = ChoiceDelta()
# check if this chunk is the start of a function call
if delta.tool_calls:
is_function = True
# update using deltas
role = delta.role or MessageRole.ASSISTANT
content_delta = delta.content or ""
content += content_delta
additional_kwargs = {}
if is_function:
                    tool_calls = update_tool_calls(tool_calls, delta.tool_calls)
additional_kwargs["tool_calls"] = tool_calls
yield ChatResponse(
message=ChatMessage(
role=role,
content=content,
additional_kwargs=additional_kwargs,
),
delta=content_delta,
raw=response,
additional_kwargs=self._get_response_token_counts(response),
)
return gen()

    def complete(
self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
) -> CompletionResponse:
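        """Complete the prompt with the given image documents."""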
return self._complete(prompt, image_documents, **kwargs)

    def stream_complete(
self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
) -> CompletionResponseGen:
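        """Stream-complete the prompt with the given image documents."""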
return self._stream_complete(prompt, image_documents, **kwargs)

    def chat(
self,
messages: Sequence[ChatMessage],
**kwargs: Any,
) -> ChatResponse:
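        """Chat with the model."""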
return self._chat(messages, **kwargs)

    def stream_chat(
self,
messages: Sequence[ChatMessage],
**kwargs: Any,
) -> ChatResponseGen:
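        """Stream chat with the model."""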
return self._stream_chat(messages, **kwargs)

    # ===== Async Endpoints =====
async def _acomplete(
self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
) -> CompletionResponse:
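        """Async version of ``_complete``."""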
all_kwargs = self._get_model_kwargs(**kwargs)
message_dict = self._get_multi_modal_chat_messages(
prompt=prompt, role=MessageRole.USER, image_documents=image_documents
)
response = await self._aclient.chat.completions.create(
messages=message_dict,
stream=False,
**all_kwargs,
)
return CompletionResponse(
text=response.choices[0].message.content,
raw=response,
additional_kwargs=self._get_response_token_counts(response),
)

    async def acomplete(
self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
) -> CompletionResponse:
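        """Async-complete the prompt with the given image documents."""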
return await self._acomplete(prompt, image_documents, **kwargs)

    async def _astream_complete(
self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
) -> CompletionResponseAsyncGen:
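        """Async version of ``_stream_complete``."""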
all_kwargs = self._get_model_kwargs(**kwargs)
message_dict = self._get_multi_modal_chat_messages(
prompt=prompt, role=MessageRole.USER, image_documents=image_documents
)

        async def gen() -> CompletionResponseAsyncGen:
text = ""
async for response in await self._aclient.chat.completions.create(
messages=message_dict,
stream=True,
**all_kwargs,
):
response = cast(ChatCompletionChunk, response)
if len(response.choices) > 0:
delta = response.choices[0].delta
else:
delta = ChoiceDelta()
# update using deltas
content_delta = delta.content or ""
text += content_delta
yield CompletionResponse(
delta=content_delta,
text=text,
raw=response,
additional_kwargs=self._get_response_token_counts(response),
)
return gen()

    async def _achat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponse:
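        """Async version of ``_chat``."""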
all_kwargs = self._get_model_kwargs(**kwargs)
message_dicts = to_openai_message_dicts(messages)
response = await self._aclient.chat.completions.create(
messages=message_dicts,
stream=False,
**all_kwargs,
)
openai_message = response.choices[0].message
message = from_openai_message(openai_message)
return ChatResponse(
message=message,
raw=response,
additional_kwargs=self._get_response_token_counts(response),
)

    async def _astream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseAsyncGen:
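        """Async version of ``_stream_chat``."""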
message_dicts = to_openai_message_dicts(messages)

        async def gen() -> ChatResponseAsyncGen:
content = ""
tool_calls: List[ChoiceDeltaToolCall] = []
is_function = False
async for response in await self._aclient.chat.completions.create(
messages=message_dicts,
stream=True,
**self._get_model_kwargs(**kwargs),
):
response = cast(ChatCompletionChunk, response)
if len(response.choices) > 0:
delta = response.choices[0].delta
else:
delta = ChoiceDelta()
# check if this chunk is the start of a function call
if delta.tool_calls:
is_function = True
# update using deltas
role = delta.role or MessageRole.ASSISTANT
content_delta = delta.content or ""
content += content_delta
additional_kwargs = {}
if is_function:
                    tool_calls = update_tool_calls(tool_calls, delta.tool_calls)
additional_kwargs["tool_calls"] = tool_calls
yield ChatResponse(
message=ChatMessage(
role=role,
content=content,
additional_kwargs=additional_kwargs,
),
delta=content_delta,
raw=response,
additional_kwargs=self._get_response_token_counts(response),
)
return gen()

    async def astream_complete(
self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
) -> CompletionResponseAsyncGen:
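        """Async stream-complete the prompt with the given image documents."""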
return await self._astream_complete(prompt, image_documents, **kwargs)

    async def achat(
self,
messages: Sequence[ChatMessage],
**kwargs: Any,
) -> ChatResponse:
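        """Async chat with the model."""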
return await self._achat(messages, **kwargs)

    async def astream_chat(
self,
messages: Sequence[ChatMessage],
**kwargs: Any,
) -> ChatResponseAsyncGen:
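        """Async stream chat with the model."""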
return await self._astream_chat(messages, **kwargs)