//
// SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
"use client"

import { useState, useEffect, useRef } from "react"
import { ChevronDown, Cpu, Server, RefreshCw } from "lucide-react"
import { OllamaIcon } from "@/components/ui/ollama-icon"

// NOTE(review): this copy of the file appears to have had its JSX markup and
// TypeScript generic type arguments stripped. Examples: `icon: ,` has no value,
// `useState([])` / `useRef(null)` have no type arguments, `{} as Record)` is
// incomplete, and the render `return (...)` bodies contain only text and
// expressions. It will not compile as shown; restore the markup from version
// control before making code changes.

/**
 * A selectable chat model and the endpoint settings associated with it.
 *
 * Fields, as populated in this file:
 * - `id`: unique key. Discovered local models are prefixed `ollama-` / `vllm-`.
 *   Also used to match a saved selection read back from localStorage.
 * - `name`: display label shown in the trigger and the list.
 * - `icon`: React node rendered next to the name.
 * - `description`: secondary text shown under each list entry.
 * - `model`: model identifier as reported by the backend, or the NVIDIA catalog name.
 * - `baseURL`: endpoint base URL for the provider. Not read in this file;
 *   presumably consumed by listeners of the `modelSelected` event (verify).
 * - `provider`: "nvidia", "ollama" or "vllm" in this file; used for grouping.
 * - `apiKeyName`: name of the API key setting; set only on the NVIDIA models here.
 */
interface Model {
  id: string
  name: string
  icon: React.ReactNode
  description: string
  model: string
  baseURL: string
  provider: string
  apiKeyName?: string
}

// NVIDIA API models (always available).
// They are appended after any discovered local models, so they are always listed.
const NVIDIA_MODELS: Model[] = [
  {
    id: "nvidia-nemotron",
    name: "NVIDIA Llama 3.3 Nemotron Super 49B",
    icon: ,
    description: "NVIDIA hosted Nemotron Super 49B v1.5 model",
    model: "nvidia/llama-3.3-nemotron-super-49b-v1.5",
    apiKeyName: "NVIDIA_API_KEY",
    baseURL: "https://integrate.api.nvidia.com/v1",
    provider: "nvidia",
  },
  {
    id: "nvidia-nemotron-nano",
    name: "NVIDIA Nemotron Nano 9B v2",
    icon: ,
    description: "NVIDIA hosted Nemotron Nano 9B v2 - Faster and more efficient",
    model: "nvidia/nvidia-nemotron-nano-9b-v2",
    apiKeyName: "NVIDIA_API_KEY",
    baseURL: "https://integrate.api.nvidia.com/v1",
    provider: "nvidia",
  },
]

// Helpers to create model objects.

/**
 * Build a Model entry for a model name reported by `/api/ollama/tags`.
 * The name is interpolated into `id` and `name` via template literals.
 * `baseURL` is hard-coded to http://localhost:11434/v1.
 */
const createOllamaModel = (modelName: string): Model => ({
  id: `ollama-${modelName}`,
  name: `Ollama ${modelName}`,
  icon: ,
  description: `Local Ollama model`,
  model: modelName,
  baseURL: "http://localhost:11434/v1",
  provider: "ollama",
})

/**
 * Build a Model entry for a model id reported by `/api/vllm/models`.
 * The display name is the last `/`-separated segment of the id
 * (e.g. "org/model" -> "model"), falling back to the full id if that segment is empty.
 * `baseURL` is hard-coded to http://localhost:8001/v1.
 * Calls `.split`, so it throws if given a non-string value at runtime.
 */
const createVllmModel = (modelName: string): Model => ({
  id: `vllm-${modelName}`,
  name: modelName.split('/').pop() || modelName,
  icon: ,
  description: "vLLM (GPU-accelerated)",
  model: modelName,
  baseURL: "http://localhost:8001/v1",
  provider: "vllm",
})

/**
 * Dropdown for choosing the active chat model.
 *
 * Loading the list:
 * - On mount it probes the local vLLM and Ollama backends through
 *   `/api/vllm/models` and `/api/ollama/tags`.
 * - It then appends the fixed NVIDIA API models.
 *
 * Initial selection, in order of preference:
 * 1. The model saved in localStorage under "selectedModel" (matched by `id`),
 *    if it is still in the list.
 * 2. The first vLLM/Ollama model.
 * 3. The first NVIDIA model.
 *
 * Picking a model from the list:
 * - saves it to localStorage;
 * - dispatches a window `modelSelected` CustomEvent whose `detail.model` is the chosen Model.
 *
 * The list is re-fetched when the window receives `ollama-models-updated`,
 * and when the "Refresh models" control in the open dropdown is clicked.
 */
export function ModelSelector() {
  // NOTE(review): generic type arguments appear to be missing from the next lines.
  // Presumably `useState<Model[]>([])`, `useState<Model | null>(null)` and element
  // types for the refs — confirm against the original file.
  const [models, setModels] = useState([])
  const [selectedModel, setSelectedModel] = useState(null)
  const [isOpen, setIsOpen] = useState(false)
  const [isLoading, setIsLoading] = useState(true)
  // NOTE(review): `buttonRef` is not referenced by any visible code. It may be
  // attached in the (missing) markup — verify, or remove it.
  const buttonRef = useRef(null)
  // The outside-click handler below checks clicks against this element.
  const containerRef = useRef(null)
  const [mounted, setMounted] = useState(false)

  // Fetch available models from running backends.
  // NOTE(review): this function is called from three places:
  //   - the mount effect;
  //   - the 'ollama-models-updated' listener;
  //   - the refresh control.
  // There is no cancellation or sequence check, so overlapping calls can resolve
  // out of order: an older call's setModels/setSelectedModel can overwrite a newer
  // call's results. The setters may also run after unmount.
  const fetchAvailableModels = async () => {
    // While loading, the component renders only the loading view (see the
    // `isLoading` early return). A refresh from the open dropdown therefore hides
    // the list until it finishes; `isOpen` itself is left unchanged.
    setIsLoading(true)
    const availableModels: Model[] = []

    // NOTE(review): the two probes below run one after the other, each with a
    // 3000 ms timeout, so unreachable backends can delay loading by ~6 s. They are
    // independent and could run in parallel (e.g. Promise.allSettled).

    // Check vLLM first (port 8001).
    try {
      const vllmResponse = await fetch('/api/vllm/models', { signal: AbortSignal.timeout(3000) })
      if (vllmResponse.ok) {
        const data = await vllmResponse.json()
        if (data.models && Array.isArray(data.models)) {
          data.models.forEach((model: any) => {
            // Entries may be objects (with `id` or `name`) or bare strings.
            // NOTE(review): if an object has neither field, the object itself is passed
            // to createVllmModel and `.split` throws. The catch below then logs
            // "vLLM not available" even though the server responded, and any entries
            // already pushed are kept.
            const modelId = model.id || model.name || model
            availableModels.push(createVllmModel(modelId))
          })
        }
      }
    } catch (e) {
      // Network error, timeout abort, bad JSON, or a throw from the loop above.
      console.log("vLLM not available")
    }

    // Check Ollama (port 11434).
    try {
      const ollamaResponse = await fetch('/api/ollama/tags', { signal: AbortSignal.timeout(3000) })
      if (ollamaResponse.ok) {
        const data = await ollamaResponse.json()
        if (data.models && Array.isArray(data.models)) {
          data.models.forEach((model: any) => {
            // NOTE(review): an object entry without `name` is interpolated by
            // createOllamaModel, producing ids like "ollama-[object Object]".
            const modelName = model.name || model
            availableModels.push(createOllamaModel(modelName))
          })
        }
      }
    } catch (e) {
      console.log("Ollama not available")
    }

    // Always add the NVIDIA API models, after the local ones.
    availableModels.push(...NVIDIA_MODELS)
    setModels(availableModels)

    // Set the default selected model.
    // Because NVIDIA_MODELS is non-empty, this condition is always true here. As a
    // result, the "No models available" view below is only reachable if that list is emptied.
    if (availableModels.length > 0) {
      // Try to restore the saved selection.
      // NOTE(review): restoring or defaulting the selection neither writes
      // localStorage nor dispatches `modelSelected`; only updateSelectedModel does.
      // Listeners therefore do not learn the initial model from this component —
      // confirm this is intended.
      try {
        const saved = localStorage.getItem("selectedModel")
        if (saved) {
          const savedModel = JSON.parse(saved)
          const found = availableModels.find(m => m.id === savedModel.id)
          if (found) {
            setSelectedModel(found)
            setIsLoading(false)
            return
          }
        }
      } catch (e) {
        // Ignore storage access errors, JSON.parse failures, or a null parse result.
        // Fall through to the default choice.
      }

      // Default to the first available local model (vLLM or Ollama), else the first entry.
      const localModel = availableModels.find(m => m.provider === "vllm" || m.provider === "ollama")
      setSelectedModel(localModel || availableModels[0])
    }
    setIsLoading(false)
  }

  // Select a model, persist it, and dispatch a custom event announcing the change.
  // The full Model object, including the `icon` React node, goes through
  // JSON.stringify. The restore path above reads back only `id`.
  const updateSelectedModel = (model: Model) => {
    setSelectedModel(model)
    localStorage.setItem("selectedModel", JSON.stringify(model))
    // Dispatch a window-level custom event with the selected model data.
    const event = new CustomEvent('modelSelected', { detail: { model } })
    window.dispatchEvent(event)
  }

  // Fetch models on mount.
  // NOTE(review): fetchAvailableModels is recreated on every render but omitted from
  // the dependency list. This works because it uses only state setters and
  // module-level values; react-hooks/exhaustive-deps will still flag it.
  useEffect(() => {
    fetchAvailableModels()
  }, [])

  // Set mounted state after the component mounts (for SSR compatibility).
  // The dropdown panel is rendered only when `isOpen && mounted`.
  useEffect(() => {
    setMounted(true)
  }, [])

  // Close the dropdown on a mousedown outside containerRef's element, or on Escape.
  // The document listeners are registered once and removed on unmount.
  useEffect(() => {
    const handleMouseDown = (e: MouseEvent) => {
      if (containerRef.current && !containerRef.current.contains(e.target as Node)) {
        setIsOpen(false)
      }
    }
    const handleKeyDown = (e: KeyboardEvent) => {
      if (e.key === 'Escape') setIsOpen(false)
    }
    document.addEventListener('mousedown', handleMouseDown)
    document.addEventListener('keydown', handleKeyDown)
    return () => {
      document.removeEventListener('mousedown', handleMouseDown)
      document.removeEventListener('keydown', handleKeyDown)
    }
  }, [])

  // Re-fetch the model list whenever the window receives 'ollama-models-updated'.
  useEffect(() => {
    const handleOllamaUpdate = () => {
      console.log("Ollama models updated, reloading...")
      fetchAvailableModels()
    }
    window.addEventListener('ollama-models-updated', handleOllamaUpdate)
    return () => {
      window.removeEventListener('ollama-models-updated', handleOllamaUpdate)
    }
  }, [])

  // Loading view (markup appears stripped in this copy).
  if (isLoading) {
    return (
      Loading models...
    )
  }

  // Empty view (markup appears stripped in this copy).
  if (!selectedModel) {
    return (
      No models available
    )
  }

  // Group models by provider, keeping discovery order within each group.
  // Object.entries then yields the groups in first-seen order: vLLM, Ollama, NVIDIA.
  // NOTE(review): the `Record` type argument appears missing (presumably
  // `Record<string, Model[]>`).
  const groupedModels = models.reduce((acc, model) => {
    if (!acc[model.provider]) {
      acc[model.provider] = []
    }
    acc[model.provider].push(model)
    return acc
  }, {} as Record)

  // Heading shown for a provider group; unknown providers fall back to the raw string.
  const getProviderLabel = (provider: string) => {
    switch (provider) {
      case "ollama":
        return "Ollama (Local)"
      case "vllm":
        return "vLLM (GPU-accelerated)"
      case "nvidia":
        return "NVIDIA API (Cloud)"
      default:
        return provider
    }
  }

  // Render: trigger button, then (when open) the grouped model list with a refresh
  // control. Choosing an entry calls updateSelectedModel and closes the dropdown.
  // NOTE(review): the element tags below appear stripped in this copy; only
  // attributes, handlers and child expressions remain.
  return (
    setIsOpen(!isOpen)}
    >
      {selectedModel.icon}
      {selectedModel.name}
      {isOpen && mounted && (
        Select Model
        {
          e.stopPropagation()
          fetchAvailableModels()
        }}
        className="p-1 hover:bg-muted/50 rounded"
        title="Refresh models"
        >
        {Object.entries(groupedModels).map(([provider, providerModels]) => (
          {getProviderLabel(provider)}
          {providerModels.map((model) => (
            {
              updateSelectedModel(model)
              setIsOpen(false)
            }}
            >
            {model.icon}
            {model.name}
            {model.description}
          ))}
        ))}
      )}
  )
}