197 lines
5.9 KiB
TypeScript
197 lines
5.9 KiB
TypeScript
import { generateLocalEmbedding } from "@/lib/generate-local-embedding"
|
|
import { generateBgeM3Embedding } from "@/lib/generate-bgem3-embedding"
|
|
import {
|
|
processCSV,
|
|
processJSON,
|
|
processMarkdown,
|
|
processPdf,
|
|
processTxt
|
|
} from "@/lib/retrieval/processing"
|
|
import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers"
|
|
import { Database } from "@/supabase/types"
|
|
import { FileItemChunk } from "@/types"
|
|
import { createClient } from "@supabase/supabase-js"
|
|
import { NextResponse } from "next/server"
|
|
import OpenAI from "openai"
|
|
import { getRuntimeEnv } from "@/lib/ipconfig" // 新增引入
|
|
|
|
/** Embeddings providers this route can generate vectors for. */
const SUPPORTED_EMBEDDINGS_PROVIDERS = ["openai", "local", "bge-m3"] as const

type EmbeddingsProvider = (typeof SUPPORTED_EMBEDDINGS_PROVIDERS)[number]

/** Profile shape returned by `getServerProfile`. */
type ServerProfile = Awaited<ReturnType<typeof getServerProfile>>

/** Narrows an untrusted form value to one of the supported embeddings providers. */
function isEmbeddingsProvider(value: unknown): value is EmbeddingsProvider {
  return (
    typeof value === "string" &&
    (SUPPORTED_EMBEDDINGS_PROVIDERS as readonly string[]).includes(value)
  )
}

/**
 * Checks that the profile has an API key for the OpenAI flavour it uses
 * (Azure OpenAI when `use_azure_openai` is set, plain OpenAI otherwise).
 *
 * @throws the error raised by `checkApiKey`, with a hint appended to its message
 *   suggesting local embeddings as an alternative.
 */
function assertOpenAIKeyConfigured(profile: ServerProfile): void {
  try {
    if (profile.use_azure_openai) {
      checkApiKey(profile.azure_openai_api_key, "Azure OpenAI")
    } else {
      checkApiKey(profile.openai_api_key, "OpenAI")
    }
  } catch (error: unknown) {
    if (error instanceof Error) {
      error.message += ", make sure it is configured or else use local embeddings"
    }
    throw error
  }
}

/**
 * Splits a file into chunks using the processor for its extension.
 *
 * @param blob the raw file contents.
 * @param fileExtension lower-cased extension without the dot, if any.
 * @returns the chunks, or `null` when the extension is not supported.
 */
async function chunkFileByExtension(
  blob: Blob,
  fileExtension: string | undefined
): Promise<FileItemChunk[] | null> {
  switch (fileExtension) {
    case "csv":
      return processCSV(blob)
    case "json":
      return processJSON(blob)
    case "md":
      return processMarkdown(blob)
    case "pdf": {
      console.log("......[process] Processing PDF...")
      const pdfChunks = await processPdf(blob)
      console.log("......[process] PDF Processed.")
      return pdfChunks
    }
    case "txt":
      return processTxt(blob)
    default:
      return null
  }
}

/** Builds an OpenAI SDK client pointed at Azure OpenAI or OpenAI, per the profile. */
function createOpenAIClient(profile: ServerProfile): OpenAI {
  if (profile.use_azure_openai) {
    return new OpenAI({
      apiKey: profile.azure_openai_api_key || "",
      baseURL: `${profile.azure_openai_endpoint}/openai/deployments/${profile.azure_openai_embeddings_id}`,
      defaultQuery: { "api-version": "2023-12-01-preview" },
      defaultHeaders: { "api-key": profile.azure_openai_api_key }
    })
  }

  return new OpenAI({
    apiKey: profile.openai_api_key || "",
    organization: profile.openai_organization_id
  })
}

/**
 * Generates one embedding per chunk with the given provider.
 *
 * The result is index-aligned with `chunks`. For "local" and "bge-m3", a chunk
 * whose embedding fails is logged and yields `null`, so the remaining chunks
 * are still stored (best effort). Failures from the OpenAI API propagate.
 */
async function generateEmbeddings(
  provider: EmbeddingsProvider,
  chunks: FileItemChunk[],
  profile: ServerProfile
): Promise<unknown[]> {
  // Nothing to embed; also avoids sending an empty `input` to the OpenAI API.
  if (chunks.length === 0) return []

  switch (provider) {
    case "openai": {
      // The client is only needed for this provider, so it is built lazily here.
      const openai = createOpenAIClient(profile)
      const response = await openai.embeddings.create({
        model: "text-embedding-3-small",
        input: chunks.map(chunk => chunk.content)
      })
      return response.data.map(item => item.embedding)
    }
    case "local": {
      return Promise.all(
        chunks.map(async (chunk, index) => {
          try {
            return await generateLocalEmbedding(chunk.content)
          } catch (error) {
            console.error(
              `Error generating embedding for chunk ${index}: ${chunk.content}`,
              error
            )
            return null
          }
        })
      )
    }
    case "bge-m3": {
      console.log("......[embedding] enter bge-m3.")
      // Calls the BGE-M3 embedding helper (your own BGE-M3 API or a local function).
      return Promise.all(
        chunks.map(async (chunk, index) => {
          try {
            return await generateBgeM3Embedding(chunk.content)
          } catch (error) {
            console.error(
              `Error generating BGE-M3 embedding for chunk ${index}: ${chunk.content}`,
              error
            )
            return null
          }
        })
      )
    }
    default: {
      const unsupported: never = provider
      throw new Error(`Unsupported embeddings provider: ${String(unsupported)}`)
    }
  }
}

/**
 * Extracts a client-facing message and HTTP status from an arbitrary thrown value.
 * Falls back to a generic message and 500 when the value carries neither
 * (including thrown `null`/`undefined`).
 */
function describeError(error: unknown): { message: string; status: number } {
  const candidate = (
    typeof error === "object" && error !== null ? error : {}
  ) as { message?: unknown; status?: unknown }

  const message =
    typeof candidate.message === "string" && candidate.message
      ? candidate.message
      : "An unexpected error occurred"
  // SDK errors (e.g. OpenAI API errors) may carry an HTTP `status`; reuse it when present.
  const status =
    typeof candidate.status === "number" && candidate.status
      ? candidate.status
      : 500

  return { message, status }
}

/**
 * Processes an uploaded file into embedded chunks.
 *
 * Expects multipart form data with `file_id` (id of a row in `files` owned by
 * the caller) and `embeddingsProvider` ("openai", "local" or "bge-m3").
 *
 * Steps:
 * 1. Downloads the file from the `files` storage bucket.
 * 2. Chunks it by extension (csv/json/md/pdf/txt).
 * 3. Embeds each chunk with the chosen provider.
 * 4. Upserts the chunks into `file_items`.
 * 5. Stores the total token count on the `files` row.
 *
 * Responses:
 * - 200 "Embed Successful" on success.
 * - 400 for a missing `file_id`, an unsupported provider, or an unsupported file type.
 * - Otherwise a JSON `{ message }` body with the error's status, or 500.
 */
export async function POST(req: Request) {
  console.log("......[process] Starting file processing request")

  try {
    const supabaseAdmin = createClient<Database>(
      getRuntimeEnv("SUPABASE_URL") ?? "http://localhost:8000",
      process.env.SUPABASE_SERVICE_ROLE_KEY!
    )

    const profile = await getServerProfile()

    const formData = await req.formData()
    const file_id = formData.get("file_id")
    const embeddingsProvider = formData.get("embeddingsProvider")

    // FormData.get returns null (or a File) for missing/invalid fields; reject early.
    if (typeof file_id !== "string" || file_id === "") {
      return new NextResponse("Missing file_id", { status: 400 })
    }

    // Otherwise rows would be stored with no embedding in any column.
    if (!isEmbeddingsProvider(embeddingsProvider)) {
      return new NextResponse("Unsupported embeddings provider", {
        status: 400
      })
    }

    const { data: fileMetadata, error: metadataError } = await supabaseAdmin
      .from("files")
      .select("*")
      .eq("id", file_id)
      .single()

    if (metadataError) {
      throw new Error(
        `Failed to retrieve file metadata: ${metadataError.message}`
      )
    }

    if (!fileMetadata) {
      throw new Error("File not found")
    }

    if (fileMetadata.user_id !== profile.user_id) {
      throw new Error("Unauthorized")
    }

    const { data: file, error: fileError } = await supabaseAdmin.storage
      .from("files")
      .download(fileMetadata.file_path)

    if (fileError) {
      throw new Error(`Failed to retrieve file: ${fileError.message}`)
    }

    if (!file) {
      throw new Error("Failed to retrieve file: no data returned")
    }

    // Re-wrap the downloaded bytes in a plain Blob before handing them to the processors.
    const blob = new Blob([Buffer.from(await file.arrayBuffer())])
    const fileExtension = fileMetadata.name.split(".").pop()?.toLowerCase()

    if (embeddingsProvider === "openai") {
      assertOpenAIKeyConfigured(profile)
    }

    const chunks = await chunkFileByExtension(blob, fileExtension)

    if (chunks === null) {
      return new NextResponse("Unsupported file type", {
        status: 400
      })
    }

    const embeddings = await generateEmbeddings(
      embeddingsProvider,
      chunks,
      profile
    )

    const file_items = chunks.map((chunk, index) => {
      // NOTE(review): cast kept from the original; the generated DB column types
      // presumably do not accept the raw embedding array directly — confirm.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const vector = (embeddings[index] ?? null) as any

      return {
        file_id,
        user_id: profile.user_id,
        content: chunk.content,
        tokens: chunk.tokens,
        openai_embedding: embeddingsProvider === "openai" ? vector : null,
        local_embedding: embeddingsProvider === "local" ? vector : null,
        bge_m3_embedding: embeddingsProvider === "bge-m3" ? vector : null
      }
    })

    if (file_items.length > 0) {
      const { error: upsertError } = await supabaseAdmin
        .from("file_items")
        .upsert(file_items)

      if (upsertError) {
        throw new Error(`Failed to store file items: ${upsertError.message}`)
      }
    }

    const totalTokens = file_items.reduce((acc, item) => acc + item.tokens, 0)

    const { error: updateError } = await supabaseAdmin
      .from("files")
      .update({ tokens: totalTokens })
      .eq("id", file_id)

    if (updateError) {
      throw new Error(`Failed to update file tokens: ${updateError.message}`)
    }

    return new NextResponse("Embed Successful", {
      status: 200
    })
  } catch (error: unknown) {
    const trace = error instanceof Error ? error.stack : String(error)
    console.log(`Error in retrieval/process: ${trace}`)

    const { message: errorMessage, status: errorCode } = describeError(error)

    return new Response(JSON.stringify({ message: errorMessage }), {
      status: errorCode
    })
  }
}
|