mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-04-23 18:33:54 +00:00
LitServe-based prompt injection detection server with a React monitoring dashboard. Serves HuggingFace classification models behind an OpenAI-compatible API with real-time metrics and GPU acceleration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
172 lines
4.8 KiB
Python
172 lines
4.8 KiB
Python
"""LitServe app for litguard - prompt injection detection."""
|
|
|
|
import json
|
|
import time
|
|
import os
|
|
import subprocess
|
|
|
|
import litserve as ls
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from .models import ModelRegistry, load_config
|
|
from .metrics import metrics, ClassificationRecord
|
|
|
|
|
|
class PromptInjectionAPI(ls.LitAPI):
|
|
def setup(self, device: str):
|
|
self.config = load_config()
|
|
self.registry = ModelRegistry()
|
|
self.registry.load_from_config(self.config)
|
|
|
|
def decode_request(self, request: dict) -> dict:
|
|
# Support OpenAI chat completions format
|
|
messages = request.get("messages", [])
|
|
model_name = request.get("model")
|
|
# Extract text from the last user message
|
|
text = ""
|
|
for msg in reversed(messages):
|
|
if msg.get("role") == "user":
|
|
content = msg.get("content", "")
|
|
if isinstance(content, list):
|
|
# Handle content array format
|
|
text = " ".join(
|
|
p.get("text", "") for p in content if p.get("type") == "text"
|
|
)
|
|
else:
|
|
text = content
|
|
break
|
|
return {"text": text, "model": model_name}
|
|
|
|
def predict(self, inputs: dict) -> dict:
|
|
text = inputs["text"]
|
|
model_name = inputs.get("model")
|
|
|
|
if model_name:
|
|
model = self.registry.get(model_name)
|
|
else:
|
|
model = None
|
|
|
|
if model is None:
|
|
model = self.registry.get_default()
|
|
|
|
start = time.time()
|
|
results = model.predict([text])
|
|
latency_ms = (time.time() - start) * 1000
|
|
|
|
result = results[0]
|
|
|
|
# Record metrics
|
|
metrics.record(
|
|
ClassificationRecord(
|
|
timestamp=time.time(),
|
|
input_text=text,
|
|
model=model.name,
|
|
label=result["label"],
|
|
score=result["score"],
|
|
latency_ms=latency_ms,
|
|
)
|
|
)
|
|
|
|
return {**result, "model": model.name, "latency_ms": round(latency_ms, 2)}
|
|
|
|
def encode_response(self, output: dict) -> dict:
|
|
# Return as OpenAI-compatible chat completion response
|
|
result_json = json.dumps(
|
|
{
|
|
"label": output["label"],
|
|
"score": output["score"],
|
|
"confidence": output["confidence"],
|
|
}
|
|
)
|
|
return {
|
|
"id": f"chatcmpl-litguard-{int(time.time()*1000)}",
|
|
"object": "chat.completion",
|
|
"created": int(time.time()),
|
|
"model": output["model"],
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"message": {"role": "assistant", "content": result_json},
|
|
"finish_reason": "stop",
|
|
}
|
|
],
|
|
"usage": {
|
|
"prompt_tokens": 0,
|
|
"completion_tokens": 0,
|
|
"total_tokens": 0,
|
|
},
|
|
}
|
|
|
|
|
|
def _get_gpu_utilization() -> str:
|
|
try:
|
|
result = subprocess.run(
|
|
["nvidia-smi", "--query-gpu=utilization.gpu", "--format=csv,noheader,nounits"],
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=5,
|
|
)
|
|
return result.stdout.strip()
|
|
except Exception:
|
|
return "N/A"
|
|
|
|
|
|
def create_app():
|
|
config = load_config()
|
|
api = PromptInjectionAPI()
|
|
|
|
server = ls.LitServer(
|
|
api,
|
|
api_path="/v1/chat/completions",
|
|
timeout=30,
|
|
)
|
|
|
|
# Build model info from config (available without worker process)
|
|
model_info = [
|
|
{
|
|
"name": m["name"],
|
|
"hf_model": m["hf_model"],
|
|
"device": os.environ.get("DEVICE", m.get("device", "cpu")),
|
|
"batch_size": m.get("batch_size", 32),
|
|
}
|
|
for m in config.get("models", [])
|
|
]
|
|
model_names = [m["name"] for m in model_info]
|
|
|
|
# Add custom endpoints via FastAPI app
|
|
fastapi_app = server.app
|
|
|
|
fastapi_app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"],
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
@fastapi_app.get("/health")
|
|
def health():
|
|
return {"status": "ok", "models_loaded": model_names}
|
|
|
|
@fastapi_app.get("/models")
|
|
def list_models():
|
|
return {"models": model_info}
|
|
|
|
@fastapi_app.get("/metrics")
|
|
def get_metrics():
|
|
m = metrics.get_metrics()
|
|
m["gpu_utilization"] = _get_gpu_utilization()
|
|
m["models_loaded"] = model_info
|
|
return m
|
|
|
|
@fastapi_app.get("/api/history")
|
|
def get_history():
|
|
return {"history": metrics.get_history()}
|
|
|
|
return server
|
|
|
|
|
|
if __name__ == "__main__":
|
|
config = load_config()
|
|
server = create_app()
|
|
server.run(port=config.get("port", 8234), host="0.0.0.0")
|