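"""vLLM integrations for LlamaIndex.

`Vllm` runs a model in-process through the `vllm` package, while `VllmServer`
sends requests to a running vLLM API server over HTTP.
"""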
import json
from typing import Any, Callable, Dict, List, Optional, Sequence

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.llms.llm import LLM
from llama_index.llms.vllm_utils import get_response, post_http_request
from llama_index.types import BaseOutputParser, PydanticProgramMode


class Vllm(LLM):
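    """LLM backed by a locally loaded vLLM engine.

    Wraps `vllm.LLM`: the model is loaded in-process (downloaded from
    HuggingFace if needed) and sampling behaviour is controlled by the fields
    below.

    Minimal usage sketch (assumes the `vllm` package is installed and there is
    enough GPU memory for the chosen model):

        llm = Vllm(model="facebook/opt-125m", temperature=0.8, max_new_tokens=128)
        print(llm.complete("Hello, my name is").text)
    """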

    model: Optional[str] = Field(description="The HuggingFace model to use.")

    temperature: float = Field(description="The temperature to use for sampling.")

    tensor_parallel_size: Optional[int] = Field(
        default=1,
        description="The number of GPUs to use for distributed execution with tensor parallelism.",
    )

    trust_remote_code: Optional[bool] = Field(
        default=True,
        description="Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.",
    )

    n: int = Field(
        default=1,
        description="Number of output sequences to return for the given prompt.",
    )

    best_of: Optional[int] = Field(
        default=None,
        description="Number of output sequences that are generated from the prompt.",
    )

    presence_penalty: float = Field(
        default=0.0,
        description="Float that penalizes new tokens based on whether they appear in the generated text so far.",
    )

    frequency_penalty: float = Field(
        default=0.0,
        description="Float that penalizes new tokens based on their frequency in the generated text so far.",
    )

    top_p: float = Field(
        default=1.0,
        description="Float that controls the cumulative probability of the top tokens to consider.",
    )

    top_k: int = Field(
        default=-1,
        description="Integer that controls the number of top tokens to consider.",
    )

    use_beam_search: bool = Field(
        default=False, description="Whether to use beam search instead of sampling."
    )

    stop: Optional[List[str]] = Field(
        default=None,
        description="List of strings that stop the generation when they are generated.",
    )

    ignore_eos: bool = Field(
        default=False,
        description="Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.",
    )

    max_new_tokens: int = Field(
        default=512,
        description="Maximum number of tokens to generate per output sequence.",
    )

    logprobs: Optional[int] = Field(
        default=None,
        description="Number of log probabilities to return per output token.",
    )

    dtype: str = Field(
        default="auto",
        description="The data type for the model weights and activations.",
    )

    download_dir: Optional[str] = Field(
        default=None,
        description="Directory to download and load the weights. (Defaults to the HuggingFace cache directory.)",
    )

    vllm_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Holds any model parameters valid for the `vllm.LLM` call not explicitly specified.",
    )

    api_url: str = Field(description="The API URL of the vLLM server (used by VllmServer).")

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "facebook/opt-125m",
        temperature: float = 1.0,
        tensor_parallel_size: int = 1,
        trust_remote_code: bool = True,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop: Optional[List[str]] = None,
        ignore_eos: bool = False,
        max_new_tokens: int = 512,
        logprobs: Optional[int] = None,
        dtype: str = "auto",
        download_dir: Optional[str] = None,
        vllm_kwargs: Dict[str, Any] = {},
        api_url: Optional[str] = "",
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        try:
            from vllm import LLM as VLLModel
        except ImportError:
            raise ImportError(
                "Could not import vllm python package. "
                "Please install it with `pip install vllm`."
            )
        if model != "":
            # Build the in-process vLLM engine; VllmServer passes model="" to skip this.
            self._client = VLLModel(
                model=model,
                tensor_parallel_size=tensor_parallel_size,
                trust_remote_code=trust_remote_code,
                dtype=dtype,
                download_dir=download_dir,
                **vllm_kwargs,
            )
        else:
            self._client = None
        callback_manager = callback_manager or CallbackManager([])
        super().__init__(
            model=model,
            temperature=temperature,
            n=n,
            best_of=best_of,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            top_p=top_p,
            top_k=top_k,
            use_beam_search=use_beam_search,
            stop=stop,
            ignore_eos=ignore_eos,
            max_new_tokens=max_new_tokens,
            logprobs=logprobs,
            dtype=dtype,
            download_dir=download_dir,
            vllm_kwargs=vllm_kwargs,
            api_url=api_url,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "Vllm"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(model_name=self.model)

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        # Keyword arguments forwarded to vllm.SamplingParams.
        base_kwargs = {
            "temperature": self.temperature,
            "max_tokens": self.max_new_tokens,
            "n": self.n,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "use_beam_search": self.use_beam_search,
            "best_of": self.best_of,
            "ignore_eos": self.ignore_eos,
            "stop": self.stop,
            "logprobs": self.logprobs,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }
        return {**base_kwargs}

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        return {
            **self._model_kwargs,
            **kwargs,
        }

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        kwargs = kwargs if kwargs else {}
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # build sampling parameters
        sampling_params = SamplingParams(**params)
        outputs = self._client.generate([prompt], sampling_params)
        return CompletionResponse(text=outputs[0].outputs[0].text)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise ValueError("Not Implemented")

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise ValueError("Not Implemented")

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        kwargs = kwargs if kwargs else {}
        return self.chat(messages, **kwargs)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        raise ValueError("Not Implemented")

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise ValueError("Not Implemented")

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise ValueError("Not Implemented")


class VllmServer(Vllm):
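    """LLM that talks to a running vLLM API server over HTTP.

    No local engine is created (the parent is initialized with an empty model
    name); requests are POSTed to `api_url`. The `vllm` package must still be
    installed because `SamplingParams` is built client-side.

    Minimal usage sketch (assumes a vLLM API server, e.g. one started with
    `python -m vllm.entrypoints.api_server`, and that `api_url` points at its
    generate endpoint):

        llm = VllmServer(api_url="http://localhost:8000/generate")
        print(llm.complete("Hello, my name is").text)
    """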

    def __init__(
        self,
        model: str = "facebook/opt-125m",
        api_url: str = "http://localhost:8000",
        temperature: float = 1.0,
        tensor_parallel_size: Optional[int] = 1,
        trust_remote_code: Optional[bool] = True,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop: Optional[List[str]] = None,
        ignore_eos: bool = False,
        max_new_tokens: int = 512,
        logprobs: Optional[int] = None,
        dtype: str = "auto",
        download_dir: Optional[str] = None,
        messages_to_prompt: Optional[Callable] = None,
        completion_to_prompt: Optional[Callable] = None,
        vllm_kwargs: Dict[str, Any] = {},
        callback_manager: Optional[CallbackManager] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        self._client = None
        messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
        completion_to_prompt = completion_to_prompt or (lambda x: x)
        callback_manager = callback_manager or CallbackManager([])

        # Clear the model name so the parent __init__ skips building a local
        # vllm.LLM engine; generation happens on the remote server instead.
        model = ""
        super().__init__(
            model=model,
            temperature=temperature,
            n=n,
            best_of=best_of,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            top_p=top_p,
            top_k=top_k,
            use_beam_search=use_beam_search,
            stop=stop,
            ignore_eos=ignore_eos,
            max_new_tokens=max_new_tokens,
            logprobs=logprobs,
            dtype=dtype,
            download_dir=download_dir,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            vllm_kwargs=vllm_kwargs,
            api_url=api_url,
            callback_manager=callback_manager,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "VllmServer"

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # build sampling parameters
        sampling_params = SamplingParams(**params).__dict__
        sampling_params["prompt"] = prompt
        response = post_http_request(self.api_url, sampling_params, stream=False)
        output = get_response(response)

        return CompletionResponse(text=output[0])

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # build sampling parameters
        sampling_params = SamplingParams(**params).__dict__
        sampling_params["prompt"] = prompt
        response = post_http_request(self.api_url, sampling_params, stream=True)

        def gen() -> CompletionResponseGen:
            for chunk in response.iter_lines(
                chunk_size=8192, decode_unicode=False, delimiter=b"\0"
            ):
                if chunk:
                    data = json.loads(chunk.decode("utf-8"))

                    yield CompletionResponse(text=data["text"][0])

        return gen()

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        return self.complete(prompt, **kwargs)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        kwargs = kwargs if kwargs else {}

        # Delegate to the synchronous streaming implementation and re-expose
        # it as an async generator.
        async def gen() -> CompletionResponseAsyncGen:
            for message in self.stream_complete(prompt, **kwargs):
                yield message

        return gen()

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        # Wrap the synchronous chat stream so the declared async generator
        # return type actually holds.
        async def gen() -> ChatResponseAsyncGen:
            for message in self.stream_chat(messages, **kwargs):
                yield message

        return gen()