103 lines
3.1 KiB
TypeScript
103 lines
3.1 KiB
TypeScript
import { generateLocalEmbedding } from "@/lib/generate-local-embedding"
|
|
import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers"
|
|
import { Database } from "@/supabase/types"
|
|
import { createClient } from "@supabase/supabase-js"
|
|
import OpenAI from "openai"
|
|
|
|
export async function POST(request: Request) {
|
|
const json = await request.json()
|
|
const { userInput, fileIds, embeddingsProvider, sourceCount } = json as {
|
|
userInput: string
|
|
fileIds: string[]
|
|
embeddingsProvider: "openai" | "local"
|
|
sourceCount: number
|
|
}
|
|
|
|
const uniqueFileIds = [...new Set(fileIds)]
|
|
|
|
try {
|
|
const supabaseAdmin = createClient<Database>(
|
|
process.env.NEXT_PUBLIC_SUPABASE_URL!,
|
|
process.env.SUPABASE_SERVICE_ROLE_KEY!
|
|
)
|
|
|
|
const profile = await getServerProfile()
|
|
|
|
if (embeddingsProvider === "openai") {
|
|
if (profile.use_azure_openai) {
|
|
checkApiKey(profile.azure_openai_api_key, "Azure OpenAI")
|
|
} else {
|
|
checkApiKey(profile.openai_api_key, "OpenAI")
|
|
}
|
|
}
|
|
|
|
let chunks: any[] = []
|
|
|
|
let openai
|
|
if (profile.use_azure_openai) {
|
|
openai = new OpenAI({
|
|
apiKey: profile.azure_openai_api_key || "",
|
|
baseURL: `${profile.azure_openai_endpoint}/openai/deployments/${profile.azure_openai_embeddings_id}`,
|
|
defaultQuery: { "api-version": "2023-12-01-preview" },
|
|
defaultHeaders: { "api-key": profile.azure_openai_api_key }
|
|
})
|
|
} else {
|
|
openai = new OpenAI({
|
|
apiKey: profile.openai_api_key || "",
|
|
organization: profile.openai_organization_id
|
|
})
|
|
}
|
|
|
|
if (embeddingsProvider === "openai") {
|
|
const response = await openai.embeddings.create({
|
|
model: "text-embedding-3-small",
|
|
input: userInput
|
|
})
|
|
|
|
const openaiEmbedding = response.data.map(item => item.embedding)[0]
|
|
|
|
const { data: openaiFileItems, error: openaiError } =
|
|
await supabaseAdmin.rpc("match_file_items_openai", {
|
|
query_embedding: openaiEmbedding as any,
|
|
match_count: sourceCount,
|
|
file_ids: uniqueFileIds
|
|
})
|
|
|
|
if (openaiError) {
|
|
throw openaiError
|
|
}
|
|
|
|
chunks = openaiFileItems
|
|
} else if (embeddingsProvider === "local") {
|
|
const localEmbedding = await generateLocalEmbedding(userInput)
|
|
|
|
const { data: localFileItems, error: localFileItemsError } =
|
|
await supabaseAdmin.rpc("match_file_items_local", {
|
|
query_embedding: localEmbedding as any,
|
|
match_count: sourceCount,
|
|
file_ids: uniqueFileIds
|
|
})
|
|
|
|
if (localFileItemsError) {
|
|
throw localFileItemsError
|
|
}
|
|
|
|
chunks = localFileItems
|
|
}
|
|
|
|
const mostSimilarChunks = chunks?.sort(
|
|
(a, b) => b.similarity - a.similarity
|
|
)
|
|
|
|
return new Response(JSON.stringify({ results: mostSimilarChunks }), {
|
|
status: 200
|
|
})
|
|
} catch (error: any) {
|
|
const errorMessage = error.error?.message || "An unexpected error occurred"
|
|
const errorCode = error.status || 500
|
|
return new Response(JSON.stringify({ message: errorMessage }), {
|
|
status: errorCode
|
|
})
|
|
}
|
|
}
|