211 lines
6.1 KiB
TypeScript
211 lines
6.1 KiB
TypeScript
import { ChatbotUIContext } from "@/context/context"
|
|
import { LLM, LLMID, ModelProvider } from "@/types"
|
|
import { IconCheck, IconChevronDown } from "@tabler/icons-react"
|
|
import { FC, useContext, useEffect, useRef, useState } from "react"
|
|
import { Button } from "../ui/button"
|
|
import {
|
|
DropdownMenu,
|
|
DropdownMenuContent,
|
|
DropdownMenuTrigger
|
|
} from "../ui/dropdown-menu"
|
|
import { Input } from "../ui/input"
|
|
import { Tabs, TabsList, TabsTrigger } from "../ui/tabs"
|
|
import { ModelIcon } from "./model-icon"
|
|
import { ModelOption } from "./model-option"
|
|
|
|
import { useTranslation } from 'react-i18next'
|
|
|
|
/** Props accepted by the `ModelSelect` dropdown. */
interface ModelSelectProps {
  /** Callback fired with the chosen model's id when the user picks an option. */
  onSelectModel: (modelId: LLMID) => void
  /** Id of the model that is currently selected; shown in the trigger and check-marked in the list. */
  selectedModelId: string
}
|
|
|
|
/** Tabs offered by the model picker; the "local" tab is only shown when local models exist. */
type ModelTab = "hosted" | "local"

/**
 * Returns whether `model` is listed on `tab`: Ollama models appear under
 * "local", every other provider (including custom models) under "hosted".
 */
const isModelOnTab = (model: LLM, tab: ModelTab): boolean =>
  tab === "local" ? model.provider === "ollama" : model.provider !== "ollama"

/**
 * Dropdown for choosing the chat model.
 *
 * Merges the user's custom models with the hosted, local and OpenRouter
 * models from `ChatbotUIContext`, groups them by provider, and filters them
 * by the active tab (hosted/local) and a case-insensitive name search.
 *
 * @param selectedModelId - id of the currently selected model; rendered in the
 *   trigger and marked with a check icon in the list.
 * @param onSelectModel - called with the picked model's `modelId`; the menu
 *   closes afterwards.
 * @returns `null` while no profile is available in context.
 */
export const ModelSelect: FC<ModelSelectProps> = ({
  selectedModelId,
  onSelectModel
}) => {
  const { t } = useTranslation()

  const {
    profile,
    models,
    availableHostedModels,
    availableLocalModels,
    availableOpenRouterModels
  } = useContext(ChatbotUIContext)

  const inputRef = useRef<HTMLInputElement>(null)
  const triggerRef = useRef<HTMLButtonElement>(null)

  const [isOpen, setIsOpen] = useState(false)
  const [search, setSearch] = useState("")
  const [tab, setTab] = useState<ModelTab>("hosted")

  useEffect(() => {
    if (!isOpen) return

    // FIX: hacky — wait for the menu content to mount before focusing the
    // search input. Cancel the pending focus if the menu closes or the
    // component unmounts before the timer fires.
    const focusTimer = setTimeout(() => {
      inputRef.current?.focus()
    }, 100)

    return () => clearTimeout(focusTimer)
  }, [isOpen])

  const handleSelectModel = (modelId: LLMID) => {
    onSelectModel(modelId)
    setIsOpen(false)
  }

  const allModels = [
    ...models.map(model => ({
      modelId: model.model_id as LLMID,
      modelName: model.name,
      provider: "custom" as ModelProvider,
      hostedId: model.id,
      platformLink: "",
      imageInput: false
    })),
    ...availableHostedModels,
    ...availableLocalModels,
    ...availableOpenRouterModels
  ]

  // Bucket the models by provider so each provider gets its own section.
  const groupedModels = allModels.reduce<Record<string, LLM[]>>(
    (groups, model) => {
      const group = groups[model.provider] ?? []
      group.push(model)
      groups[model.provider] = group
      return groups
    },
    {}
  )

  const selectedModel = allModels.find(
    model => model.modelId === selectedModelId
  )

  // The tab list is hidden when there are no local models, so never leave the
  // user on a "local" tab they can no longer switch away from.
  const activeTab: ModelTab = availableLocalModels.length > 0 ? tab : "hosted"
  const normalizedSearch = search.toLowerCase()

  if (!profile) return null

  return (
    <DropdownMenu
      open={isOpen}
      onOpenChange={open => {
        setIsOpen(open)
        setSearch("")
      }}
    >
      <DropdownMenuTrigger
        className="bg-background w-full justify-start border-2 px-3 py-5"
        asChild
        disabled={allModels.length === 0}
      >
        {allModels.length === 0 ? (
          <div className="rounded text-sm font-bold">
            {t("chat.unlockModelsMessage")}
          </div>
        ) : (
          <Button
            ref={triggerRef}
            className="flex items-center justify-between"
            variant="ghost"
          >
            <div className="flex items-center">
              {selectedModel ? (
                <>
                  <ModelIcon
                    provider={selectedModel.provider}
                    width={26}
                    height={26}
                  />
                  <div className="ml-2 flex items-center">
                    {selectedModel.modelName}
                  </div>
                </>
              ) : (
                <div className="flex items-center">{t("chat.selectModel")}</div>
              )}
            </div>

            <IconChevronDown />
          </Button>
        )}
      </DropdownMenuTrigger>

      <DropdownMenuContent
        className="space-y-2 overflow-auto p-2"
        style={{ width: triggerRef.current?.offsetWidth }}
        align="start"
      >
        <Tabs
          value={activeTab}
          onValueChange={(value: string) =>
            // Narrow the untyped tab value instead of trusting it blindly.
            setTab(value === "local" ? "local" : "hosted")
          }
        >
          {availableLocalModels.length > 0 && (
            <TabsList className="grid grid-cols-2">
              <TabsTrigger value="hosted">{t("chat.hosted")}</TabsTrigger>

              <TabsTrigger value="local">{t("chat.local")}</TabsTrigger>
            </TabsList>
          )}
        </Tabs>

        <Input
          ref={inputRef}
          className="w-full"
          placeholder={t("chat.searchModelsPlaceholder")}
          value={search}
          onChange={e => setSearch(e.target.value)}
        />

        <div className="max-h-[300px] overflow-auto">
          {Object.entries(groupedModels).map(([provider, providerModels]) => {
            // Every model in this group shares `provider`, so no sorting is needed.
            const filteredModels = providerModels.filter(
              model =>
                isModelOnTab(model, activeTab) &&
                model.modelName.toLowerCase().includes(normalizedSearch)
            )

            if (filteredModels.length === 0) return null

            return (
              <div key={provider}>
                <div className="mb-1 ml-2 text-xs font-bold tracking-wide opacity-50">
                  {provider === "openai" && profile.use_azure_openai
                    ? "AZURE OPENAI"
                    : provider === "custom"
                      ? t("modelProvider.custom")
                      : provider.toUpperCase()}
                </div>

                <div className="mb-4">
                  {filteredModels.map(model => (
                    <div
                      key={model.modelId}
                      className="flex items-center space-x-1"
                    >
                      {selectedModelId === model.modelId && (
                        <IconCheck className="ml-2" size={32} />
                      )}

                      <ModelOption
                        model={model}
                        onSelect={() => handleSelectModel(model.modelId)}
                      />
                    </div>
                  ))}
                </div>
              </div>
            )
          })}
        </div>
      </DropdownMenuContent>
    </DropdownMenu>
  )
}
|