From 6794ebf5c1858a83c9b18396034d1b0838797d2c Mon Sep 17 00:00:00 2001 From: hailin Date: Wed, 28 May 2025 20:34:22 +0800 Subject: [PATCH] . --- .../app/api/retrieval/retrieve/route.ts | 13 ++++++++ chatdesk-ui/lib/generate-bgem3-embedding.ts | 33 ++++++++++++++----- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/chatdesk-ui/app/api/retrieval/retrieve/route.ts b/chatdesk-ui/app/api/retrieval/retrieve/route.ts index adc2ead..2a7a543 100644 --- a/chatdesk-ui/app/api/retrieval/retrieve/route.ts +++ b/chatdesk-ui/app/api/retrieval/retrieve/route.ts @@ -1,4 +1,5 @@ import { generateLocalEmbedding } from "@/lib/generate-local-embedding" +import { generateBgeM3Embedding } from "@/lib/generate-bge-m3-embedding" // 新增 import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" import { Database } from "@/supabase/types" import { createClient } from "@supabase/supabase-js" @@ -86,6 +87,18 @@ export async function POST(request: Request) { chunks = localFileItems } else if (embeddingsProvider === "bge-m3"){ // 示例:调用你自己的 BGE-M3 API 或本地函数 + // 新增:使用 BGE-M3 嵌入 + const bgeEmbedding = await generateBgeM3Embedding(userInput) + + // 调用对应的 RPC,需要在数据库侧提前定义 match_file_items_bge_m3 + const { data: bgeFileItems, error: bgeError } = + await supabaseAdmin.rpc("match_file_items_bge_m3", { + query_embedding: bgeEmbedding as any, + match_count: sourceCount, + file_ids: uniqueFileIds + }) + if (bgeError) throw bgeError + chunks = bgeFileItems } const mostSimilarChunks = chunks?.sort( diff --git a/chatdesk-ui/lib/generate-bgem3-embedding.ts b/chatdesk-ui/lib/generate-bgem3-embedding.ts index 3472c53..699e1ee 100644 --- a/chatdesk-ui/lib/generate-bgem3-embedding.ts +++ b/chatdesk-ui/lib/generate-bgem3-embedding.ts @@ -1,19 +1,36 @@ export async function generateBgeM3Embedding(text: string): Promise { try { - const response = await fetch("http://localhost:8000/embedding", { + // 动态获取当前协议和主机(不含端口),然后指定后端端口 8001 + const { protocol, host } = window.location; + const hostname = host.split(":")[0]; + const apiUrl = `${protocol}//${hostname}:8001/v1/embeddings`; + + const response = await fetch(apiUrl, { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ input: text }) - }) + body: JSON.stringify({ + // OpenAI 兼容请求字段 + input: text, + model: "text-embedding-bge-m3" + }) + }); if (!response.ok) { - throw new Error(`Failed to fetch BGE-M3 embedding: ${response.status}`) + throw new Error(`Failed to fetch BGE-M3 embedding: ${response.status}`); } - const result = await response.json() - return result.embedding as number[] + // 返回结构为 { object, data: [{ embedding, … }], model, usage } + const result = await response.json(); + + // 取 data[0].embedding + if (Array.isArray(result.data) && result.data.length > 0) { + return result.data[0].embedding as number[]; + } else { + console.error("Unexpected embedding response format:", result); + return null; + } } catch (err) { - console.error("Error in generateBgeM3Embedding:", err) - return null + console.error("Error in generateBgeM3Embedding:", err); + return null; } }