// 145 lines, 4.7 KiB, TypeScript
import { generateLocalEmbedding } from "@/lib/generate-local-embedding"
|
||
import { generateBgeM3Embedding } from "@/lib/generate-bgem3-embedding" // 新增
|
||
import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers"
|
||
import { Database } from "@/supabase/types"
|
||
import { createClient } from "@supabase/supabase-js"
|
||
import OpenAI from "openai"
|
||
import { getRuntimeEnv } from "@/lib/ipconfig" // 新增引入
|
||
|
||
/**
 * Retrieval endpoint: embeds `userInput` with the requested provider and asks
 * the matching Supabase RPC for the most similar file chunks.
 *
 * Request body (JSON object):
 * - `userInput`: text to embed and search for.
 * - `fileIds`: IDs of the files whose chunks may be matched (duplicates are
 *   removed; a missing or non-array value is treated as an empty list).
 * - `embeddingsProvider`: `"openai"`, `"local"` or `"bge-m3"`. Selects both the
 *   embedding function and the RPC (`match_file_items_openai`,
 *   `match_file_items_local`, `match_file_items_bge_m3`).
 * - `sourceCount`: forwarded to the RPC as `match_count`.
 *
 * Responses:
 * - 200 `{ results }`: matched chunks sorted by descending `similarity`.
 * - 400 `{ message }`: body is not a JSON object, or the provider is unknown.
 * - 4xx/5xx `{ message }`: a downstream failure. The status is taken from the
 *   thrown error when it carries a valid HTTP error status, otherwise 500.
 */
export async function POST(request: Request) {
  // Log only method and URL: the Request object also exposes headers
  // (cookies, auth tokens) that must not end up in server logs.
  console.log("......[retrieve] request=", request.method, request.url)

  // Parse the body up front so malformed input gets a 400 response instead of
  // a rejection escaping the handler.
  let json: unknown
  try {
    json = await request.json()
  } catch {
    return jsonResponse({ message: "Request body must be valid JSON" }, 400)
  }
  if (typeof json !== "object" || json === null) {
    return jsonResponse({ message: "Request body must be a JSON object" }, 400)
  }

  const { userInput, fileIds, embeddingsProvider, sourceCount } = json as {
    userInput: string
    fileIds: string[]
    embeddingsProvider: "openai" | "local" | "bge-m3"
    sourceCount: number
  }

  // Reject unknown providers explicitly; previously they fell through every
  // branch and produced an empty 200 response.
  if (!["openai", "local", "bge-m3"].includes(embeddingsProvider)) {
    return jsonResponse(
      {
        message: `Unsupported embeddings provider: ${String(embeddingsProvider)}`
      },
      400
    )
  }

  const uniqueFileIds = [...new Set(Array.isArray(fileIds) ? fileIds : [])]

  try {
    // NOTE(review): the port is forced to 8000 regardless of the port in
    // SUPABASE_URL (presumably the self-hosted Supabase gateway port for this
    // deployment). Confirm before pointing SUPABASE_URL elsewhere.
    const supaUrl = new URL(
      getRuntimeEnv("SUPABASE_URL") ?? "http://localhost:8000"
    )
    supaUrl.port = "8000"

    const supabaseAdmin = createClient<Database>(
      supaUrl.origin,
      process.env.SUPABASE_SERVICE_ROLE_KEY!
    )

    const profile = await getServerProfile()

    let chunks: any[] = []

    if (embeddingsProvider === "openai") {
      if (profile.use_azure_openai) {
        checkApiKey(profile.azure_openai_api_key, "Azure OpenAI")
      } else {
        checkApiKey(profile.openai_api_key, "OpenAI")
      }

      // Only this branch talks to OpenAI, so the client is built here rather
      // than for every provider. For Azure, the embeddings deployment is
      // addressed through the base URL.
      const openai = profile.use_azure_openai
        ? new OpenAI({
            apiKey: profile.azure_openai_api_key || "",
            baseURL: `${profile.azure_openai_endpoint}/openai/deployments/${profile.azure_openai_embeddings_id}`,
            defaultQuery: { "api-version": "2023-12-01-preview" },
            defaultHeaders: { "api-key": profile.azure_openai_api_key }
          })
        : new OpenAI({
            apiKey: profile.openai_api_key || "",
            organization: profile.openai_organization_id
          })

      const response = await openai.embeddings.create({
        model: "text-embedding-3-small",
        input: userInput
      })

      const openaiEmbedding = response.data[0]?.embedding

      const { data: openaiFileItems, error: openaiError } =
        await supabaseAdmin.rpc("match_file_items_openai", {
          query_embedding: openaiEmbedding as any,
          match_count: sourceCount,
          file_ids: uniqueFileIds
        })

      if (openaiError) {
        throw openaiError
      }

      chunks = openaiFileItems ?? []
    } else if (embeddingsProvider === "local") {
      const localEmbedding = await generateLocalEmbedding(userInput)

      const { data: localFileItems, error: localFileItemsError } =
        await supabaseAdmin.rpc("match_file_items_local", {
          query_embedding: localEmbedding as any,
          match_count: sourceCount,
          file_ids: uniqueFileIds
        })

      if (localFileItemsError) {
        throw localFileItemsError
      }

      chunks = localFileItems ?? []
    } else if (embeddingsProvider === "bge-m3") {
      console.log("......[retrieve] userInput=", userInput)

      const bgeEmbedding = await generateBgeM3Embedding(userInput)

      // Log the vector's size, not its contents: dumping the full embedding
      // floods the log on every request.
      console.log(
        "......[retrieve] [bge-m3] got embedding, dims =",
        Array.isArray(bgeEmbedding) ? bgeEmbedding.length : "n/a"
      )
      console.log(
        "......[retrieve] [bge-m3] calling RPC match_file_items_bge_m3 with:",
        { match_count: sourceCount, file_ids: uniqueFileIds }
      )

      // The match_file_items_bge_m3 function must already exist in the
      // database; it is not created by this handler.
      const { data: bgeFileItems, error: bgeError } = await supabaseAdmin.rpc(
        "match_file_items_bge_m3",
        {
          query_embedding: bgeEmbedding as any,
          match_count: sourceCount,
          file_ids: uniqueFileIds
        }
      )

      if (bgeError) {
        throw bgeError
      }

      console.log(
        "......[retrieve] [bge-m3] RPC result count =",
        bgeFileItems?.length
      )

      chunks = bgeFileItems ?? []
    }

    // Sort a copy so the RPC result array is not mutated in place.
    const mostSimilarChunks = [...chunks].sort(
      (a, b) => b.similarity - a.similarity
    )

    return jsonResponse({ results: mostSimilarChunks }, 200)
  } catch (error: unknown) {
    console.error("......[retrieve] error:", error)
    return jsonResponse(
      { message: getErrorMessage(error) },
      getErrorStatus(error)
    )
  }
}

/** Serializes `payload` as a JSON response with the given HTTP status. */
function jsonResponse(payload: unknown, status: number): Response {
  return new Response(JSON.stringify(payload), {
    status,
    headers: { "Content-Type": "application/json" }
  })
}

/**
 * Extracts a client-facing message from a thrown value.
 *
 * A nested `error.error.message` (the shape the original handler read, as on
 * OpenAI SDK API errors) is preferred. Otherwise a top-level `message` is used
 * (plain `Error`s, Supabase/Postgrest errors). Falls back to a generic text.
 */
function getErrorMessage(error: unknown): string {
  if (typeof error === "object" && error !== null) {
    const nested = (error as { error?: { message?: unknown } }).error?.message
    if (typeof nested === "string" && nested !== "") {
      return nested
    }
    const direct = (error as { message?: unknown }).message
    if (typeof direct === "string" && direct !== "") {
      return direct
    }
  }
  return "An unexpected error occurred"
}

/**
 * Picks the HTTP status for an error response.
 *
 * Uses `error.status` only when it is an integer in 400–599. `new Response`
 * throws a RangeError for statuses outside 200–599, and a 2xx/3xx would
 * misreport a failure as success. Anything else yields 500.
 */
function getErrorStatus(error: unknown): number {
  const status =
    typeof error === "object" && error !== null
      ? (error as { status?: unknown }).status
      : undefined
  return typeof status === "number" &&
    Number.isInteger(status) &&
    status >= 400 &&
    status <= 599
    ? status
    : 500
}
|