from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Union

import torch
from transformers import AutoTokenizer
from FlagEmbedding import BGEM3FlagModel

# Auto-detect the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Enable fp16 acceleration on GPU; fall back to fp32 on CPU
USE_FP16 = DEVICE == "cuda"

# Load the model and tokenizer
MODEL_PATH = "model/bge-m3"
model = BGEM3FlagModel(MODEL_PATH, use_fp16=USE_FP16, device=DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

app = FastAPI()


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-bge-m3"


@app.post("/v1/embeddings")
def create_embedding(request: EmbeddingRequest):
    # Normalize the input into a list
    texts = [request.input] if isinstance(request.input, str) else request.input
    try:
        # Count tokens with the tokenizer first
        encoding = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors="pt",
        )
        # The number of 1s in the attention_mask is the actual token count (padding excluded)
        mask = encoding["attention_mask"]
        # prompt_tokens = sum of tokens across all inputs
        prompt_tokens = int(mask.sum().item())
        total_tokens = prompt_tokens  # embeddings produce no extra tokens

        # Generate dense vectors
        output = model.encode(texts, return_dense=True)
        embeddings = output["dense_vecs"]

        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": emb.tolist() if hasattr(emb, "tolist") else emb,
                }
                for idx, emb in enumerate(embeddings)
            ],
            "model": request.model,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens,
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
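

# --- Usage sketch (an addition, not part of the original service): a minimal entry
# point for running the app with uvicorn, plus an example request. The host, port,
# and curl invocation below are illustrative assumptions; adjust for your deployment.
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app; 0.0.0.0:8000 is an assumed default, not prescribed above.
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request (assuming the server listens on localhost:8000):
#   curl -s http://localhost:8000/v1/embeddings \
#        -H "Content-Type: application/json" \
#        -d '{"input": ["hello world"], "model": "text-embedding-bge-m3"}'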