mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-06-20 21:29:31 +00:00
310 lines
12 KiB
Python
310 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""Benchmark multi-turn conversation throughput on SGLang.
|
|
|
|
Sends parallel multi-turn conversations and measures:
|
|
- Per-turn latency (end-to-end wall time; may rise as prompts grow)
|
|
- Total throughput in tokens per second
|
|
- Optional cached prefill tokens from API usage when exposed by the server
|
|
|
|
Server cache / metrics text is summarized on stdout; full scrape is written to a
|
|
detail file for easier review (avoids dumping large /metrics blobs in the terminal).
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import re
|
|
import statistics
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import requests
|
|
|
|
|
|
def _cached_from_usage(usage: Dict[str, Any]) -> Optional[int]:
|
|
"""Return cached prefill token count if the server reports it (OpenAI-style usage)."""
|
|
if not usage:
|
|
return None
|
|
details = usage.get("prompt_tokens_details")
|
|
if isinstance(details, dict):
|
|
v = details.get("cached_tokens")
|
|
if isinstance(v, int):
|
|
return v
|
|
# Some stacks use alternate keys
|
|
for key in ("cached_prompt_tokens", "cache_read_input_tokens"):
|
|
v = usage.get(key)
|
|
if isinstance(v, int):
|
|
return v
|
|
return None
|
|
|
|
|
|
def chat_completion(base_url, model, messages, max_tokens=128):
|
|
"""Send a chat completion request and measure timing."""
|
|
start = time.perf_counter()
|
|
resp = requests.post(
|
|
f"{base_url}/v1/chat/completions",
|
|
json={
|
|
"model": model,
|
|
"messages": messages,
|
|
"max_tokens": max_tokens,
|
|
},
|
|
timeout=120,
|
|
)
|
|
elapsed = time.perf_counter() - start
|
|
resp.raise_for_status()
|
|
data = resp.json()
|
|
usage = data.get("usage") or {}
|
|
cached = _cached_from_usage(usage)
|
|
return {
|
|
"elapsed_s": elapsed,
|
|
"prompt_tokens": usage.get("prompt_tokens", 0),
|
|
"completion_tokens": usage.get("completion_tokens", 0),
|
|
"cached_prompt_tokens": cached,
|
|
"content": data["choices"][0]["message"]["content"],
|
|
}
|
|
|
|
|
|
def run_conversation(base_url, model, conv_id, num_turns):
|
|
"""Run a multi-turn conversation and return per-turn metrics."""
|
|
system_msg = {"role": "system", "content": "You are a helpful assistant. Keep responses concise."}
|
|
topics = [
|
|
"What is photosynthesis?",
|
|
"How does it relate to the carbon cycle?",
|
|
"What role do oceans play in that cycle?",
|
|
"How does climate change affect the oceans?",
|
|
"What can individuals do to reduce their carbon footprint?",
|
|
"How do renewable energy sources compare?",
|
|
"What are the most promising new energy technologies?",
|
|
"How long until fusion energy is practical?",
|
|
]
|
|
|
|
messages = [system_msg]
|
|
results = []
|
|
|
|
for turn in range(num_turns):
|
|
user_msg = {"role": "user", "content": topics[turn % len(topics)]}
|
|
messages.append(user_msg)
|
|
|
|
result = chat_completion(base_url, model, messages)
|
|
result["conversation_id"] = conv_id
|
|
result["turn"] = turn + 1
|
|
results.append(result)
|
|
|
|
messages.append({"role": "assistant", "content": result["content"]})
|
|
|
|
return results
|
|
|
|
|
|
def run_structured_output_test(base_url, model):
|
|
"""Demonstrate structured JSON output with timing."""
|
|
print("\n--- Structured Output Test ---")
|
|
|
|
schema = {
|
|
"type": "object",
|
|
"properties": {
|
|
"items": {
|
|
"type": "array",
|
|
"items": {
|
|
"type": "object",
|
|
"properties": {
|
|
"name": {"type": "string"},
|
|
"category": {"type": "string"},
|
|
"score": {"type": "number"},
|
|
},
|
|
"required": ["name", "category", "score"],
|
|
},
|
|
}
|
|
},
|
|
"required": ["items"],
|
|
}
|
|
|
|
start = time.perf_counter()
|
|
resp = requests.post(
|
|
f"{base_url}/v1/chat/completions",
|
|
json={
|
|
"model": model,
|
|
"messages": [
|
|
{
|
|
"role": "user",
|
|
"content": "Rate these programming languages for data science: Python, R, Julia. "
|
|
"Return name, category, and a score from 1-10.",
|
|
}
|
|
],
|
|
"max_tokens": 512,
|
|
"response_format": {
|
|
"type": "json_schema",
|
|
"json_schema": {"name": "ratings", "schema": schema},
|
|
},
|
|
},
|
|
timeout=120,
|
|
)
|
|
elapsed = time.perf_counter() - start
|
|
resp.raise_for_status()
|
|
|
|
data = resp.json()
|
|
content = data["choices"][0]["message"]["content"]
|
|
|
|
parsed = json.loads(content)
|
|
print(f" Time: {elapsed:.2f}s")
|
|
print(" Valid JSON: Yes")
|
|
print(f" Items returned: {len(parsed.get('items', []))}")
|
|
print(f" Output: {json.dumps(parsed, indent=2)}")
|
|
|
|
|
|
_MAX_HIGHLIGHT_LEN = 200
|
|
|
|
|
|
def _interesting_metric_line(line: str) -> bool:
|
|
stripped = line.strip()
|
|
# /server_info returns a single-line JSON blob whose schema mentions "cache"
|
|
# dozens of times; suppress it (and any other JSON or oversized line) so the
|
|
# terminal summary stays readable. The full body is still in the detail file.
|
|
if not stripped or stripped[0] in "{[":
|
|
return False
|
|
if len(stripped) > _MAX_HIGHLIGHT_LEN:
|
|
return False
|
|
lower = stripped.lower()
|
|
if not any(kw in lower for kw in ("cache", "prefix", "radix", "hit", "reuse")):
|
|
return False
|
|
# Prefer lines that look like Prometheus samples or log counters
|
|
return bool(re.search(r"\d", stripped))
|
|
|
|
|
|
def get_cache_stats(base_url: str, detail_path: Path) -> None:
|
|
"""Write full /server_info and /metrics text to detail_path; print a short summary."""
|
|
print("\n--- Server cache / metrics (summary) ---")
|
|
sections: List[Tuple[str, str]] = []
|
|
for endpoint in ("/server_info", "/metrics"):
|
|
try:
|
|
resp = requests.get(f"{base_url}{endpoint}", timeout=30)
|
|
label = f"{endpoint} HTTP {resp.status_code}"
|
|
body = resp.text if resp.text else ""
|
|
sections.append((label, body))
|
|
except requests.RequestException as exc:
|
|
sections.append((f"{endpoint} error", str(exc)))
|
|
|
|
detail_path.parent.mkdir(parents=True, exist_ok=True)
|
|
with detail_path.open("w", encoding="utf-8") as fh:
|
|
for label, body in sections:
|
|
fh.write(f"===== {label} =====\n")
|
|
fh.write(body)
|
|
if body and not body.endswith("\n"):
|
|
fh.write("\n")
|
|
fh.write("\n")
|
|
|
|
print(f" Full raw responses written to: {detail_path.resolve()}")
|
|
|
|
highlights: List[str] = []
|
|
for _, body in sections:
|
|
if not body or body.startswith("HTTP"):
|
|
continue
|
|
for line in body.splitlines():
|
|
s = line.strip()
|
|
if _interesting_metric_line(s):
|
|
highlights.append(s)
|
|
|
|
# De-duplicate while preserving order
|
|
seen: set[str] = set()
|
|
unique: List[str] = []
|
|
for h in highlights:
|
|
if h not in seen:
|
|
seen.add(h)
|
|
unique.append(h)
|
|
|
|
max_print = 12
|
|
if not unique:
|
|
print(" No cache-related lines matched heuristics in /server_info or /metrics.")
|
|
print(" See the detail file above; also check `docker logs` for #cached-token lines.")
|
|
return
|
|
|
|
print(f" Showing up to {max_print} cache-related lines (of {len(unique)} matched):")
|
|
for line in unique[:max_print]:
|
|
print(f" {line}")
|
|
if len(unique) > max_print:
|
|
print(f" … plus {len(unique) - max_print} more in the detail file.")
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="Benchmark SGLang multi-turn inference")
|
|
parser.add_argument("--base-url", default="http://localhost:30000", help="SGLang server URL")
|
|
parser.add_argument("--model", default="Qwen/Qwen3-8B", help="Model name")
|
|
parser.add_argument("--num-conversations", type=int, default=20, help="Number of parallel conversations")
|
|
parser.add_argument("--turns-per-conversation", type=int, default=5, help="Turns per conversation")
|
|
parser.add_argument("--max-workers", type=int, default=8, help="Max parallel workers")
|
|
parser.add_argument(
|
|
"--cache-detail-file",
|
|
default="sglang_benchmark_cache_details.log",
|
|
help="Path to write full /server_info and /metrics bodies for offline review",
|
|
)
|
|
args = parser.parse_args()
|
|
detail_path = Path(args.cache_detail_file)
|
|
|
|
print("SGLang Multi-Turn Benchmark")
|
|
print(f" Server: {args.base_url}")
|
|
print(f" Model: {args.model}")
|
|
print(f" Conversations: {args.num_conversations}")
|
|
print(f" Turns each: {args.turns_per_conversation}")
|
|
print()
|
|
print(
|
|
"Note: per-turn **wall time** often **increases** as the transcript grows (more prompt "
|
|
"tokens to attend to, longer assistant decodes, and parallel load). RadixAttention still "
|
|
"reuses KV for shared prefixes—confirm with `docker logs` (#cached-token) or cached_tokens "
|
|
"in usage when the server exposes them."
|
|
)
|
|
print()
|
|
|
|
print("Warming up...")
|
|
chat_completion(args.base_url, args.model, [{"role": "user", "content": "Hello"}], max_tokens=8)
|
|
|
|
print(f"Running {args.num_conversations} conversations with {args.turns_per_conversation} turns each...")
|
|
all_results = []
|
|
bench_start = time.perf_counter()
|
|
|
|
with ThreadPoolExecutor(max_workers=args.max_workers) as pool:
|
|
futures = {
|
|
pool.submit(run_conversation, args.base_url, args.model, i, args.turns_per_conversation): i
|
|
for i in range(args.num_conversations)
|
|
}
|
|
for future in as_completed(futures):
|
|
conv_results = future.result()
|
|
all_results.extend(conv_results)
|
|
|
|
bench_elapsed = time.perf_counter() - bench_start
|
|
|
|
print(f"\n--- Results ({bench_elapsed:.1f}s total) ---\n")
|
|
|
|
total_prompt_tokens = sum(r["prompt_tokens"] for r in all_results)
|
|
total_completion_tokens = sum(r["completion_tokens"] for r in all_results)
|
|
total_tokens = total_prompt_tokens + total_completion_tokens
|
|
|
|
print(f"Total tokens: {total_tokens:,} ({total_prompt_tokens:,} prompt + {total_completion_tokens:,} completion)")
|
|
print(f"Throughput: {total_tokens / bench_elapsed:.0f} tok/s")
|
|
print()
|
|
|
|
any_cached = any(r.get("cached_prompt_tokens") is not None for r in all_results)
|
|
|
|
print(f"{'Turn':<6} {'Median wall':<14} {'P90 wall':<14} {'Med prompt tok':<16} {'Med cached prefill':<20} {'N':<6}")
|
|
print("-" * 82)
|
|
|
|
for turn in range(1, args.turns_per_conversation + 1):
|
|
turn_results = [r for r in all_results if r["turn"] == turn]
|
|
latencies = sorted(r["elapsed_s"] for r in turn_results)
|
|
median = statistics.median(latencies)
|
|
p90 = latencies[int(len(latencies) * 0.9)] if len(latencies) >= 10 else latencies[-1]
|
|
med_prompt = int(statistics.median([r["prompt_tokens"] for r in turn_results]))
|
|
cached_vals = [r["cached_prompt_tokens"] for r in turn_results if r["cached_prompt_tokens"] is not None]
|
|
med_cached = f"{statistics.median(cached_vals):.0f}" if cached_vals else ("n/a" if not any_cached else "0")
|
|
print(f" {turn:<4} {median:>10.3f}s {p90:>10.3f}s {med_prompt:<16} {med_cached:<20} {len(turn_results):<6}")
|
|
|
|
print()
|
|
print("Interpretation: rising median wall time with turn index is normal under load and longer contexts.")
|
|
print("Use single-conversation runs (`--num-conversations 1`) and server logs for the clearest cache story.")
|
|
|
|
run_structured_output_test(args.base_url, args.model)
|
|
get_cache_stats(args.base_url, detail_path)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|