// ModelSelect: dropdown used to choose the active chat model.
//
// Combines the user's own `models` from ChatbotUIContext (re-shaped into the
// same fields as the other lists, with provider "custom") with
// availableHostedModels, availableLocalModels and availableOpenRouterModels.
// It buckets them by `provider` and lists each non-empty bucket as a section.
// Models are filtered by the active tab ("hosted" = every provider except
// "ollama", "local" = "ollama" only) and by a case-insensitive substring match
// of `modelName` against the search box. Picking a model calls
// `onSelectModel(modelId)` and closes the menu. Opening or closing the menu
// clears the search text.
//
// Each section heading shows:
// - "AZURE OPENAI" when provider is "openai" and profile.use_azure_openai is set;
// - the translated "modelProvider.custom" label for custom models;
// - otherwise the upper-cased provider id.
//
// When there are no models at all, "chat.unlockModelsMessage" is shown.
// The component renders nothing while `profile` is not available.
//
// NOTE(review): this file appears to be damaged and will not compile as-is.
// - The JSX elements and several TypeScript generic arguments were stripped
//   (e.g. `FC = ({`, `reduce>(`, `useRef(null)`, and `return ( { setIsOpen(...`
//   with no opening tag).
// - Statements have been collapsed onto single lines with no separators.
// - The `// FIX: hacky` line comment on the first code line now comments out
//   everything after it on that line.
// Restore the markup from version control before editing. The comments in
// this file describe only the logic that survived.
import { ChatbotUIContext } from "@/context/context" import { LLM, LLMID, ModelProvider } from "@/types" import { IconCheck, IconChevronDown } from "@tabler/icons-react" import { FC, useContext, useEffect, useRef, useState } from "react" import { Button } from "../ui/button" import { DropdownMenu, DropdownMenuContent, DropdownMenuTrigger } from "../ui/dropdown-menu" import { Input } from "../ui/input" import { Tabs, TabsList, TabsTrigger } from "../ui/tabs" import { ModelIcon } from "./model-icon" import { ModelOption } from "./model-option" import { useTranslation } from 'react-i18next' /** Props: `selectedModelId` is the id of the currently selected model; `onSelectModel` is called with the id of the model the user picks. */ interface ModelSelectProps { selectedModelId: string onSelectModel: (modelId: LLMID) => void } /** Model picker dropdown (see the header comment at the top of this file). NOTE(review): `FC` presumably lost its `<ModelSelectProps>` type argument when the file was damaged. */ export const ModelSelect: FC = ({ selectedModelId, onSelectModel }) => { const { t } = useTranslation() const { profile, models, availableHostedModels, availableLocalModels, availableOpenRouterModels } = useContext(ChatbotUIContext) /* NOTE(review): both refs presumably lost their element type arguments (e.g. useRef<HTMLInputElement>(null)); triggerRef is not referenced by any surviving code. */ const inputRef = useRef(null) const triggerRef = useRef(null) const [isOpen, setIsOpen] = useState(false) const [search, setSearch] = useState("") const [tab, setTab] = useState<"hosted" | "local">("hosted") /* Focus the search input 100 ms after the menu opens. NOTE(review): the timeout is never cleared if the menu closes or the component unmounts first. The `// FIX: hacky` line comment that follows now comments out everything after it on this collapsed line (handleSelectModel, allModels, groupedModels, selectedModel and the start of the JSX). */ useEffect(() => { if (isOpen) { setTimeout(() => { inputRef.current?.focus() }, 100) // FIX: hacky } }, [isOpen]) const handleSelectModel = (modelId: LLMID) => { onSelectModel(modelId) setIsOpen(false) } const allModels = [ ...models.map(model => ({ modelId: model.model_id as LLMID, modelName: model.name, provider: "custom" as ModelProvider, hostedId: model.id, platformLink: "", imageInput: false })), ...availableHostedModels, ...availableLocalModels, ...availableOpenRouterModels ] const groupedModels = allModels.reduce>( (groups, model) => { const key = model.provider if (!groups[key]) { groups[key] = [] } groups[key].push(model) return groups }, {} ) const selectedModel = allModels.find( model => model.modelId === selectedModelId ) if (!profile) return null return ( { setIsOpen(isOpen) setSearch("") }} > {allModels.length === 0 ? (
{t("chat.unlockModelsMessage")}
) : ( )}
setTab(value)}> {availableLocalModels.length > 0 && ( {t("chat.hosted")} {t("chat.local")} )} setSearch(e.target.value)} />
{Object.entries(groupedModels).map(([provider, models]) => { const filteredModels = models .filter(model => { /* Tab filter. NOTE(review): if no branch matches, the callback returns undefined (model excluded). The "openrouter" branch is unreachable because `tab` is declared as "hosted" | "local", so models with provider "openrouter" are listed under the "hosted" tab. */ if (tab === "hosted") return model.provider !== "ollama" if (tab === "local") return model.provider === "ollama" if (tab === "openrouter") return model.provider === "openrouter" }) .filter(model => model.modelName.toLowerCase().includes(search.toLowerCase()) ) /* NOTE(review): every model in this group has the same `provider` (groups are keyed by provider), so this comparator always returns 0 and the sort does not reorder anything. */ .sort((a, b) => a.provider.localeCompare(b.provider)) /* Hide provider sections with no matching models. */ if (filteredModels.length === 0) return null return (
{provider === "openai" && profile.use_azure_openai ? "AZURE OPENAI" : provider === "custom" ? t("modelProvider.custom") : provider.toUpperCase()}
{filteredModels.map(model => { return (
{selectedModelId === model.modelId && ( )} handleSelectModel(model.modelId)} />
) })}
) })}
) }