# faiss_rag_enterprise/llama_index/prompts/lmformatenforcer_utils.py
#
# 63 lines
# 2.1 KiB
# Python

from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Iterator
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.llms.llm import LLM
if TYPE_CHECKING:
from lmformatenforcer import CharacterLevelParser
def build_lm_format_enforcer_function(
    llm: LLM, character_level_parser: "CharacterLevelParser"
) -> Callable:
    """Build the object that plugs the LM Format Enforcer into ``llm``.

    The returned value is meant to be stored in the LLM's generation
    keyword arguments (see ``activate_lm_format_enforcer``).

    Args:
        llm: The LLM to build the enforcer for. Only ``HuggingFaceLLM`` and
            ``LlamaCPP`` instances are handled.
        character_level_parser: The parser passed through to the
            lmformatenforcer integration builder.

    Returns:
        For a ``HuggingFaceLLM``: the result of
        ``build_transformers_prefix_allowed_tokens_fn`` called with the
        LLM's ``_tokenizer``. For a ``LlamaCPP``: a ``LogitsProcessorList``
        holding the single processor built from the LLM's ``_model``.

    Raises:
        ValueError: If ``llm`` is neither a ``HuggingFaceLLM`` nor a
            ``LlamaCPP``.
    """
    # Reject unsupported backends up front; nothing is imported in that case.
    if not isinstance(llm, (HuggingFaceLLM, LlamaCPP)):
        raise ValueError("Unsupported LLM type")

    # The integration packages are imported lazily so that only the backend
    # actually in use needs to be importable.
    if isinstance(llm, HuggingFaceLLM):
        from lmformatenforcer.integrations.transformers import (
            build_transformers_prefix_allowed_tokens_fn as make_prefix_fn,
        )

        hf_tokenizer = llm._tokenizer
        return make_prefix_fn(hf_tokenizer, character_level_parser)

    from llama_cpp import LogitsProcessorList
    from lmformatenforcer.integrations.llamacpp import (
        build_llamacpp_logits_processor as make_logits_processor,
    )

    processor = make_logits_processor(llm._model, character_level_parser)
    return LogitsProcessorList([processor])
@contextmanager
def activate_lm_format_enforcer(
    llm: LLM, lm_format_enforcer_fn: Callable
) -> Iterator[None]:
    """Activate the LM Format Enforcer for the given LLM.

    While the ``with`` block runs, ``lm_format_enforcer_fn`` is stored in
    ``llm.generate_kwargs`` under the key matching the LLM type
    (``"prefix_allowed_tokens_fn"`` for ``HuggingFaceLLM``,
    ``"logits_processor"`` for ``LlamaCPP``). On exit, the entry is put back
    to what it was before entering: a previously configured value is
    restored, otherwise the key is removed.

    Example::

        with activate_lm_format_enforcer(llm, lm_format_enforcer_fn):
            llm.complete(...)

    Args:
        llm: The LLM whose ``generate_kwargs`` is temporarily modified.
        lm_format_enforcer_fn: The value to install, e.g. the result of
            ``build_lm_format_enforcer_function``.

    Raises:
        ValueError: If ``llm`` is neither a ``HuggingFaceLLM`` nor a
            ``LlamaCPP``.
    """
    if isinstance(llm, HuggingFaceLLM):
        generate_kwargs_key = "prefix_allowed_tokens_fn"
    elif isinstance(llm, LlamaCPP):
        generate_kwargs_key = "logits_processor"
    else:
        raise ValueError("Unsupported LLM type")

    # Remember any value already configured under this key so it is not lost
    # (e.g. a caller-supplied logits processor, or an outer, nested
    # activation). A private sentinel distinguishes "absent" from "None".
    # NOTE(review): assumes generate_kwargs is a plain dict — confirm.
    not_set = object()
    previous_value = llm.generate_kwargs.get(generate_kwargs_key, not_set)

    llm.generate_kwargs[generate_kwargs_key] = lm_format_enforcer_fn
    try:
        # This is where the user code will run
        yield
    finally:
        # Undo our change in case other code paths use the same llm object.
        if previous_value is not_set:
            # pop() with a default instead of del: if user code already
            # removed the key, cleanup must not raise KeyError and mask an
            # exception coming out of the with-block.
            llm.generate_kwargs.pop(generate_kwargs_key, None)
        else:
            llm.generate_kwargs[generate_kwargs_key] = previous_value