dgx-spark-playbooks/nvidia/station-txt2kg/assets/frontend/components/llm-selector-compact.tsx
2026-05-26 18:25:53 +00:00

311 lines
10 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//
// SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
"use client"
import { useState, useEffect } from "react"
import { ChevronDown, Cpu, Server, RefreshCw } from "lucide-react"
import { OllamaIcon } from "@/components/ui/ollama-icon"
/**
 * One selectable LLM entry shown in the model dropdown.
 * The selected entry is serialised as JSON into localStorage and sent in the
 * `ragModelSelected` event detail.
 */
interface LLMModel {
/** Unique key for the entry; used as the React list key and to match a saved selection. */
id: string
/** Label displayed on the trigger button and in the dropdown row. */
name: string
/** Model identifier as reported by the backend (empty string for the vLLM placeholder). */
model: string
/** Backend the model belongs to; this file uses "vllm", "ollama" and "nvidia". */
provider: string
/** Optional secondary line rendered under the name in the dropdown. */
description?: string
}
/**
 * Cloud models served through the NVIDIA API.
 * These entries are always appended to the selectable list, whether or not a
 * local backend responded; using them requires an NVIDIA API key.
 */
const NVIDIA_MODELS: LLMModel[] = [
  // Larger Nemotron variant.
  {
    id: "nvidia-nemotron-super",
    name: "Nemotron Super 49B",
    model: "nvidia/llama-3.3-nemotron-super-49b-v1.5",
    provider: "nvidia",
    description: "NVIDIA API (requires key)",
  },
  // Smaller Nemotron variant.
  {
    id: "nvidia-nemotron-nano",
    name: "Nemotron Nano 9B v2",
    model: "nvidia/nvidia-nemotron-nano-9b-v2",
    provider: "nvidia",
    description: "NVIDIA API - Fast & efficient",
  },
]
/**
 * Placeholder entry shown while no local backend (vLLM or Ollama) has reported
 * any models. It carries the "vllm" provider so that it is picked as the
 * default instead of a cloud model, and its empty `model` marks it as not
 * usable for generation.
 */
const VLLM_INITIALIZING: LLMModel = {
  id: "vllm-initializing",
  name: "vLLM (Local) Initializing…",
  model: "",
  provider: "vllm",
  description: "Backend loading (can take 30+ min)",
}
/** Returns the `models` array of a backend list response, or [] when the payload has another shape. */
function extractModelEntries(payload: unknown): unknown[] {
  if (typeof payload !== "object" || payload === null) return []
  const entries = (payload as { models?: unknown }).models
  return Array.isArray(entries) ? entries : []
}

/**
 * Resolves a model identifier from a list entry: either the entry itself when
 * it is a non-empty string, or the first non-empty string field among `keys`.
 * Returns null for entries that carry no usable identifier.
 */
function resolveModelName(entry: unknown, keys: readonly string[]): string | null {
  if (typeof entry === "string") return entry === "" ? null : entry
  if (typeof entry !== "object" || entry === null) return null
  const record = entry as Record<string, unknown>
  for (const key of keys) {
    const value = record[key]
    if (typeof value === "string" && value !== "") return value
  }
  return null
}

/**
 * Fetches a backend model list with a 3 second timeout.
 * Never rejects: network errors, timeouts and unparsable bodies are logged and
 * yield []; a non-OK HTTP status yields [] silently.
 */
async function fetchModelEntries(url: string, backendLabel: string): Promise<unknown[]> {
  try {
    const response = await fetch(url, { signal: AbortSignal.timeout(3000) })
    if (!response.ok) return []
    return extractModelEntries(await response.json())
  } catch {
    console.log(`${backendLabel} not available`)
    return []
  }
}

/** Reads the id of the model previously saved in localStorage, or null if absent or unreadable. */
function readSavedModelId(): string | null {
  try {
    const raw = localStorage.getItem("selectedModelForRAG")
    if (!raw) return null
    const parsed: unknown = JSON.parse(raw)
    if (typeof parsed === "object" && parsed !== null) {
      const id = (parsed as { id?: unknown }).id
      if (typeof id === "string") return id
    }
  } catch {
    // Storage unavailable or corrupted JSON: behave as if nothing was saved.
  }
  return null
}

/**
 * Compact dropdown for choosing the LLM used for answer generation.
 *
 * On mount it queries the vLLM and Ollama proxy endpoints for available
 * models, appends the NVIDIA API models, and selects (in order of preference)
 * the model saved in localStorage, the first local model, or the first entry.
 * While no local backend reports models, a placeholder entry is shown and the
 * list is re-fetched every 60 seconds. Picking a model stores it in
 * localStorage under "selectedModelForRAG" and dispatches a
 * `ragModelSelected` window event with the model in `detail`.
 */
export function LLMSelectorCompact() {
  const [models, setModels] = useState<LLMModel[]>([])
  const [selectedModel, setSelectedModel] = useState<LLMModel | null>(null)
  const [isOpen, setIsOpen] = useState(false)
  const [isLoading, setIsLoading] = useState(true)

  /**
   * Rebuilds the model list from the running backends and re-applies the
   * default selection.
   *
   * @param options.silent When true the loading spinner is not shown, so a
   *   background poll does not replace the control (or an open dropdown).
   */
  const fetchAvailableModels = async ({ silent = false }: { silent?: boolean } = {}) => {
    if (!silent) setIsLoading(true)

    // Probe both backends concurrently; each probe resolves to [] on failure.
    const [vllmEntries, ollamaEntries] = await Promise.all([
      fetchModelEntries('/api/vllm/models', "vLLM"),
      fetchModelEntries('/api/ollama/tags', "Ollama"),
    ])

    const availableModels: LLMModel[] = []

    // vLLM models are listed first.
    for (const entry of vllmEntries) {
      const modelId = resolveModelName(entry, ["id", "name"])
      if (modelId === null) continue // skip malformed entries, keep the rest
      availableModels.push({
        id: `vllm-${modelId}`,
        // Show only the last path segment of ids such as "org/model".
        name: modelId.split('/').pop() || modelId,
        model: modelId,
        provider: "vllm",
        description: "vLLM (GPU-accelerated)"
      })
    }

    for (const entry of ollamaEntries) {
      const modelName = resolveModelName(entry, ["name"])
      if (modelName === null) continue
      availableModels.push({
        id: `ollama-${modelName}`,
        name: modelName,
        model: modelName,
        provider: "ollama",
        description: "Local Ollama model"
      })
    }

    const hasLocal = availableModels.length > 0

    // Always add NVIDIA API models
    availableModels.push(...NVIDIA_MODELS)

    // When no local backend is ready yet, add placeholder so we don't default to cloud
    if (!hasLocal) {
      availableModels.unshift(VLLM_INITIALIZING)
    }
    setModels(availableModels)

    // Prefer the saved model when it is still offered; a saved placeholder is ignored.
    const savedId = readSavedModelId()
    const savedModel =
      savedId !== null && savedId !== "vllm-initializing"
        ? availableModels.find(m => m.id === savedId)
        : undefined

    // Otherwise prefer a local model (the placeholder counts as local) over cloud.
    const localModel = availableModels.find(m => m.provider === "vllm" || m.provider === "ollama")
    const nextSelection = savedModel ?? localModel ?? availableModels[0]
    if (nextSelection) {
      setSelectedModel(nextSelection)
    }
    setIsLoading(false)
  }

  // Fetch models on mount
  useEffect(() => {
    void fetchAvailableModels()
  }, [])

  // When showing vLLM initializing placeholder, poll to switch to real vLLM when ready
  useEffect(() => {
    if (selectedModel?.id !== "vllm-initializing") return
    const id = setInterval(() => {
      void fetchAvailableModels({ silent: true })
    }, 60_000)
    return () => clearInterval(id)
  }, [selectedModel?.id])

  /** Applies the user's choice, persists it, and notifies other components. */
  const handleSelectModel = (model: LLMModel) => {
    setSelectedModel(model)
    setIsOpen(false)
    try {
      localStorage.setItem("selectedModelForRAG", JSON.stringify(model))
    } catch (e: unknown) {
      // Persisting is best-effort; listeners must still be notified below.
      console.warn("Could not persist selected model", e)
    }
    // Dispatch event for other components
    window.dispatchEvent(new CustomEvent('ragModelSelected', {
      detail: { model }
    }))
  }

  /** Returns the small provider icon; unknown providers get the CPU icon. */
  const getModelIcon = (provider: string) => {
    if (provider === "ollama") {
      return <OllamaIcon className="h-3 w-3 text-orange-500" />
    }
    if (provider === "vllm") {
      return <Server className="h-3 w-3 text-purple-500" />
    }
    return <Cpu className="h-3 w-3 text-green-500" />
  }

  /** Returns the group heading for a provider; unknown providers are shown verbatim. */
  const getProviderLabel = (provider: string) => {
    switch (provider) {
      case "ollama": return "Ollama"
      case "vllm": return "vLLM"
      case "nvidia": return "NVIDIA API"
      default: return provider
    }
  }

  if (isLoading) {
    return (
      <div className="flex items-center gap-2 px-3 py-1.5 text-sm border border-border/40 rounded-lg bg-background/50">
        <RefreshCw className="h-3 w-3 animate-spin text-muted-foreground" />
        <span className="text-muted-foreground">Loading models...</span>
      </div>
    )
  }

  if (!selectedModel) {
    return (
      <div className="flex items-center gap-2 px-3 py-1.5 text-sm border border-border/40 rounded-lg bg-background/50 text-muted-foreground">
        No models available
      </div>
    )
  }

  // Group models by provider, keeping first-seen provider order
  const groupedModels = models.reduce((acc, model) => {
    (acc[model.provider] ??= []).push(model)
    return acc
  }, {} as Record<string, LLMModel[]>)

  return (
    <div className="relative">
      <button
        type="button"
        onClick={() => setIsOpen(!isOpen)}
        aria-haspopup="listbox"
        aria-expanded={isOpen}
        aria-label={`Select LLM model. Currently selected: ${selectedModel.name}`}
        className="flex items-center gap-2 px-3 py-1.5 text-sm border border-border/40 rounded-lg bg-background/50 hover:bg-muted/30 transition-colors"
      >
        {getModelIcon(selectedModel.provider)}
        <span className="font-medium">{selectedModel.name}</span>
        <ChevronDown className={`h-3 w-3 text-muted-foreground transition-transform ${isOpen ? 'rotate-180' : ''}`} />
      </button>
      {isOpen && (
        <>
          {/* Backdrop: full-screen click target that closes the dropdown */}
          <div
            className="fixed inset-0 z-40"
            aria-hidden="true"
            onClick={() => setIsOpen(false)}
          />
          {/* Dropdown */}
          <div
            className="absolute top-full left-0 mt-2 w-72 border border-border/40 rounded-lg bg-popover shadow-lg z-50 overflow-hidden"
            role="listbox"
            aria-label="Available LLM models"
          >
            <div className="p-2 border-b border-border/40 bg-muted/30 flex items-center justify-between">
              <h4 className="text-xs font-semibold text-foreground">Select LLM for Answer Generation</h4>
              <button
                type="button"
                onClick={(e) => {
                  e.stopPropagation()
                  void fetchAvailableModels()
                }}
                className="p-1 hover:bg-muted/50 rounded"
                title="Refresh models"
                aria-label="Refresh models"
              >
                <RefreshCw className="h-3 w-3 text-muted-foreground" />
              </button>
            </div>
            <div className="max-h-80 overflow-y-auto">
              {Object.entries(groupedModels).map(([provider, providerModels]) => (
                <div key={provider}>
                  <div className="px-3 py-1.5 text-xs font-semibold text-muted-foreground bg-muted/20 border-b border-border/20">
                    {getProviderLabel(provider)}
                  </div>
                  {providerModels.map((model) => (
                    <button
                      key={model.id}
                      type="button"
                      role="option"
                      aria-selected={selectedModel.id === model.id}
                      onClick={() => handleSelectModel(model)}
                      className={`w-full flex items-start gap-2 p-3 hover:bg-muted/50 transition-colors text-left ${
                        selectedModel.id === model.id ? 'bg-nvidia-green/10' : ''
                      }`}
                    >
                      <div className="mt-0.5">
                        {getModelIcon(model.provider)}
                      </div>
                      <div className="flex-1 min-w-0">
                        <div className="text-sm font-medium text-foreground truncate">
                          {model.name}
                        </div>
                        {model.description && (
                          <div className="text-xs text-muted-foreground">
                            {model.description}
                          </div>
                        )}
                      </div>
                      {selectedModel.id === model.id && (
                        <div className="w-2 h-2 rounded-full bg-nvidia-green flex-shrink-0 mt-1.5" />
                      )}
                    </button>
                  ))}
                </div>
              ))}
            </div>
          </div>
        </>
      )}
    </div>
  )
}