chatdesk-ui/chatdesk-ui/app/api/retrieval/retrieve/route.ts

145 lines
4.7 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { generateLocalEmbedding } from "@/lib/generate-local-embedding"
import { generateBgeM3Embedding } from "@/lib/generate-bgem3-embedding" // newly added
import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers"
import { Database } from "@/supabase/types"
import { createClient } from "@supabase/supabase-js"
import OpenAI from "openai"
import { getRuntimeEnv } from "@/lib/ipconfig" // newly added import
/**
 * Retrieval endpoint: embeds the user's query with the requested provider and
 * returns the best-matching file chunks.
 *
 * Expected JSON body:
 * - `userInput`: the query text to embed.
 * - `fileIds`: ids of the files to search (duplicates are removed).
 * - `embeddingsProvider`: `"openai"`, `"local"` or `"bge-m3"`.
 * - `sourceCount`: forwarded to the RPC as `match_count`.
 *
 * @param request Incoming HTTP request carrying the JSON body described above.
 * @returns A 200 response `{ results }` with the matched rows ordered by
 *   descending `similarity`; a 400 response `{ message }` when the body is not
 *   a JSON object; otherwise `{ message }` with the error's own 4xx/5xx status
 *   if it has one, else 500.
 */
export async function POST(request: Request) {
  // Parse the body before anything else so a malformed payload is answered
  // with a 400 instead of escaping the handler as an unhandled rejection.
  let json: unknown
  try {
    json = await request.json()
  } catch {
    return new Response(JSON.stringify({ message: "Invalid JSON body" }), {
      status: 400
    })
  }
  if (json === null || typeof json !== "object") {
    return new Response(
      JSON.stringify({ message: "Request body must be a JSON object" }),
      { status: 400 }
    )
  }

  try {
    const { userInput, fileIds, embeddingsProvider, sourceCount } = json as {
      userInput: string
      fileIds: string[]
      embeddingsProvider: "openai" | "local" | "bge-m3"
      sourceCount: number
    }
    const uniqueFileIds = [...new Set(fileIds)]

    // Resolve the Supabase URL from the runtime environment and force port
    // 8000 on it, whatever port the configured URL carried.
    const rawSupaUrl = getRuntimeEnv("SUPABASE_URL") ?? "http://localhost:8000"
    const supaUrlObj = new URL(rawSupaUrl)
    supaUrlObj.port = "8000"
    const supabaseAdmin = createClient<Database>(
      supaUrlObj.origin,
      process.env.SUPABASE_SERVICE_ROLE_KEY!
    )

    const profile = await getServerProfile()

    let chunks: any[] = []

    if (embeddingsProvider === "openai") {
      // Validate the key that will actually be used before building a client.
      if (profile.use_azure_openai) {
        checkApiKey(profile.azure_openai_api_key, "Azure OpenAI")
      } else {
        checkApiKey(profile.openai_api_key, "OpenAI")
      }

      // The OpenAI client is only needed on this branch, so it is created here
      // rather than for every provider.
      const openai = profile.use_azure_openai
        ? new OpenAI({
            apiKey: profile.azure_openai_api_key || "",
            baseURL: `${profile.azure_openai_endpoint}/openai/deployments/${profile.azure_openai_embeddings_id}`,
            defaultQuery: { "api-version": "2023-12-01-preview" },
            defaultHeaders: { "api-key": profile.azure_openai_api_key }
          })
        : new OpenAI({
            apiKey: profile.openai_api_key || "",
            organization: profile.openai_organization_id
          })

      const response = await openai.embeddings.create({
        model: "text-embedding-3-small",
        input: userInput
      })
      const openaiEmbedding = response.data.map(item => item.embedding)[0]

      const { data: openaiFileItems, error: openaiError } =
        await supabaseAdmin.rpc("match_file_items_openai", {
          query_embedding: openaiEmbedding as any,
          match_count: sourceCount,
          file_ids: uniqueFileIds
        })
      if (openaiError) {
        throw openaiError
      }
      chunks = openaiFileItems
    } else if (embeddingsProvider === "local") {
      const localEmbedding = await generateLocalEmbedding(userInput)

      const { data: localFileItems, error: localFileItemsError } =
        await supabaseAdmin.rpc("match_file_items_local", {
          query_embedding: localEmbedding as any,
          match_count: sourceCount,
          file_ids: uniqueFileIds
        })
      if (localFileItemsError) {
        throw localFileItemsError
      }
      chunks = localFileItems
    } else if (embeddingsProvider === "bge-m3") {
      // BGE-M3 embeddings come from the project's own helper.
      console.log("......[retrieve] userInput=", userInput)
      const bgeEmbedding = await generateBgeM3Embedding(userInput)
      // Log only the vector's length; dumping every component on each request
      // floods the logs without helping diagnosis.
      const bgeEmbeddingLength = Array.isArray(bgeEmbedding)
        ? bgeEmbedding.length
        : undefined
      console.log(
        "......[retrieve] [bge-m3] got embedding, length =",
        bgeEmbeddingLength
      )
      console.log(
        "......[retrieve] [bge-m3] calling RPC match_file_items_bge_m3 with:",
        {
          query_embedding_length: bgeEmbeddingLength,
          match_count: sourceCount,
          file_ids: uniqueFileIds
        }
      )

      // The match_file_items_bge_m3 RPC must already be defined on the
      // database side.
      const { data: bgeFileItems, error: bgeError } = await supabaseAdmin.rpc(
        "match_file_items_bge_m3",
        {
          query_embedding: bgeEmbedding as any,
          match_count: sourceCount,
          file_ids: uniqueFileIds
        }
      )
      if (bgeError) {
        console.error("......[retrieve] [bge-m3] RPC error:", bgeError)
        throw bgeError
      }
      console.log(
        "......[retrieve] [bge-m3] RPC result count =",
        bgeFileItems?.length
      )
      chunks = bgeFileItems
    }

    // Fall back to an empty list so the response always carries a `results`
    // array, even if the matched rows came back as null.
    const mostSimilarChunks = (chunks ?? []).sort(
      (a, b) => b.similarity - a.similarity
    )

    return new Response(JSON.stringify({ results: mostSimilarChunks }), {
      status: 200
    })
  } catch (error: any) {
    console.error("......[retrieve] request failed:", error)

    // Some errors nest the message under `error.error`; plain Error objects
    // and thrown RPC error objects expose it directly as `error.message`.
    const errorMessage =
      error?.error?.message ||
      error?.message ||
      "An unexpected error occurred"

    // Only trust a status that is a real HTTP error code: the Response
    // constructor throws a RangeError for values outside 200-599.
    const status = Number(error?.status)
    const errorCode =
      Number.isInteger(status) && status >= 400 && status <= 599 ? status : 500

    return new Response(JSON.stringify({ message: errorMessage }), {
      status: errorCode
    })
  }
}