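"""OpenAI multi-modal LLM for LlamaIndex.

Defines ``OpenAIMultiModal``, a ``MultiModalLLM`` wrapper around OpenAI's
vision-capable chat completion models.
"""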
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, cast

import httpx
from openai import AsyncOpenAI
from openai import OpenAI as SyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.chat_completion_chunk import (
    ChatCompletionChunk,
    ChoiceDelta,
    ChoiceDeltaToolCall,
)

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import (
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_NUM_OUTPUTS,
    DEFAULT_TEMPERATURE,
)
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    MessageRole,
)
from llama_index.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.llms.openai_utils import (
    from_openai_message,
    resolve_openai_credentials,
    to_openai_message_dicts,
)
from llama_index.multi_modal_llms import (
    MultiModalLLM,
    MultiModalLLMMetadata,
)
from llama_index.multi_modal_llms.openai_utils import (
    GPT4V_MODELS,
    generate_openai_multi_modal_chat_message,
)
from llama_index.schema import ImageDocument


class OpenAIMultiModal(MultiModalLLM):
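    """OpenAI multi-modal LLM.

    A ``MultiModalLLM`` implementation backed by the OpenAI chat completions
    API for vision-capable models (e.g. ``gpt-4-vision-preview``), accepting
    image documents alongside text prompts in sync and async, streaming and
    non-streaming variants.
    """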

    model: str = Field(description="The Multi-Modal model to use from OpenAI.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_new_tokens: Optional[int] = Field(
        description="The maximum number of tokens to generate, ignoring the number of tokens in the prompt.",
        gt=0,
    )
    context_window: Optional[int] = Field(
        description="The maximum number of context tokens for the model.",
        gt=0,
    )
    image_detail: str = Field(
        description="The level of detail for images in API calls. Can be 'low', 'high', or 'auto'."
    )
    max_retries: int = Field(
        default=3,
        description="Maximum number of retries.",
        ge=0,
    )
    timeout: float = Field(
        default=60.0,
        description="The timeout, in seconds, for API requests.",
        ge=0,
    )
    api_key: str = Field(default=None, description="The OpenAI API key.", exclude=True)
    api_base: str = Field(default=None, description="The base URL for OpenAI API.")
    api_version: str = Field(description="The API version for OpenAI API.")
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the OpenAI API."
    )
    default_headers: Dict[str, str] = Field(
        default=None, description="The default headers for API requests."
    )

    _messages_to_prompt: Callable = PrivateAttr()
    _completion_to_prompt: Callable = PrivateAttr()
    _client: SyncOpenAI = PrivateAttr()
    _aclient: AsyncOpenAI = PrivateAttr()
    _http_client: Optional[httpx.Client] = PrivateAttr()

    def __init__(
        self,
        model: str = "gpt-4-vision-preview",
        temperature: float = DEFAULT_TEMPERATURE,
        max_new_tokens: Optional[int] = 300,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        context_window: Optional[int] = DEFAULT_CONTEXT_WINDOW,
        max_retries: int = 3,
        timeout: float = 60.0,
        image_detail: str = "low",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        messages_to_prompt: Optional[Callable] = None,
        completion_to_prompt: Optional[Callable] = None,
        callback_manager: Optional[CallbackManager] = None,
        default_headers: Optional[Dict[str, str]] = None,
        http_client: Optional[httpx.Client] = None,
        **kwargs: Any,
    ) -> None:
        self._messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
        self._completion_to_prompt = completion_to_prompt or (lambda x: x)
        # Fill in any missing credentials from the environment before they
        # are stored on the pydantic fields.
        api_key, api_base, api_version = resolve_openai_credentials(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
        )

        super().__init__(
            model=model,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            additional_kwargs=additional_kwargs or {},
            context_window=context_window,
            image_detail=image_detail,
            max_retries=max_retries,
            timeout=timeout,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            callback_manager=callback_manager,
            default_headers=default_headers,
            **kwargs,
        )
        self._http_client = http_client
        self._client, self._aclient = self._get_clients(**kwargs)

    def _get_clients(self, **kwargs: Any) -> Tuple[SyncOpenAI, AsyncOpenAI]:
        client = SyncOpenAI(**self._get_credential_kwargs())
        aclient = AsyncOpenAI(**self._get_credential_kwargs())
        return client, aclient

    @classmethod
    def class_name(cls) -> str:
        return "openai_multi_modal_llm"

    @property
    def metadata(self) -> MultiModalLLMMetadata:
        """Multi Modal LLM metadata."""
        return MultiModalLLMMetadata(
            num_output=self.max_new_tokens or DEFAULT_NUM_OUTPUTS,
            model_name=self.model,
        )

    def _get_credential_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        return {
            "api_key": self.api_key,
            "base_url": self.api_base,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "http_client": self._http_client,
            "timeout": self.timeout,
            **kwargs,
        }

    def _get_multi_modal_chat_messages(
        self,
        prompt: str,
        role: str,
        image_documents: Sequence[ImageDocument],
        **kwargs: Any,
    ) -> List[ChatCompletionMessageParam]:
        return to_openai_message_dicts(
            [
                generate_openai_multi_modal_chat_message(
                    prompt=prompt,
                    role=role,
                    image_documents=image_documents,
                    image_detail=self.image_detail,
                )
            ]
        )
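
    # Illustrative shape of the message dicts built above (the exact layout is
    # determined by ``generate_openai_multi_modal_chat_message`` and may vary
    # across openai/llama-index versions):
    #     [{"role": "user", "content": [
    #         {"type": "text", "text": "<prompt>"},
    #         {"type": "image_url", "image_url": {"url": "<...>", "detail": "low"}},
    #     ]}]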

    # Model params for the OpenAI GPT-4V model.
    def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        if self.model not in GPT4V_MODELS:
            raise ValueError(
                f"Invalid model {self.model}. "
                f"Available models are: {list(GPT4V_MODELS.keys())}"
            )
        base_kwargs = {"model": self.model, "temperature": self.temperature, **kwargs}
        if self.max_new_tokens is not None:
            # If max_tokens is None, don't include it in the payload:
            # https://platform.openai.com/docs/api-reference/chat
            # https://platform.openai.com/docs/api-reference/completions
            base_kwargs["max_tokens"] = self.max_new_tokens
        return {**base_kwargs, **self.additional_kwargs}
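
    # For example (illustrative, assuming the defaults above): a model of
    # "gpt-4-vision-preview" with max_new_tokens=300 produces a payload like
    #     {"model": "gpt-4-vision-preview", "temperature": <temperature>,
    #      "max_tokens": 300}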

    def _get_response_token_counts(self, raw_response: Any) -> dict:
        """Get the token usage reported by the response."""
        usage = getattr(raw_response, "usage", None)
        if usage is None and isinstance(raw_response, dict):
            usage = raw_response.get("usage")

        # NOTE: other model providers that use the OpenAI client may not report usage
        if usage is None:
            return {}

        if isinstance(usage, dict):
            return {
                "prompt_tokens": usage.get("prompt_tokens", 0),
                "completion_tokens": usage.get("completion_tokens", 0),
                "total_tokens": usage.get("total_tokens", 0),
            }
        # The openai v1 client returns a pydantic ``CompletionUsage`` object
        # rather than a dict, so fall back to attribute access.
        return {
            "prompt_tokens": getattr(usage, "prompt_tokens", 0),
            "completion_tokens": getattr(usage, "completion_tokens", 0),
            "total_tokens": getattr(usage, "total_tokens", 0),
        }

    def _complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        all_kwargs = self._get_model_kwargs(**kwargs)
        message_dict = self._get_multi_modal_chat_messages(
            prompt=prompt, role=MessageRole.USER, image_documents=image_documents
        )
        response = self._client.chat.completions.create(
            messages=message_dict,
            stream=False,
            **all_kwargs,
        )

        return CompletionResponse(
            text=response.choices[0].message.content,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        all_kwargs = self._get_model_kwargs(**kwargs)
        message_dicts = to_openai_message_dicts(messages)
        response = self._client.chat.completions.create(
            messages=message_dicts,
            stream=False,
            **all_kwargs,
        )
        openai_message = response.choices[0].message
        message = from_openai_message(openai_message)

        return ChatResponse(
            message=message,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    def _stream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseGen:
        all_kwargs = self._get_model_kwargs(**kwargs)
        message_dict = self._get_multi_modal_chat_messages(
            prompt=prompt, role=MessageRole.USER, image_documents=image_documents
        )

        def gen() -> CompletionResponseGen:
            text = ""

            for response in self._client.chat.completions.create(
                messages=message_dict,
                stream=True,
                **all_kwargs,
            ):
                response = cast(ChatCompletionChunk, response)
                if len(response.choices) > 0:
                    delta = response.choices[0].delta
                else:
                    delta = ChoiceDelta()

                # update using deltas
                content_delta = delta.content or ""
                text += content_delta

                yield CompletionResponse(
                    delta=content_delta,
                    text=text,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()

    def _stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        message_dicts = to_openai_message_dicts(messages)

        def gen() -> ChatResponseGen:
            content = ""
            tool_calls: List[ChoiceDeltaToolCall] = []

            is_function = False
            for response in self._client.chat.completions.create(
                messages=message_dicts,
                stream=True,
                **self._get_model_kwargs(**kwargs),
            ):
                response = cast(ChatCompletionChunk, response)
                if len(response.choices) > 0:
                    delta = response.choices[0].delta
                else:
                    delta = ChoiceDelta()

                # check if this chunk is the start of a function call
                if delta.tool_calls:
                    is_function = True

                # update using deltas
                role = delta.role or MessageRole.ASSISTANT
                content_delta = delta.content or ""
                content += content_delta

                additional_kwargs = {}
                if is_function:
                    # Merge streamed tool-call deltas into the running list.
                    # NOTE: assumes a ``_update_tool_calls`` helper is provided
                    # by the class hierarchy; it is not defined in this module.
                    tool_calls = self._update_tool_calls(tool_calls, delta.tool_calls)
                    additional_kwargs["tool_calls"] = tool_calls

                yield ChatResponse(
                    message=ChatMessage(
                        role=role,
                        content=content,
                        additional_kwargs=additional_kwargs,
                    ),
                    delta=content_delta,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()

    def complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        return self._complete(prompt, image_documents, **kwargs)

    def stream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseGen:
        return self._stream_complete(prompt, image_documents, **kwargs)

    def chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        return self._chat(messages, **kwargs)

    def stream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseGen:
        return self._stream_chat(messages, **kwargs)

    # ===== Async Endpoints =====

    async def _acomplete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        all_kwargs = self._get_model_kwargs(**kwargs)
        message_dict = self._get_multi_modal_chat_messages(
            prompt=prompt, role=MessageRole.USER, image_documents=image_documents
        )
        response = await self._aclient.chat.completions.create(
            messages=message_dict,
            stream=False,
            **all_kwargs,
        )

        return CompletionResponse(
            text=response.choices[0].message.content,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    async def acomplete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponse:
        return await self._acomplete(prompt, image_documents, **kwargs)

    async def _astream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        all_kwargs = self._get_model_kwargs(**kwargs)
        message_dict = self._get_multi_modal_chat_messages(
            prompt=prompt, role=MessageRole.USER, image_documents=image_documents
        )

        async def gen() -> CompletionResponseAsyncGen:
            text = ""

            async for response in await self._aclient.chat.completions.create(
                messages=message_dict,
                stream=True,
                **all_kwargs,
            ):
                response = cast(ChatCompletionChunk, response)
                if len(response.choices) > 0:
                    delta = response.choices[0].delta
                else:
                    delta = ChoiceDelta()

                # update using deltas
                content_delta = delta.content or ""
                text += content_delta

                yield CompletionResponse(
                    delta=content_delta,
                    text=text,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()

    async def _achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        all_kwargs = self._get_model_kwargs(**kwargs)
        message_dicts = to_openai_message_dicts(messages)
        response = await self._aclient.chat.completions.create(
            messages=message_dicts,
            stream=False,
            **all_kwargs,
        )
        openai_message = response.choices[0].message
        message = from_openai_message(openai_message)

        return ChatResponse(
            message=message,
            raw=response,
            additional_kwargs=self._get_response_token_counts(response),
        )

    async def _astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        message_dicts = to_openai_message_dicts(messages)

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            tool_calls: List[ChoiceDeltaToolCall] = []

            is_function = False
            async for response in await self._aclient.chat.completions.create(
                messages=message_dicts,
                stream=True,
                **self._get_model_kwargs(**kwargs),
            ):
                response = cast(ChatCompletionChunk, response)
                if len(response.choices) > 0:
                    delta = response.choices[0].delta
                else:
                    delta = ChoiceDelta()

                # check if this chunk is the start of a function call
                if delta.tool_calls:
                    is_function = True

                # update using deltas
                role = delta.role or MessageRole.ASSISTANT
                content_delta = delta.content or ""
                content += content_delta

                additional_kwargs = {}
                if is_function:
                    # Merge streamed tool-call deltas into the running list.
                    # NOTE: assumes a ``_update_tool_calls`` helper is provided
                    # by the class hierarchy; it is not defined in this module.
                    tool_calls = self._update_tool_calls(tool_calls, delta.tool_calls)
                    additional_kwargs["tool_calls"] = tool_calls

                yield ChatResponse(
                    message=ChatMessage(
                        role=role,
                        content=content,
                        additional_kwargs=additional_kwargs,
                    ),
                    delta=content_delta,
                    raw=response,
                    additional_kwargs=self._get_response_token_counts(response),
                )

        return gen()

    async def astream_complete(
        self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        return await self._astream_complete(prompt, image_documents, **kwargs)

    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        return await self._achat(messages, **kwargs)

    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        return await self._astream_chat(messages, **kwargs)
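

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the class). It assumes a
# valid OPENAI_API_KEY in the environment; the local image path "./image.png"
# is a hypothetical placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    llm = OpenAIMultiModal(model="gpt-4-vision-preview", max_new_tokens=100)
    image_doc = ImageDocument(image_path="./image.png")  # hypothetical path
    response = llm.complete(
        prompt="Describe this image in one sentence.",
        image_documents=[image_doc],
    )
    print(response.text)

    # Streaming variant: print deltas as they arrive.
    for chunk in llm.stream_complete(
        prompt="Now describe it in detail.",
        image_documents=[image_doc],
    ):
        print(chunk.delta, end="")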