This commit is contained in:
parent
be0662d918
commit
6794ebf5c1
|
|
@ -1,4 +1,5 @@
|
||||||
import { generateLocalEmbedding } from "@/lib/generate-local-embedding"
|
import { generateLocalEmbedding } from "@/lib/generate-local-embedding"
|
||||||
|
import { generateBgeM3Embedding } from "@/lib/generate-bge-m3-embedding" // 新增
|
||||||
import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers"
|
import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers"
|
||||||
import { Database } from "@/supabase/types"
|
import { Database } from "@/supabase/types"
|
||||||
import { createClient } from "@supabase/supabase-js"
|
import { createClient } from "@supabase/supabase-js"
|
||||||
|
|
@ -86,6 +87,18 @@ export async function POST(request: Request) {
|
||||||
chunks = localFileItems
|
chunks = localFileItems
|
||||||
} else if (embeddingsProvider === "bge-m3"){
|
} else if (embeddingsProvider === "bge-m3"){
|
||||||
// 示例:调用你自己的 BGE-M3 API 或本地函数
|
// 示例:调用你自己的 BGE-M3 API 或本地函数
|
||||||
|
// 新增:使用 BGE-M3 嵌入
|
||||||
|
const bgeEmbedding = await generateBgeM3Embedding(userInput)
|
||||||
|
|
||||||
|
// 调用对应的 RPC,需要在数据库侧提前定义 match_file_items_bge_m3
|
||||||
|
const { data: bgeFileItems, error: bgeError } =
|
||||||
|
await supabaseAdmin.rpc("match_file_items_bge_m3", {
|
||||||
|
query_embedding: bgeEmbedding as any,
|
||||||
|
match_count: sourceCount,
|
||||||
|
file_ids: uniqueFileIds
|
||||||
|
})
|
||||||
|
if (bgeError) throw bgeError
|
||||||
|
chunks = bgeFileItems
|
||||||
}
|
}
|
||||||
|
|
||||||
const mostSimilarChunks = chunks?.sort(
|
const mostSimilarChunks = chunks?.sort(
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,36 @@
|
||||||
export async function generateBgeM3Embedding(text: string): Promise<number[] | null> {
|
export async function generateBgeM3Embedding(text: string): Promise<number[] | null> {
|
||||||
try {
|
try {
|
||||||
const response = await fetch("http://localhost:8000/embedding", {
|
// 动态获取当前协议和主机(不含端口),然后指定后端端口 8001
|
||||||
|
const { protocol, host } = window.location;
|
||||||
|
const hostname = host.split(":")[0];
|
||||||
|
const apiUrl = `${protocol}//${hostname}:8001/v1/embeddings`;
|
||||||
|
|
||||||
|
const response = await fetch(apiUrl, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
body: JSON.stringify({ input: text })
|
body: JSON.stringify({
|
||||||
})
|
// OpenAI 兼容请求字段
|
||||||
|
input: text,
|
||||||
|
model: "text-embedding-bge-m3"
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(`Failed to fetch BGE-M3 embedding: ${response.status}`)
|
throw new Error(`Failed to fetch BGE-M3 embedding: ${response.status}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await response.json()
|
// 返回结构为 { object, data: [{ embedding, … }], model, usage }
|
||||||
return result.embedding as number[]
|
const result = await response.json();
|
||||||
|
|
||||||
|
// 取 data[0].embedding
|
||||||
|
if (Array.isArray(result.data) && result.data.length > 0) {
|
||||||
|
return result.data[0].embedding as number[];
|
||||||
|
} else {
|
||||||
|
console.error("Unexpected embedding response format:", result);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error("Error in generateBgeM3Embedding:", err)
|
console.error("Error in generateBgeM3Embedding:", err);
|
||||||
return null
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue