embed-bge-m3/app/main.py.old

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Union
import torch
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# Auto-detect the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Enable fp16 acceleration on GPU, otherwise fall back to fp32
USE_FP16 = DEVICE == "cuda"

# Load the model and tokenizer
MODEL_PATH = "model/bge-m3"
model = BGEM3FlagModel(MODEL_PATH, use_fp16=USE_FP16, device=DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

app = FastAPI()


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"


@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # Normalize the input to a list of strings
    texts = [request.input] if isinstance(request.input, str) else request.input
    try:
        # Run the tokenizer first, just to count tokens
        encoding = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors="pt"
        )
        # The number of 1s in the attention_mask is the real token count (padding excluded)
        mask = encoding["attention_mask"]
        # prompt_tokens = sum of tokens across all inputs
        prompt_tokens = int(mask.sum().item())
        total_tokens = prompt_tokens  # embeddings produce no completion tokens
        # Generate the dense vectors
        output = model.encode(texts, return_dense=True)
        embeddings = output["dense_vecs"]
        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": emb.tolist() if hasattr(emb, "tolist") else emb
                }
                for idx, emb in enumerate(embeddings)
            ],
            "model": request.model,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
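
A minimal client sketch for the endpoint above (not part of the original file): it assumes the app is served with uvicorn on localhost:8000, and uses the `requests` library; the host, port, and input strings are illustrative assumptions.

import requests

# Call the OpenAI-compatible /v1/embeddings endpoint defined above.
# Host/port are assumptions (e.g. started with: uvicorn app.main:app --port 8000).
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"input": ["hello world", "another sentence"], "model": "text-embedding-bge-m3"},
)
resp.raise_for_status()
body = resp.json()
print(len(body["data"]))                  # one embedding object per input string
print(len(body["data"][0]["embedding"]))  # 1024 — the BGE-M3 dense vector dimension
print(body["usage"])                      # {'prompt_tokens': ..., 'total_tokens': ...}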