import asyncio
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
)

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.generic_utils import (
    completion_response_to_chat_response,
)
from llama_index.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.llms.llm import LLM
from llama_index.types import PydanticProgramMode

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from typing import TypeVar

    M = TypeVar("M")
    T = TypeVar("T")
    Metadata = Any


class OpenLLM(LLM):
    """OpenLLM LLM."""
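    # Usage sketch (illustrative, not from the original source): assuming the
    # model id below is available locally or downloadable from the HuggingFace
    # Hub, an embedded LLM can be constructed and queried roughly like:
    #
    #   llm = OpenLLM(model_id="HuggingFaceH4/zephyr-7b-beta")
    #   print(llm.complete("What is OpenLLM?").text)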
    model_id: str = Field(
        description="Model ID from the HuggingFace Hub. This can be either a pretrained ID or a local path, and is synonymous with the first argument of HuggingFace's '.from_pretrained'."
    )
    model_version: Optional[str] = Field(
        description="Optional model version to save the model as."
    )
    model_tag: Optional[str] = Field(
        description="Optional tag to save the model under in the BentoML store."
    )
    prompt_template: Optional[str] = Field(
        description="Optional prompt template to use with this LLM."
    )
    backend: Optional[Literal["vllm", "pt"]] = Field(
        description="Optional backend to use for this LLM. By default, vLLM is used if it is available on the local system; otherwise it falls back to PyTorch."
    )
    quantize: Optional[Literal["awq", "gptq", "int8", "int4", "squeezellm"]] = Field(
        description="Optional quantization method to use with this LLM. See the `--quantize` option of `openllm start` for more information."
    )
    serialization: Literal["safetensors", "legacy"] = Field(
        description="Serialization format to save this LLM as. Defaults to 'safetensors', but falls back to the PyTorch pickle `.bin` format on some models."
    )
    trust_remote_code: bool = Field(
        description="Optional flag to trust remote code. This is synonymous with Transformers' `trust_remote_code`. Defaults to False."
    )

    if TYPE_CHECKING:
        from typing import Generic

        try:
            import openllm

            _llm: openllm.LLM[Any, Any]
        except ImportError:
            _llm: Any  # type: ignore[no-redef]
    else:
        _llm: Any = PrivateAttr()
    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        model_tag: Optional[str] = None,
        prompt_template: Optional[str] = None,
        backend: Optional[Literal["vllm", "pt"]] = None,
        *args: Any,
        quantize: Optional[Literal["awq", "gptq", "int8", "int4", "squeezellm"]] = None,
        serialization: Literal["safetensors", "legacy"] = "safetensors",
        trust_remote_code: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        **attrs: Any,
    ):
        try:
            import openllm
        except ImportError:
            raise ImportError(
                "OpenLLM is not installed. Please install OpenLLM via `pip install openllm`"
            )
        # `embedded=True` runs the model in-process rather than against a
        # separate OpenLLM server. Note that openllm expects the British
        # spelling `serialisation` for its keyword argument.
        self._llm = openllm.LLM[Any, Any](
            model_id,
            model_version=model_version,
            model_tag=model_tag,
            prompt_template=prompt_template,
            system_message=system_prompt,
            backend=backend,
            quantize=quantize,
            serialisation=serialization,
            trust_remote_code=trust_remote_code,
            embedded=True,
            **attrs,
        )
        if messages_to_prompt is None:
            messages_to_prompt = self._tokenizer_messages_to_prompt

        # NOTE: We need to do this here to ensure the model is saved and the
        # revision is set correctly.
        assert self._llm.bentomodel

        super().__init__(
            model_id=model_id,
            model_version=self._llm.revision,
            model_tag=str(self._llm.tag),
            prompt_template=prompt_template,
            backend=self._llm.__llm_backend__,
            quantize=self._llm.quantise,
            serialization=self._llm._serialisation,
            trust_remote_code=self._llm.trust_remote_code,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
        )
    @classmethod
    def class_name(cls) -> str:
        return "OpenLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            num_output=self._llm.config["max_new_tokens"],
            model_name=self.model_id,
        )
    def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Use the tokenizer to convert messages to a prompt. Fall back to generic."""
        if hasattr(self._llm.tokenizer, "apply_chat_template"):
            return self._llm.tokenizer.apply_chat_template(
                [message.dict() for message in messages],
                tokenize=False,
                add_generation_prompt=True,
            )
        return generic_messages_to_prompt(messages)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return asyncio.run(self.acomplete(prompt, **kwargs))

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        return asyncio.run(self.achat(messages, **kwargs))

    @property
    def _loop(self) -> asyncio.AbstractEventLoop:
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.get_event_loop()
        return loop
    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        generator = self.astream_complete(prompt, **kwargs)
        # Drive the async generator on the event loop, yielding each chunk synchronously.
        while True:
            try:
                yield self._loop.run_until_complete(generator.__anext__())
            except StopAsyncIteration:
                break

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        generator = self.astream_chat(messages, **kwargs)
        # Drive the async generator on the event loop, yielding each chunk synchronously.
        while True:
            try:
                yield self._loop.run_until_complete(generator.__anext__())
            except StopAsyncIteration:
                break
    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        response = await self.acomplete(self.messages_to_prompt(messages), **kwargs)
        return completion_response_to_chat_response(response)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._llm.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )
    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        async for response_chunk in self.astream_complete(
            self.messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        config = self._llm.config.model_construct_env(**kwargs)
        if config["n"] > 1:
            logger.warning("Currently only n=1 is supported.")

        # Use independent lists per candidate; `[[]] * n` would alias a single list.
        texts: List[List[str]] = [[] for _ in range(config["n"])]

        async for response_chunk in self._llm.generate_iterator(prompt, **kwargs):
            for output in response_chunk.outputs:
                texts[output.index].append(output.text)
            yield CompletionResponse(
                text=response_chunk.outputs[0].text,
                delta=response_chunk.outputs[0].text,
                raw=response_chunk.model_dump(),
                additional_kwargs={
                    "prompt_token_ids": response_chunk.prompt_token_ids,
                    "prompt_logprobs": response_chunk.prompt_logprobs,
                    "finished": response_chunk.finished,
                    "outputs": {
                        "text": response_chunk.outputs[0].text,
                        "token_ids": response_chunk.outputs[0].token_ids,
                        "cumulative_logprob": response_chunk.outputs[
                            0
                        ].cumulative_logprob,
                        "logprobs": response_chunk.outputs[0].logprobs,
                        "finish_reason": response_chunk.outputs[0].finish_reason,
                    },
                },
            )


class OpenLLMAPI(LLM):
    """OpenLLM client interface. This is useful when interacting with a remote OpenLLM server."""
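    # Usage sketch (illustrative, not from the original source): assuming an
    # OpenLLM server is reachable at the address below, or that the
    # OPENLLM_ENDPOINT environment variable is set, a remote client can be
    # used roughly like:
    #
    #   remote_llm = OpenLLMAPI(address="http://localhost:3000")
    #   print(remote_llm.complete("What is OpenLLM?").text)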
    address: Optional[str] = Field(
        description="OpenLLM server address. This can either be set here or via the OPENLLM_ENDPOINT environment variable."
    )
    timeout: int = Field(description="Timeout for sending requests.")
    max_retries: int = Field(description="Maximum number of retries.")
    api_version: Literal["v1"] = Field(description="OpenLLM Server API version.")

    if TYPE_CHECKING:
        try:
            from openllm_client import AsyncHTTPClient, HTTPClient

            _sync_client: HTTPClient
            _async_client: AsyncHTTPClient
        except ImportError:
            _sync_client: Any  # type: ignore[no-redef]
            _async_client: Any  # type: ignore[no-redef]
    else:
        _sync_client: Any = PrivateAttr()
        _async_client: Any = PrivateAttr()
    def __init__(
        self,
        address: Optional[str] = None,
        timeout: int = 30,
        max_retries: int = 2,
        api_version: Literal["v1"] = "v1",
        **kwargs: Any,
    ):
        try:
            from openllm_client import AsyncHTTPClient, HTTPClient
        except ImportError:
            raise ImportError(
                f'"{type(self).__name__}" requires "openllm-client". Make sure to install with `pip install openllm-client`'
            )
        super().__init__(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
            **kwargs,
        )
        self._sync_client = HTTPClient(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
        )
        self._async_client = AsyncHTTPClient(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
        )
    @classmethod
    def class_name(cls) -> str:
        return "OpenLLM_Client"

    @property
    def _server_metadata(self) -> "Metadata":
        return self._sync_client._metadata

    @property
    def _server_config(self) -> Dict[str, Any]:
        return self._sync_client._config

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            num_output=self._server_config["max_new_tokens"],
            model_name=self._server_metadata.model_id.replace("/", "--"),
        )

    def _convert_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return self._sync_client.helpers.messages(
            messages=[
                {"role": message.role, "content": message.content}
                for message in messages
            ],
            add_generation_prompt=True,
        )

    async def _async_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return await self._async_client.helpers.messages(
            messages=[
                {"role": message.role, "content": message.content}
                for message in messages
            ],
            add_generation_prompt=True,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = self._sync_client.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        for response_chunk in self._sync_client.generate_stream(prompt, **kwargs):
            yield CompletionResponse(
                text=response_chunk.text,
                delta=response_chunk.text,
                raw=response_chunk.model_dump(),
                additional_kwargs={"token_ids": response_chunk.token_ids},
            )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        return completion_response_to_chat_response(
            self.complete(self._convert_messages_to_prompt(messages), **kwargs)
        )

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        for response_chunk in self.stream_complete(
            self._convert_messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._async_client.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        async for response_chunk in self._async_client.generate_stream(
            prompt, **kwargs
        ):
            yield CompletionResponse(
                text=response_chunk.text,
                delta=response_chunk.text,
                raw=response_chunk.model_dump(),
                additional_kwargs={"token_ids": response_chunk.token_ids},
            )

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        return completion_response_to_chat_response(
            await self.acomplete(
                await self._async_messages_to_prompt(messages), **kwargs
            )
        )

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        async for response_chunk in self.astream_complete(
            await self._async_messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)