from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Union
import torch
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# Auto-detect the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Enable fp16 acceleration on GPU; otherwise fall back to fp32
USE_FP16 = DEVICE == "cuda"

# Load the model and tokenizer
MODEL_PATH = "model/bge-m3"
model = BGEM3FlagModel(MODEL_PATH, use_fp16=USE_FP16, device=DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
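
# Background note (not from the original code): BGE-M3 produces
# 1024-dimensional dense vectors and accepts inputs up to 8192 tokens,
# which matches the max_length used for truncation below.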

app = FastAPI()

class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"
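
# Example request body (illustrative values; the shape mirrors the OpenAI
# embeddings API, so existing OpenAI clients can send it unchanged):
#   {"input": ["hello", "world"], "model": "text-embedding-bge-m3"}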

@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # Normalize the input to a list of strings
    texts = [request.input] if isinstance(request.input, str) else request.input

    try:
        # Run the tokenizer first to count tokens
        encoding = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors="pt"
        )
        # The number of 1s in attention_mask is the number of real tokens (padding excluded)
        mask = encoding["attention_mask"]
        # prompt_tokens = total tokens across all inputs
        prompt_tokens = int(mask.sum().item())
        total_tokens = prompt_tokens  # embeddings generate no extra tokens

        # Generate dense vectors
        output = model.encode(texts, return_dense=True)
        embeddings = output["dense_vecs"]

        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": emb.tolist() if hasattr(emb, "tolist") else emb
                }
                for idx, emb in enumerate(embeddings)
            ],
            "model": request.model,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens
            }
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
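
# Usage sketch (assumptions: this file is saved as app.py and served with
# `uvicorn app:app --port 8000`; the module name and port are illustrative).
# Because the endpoint mimics the OpenAI embeddings API, the official client
# works against it by overriding base_url:
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
#   resp = client.embeddings.create(model="text-embedding-bge-m3",
#                                   input=["hello", "world"])
#   print(len(resp.data[0].embedding), resp.usage.prompt_tokens)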