import logging
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import (
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_NUM_OUTPUTS,
)
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.custom import CustomLLM
from llama_index.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.prompts.base import PromptTemplate
from llama_index.types import BaseOutputParser, PydanticProgramMode

DEFAULT_HUGGINGFACE_MODEL = "StabilityAI/stablelm-tuned-alpha-3b"
if TYPE_CHECKING:
    try:
        from huggingface_hub import AsyncInferenceClient, InferenceClient
        from huggingface_hub.hf_api import ModelInfo
        from huggingface_hub.inference._types import ConversationalOutput
    except ModuleNotFoundError:
        AsyncInferenceClient = Any
        InferenceClient = Any
        ConversationalOutput = dict
        ModelInfo = Any

logger = logging.getLogger(__name__)


class HuggingFaceLLM(CustomLLM):
    """HuggingFace LLM."""

    model_name: str = Field(
        default=DEFAULT_HUGGINGFACE_MODEL,
        description=(
            "The model name to use from HuggingFace. "
            "Unused if `model` is passed in directly."
        ),
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of tokens available for input.",
        gt=0,
    )
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    system_prompt: str = Field(
        default="",
        description=(
            "The system prompt, containing any extra instructions or context. "
            "The model card on HuggingFace should specify if this is needed."
        ),
    )
    query_wrapper_prompt: PromptTemplate = Field(
        default=PromptTemplate("{query_str}"),
        description=(
            "The query wrapper prompt, containing the query placeholder. "
            "The model card on HuggingFace should specify if this is needed. "
            "Should contain a `{query_str}` placeholder."
        ),
    )
    tokenizer_name: str = Field(
        default=DEFAULT_HUGGINGFACE_MODEL,
        description=(
            "The name of the tokenizer to use from HuggingFace. "
            "Unused if `tokenizer` is passed in directly."
        ),
    )
    device_map: str = Field(
        default="auto", description="The device_map to use. Defaults to 'auto'."
    )
    stopping_ids: List[int] = Field(
        default_factory=list,
        description=(
            "The stopping ids to use. "
            "Generation stops when these token IDs are predicted."
        ),
    )
    tokenizer_outputs_to_remove: list = Field(
        default_factory=list,
        description=(
            "The outputs to remove from the tokenizer. "
            "Sometimes huggingface tokenizers return extra inputs that cause errors."
        ),
    )
    tokenizer_kwargs: dict = Field(
        default_factory=dict, description="The kwargs to pass to the tokenizer."
    )
    model_kwargs: dict = Field(
        default_factory=dict,
        description="The kwargs to pass to the model during initialization.",
    )
    generate_kwargs: dict = Field(
        default_factory=dict,
        description="The kwargs to pass to the model during generation.",
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_chat_model"].field_info.description
            + " Be sure to verify that you either pass an appropriate tokenizer "
            "that can convert prompts to properly formatted chat messages or a "
            "`messages_to_prompt` that does so."
        ),
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _stopping_criteria: Any = PrivateAttr()

    def __init__(
        self,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        query_wrapper_prompt: Union[str, PromptTemplate] = "{query_str}",
        tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
        model_name: str = DEFAULT_HUGGINGFACE_MODEL,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        device_map: Optional[str] = "auto",
        stopping_ids: Optional[List[int]] = None,
        tokenizer_kwargs: Optional[dict] = None,
        tokenizer_outputs_to_remove: Optional[list] = None,
        model_kwargs: Optional[dict] = None,
        generate_kwargs: Optional[dict] = None,
        is_chat_model: Optional[bool] = False,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: str = "",
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Initialize params."""
        try:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                StoppingCriteria,
                StoppingCriteriaList,
            )
        except ImportError as exc:
            raise ImportError(
                f"{type(self).__name__} requires torch and transformers packages.\n"
                "Please install both with `pip install transformers[torch]`."
            ) from exc

        model_kwargs = model_kwargs or {}
        self._model = model or AutoModelForCausalLM.from_pretrained(
            model_name, device_map=device_map, **model_kwargs
        )

        # check context_window
        config_dict = self._model.config.to_dict()
        model_context_window = int(
            config_dict.get("max_position_embeddings", context_window)
        )
        if model_context_window and model_context_window < context_window:
            logger.warning(
                f"Supplied context_window {context_window} is greater "
                f"than the model's max input size {model_context_window}. "
                "Disable this warning by setting a lower context_window."
            )
            context_window = model_context_window

        tokenizer_kwargs = tokenizer_kwargs or {}
        if "max_length" not in tokenizer_kwargs:
            tokenizer_kwargs["max_length"] = context_window

        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(
            tokenizer_name, **tokenizer_kwargs
        )

        if tokenizer_name != model_name:
            logger.warning(
                f"The model `{model_name}` and tokenizer `{tokenizer_name}` "
                f"are different, please ensure that they are compatible."
            )

        # setup stopping criteria
        stopping_ids_list = stopping_ids or []

        class StopOnTokens(StoppingCriteria):
            def __call__(
                self,
                input_ids: torch.LongTensor,
                scores: torch.FloatTensor,
                **kwargs: Any,
            ) -> bool:
                for stop_id in stopping_ids_list:
                    if input_ids[0][-1] == stop_id:
                        return True
                return False

        self._stopping_criteria = StoppingCriteriaList([StopOnTokens()])

        if isinstance(query_wrapper_prompt, str):
            query_wrapper_prompt = PromptTemplate(query_wrapper_prompt)

        messages_to_prompt = messages_to_prompt or self._tokenizer_messages_to_prompt

        super().__init__(
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            query_wrapper_prompt=query_wrapper_prompt,
            tokenizer_name=tokenizer_name,
            model_name=model_name,
            device_map=device_map,
            stopping_ids=stopping_ids or [],
            tokenizer_kwargs=tokenizer_kwargs or {},
            tokenizer_outputs_to_remove=tokenizer_outputs_to_remove or [],
            model_kwargs=model_kwargs or {},
            generate_kwargs=generate_kwargs or {},
            is_chat_model=is_chat_model,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "HuggingFace_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
            is_chat_model=self.is_chat_model,
        )

    def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Use the tokenizer to convert messages to prompt. Fallback to generic."""
        if hasattr(self._tokenizer, "apply_chat_template"):
            messages_dict = [
                {"role": message.role.value, "content": message.content}
                for message in messages
            ]
            tokens = self._tokenizer.apply_chat_template(messages_dict)
            return self._tokenizer.decode(tokens)

        return generic_messages_to_prompt(messages)
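
    # Illustrative sketch (not executed here): with a chat-templated tokenizer,
    # the method above defers to `tokenizer.apply_chat_template`; otherwise the
    # generic fallback renders a prompt roughly of the form
    # "system: ...\nuser: ...\nassistant: " (an assumption about generic_utils).
    #
    #   messages = [
    #       ChatMessage(role=MessageRole.SYSTEM, content="You are helpful."),
    #       ChatMessage(role=MessageRole.USER, content="Hi!"),
    #   ]
    #   prompt = llm._tokenizer_messages_to_prompt(messages)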

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Completion endpoint."""
        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"

        inputs = self._tokenizer(full_prompt, return_tensors="pt")
        inputs = inputs.to(self._model.device)

        # remove keys from the tokenizer if needed, to avoid HF errors
        for key in self.tokenizer_outputs_to_remove:
            if key in inputs:
                inputs.pop(key, None)

        tokens = self._model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )
        completion_tokens = tokens[0][inputs["input_ids"].size(1) :]
        completion = self._tokenizer.decode(completion_tokens, skip_special_tokens=True)

        return CompletionResponse(text=completion, raw={"model_output": tokens})

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming completion endpoint."""
        from transformers import TextIteratorStreamer

        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"

        inputs = self._tokenizer(full_prompt, return_tensors="pt")
        inputs = inputs.to(self._model.device)

        # remove keys from the tokenizer if needed, to avoid HF errors
        for key in self.tokenizer_outputs_to_remove:
            if key in inputs:
                inputs.pop(key, None)

        streamer = TextIteratorStreamer(
            self._tokenizer,
            skip_prompt=True,
            # extra keyword arguments are forwarded to tokenizer.decode
            skip_special_tokens=True,
        )
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )

        # generate in background thread
        # NOTE/TODO: token counting doesn't work with streaming
        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()

        # create generator based off of streamer
        def gen() -> CompletionResponseGen:
            text = ""
            for x in streamer:
                text += x
                yield CompletionResponse(text=text, delta=x)

        return gen()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)
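
    # Streaming usage sketch (illustrative, assumes an `llm = HuggingFaceLLM(...)`
    # instance as above): `stream_complete` returns a generator of partial
    # CompletionResponses, where `delta` holds just the newly generated text and
    # `text` holds the running accumulation.
    #
    #   for chunk in llm.stream_complete("Tell me a joke."):
    #       print(chunk.delta, end="", flush=True)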


def chat_messages_to_conversational_kwargs(
    messages: Sequence[ChatMessage],
) -> Dict[str, Any]:
    """Convert ChatMessages to keyword arguments for Inference API conversational."""
    if len(messages) % 2 != 1:
        raise NotImplementedError("Messages passed in must be of odd length.")
    last_message = messages[-1]
    kwargs: Dict[str, Any] = {
        "text": last_message.content,
        **last_message.additional_kwargs,
    }
    if len(messages) != 1:
        kwargs["past_user_inputs"] = []
        kwargs["generated_responses"] = []
        for user_msg, assistant_msg in zip(messages[::2], messages[1::2]):
            if (
                user_msg.role != MessageRole.USER
                or assistant_msg.role != MessageRole.ASSISTANT
            ):
                raise NotImplementedError(
                    "Didn't handle when messages aren't ordered in alternating"
                    f" pairs of {(MessageRole.USER, MessageRole.ASSISTANT)}."
                )
            kwargs["past_user_inputs"].append(user_msg.content)
            kwargs["generated_responses"].append(assistant_msg.content)
    return kwargs
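
# Illustrative mapping (not executed here): for an alternating
# [USER, ASSISTANT, USER] history, the helper above produces roughly
#
#   {
#       "text": "<latest user message>",
#       "past_user_inputs": ["<first user message>"],
#       "generated_responses": ["<first assistant reply>"],
#   }
#
# which matches the payload shape of the Inference API conversational task.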


class HuggingFaceInferenceAPI(CustomLLM):
    """
    Wrapper around Hugging Face's Inference API.

    Overview of the design:
    - Synchronous uses InferenceClient, asynchronous uses AsyncInferenceClient
    - chat uses the conversational task: https://huggingface.co/tasks/conversational
    - complete uses the text generation task: https://huggingface.co/tasks/text-generation

    Note: some models that support the text generation task can leverage Hugging
    Face's optimized deployment toolkit called text-generation-inference (TGI).
    Use InferenceClient.get_model_status to check if TGI is being used.

    Relevant links:
    - General Docs: https://huggingface.co/docs/api-inference/index
    - API Docs: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client
    - Source: https://github.com/huggingface/huggingface_hub/tree/main/src/huggingface_hub/inference
    """

    @classmethod
    def class_name(cls) -> str:
        return "HuggingFaceInferenceAPI"

    # Corresponds with huggingface_hub.InferenceClient
    model_name: Optional[str] = Field(
        default=None,
        description=(
            "The model to run inference with. Can be a model id hosted on the Hugging"
            " Face Hub, e.g. bigcode/starcoder or a URL to a deployed Inference"
            " Endpoint. Defaults to None, in which case a recommended model is"
            " automatically selected for the task (see Field below)."
        ),
    )
    token: Union[str, bool, None] = Field(
        default=None,
        description=(
            "Hugging Face token. Will default to the locally saved token. Pass "
            "token=False if you don’t want to send your token to the server."
        ),
    )
    timeout: Optional[float] = Field(
        default=None,
        description=(
            "The maximum number of seconds to wait for a response from the server."
            " Loading a new model in Inference API can take up to several minutes."
            " Defaults to None, meaning it will loop until the server is available."
        ),
    )
    headers: Dict[str, str] = Field(
        default=None,
        description=(
            "Additional headers to send to the server. By default only the"
            " authorization and user-agent headers are sent. Values in this dictionary"
            " will override the default values."
        ),
    )
    cookies: Dict[str, str] = Field(
        default=None, description="Additional cookies to send to the server."
    )
    task: Optional[str] = Field(
        default=None,
        description=(
            "Optional task to pick Hugging Face's recommended model, used when"
            " model_name is left as default of None."
        ),
    )

    _sync_client: "InferenceClient" = PrivateAttr()
    _async_client: "AsyncInferenceClient" = PrivateAttr()
    _get_model_info: "Callable[..., ModelInfo]" = PrivateAttr()

    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=(
            LLMMetadata.__fields__["context_window"].field_info.description
            + " This may be looked up in a model's `config.json`."
        ),
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description=LLMMetadata.__fields__["num_output"].field_info.description,
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_chat_model"].field_info.description
            + " Unless chat templating is intentionally applied, Hugging Face models"
            " are not chat models."
        ),
    )
    is_function_calling_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_function_calling_model"].field_info.description
            + " As of 10/17/2023, Hugging Face doesn't support function calling"
            " messages."
        ),
    )

    def _get_inference_client_kwargs(self) -> Dict[str, Any]:
        """Extract the Hugging Face InferenceClient construction parameters."""
        return {
            "model": self.model_name,
            "token": self.token,
            "timeout": self.timeout,
            "headers": self.headers,
            "cookies": self.cookies,
        }

    def __init__(self, **kwargs: Any) -> None:
        """Initialize.

        Args:
            kwargs: See the class-level Fields.
        """
        try:
            from huggingface_hub import (
                AsyncInferenceClient,
                InferenceClient,
                model_info,
            )
        except ModuleNotFoundError as exc:
            raise ImportError(
                f"{type(self).__name__} requires huggingface_hub with its inference"
                " extra, please run `pip install huggingface_hub[inference]>=0.19.0`."
            ) from exc
        if kwargs.get("model_name") is None:
            task = kwargs.get("task", "")
            # NOTE: task being None or empty string leads to ValueError,
            # which ensures model is present
            kwargs["model_name"] = InferenceClient.get_recommended_model(task=task)
            logger.debug(
                f"Using Hugging Face's recommended model {kwargs['model_name']}"
                f" given task {task}."
            )
        if kwargs.get("task") is None:
            task = "conversational"
        else:
            task = kwargs["task"].lower()

        super().__init__(**kwargs)  # Populate pydantic Fields
        self._sync_client = InferenceClient(**self._get_inference_client_kwargs())
        self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs())
        self._get_model_info = model_info

    def validate_supported(self, task: str) -> None:
        """
        Confirm the contained model_name is deployed on the Inference API service.

        Args:
            task: Hugging Face task to check within. A list of all tasks can be
                found here: https://huggingface.co/tasks
        """
        all_models = self._sync_client.list_deployed_models(frameworks="all")
        try:
            if self.model_name not in all_models[task]:
                raise ValueError(
                    "The Inference API service doesn't have the model"
                    f" {self.model_name!r} deployed."
                )
        except KeyError as exc:
            raise KeyError(
                f"Input task {task!r} not in possible tasks {list(all_models.keys())}."
            ) from exc
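
    # Illustrative check (not executed here): confirm the configured model is
    # actually deployed for a given task before issuing requests.
    #
    #   remote_llm.validate_supported(task="text-generation")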

    def get_model_info(self, **kwargs: Any) -> "ModelInfo":
        """Get metadata on the current model from Hugging Face."""
        return self._get_model_info(self.model_name, **kwargs)

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            is_chat_model=self.is_chat_model,
            is_function_calling_model=self.is_function_calling_model,
            model_name=self.model_name,
        )

    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        # default to conversational task as that was the previous functionality
        if self.task == "conversational" or self.task is None:
            output: "ConversationalOutput" = self._sync_client.conversational(
                **{**chat_messages_to_conversational_kwargs(messages), **kwargs}
            )
            return ChatResponse(
                message=ChatMessage(
                    role=MessageRole.ASSISTANT, content=output["generated_text"]
                )
            )
        else:
            # try and use text generation
            prompt = self.messages_to_prompt(messages)
            completion = self.complete(prompt)
            return ChatResponse(
                message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text)
            )

    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return CompletionResponse(
            text=self._sync_client.text_generation(
                prompt, **{**{"max_new_tokens": self.num_output}, **kwargs}
            )
        )

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise NotImplementedError

    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise NotImplementedError

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        raise NotImplementedError

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._async_client.text_generation(
            prompt, **{**{"max_new_tokens": self.num_output}, **kwargs}
        )
        return CompletionResponse(text=response)

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError