# faiss_rag_enterprise/llama_index/llms/huggingface.py

import logging
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import (
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_NUM_OUTPUTS,
)
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.custom import CustomLLM
from llama_index.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.prompts.base import PromptTemplate
from llama_index.types import BaseOutputParser, PydanticProgramMode

DEFAULT_HUGGINGFACE_MODEL = "StabilityAI/stablelm-tuned-alpha-3b"

if TYPE_CHECKING:
    try:
        from huggingface_hub import AsyncInferenceClient, InferenceClient
        from huggingface_hub.hf_api import ModelInfo
        from huggingface_hub.inference._types import ConversationalOutput
    except ModuleNotFoundError:
        AsyncInferenceClient = Any
        InferenceClient = Any
        ConversationalOutput = dict
        ModelInfo = Any

logger = logging.getLogger(__name__)


class HuggingFaceLLM(CustomLLM):
    """HuggingFace LLM."""

    model_name: str = Field(
        default=DEFAULT_HUGGINGFACE_MODEL,
        description=(
            "The model name to use from HuggingFace. "
            "Unused if `model` is passed in directly."
        ),
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of tokens available for input.",
        gt=0,
    )
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    system_prompt: str = Field(
        default="",
        description=(
            "The system prompt, containing any extra instructions or context. "
            "The model card on HuggingFace should specify if this is needed."
        ),
    )
    query_wrapper_prompt: PromptTemplate = Field(
        default=PromptTemplate("{query_str}"),
        description=(
            "The query wrapper prompt, containing the query placeholder. "
            "The model card on HuggingFace should specify if this is needed. "
            "Should contain a `{query_str}` placeholder."
        ),
    )
    tokenizer_name: str = Field(
        default=DEFAULT_HUGGINGFACE_MODEL,
        description=(
            "The name of the tokenizer to use from HuggingFace. "
            "Unused if `tokenizer` is passed in directly."
        ),
    )
    device_map: str = Field(
        default="auto", description="The device_map to use. Defaults to 'auto'."
    )
    stopping_ids: List[int] = Field(
        default_factory=list,
        description=(
            "The stopping ids to use. "
            "Generation stops when these token IDs are predicted."
        ),
    )
    tokenizer_outputs_to_remove: list = Field(
        default_factory=list,
        description=(
            "The outputs to remove from the tokenizer. "
            "Sometimes huggingface tokenizers return extra inputs that cause errors."
        ),
    )
    tokenizer_kwargs: dict = Field(
        default_factory=dict, description="The kwargs to pass to the tokenizer."
    )
    model_kwargs: dict = Field(
        default_factory=dict,
        description="The kwargs to pass to the model during initialization.",
    )
    generate_kwargs: dict = Field(
        default_factory=dict,
        description="The kwargs to pass to the model during generation.",
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_chat_model"].field_info.description
            + " Be sure to verify that you either pass an appropriate tokenizer "
            "that can convert prompts to properly formatted chat messages or a "
            "`messages_to_prompt` that does so."
        ),
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _stopping_criteria: Any = PrivateAttr()

    def __init__(
        self,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        query_wrapper_prompt: Union[str, PromptTemplate] = "{query_str}",
        tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
        model_name: str = DEFAULT_HUGGINGFACE_MODEL,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        device_map: Optional[str] = "auto",
        stopping_ids: Optional[List[int]] = None,
        tokenizer_kwargs: Optional[dict] = None,
        tokenizer_outputs_to_remove: Optional[list] = None,
        model_kwargs: Optional[dict] = None,
        generate_kwargs: Optional[dict] = None,
        is_chat_model: Optional[bool] = False,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: str = "",
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Initialize params."""
        try:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                StoppingCriteria,
                StoppingCriteriaList,
            )
        except ImportError as exc:
            raise ImportError(
                f"{type(self).__name__} requires torch and transformers packages.\n"
                "Please install both with `pip install transformers[torch]`."
            ) from exc

        model_kwargs = model_kwargs or {}
        self._model = model or AutoModelForCausalLM.from_pretrained(
            model_name, device_map=device_map, **model_kwargs
        )

        # check context_window
        config_dict = self._model.config.to_dict()
        model_context_window = int(
            config_dict.get("max_position_embeddings", context_window)
        )
        if model_context_window and model_context_window < context_window:
            logger.warning(
                f"Supplied context_window {context_window} is greater "
                f"than the model's max input size {model_context_window}. "
                "Disable this warning by setting a lower context_window."
            )
            context_window = model_context_window

        tokenizer_kwargs = tokenizer_kwargs or {}
        if "max_length" not in tokenizer_kwargs:
            tokenizer_kwargs["max_length"] = context_window

        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(
            tokenizer_name, **tokenizer_kwargs
        )

        if tokenizer_name != model_name:
            logger.warning(
                f"The model `{model_name}` and tokenizer `{tokenizer_name}` "
                f"are different; please ensure that they are compatible."
            )

        # setup stopping criteria
        stopping_ids_list = stopping_ids or []

        class StopOnTokens(StoppingCriteria):
            def __call__(
                self,
                input_ids: torch.LongTensor,
                scores: torch.FloatTensor,
                **kwargs: Any,
            ) -> bool:
                for stop_id in stopping_ids_list:
                    if input_ids[0][-1] == stop_id:
                        return True
                return False

        self._stopping_criteria = StoppingCriteriaList([StopOnTokens()])

        if isinstance(query_wrapper_prompt, str):
            query_wrapper_prompt = PromptTemplate(query_wrapper_prompt)

        messages_to_prompt = messages_to_prompt or self._tokenizer_messages_to_prompt

        super().__init__(
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            query_wrapper_prompt=query_wrapper_prompt,
            tokenizer_name=tokenizer_name,
            model_name=model_name,
            device_map=device_map,
            stopping_ids=stopping_ids or [],
            tokenizer_kwargs=tokenizer_kwargs or {},
            tokenizer_outputs_to_remove=tokenizer_outputs_to_remove or [],
            model_kwargs=model_kwargs or {},
            generate_kwargs=generate_kwargs or {},
            is_chat_model=is_chat_model,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

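    # Example usage (a minimal sketch; the model/tokenizer names are the class
    # defaults, and the prompt and generate_kwargs are illustrative only):
    #
    #   llm = HuggingFaceLLM(
    #       model_name="StabilityAI/stablelm-tuned-alpha-3b",
    #       tokenizer_name="StabilityAI/stablelm-tuned-alpha-3b",
    #       context_window=2048,
    #       max_new_tokens=256,
    #       generate_kwargs={"temperature": 0.7, "do_sample": True},
    #   )
    #   response = llm.complete("What does a vector index do?")
    #   print(response.text)
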
    @classmethod
    def class_name(cls) -> str:
        return "HuggingFace_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
            is_chat_model=self.is_chat_model,
        )

    def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Use the tokenizer to convert messages to a prompt. Fall back to generic."""
        if hasattr(self._tokenizer, "apply_chat_template"):
            messages_dict = [
                {"role": message.role.value, "content": message.content}
                for message in messages
            ]
            tokens = self._tokenizer.apply_chat_template(messages_dict)
            return self._tokenizer.decode(tokens)

        return generic_messages_to_prompt(messages)

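    # Rough illustration of the behavior above (a sketch, not part of the class):
    # with a chat-template-aware tokenizer, a message list such as
    #
    #   [ChatMessage(role=MessageRole.SYSTEM, content="You are terse."),
    #    ChatMessage(role=MessageRole.USER, content="Hi!")]
    #
    # is rendered by `apply_chat_template` into the model's own chat markup,
    # while tokenizers without a template fall back to
    # `generic_messages_to_prompt`, which produces plain role-prefixed text.
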
    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Completion endpoint."""
        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"

        inputs = self._tokenizer(full_prompt, return_tensors="pt")
        inputs = inputs.to(self._model.device)

        # remove keys from the tokenizer if needed, to avoid HF errors
        for key in self.tokenizer_outputs_to_remove:
            if key in inputs:
                inputs.pop(key, None)

        tokens = self._model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )
        completion_tokens = tokens[0][inputs["input_ids"].size(1) :]
        completion = self._tokenizer.decode(completion_tokens, skip_special_tokens=True)

        return CompletionResponse(text=completion, raw={"model_output": tokens})

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming completion endpoint."""
        from transformers import TextIteratorStreamer

        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"

        inputs = self._tokenizer(full_prompt, return_tensors="pt")
        inputs = inputs.to(self._model.device)

        # remove keys from the tokenizer if needed, to avoid HF errors
        for key in self.tokenizer_outputs_to_remove:
            if key in inputs:
                inputs.pop(key, None)

        streamer = TextIteratorStreamer(
            self._tokenizer,
            skip_prompt=True,
            decode_kwargs={"skip_special_tokens": True},
        )
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )

        # generate in background thread
        # NOTE/TODO: token counting doesn't work with streaming
        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()

        # create generator based off of streamer
        def gen() -> CompletionResponseGen:
            text = ""
            for x in streamer:
                text += x
                yield CompletionResponse(text=text, delta=x)

        return gen()

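    # Streaming usage sketch (illustrative; assumes an `llm` constructed as in the
    # example after `__init__` above):
    #
    #   for partial in llm.stream_complete("Summarize FAISS in one sentence."):
    #       print(partial.delta, end="", flush=True)
    #
    # Each yielded CompletionResponse carries the accumulated `text` plus the
    # newest `delta`, so callers can use whichever field suits them.
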
    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)


def chat_messages_to_conversational_kwargs(
    messages: Sequence[ChatMessage],
) -> Dict[str, Any]:
    """Convert ChatMessages to keyword arguments for Inference API conversational."""
    if len(messages) % 2 != 1:
        raise NotImplementedError("Messages passed in must be of odd length.")
    last_message = messages[-1]
    kwargs: Dict[str, Any] = {
        "text": last_message.content,
        **last_message.additional_kwargs,
    }
    if len(messages) != 1:
        kwargs["past_user_inputs"] = []
        kwargs["generated_responses"] = []
        for user_msg, assistant_msg in zip(messages[::2], messages[1::2]):
            if (
                user_msg.role != MessageRole.USER
                or assistant_msg.role != MessageRole.ASSISTANT
            ):
                raise NotImplementedError(
                    "Messages must be ordered in alternating pairs of"
                    f" {(MessageRole.USER, MessageRole.ASSISTANT)}."
                )
            kwargs["past_user_inputs"].append(user_msg.content)
            kwargs["generated_responses"].append(assistant_msg.content)
    return kwargs

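# Illustration of the mapping above (a sketch): for
#
#   [ChatMessage(role=MessageRole.USER, content="Hi"),
#    ChatMessage(role=MessageRole.ASSISTANT, content="Hello!"),
#    ChatMessage(role=MessageRole.USER, content="How are you?")]
#
# the function returns roughly
#
#   {"text": "How are you?",
#    "past_user_inputs": ["Hi"],
#    "generated_responses": ["Hello!"]}
#
# which matches the payload expected by InferenceClient.conversational.
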
class HuggingFaceInferenceAPI(CustomLLM):
    """
    Wrapper around Hugging Face's Inference API.

    Overview of the design:
    - Synchronous uses InferenceClient, asynchronous uses AsyncInferenceClient
    - chat uses the conversational task: https://huggingface.co/tasks/conversational
    - complete uses the text generation task: https://huggingface.co/tasks/text-generation

    Note: some models that support the text generation task can leverage Hugging
    Face's optimized deployment toolkit called text-generation-inference (TGI).
    Use InferenceClient.get_model_status to check if TGI is being used.

    Relevant links:
    - General Docs: https://huggingface.co/docs/api-inference/index
    - API Docs: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client
    - Source: https://github.com/huggingface/huggingface_hub/tree/main/src/huggingface_hub/inference
    """

    @classmethod
    def class_name(cls) -> str:
        return "HuggingFaceInferenceAPI"

    # Corresponds with huggingface_hub.InferenceClient
    model_name: Optional[str] = Field(
        default=None,
        description=(
            "The model to run inference with. Can be a model id hosted on the Hugging"
            " Face Hub, e.g. bigcode/starcoder or a URL to a deployed Inference"
            " Endpoint. Defaults to None, in which case a recommended model is"
            " automatically selected for the task (see Field below)."
        ),
    )
    token: Union[str, bool, None] = Field(
        default=None,
        description=(
            "Hugging Face token. Will default to the locally saved token. Pass "
            "token=False if you don't want to send your token to the server."
        ),
    )
    timeout: Optional[float] = Field(
        default=None,
        description=(
            "The maximum number of seconds to wait for a response from the server."
            " Loading a new model in the Inference API can take up to several minutes."
            " Defaults to None, meaning it will loop until the server is available."
        ),
    )
    headers: Dict[str, str] = Field(
        default=None,
        description=(
            "Additional headers to send to the server. By default only the"
            " authorization and user-agent headers are sent. Values in this dictionary"
            " will override the default values."
        ),
    )
    cookies: Dict[str, str] = Field(
        default=None, description="Additional cookies to send to the server."
    )
    task: Optional[str] = Field(
        default=None,
        description=(
            "Optional task to pick Hugging Face's recommended model, used when"
            " model_name is left as default of None."
        ),
    )

    _sync_client: "InferenceClient" = PrivateAttr()
    _async_client: "AsyncInferenceClient" = PrivateAttr()
    _get_model_info: "Callable[..., ModelInfo]" = PrivateAttr()

    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=(
            LLMMetadata.__fields__["context_window"].field_info.description
            + " This may be looked up in a model's `config.json`."
        ),
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description=LLMMetadata.__fields__["num_output"].field_info.description,
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_chat_model"].field_info.description
            + " Unless chat templating is intentionally applied, Hugging Face models"
            " are not chat models."
        ),
    )
    is_function_calling_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_function_calling_model"].field_info.description
            + " As of 10/17/2023, Hugging Face doesn't support function calling"
            " messages."
        ),
    )

    def _get_inference_client_kwargs(self) -> Dict[str, Any]:
        """Extract the Hugging Face InferenceClient construction parameters."""
        return {
            "model": self.model_name,
            "token": self.token,
            "timeout": self.timeout,
            "headers": self.headers,
            "cookies": self.cookies,
        }

    def __init__(self, **kwargs: Any) -> None:
        """Initialize.

        Args:
            kwargs: See the class-level Fields.
        """
        try:
            from huggingface_hub import (
                AsyncInferenceClient,
                InferenceClient,
                model_info,
            )
        except ModuleNotFoundError as exc:
            raise ImportError(
                f"{type(self).__name__} requires huggingface_hub with its inference"
                " extra; please run `pip install huggingface_hub[inference]>=0.19.0`."
            ) from exc

        if kwargs.get("model_name") is None:
            task = kwargs.get("task", "")
            # NOTE: task being None or empty string leads to ValueError,
            # which ensures model is present
            kwargs["model_name"] = InferenceClient.get_recommended_model(task=task)
            logger.debug(
                f"Using Hugging Face's recommended model {kwargs['model_name']}"
                f" given task {task}."
            )
        if kwargs.get("task") is None:
            task = "conversational"
        else:
            task = kwargs["task"].lower()

        super().__init__(**kwargs)  # Populate pydantic Fields
        self._sync_client = InferenceClient(**self._get_inference_client_kwargs())
        self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs())
        self._get_model_info = model_info

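    # Example usage (a minimal sketch; the model name and token handling are
    # illustrative, not requirements of this class):
    #
    #   remote_llm = HuggingFaceInferenceAPI(
    #       model_name="HuggingFaceH4/zephyr-7b-beta",
    #       token=False,  # or a real token, or None to use the locally saved one
    #       task="text-generation",
    #   )
    #   print(remote_llm.complete("Write a haiku about embeddings.").text)
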
    def validate_supported(self, task: str) -> None:
        """
        Confirm the contained model_name is deployed on the Inference API service.

        Args:
            task: Hugging Face task to check within. A list of all tasks can be
                found here: https://huggingface.co/tasks
        """
        all_models = self._sync_client.list_deployed_models(frameworks="all")
        try:
            if self.model_name not in all_models[task]:
                raise ValueError(
                    "The Inference API service doesn't have the model"
                    f" {self.model_name!r} deployed."
                )
        except KeyError as exc:
            raise KeyError(
                f"Input task {task!r} not in possible tasks {list(all_models.keys())}."
            ) from exc

    def get_model_info(self, **kwargs: Any) -> "ModelInfo":
        """Get metadata on the current model from Hugging Face."""
        return self._get_model_info(self.model_name, **kwargs)

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            is_chat_model=self.is_chat_model,
            is_function_calling_model=self.is_function_calling_model,
            model_name=self.model_name,
        )

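    # Optional pre-flight checks (a sketch, assuming `remote_llm` from the example
    # above): `validate_supported` raises if the model isn't deployed for the task,
    # and `get_model_info` fetches Hub metadata for the configured model.
    #
    #   remote_llm.validate_supported(task="text-generation")
    #   info = remote_llm.get_model_info()
    #   print(info.pipeline_tag)
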
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        # default to conversational task as that was the previous functionality
        if self.task == "conversational" or self.task is None:
            output: "ConversationalOutput" = self._sync_client.conversational(
                **{**chat_messages_to_conversational_kwargs(messages), **kwargs}
            )
            return ChatResponse(
                message=ChatMessage(
                    role=MessageRole.ASSISTANT, content=output["generated_text"]
                )
            )
        else:
            # try and use text generation
            prompt = self.messages_to_prompt(messages)
            completion = self.complete(prompt)
            return ChatResponse(
                message=ChatMessage(role=MessageRole.ASSISTANT, content=completion.text)
            )

    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return CompletionResponse(
            text=self._sync_client.text_generation(
                prompt, **{**{"max_new_tokens": self.num_output}, **kwargs}
            )
        )

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise NotImplementedError

    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise NotImplementedError

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        raise NotImplementedError

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._async_client.text_generation(
            prompt, **{**{"max_new_tokens": self.num_output}, **kwargs}
        )
        return CompletionResponse(text=response)

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError
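

# Async usage sketch (illustrative; mirrors the synchronous `remote_llm` example
# above and uses only the implemented `acomplete` path):
#
#   import asyncio
#
#   async def main() -> None:
#       response = await remote_llm.acomplete("Define cosine similarity.")
#       print(response.text)
#
#   asyncio.run(main())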