faiss_rag_enterprise/llama_index/embeddings/huggingface_utils.py

from typing import Optional

import requests

DEFAULT_HUGGINGFACE_EMBEDDING_MODEL = "BAAI/bge-small-en"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-base"

# Originally pulled from:
# https://github.com/langchain-ai/langchain/blob/v0.0.257/libs/langchain/langchain/embeddings/huggingface.py#L10
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION = (
    "Represent the question for retrieving supporting documents: "
)
DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
    "Represent this question for searching relevant passages: "
)
DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:"

BGE_MODELS = (
    "BAAI/bge-small-en",
    "BAAI/bge-small-en-v1.5",
    "BAAI/bge-base-en",
    "BAAI/bge-base-en-v1.5",
    "BAAI/bge-large-en",
    "BAAI/bge-large-en-v1.5",
    "BAAI/bge-small-zh",
    "BAAI/bge-small-zh-v1.5",
    "BAAI/bge-base-zh",
    "BAAI/bge-base-zh-v1.5",
    "BAAI/bge-large-zh",
    "BAAI/bge-large-zh-v1.5",
)
INSTRUCTOR_MODELS = (
    "hku-nlp/instructor-base",
    "hku-nlp/instructor-large",
    "hku-nlp/instructor-xl",
    "hkunlp/instructor-base",
    "hkunlp/instructor-large",
    "hkunlp/instructor-xl",
)


def get_query_instruct_for_model_name(model_name: Optional[str]) -> str:
    """Get query text instruction for a given model name."""
    if model_name in INSTRUCTOR_MODELS:
        return DEFAULT_QUERY_INSTRUCTION
    if model_name in BGE_MODELS:
        if "zh" in model_name:
            return DEFAULT_QUERY_BGE_INSTRUCTION_ZH
        return DEFAULT_QUERY_BGE_INSTRUCTION_EN
    return ""


def format_query(
    query: str, model_name: Optional[str], instruction: Optional[str] = None
) -> str:
    """Prepend the query instruction for ``model_name`` to ``query``."""
    if instruction is None:
        instruction = get_query_instruct_for_model_name(model_name)
    # NOTE: passing an empty-string instruction effectively disables the
    # prepend, since strip() removes the leading space.
    return f"{instruction} {query}".strip()


def get_text_instruct_for_model_name(model_name: Optional[str]) -> str:
    """Get text instruction for a given model name."""
    return DEFAULT_EMBED_INSTRUCTION if model_name in INSTRUCTOR_MODELS else ""


def format_text(
    text: str, model_name: Optional[str], instruction: Optional[str] = None
) -> str:
    """Prepend the embedding instruction for ``model_name`` to ``text``."""
    if instruction is None:
        instruction = get_text_instruct_for_model_name(model_name)
    # NOTE: passing an empty-string instruction effectively disables the
    # prepend, since strip() removes the leading space.
    return f"{instruction} {text}".strip()


def get_pooling_mode(model_name: Optional[str]) -> str:
    """Infer the pooling mode ("mean" or "cls") from the model's pooling
    config on the Hugging Face Hub, defaulting to "cls" when unavailable.
    """
    pooling_config_url = (
        f"https://huggingface.co/{model_name}/raw/main/1_Pooling/config.json"
    )
    try:
        # Timeout keeps a slow or unreachable Hub from hanging indefinitely.
        response = requests.get(pooling_config_url, timeout=10)
        config_data = response.json()
        cls_token = config_data.get("pooling_mode_cls_token", False)
        mean_tokens = config_data.get("pooling_mode_mean_tokens", False)
        if mean_tokens:
            return "mean"
        elif cls_token:
            return "cls"
    except (requests.exceptions.RequestException, ValueError):
        # ValueError covers a non-JSON body (e.g. a 404 page) on requests
        # versions where JSONDecodeError is not a RequestException.
        print(
            "Warning: Pooling config file not found; pooling mode is defaulted to 'cls'."
        )
    return "cls"