dgx-spark-playbooks/nvidia/station-sglang-inference/assets/benchmark_multiturn.py
2026-05-26 18:25:53 +00:00

310 lines
12 KiB
Python

#!/usr/bin/env python3
"""Benchmark multi-turn conversation throughput on SGLang.
Sends parallel multi-turn conversations and measures:
- Per-turn latency (end-to-end wall time; may rise as prompts grow)
- Total throughput in tokens per second
- Optional cached prefill tokens from API usage when exposed by the server
Server cache / metrics text is summarized on stdout; full scrape is written to a
detail file for easier review (avoids dumping large /metrics blobs in the terminal).
"""
import argparse
import json
import re
import statistics
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import requests
def _cached_from_usage(usage: Dict[str, Any]) -> Optional[int]:
"""Return cached prefill token count if the server reports it (OpenAI-style usage)."""
if not usage:
return None
details = usage.get("prompt_tokens_details")
if isinstance(details, dict):
v = details.get("cached_tokens")
if isinstance(v, int):
return v
# Some stacks use alternate keys
for key in ("cached_prompt_tokens", "cache_read_input_tokens"):
v = usage.get(key)
if isinstance(v, int):
return v
return None
def chat_completion(base_url, model, messages, max_tokens=128):
"""Send a chat completion request and measure timing."""
start = time.perf_counter()
resp = requests.post(
f"{base_url}/v1/chat/completions",
json={
"model": model,
"messages": messages,
"max_tokens": max_tokens,
},
timeout=120,
)
elapsed = time.perf_counter() - start
resp.raise_for_status()
data = resp.json()
usage = data.get("usage") or {}
cached = _cached_from_usage(usage)
return {
"elapsed_s": elapsed,
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"cached_prompt_tokens": cached,
"content": data["choices"][0]["message"]["content"],
}
def run_conversation(base_url, model, conv_id, num_turns):
"""Run a multi-turn conversation and return per-turn metrics."""
system_msg = {"role": "system", "content": "You are a helpful assistant. Keep responses concise."}
topics = [
"What is photosynthesis?",
"How does it relate to the carbon cycle?",
"What role do oceans play in that cycle?",
"How does climate change affect the oceans?",
"What can individuals do to reduce their carbon footprint?",
"How do renewable energy sources compare?",
"What are the most promising new energy technologies?",
"How long until fusion energy is practical?",
]
messages = [system_msg]
results = []
for turn in range(num_turns):
user_msg = {"role": "user", "content": topics[turn % len(topics)]}
messages.append(user_msg)
result = chat_completion(base_url, model, messages)
result["conversation_id"] = conv_id
result["turn"] = turn + 1
results.append(result)
messages.append({"role": "assistant", "content": result["content"]})
return results
def run_structured_output_test(base_url, model):
"""Demonstrate structured JSON output with timing."""
print("\n--- Structured Output Test ---")
schema = {
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {"type": "string"},
"category": {"type": "string"},
"score": {"type": "number"},
},
"required": ["name", "category", "score"],
},
}
},
"required": ["items"],
}
start = time.perf_counter()
resp = requests.post(
f"{base_url}/v1/chat/completions",
json={
"model": model,
"messages": [
{
"role": "user",
"content": "Rate these programming languages for data science: Python, R, Julia. "
"Return name, category, and a score from 1-10.",
}
],
"max_tokens": 512,
"response_format": {
"type": "json_schema",
"json_schema": {"name": "ratings", "schema": schema},
},
},
timeout=120,
)
elapsed = time.perf_counter() - start
resp.raise_for_status()
data = resp.json()
content = data["choices"][0]["message"]["content"]
parsed = json.loads(content)
print(f" Time: {elapsed:.2f}s")
print(" Valid JSON: Yes")
print(f" Items returned: {len(parsed.get('items', []))}")
print(f" Output: {json.dumps(parsed, indent=2)}")
_MAX_HIGHLIGHT_LEN = 200
def _interesting_metric_line(line: str) -> bool:
stripped = line.strip()
# /server_info returns a single-line JSON blob whose schema mentions "cache"
# dozens of times; suppress it (and any other JSON or oversized line) so the
# terminal summary stays readable. The full body is still in the detail file.
if not stripped or stripped[0] in "{[":
return False
if len(stripped) > _MAX_HIGHLIGHT_LEN:
return False
lower = stripped.lower()
if not any(kw in lower for kw in ("cache", "prefix", "radix", "hit", "reuse")):
return False
# Prefer lines that look like Prometheus samples or log counters
return bool(re.search(r"\d", stripped))
def get_cache_stats(base_url: str, detail_path: Path) -> None:
"""Write full /server_info and /metrics text to detail_path; print a short summary."""
print("\n--- Server cache / metrics (summary) ---")
sections: List[Tuple[str, str]] = []
for endpoint in ("/server_info", "/metrics"):
try:
resp = requests.get(f"{base_url}{endpoint}", timeout=30)
label = f"{endpoint} HTTP {resp.status_code}"
body = resp.text if resp.text else ""
sections.append((label, body))
except requests.RequestException as exc:
sections.append((f"{endpoint} error", str(exc)))
detail_path.parent.mkdir(parents=True, exist_ok=True)
with detail_path.open("w", encoding="utf-8") as fh:
for label, body in sections:
fh.write(f"===== {label} =====\n")
fh.write(body)
if body and not body.endswith("\n"):
fh.write("\n")
fh.write("\n")
print(f" Full raw responses written to: {detail_path.resolve()}")
highlights: List[str] = []
for _, body in sections:
if not body or body.startswith("HTTP"):
continue
for line in body.splitlines():
s = line.strip()
if _interesting_metric_line(s):
highlights.append(s)
# De-duplicate while preserving order
seen: set[str] = set()
unique: List[str] = []
for h in highlights:
if h not in seen:
seen.add(h)
unique.append(h)
max_print = 12
if not unique:
print(" No cache-related lines matched heuristics in /server_info or /metrics.")
print(" See the detail file above; also check `docker logs` for #cached-token lines.")
return
print(f" Showing up to {max_print} cache-related lines (of {len(unique)} matched):")
for line in unique[:max_print]:
print(f" {line}")
if len(unique) > max_print:
print(f" … plus {len(unique) - max_print} more in the detail file.")
def main():
parser = argparse.ArgumentParser(description="Benchmark SGLang multi-turn inference")
parser.add_argument("--base-url", default="http://localhost:30000", help="SGLang server URL")
parser.add_argument("--model", default="Qwen/Qwen3-8B", help="Model name")
parser.add_argument("--num-conversations", type=int, default=20, help="Number of parallel conversations")
parser.add_argument("--turns-per-conversation", type=int, default=5, help="Turns per conversation")
parser.add_argument("--max-workers", type=int, default=8, help="Max parallel workers")
parser.add_argument(
"--cache-detail-file",
default="sglang_benchmark_cache_details.log",
help="Path to write full /server_info and /metrics bodies for offline review",
)
args = parser.parse_args()
detail_path = Path(args.cache_detail_file)
print("SGLang Multi-Turn Benchmark")
print(f" Server: {args.base_url}")
print(f" Model: {args.model}")
print(f" Conversations: {args.num_conversations}")
print(f" Turns each: {args.turns_per_conversation}")
print()
print(
"Note: per-turn **wall time** often **increases** as the transcript grows (more prompt "
"tokens to attend to, longer assistant decodes, and parallel load). RadixAttention still "
"reuses KV for shared prefixes—confirm with `docker logs` (#cached-token) or cached_tokens "
"in usage when the server exposes them."
)
print()
print("Warming up...")
chat_completion(args.base_url, args.model, [{"role": "user", "content": "Hello"}], max_tokens=8)
print(f"Running {args.num_conversations} conversations with {args.turns_per_conversation} turns each...")
all_results = []
bench_start = time.perf_counter()
with ThreadPoolExecutor(max_workers=args.max_workers) as pool:
futures = {
pool.submit(run_conversation, args.base_url, args.model, i, args.turns_per_conversation): i
for i in range(args.num_conversations)
}
for future in as_completed(futures):
conv_results = future.result()
all_results.extend(conv_results)
bench_elapsed = time.perf_counter() - bench_start
print(f"\n--- Results ({bench_elapsed:.1f}s total) ---\n")
total_prompt_tokens = sum(r["prompt_tokens"] for r in all_results)
total_completion_tokens = sum(r["completion_tokens"] for r in all_results)
total_tokens = total_prompt_tokens + total_completion_tokens
print(f"Total tokens: {total_tokens:,} ({total_prompt_tokens:,} prompt + {total_completion_tokens:,} completion)")
print(f"Throughput: {total_tokens / bench_elapsed:.0f} tok/s")
print()
any_cached = any(r.get("cached_prompt_tokens") is not None for r in all_results)
print(f"{'Turn':<6} {'Median wall':<14} {'P90 wall':<14} {'Med prompt tok':<16} {'Med cached prefill':<20} {'N':<6}")
print("-" * 82)
for turn in range(1, args.turns_per_conversation + 1):
turn_results = [r for r in all_results if r["turn"] == turn]
latencies = sorted(r["elapsed_s"] for r in turn_results)
median = statistics.median(latencies)
p90 = latencies[int(len(latencies) * 0.9)] if len(latencies) >= 10 else latencies[-1]
med_prompt = int(statistics.median([r["prompt_tokens"] for r in turn_results]))
cached_vals = [r["cached_prompt_tokens"] for r in turn_results if r["cached_prompt_tokens"] is not None]
med_cached = f"{statistics.median(cached_vals):.0f}" if cached_vals else ("n/a" if not any_cached else "0")
print(f" {turn:<4} {median:>10.3f}s {p90:>10.3f}s {med_prompt:<16} {med_cached:<20} {len(turn_results):<6}")
print()
print("Interpretation: rising median wall time with turn index is normal under load and longer contexts.")
print("Use single-conversation runs (`--num-conversations 1`) and server logs for the clearest cache story.")
run_structured_output_test(args.base_url, args.model)
get_cache_stats(args.base_url, detail_path)
if __name__ == "__main__":
main()