dgx-spark-playbooks/nvidia/station-rec-sys/assets/pricing_agent.py
2026-05-26 18:25:53 +00:00

1266 lines
50 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Dynamic pricing agent for the rec-sys playbook.
Trains a PPO policy that picks per-item price multipliers each day,
evaluated against FixedPrice and AgeMarkdown baselines on a calibrated
inventory simulator. Subcommands:
smoke Run baselines on a tiny synthetic catalog and print KPIs.
train Train the PPO agent against the simulator.
eval Evaluate a trained PPO agent + baselines.
all train then eval.
The simulator, demand model, and baselines port code from
`enterprise-retail-demo/src/pricing/` into a single file. Bug fixes from
the prior DQN implementation are applied in the PPO section.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import pandas as pd
# torch is imported lazily inside the training/eval entry points so the
# `smoke` subcommand can run on environments without GPU drivers.
# ────────────────────────────────────────────────────────────────────
# Pricing config — embedded constants (was configs/pricing.yaml in the
# upstream repo). Inlined here so the playbook stays a single file.
# ────────────────────────────────────────────────────────────────────
# Default configuration consumed by Simulator ("simulator" / "inventory"),
# InventoryState.initialize ("inventory"), DemandModel ("demand_model") and the
# PPO action space ("price_grid").
PRICING_CONFIG: Dict[str, Any] = {
    "simulator": {
        "horizon_days": 14,
        "num_shoppers_per_day": 10_000,
        # Every `replenish_interval` days stock is topped up to at least
        # `replenish_fraction` of each item's initial inventory.
        "replenish_interval": 7,
        "replenish_fraction": 0.5,
    },
    "inventory": {
        # Initial stock is drawn uniformly from [min_initial, max_initial].
        "min_initial": 50,
        "max_initial": 200,
        # Unit cost = base price * U(cost_fraction_min, cost_fraction_max).
        "cost_fraction_min": 0.40,
        "cost_fraction_max": 0.60,
        # Charged per day on inventory * base_price. Original note: "≈ 50%/yr";
        # 0.002 * 365 ≈ 73%/yr — TODO confirm the intended annual basis.
        "holding_cost_per_unit_day": 0.002,
        "stockout_penalty": 5.0,
    },
    "demand_model": {
        # logit(p) = alpha + beta·rec + gamma·season - epsilon·Δp/p₀ - delta·stockout
        # Categories are derived from price tier on the real Amazon Dresses catalog
        # (luxury = top third by price, midrange = middle, budget = bottom).
        # Elasticity ranges follow Tellis (1988) "The Price Elasticity of Selective
        # Demand": apparel ~2.0, luxury/premium <1.0, budget/clearance 3-4.
        "seasonal_amplitude": 0.2,
        "rec_score_weight": 1.0,
        "stockout_delta": 10.0,
        "elasticity": {
            "luxury": [0.5, 1.0],  # premium buyers less price-sensitive
            "midrange": [1.5, 2.5],  # standard apparel elasticity
            "budget": [2.5, 4.0],  # budget shoppers highly price-sensitive
            "default": [1.5, 2.5],
        },
        # alpha (pre-sigmoid) calibrated to ~6-15 units/day/item with 10K shoppers
        "base_demand": {
            "luxury": -7.5,  # ~6 units/day — lower volume
            "midrange": -7.0,  # ~9 units/day
            "budget": -6.5,  # ~15 units/day — volume play
            "default": -7.0,
        },
    },
    # 9 multipliers covering -40% to +25%. Replaces the upstream [0.85, 1.05]
    # range that couldn't reach the deep markdowns aging stock needs.
    "price_grid": {
        "multipliers": [0.60, 0.70, 0.80, 0.90, 1.00, 1.05, 1.10, 1.15, 1.25],
    },
}
# Price-tier category names; their order defines the one-hot layout used by the
# PPO state encoder.
CATEGORIES: Tuple[str, ...] = ("luxury", "midrange", "budget")
# ────────────────────────────────────────────────────────────────────
# Demand model
# ────────────────────────────────────────────────────────────────────
def _sigmoid(x: np.ndarray) -> np.ndarray:
"""Numerically-stable sigmoid."""
return np.where(
x >= 0,
1.0 / (1.0 + np.exp(-x)),
np.exp(x) / (1.0 + np.exp(x)),
)
class DemandModel:
    """Vectorized logit demand model.

    Purchase probability per item::

        sigmoid(alpha_cat + beta * rec - epsilon_cat * (p - p0) / p0
                + gamma * season(day) - delta * I(inventory <= 0))

    Elasticity is set to the midpoint of each category's configured range
    (deterministic for reproducibility).
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """Read model coefficients from the ``demand_model`` config section.

        Args:
            config: Mapping with optional keys ``elasticity``, ``base_demand``,
                ``seasonal_amplitude``, ``rec_score_weight`` and
                ``stockout_delta``; missing keys fall back to the defaults
                below.
        """
        self.elasticity_ranges = config.get("elasticity", {})
        self.default_elasticity = self.elasticity_ranges.get("default", [1.5, 2.5])
        self.base_demand = config.get("base_demand", {})
        self.default_base_demand = self.base_demand.get("default", -7.0)
        self.seasonal_amplitude = config.get("seasonal_amplitude", 0.2)
        self.rec_score_weight = config.get("rec_score_weight", 1.0)
        self.stockout_delta = config.get("stockout_delta", 10.0)

    def _alpha(self, categories: np.ndarray) -> np.ndarray:
        """Return the per-item base logit; unknown categories get the default."""
        # Coerce so plain lists compare element-wise instead of as one object.
        categories = np.asarray(categories)
        out = np.full(len(categories), self.default_base_demand, dtype=np.float64)
        for cat, val in self.base_demand.items():
            if cat == "default":
                continue
            out[categories == cat] = val
        return out

    def _epsilon(self, categories: np.ndarray) -> np.ndarray:
        """Return the per-item elasticity (midpoint of the category's range)."""
        categories = np.asarray(categories)
        default_mid = float(np.mean(self.default_elasticity))
        out = np.full(len(categories), default_mid, dtype=np.float64)
        for cat, bounds in self.elasticity_ranges.items():
            if cat == "default":
                continue
            lo, hi = bounds
            out[categories == cat] = 0.5 * (lo + hi)
        return out

    @staticmethod
    def season_factor(day: int) -> float:
        """Sinusoidal weekly pattern in [-1, 1] with a 7-day period.

        Among integer days the maximum is at ``day % 7 == 4`` and the minimum
        at ``day % 7 == 0`` (the continuous peak sits at 3.75). Which calendar
        weekday that is depends on what day 0 represents, which this file does
        not establish.
        """
        return float(np.sin(2.0 * np.pi * ((day % 7) - 2) / 7.0))

    def purchase_probability(
        self,
        categories: np.ndarray,
        inventories: np.ndarray,
        base_prices: np.ndarray,
        prices: np.ndarray,
        day: int,
        rec_scores: np.ndarray,
    ) -> np.ndarray:
        """Return each item's per-shopper purchase probability for ``day``.

        Args:
            categories: Category label per item.
            inventories: Units on hand per item; ``<= 0`` triggers the
                stockout penalty.
            base_prices: Reference price ``p0`` per item.
            prices: Price charged today per item.
            day: Day index, used only for the seasonal term.
            rec_scores: Recommendation score per item.

        Returns:
            Array of probabilities, one per item.
        """
        alpha = self._alpha(categories)
        epsilon = self._epsilon(categories)
        base = np.asarray(base_prices, dtype=np.float64)
        delta_price = np.asarray(prices, dtype=np.float64) - base
        # Divide only where the base price is positive; elsewhere the relative
        # price change is defined as 0. `where=` avoids evaluating x / 0, which
        # np.where would do (and warn about) before discarding the result.
        price_ratio = np.zeros_like(delta_price)
        np.divide(delta_price, base, out=price_ratio, where=base > 0)
        stockout = (np.asarray(inventories) <= 0).astype(np.float64)
        logit = (
            alpha
            + self.rec_score_weight * rec_scores
            + self.seasonal_amplitude * self.season_factor(day)
            - epsilon * price_ratio
            - self.stockout_delta * stockout
        )
        return _sigmoid(logit)

    def expected_demand(
        self,
        categories: np.ndarray,
        inventories: np.ndarray,
        base_prices: np.ndarray,
        prices: np.ndarray,
        day: int,
        rec_scores: np.ndarray,
        num_shoppers: int,
    ) -> np.ndarray:
        """Return expected units demanded per item: shoppers * probability."""
        prob = self.purchase_probability(
            categories, inventories, base_prices, prices, day, rec_scores
        )
        return num_shoppers * prob
# ────────────────────────────────────────────────────────────────────
# Inventory state
# ────────────────────────────────────────────────────────────────────
class InventoryState:
    """Per-item inventory + pricing state. All arrays shape (n_items,)."""

    def __init__(
        self,
        inventories: np.ndarray,
        base_prices: np.ndarray,
        unit_costs: np.ndarray,
        current_prices: np.ndarray,
        days_in_stock: np.ndarray,
        categories: np.ndarray,
        rec_scores: np.ndarray,
    ) -> None:
        """Store float64 copies of the numeric inputs.

        The inventories passed here are also recorded as the *initial*
        inventories, which later serve as the reference for replenishment and
        sell-through calculations.
        """
        # np.asarray first so plain sequences are accepted; astype always copies.
        self.inventories = np.asarray(inventories).astype(np.float64)
        self.base_prices = np.asarray(base_prices).astype(np.float64)
        self.unit_costs = np.asarray(unit_costs).astype(np.float64)
        self.current_prices = np.asarray(current_prices).astype(np.float64)
        self.days_in_stock = np.asarray(days_in_stock).astype(np.float64)
        self.categories = np.asarray(categories)
        self.rec_scores = np.asarray(rec_scores).astype(np.float64)
        self._initial_inventories = self.inventories.copy()

    @property
    def n_items(self) -> int:
        """Number of items tracked by this state."""
        return len(self.inventories)

    def copy(self) -> "InventoryState":
        """Return a deep copy that keeps the original initial inventories.

        The constructor re-derives ``_initial_inventories`` from the current
        stock, so it is restored explicitly afterwards. Without this, copying
        a partially sold-down state would shrink the replenishment reference
        level.
        """
        duplicate = InventoryState(
            inventories=self.inventories.copy(),
            base_prices=self.base_prices.copy(),
            unit_costs=self.unit_costs.copy(),
            current_prices=self.current_prices.copy(),
            days_in_stock=self.days_in_stock.copy(),
            categories=self.categories.copy(),
            rec_scores=self.rec_scores.copy(),
        )
        duplicate._initial_inventories = self._initial_inventories.copy()
        return duplicate

    @classmethod
    def initialize(
        cls,
        item_features: pd.DataFrame,
        config: Dict[str, Any],
        seed: int = 42,
    ) -> "InventoryState":
        """Build a starting state from an item-features DataFrame.

        Expected columns: ``avg_price`` (float), optional ``category`` (str)
        and ``rec_score`` (float). Missing columns are synthesized.

        Args:
            item_features: One row per item.
            config: Full pricing config; only its ``inventory`` section is
                read here.
            seed: Seed for the generator that draws costs, stock levels and
                any synthesized columns.

        Returns:
            A state with zero days in stock and current prices equal to the
            base prices.
        """
        inv_cfg = config.get("inventory", {})
        rng = np.random.default_rng(seed)
        n = len(item_features)
        columns = item_features.columns

        # NOTE: the order of rng draws below (prices, costs, stock, scores)
        # determines the generated values for a given seed — do not reorder.
        if "avg_price" in columns:
            base_prices = item_features["avg_price"].to_numpy(dtype=np.float64)
            positive = base_prices > 0  # NaN compares False
            median = np.nanmedian(base_prices[positive]) if positive.any() else 50.0
            # Missing or non-positive prices are replaced by the median valid price.
            base_prices = np.where(
                np.isnan(base_prices) | (base_prices <= 0), median, base_prices
            )
        else:
            base_prices = rng.uniform(5.0, 100.0, size=n)

        cf_lo = inv_cfg.get("cost_fraction_min", 0.40)
        cf_hi = inv_cfg.get("cost_fraction_max", 0.60)
        unit_costs = base_prices * rng.uniform(cf_lo, cf_hi, size=n)

        inv_lo = inv_cfg.get("min_initial", 50)
        inv_hi = inv_cfg.get("max_initial", 200)
        # rng.integers excludes the upper bound, hence the +1.
        inventories = rng.integers(inv_lo, inv_hi + 1, size=n).astype(np.float64)

        if "category" in columns:
            categories = item_features["category"].to_numpy(dtype=str)
        else:
            # Round-robin over the known categories.
            categories = np.array([CATEGORIES[i % len(CATEGORIES)] for i in range(n)])

        if "rec_score" in columns:
            rec_scores = item_features["rec_score"].to_numpy(dtype=np.float64)
        else:
            rec_scores = rng.uniform(0.0, 1.0, size=n)

        return cls(
            inventories=inventories,
            base_prices=base_prices,
            unit_costs=unit_costs,
            current_prices=base_prices.copy(),
            days_in_stock=np.zeros(n, dtype=np.float64),
            categories=categories,
            rec_scores=rec_scores,
        )
# ────────────────────────────────────────────────────────────────────
# Simulator + result tracking
# ────────────────────────────────────────────────────────────────────
@dataclass
class SimulationResult:
    """Per-day KPI series collected by one simulator run, plus summary KPIs."""

    daily_revenue: List[float] = field(default_factory=list)
    daily_margin: List[float] = field(default_factory=list)
    daily_units_sold: List[float] = field(default_factory=list)
    daily_stockout_count: List[int] = field(default_factory=list)
    n_items: int = 0
    horizon_days: int = 0
    initial_inventories: np.ndarray = field(default_factory=lambda: np.array([]))

    @property
    def total_revenue(self) -> float:
        """Sum of daily revenue."""
        return float(np.sum(self.daily_revenue))

    @property
    def total_margin(self) -> float:
        """Sum of daily margin."""
        return float(np.sum(self.daily_margin))

    @property
    def avg_stockout_rate(self) -> float:
        """Mean fraction of items out of stock per recorded day.

        Returns 0.0 when there are no items, the horizon is zero, or no day
        has been recorded yet (``np.mean`` of an empty list would be NaN).
        """
        if self.n_items == 0 or self.horizon_days == 0 or not self.daily_stockout_count:
            return 0.0
        return float(np.mean(self.daily_stockout_count)) / self.n_items

    @property
    def sell_through_rate(self) -> float:
        """Total units sold divided by total initial inventory (0.0 if none)."""
        total_initial = float(np.sum(self.initial_inventories))
        if total_initial <= 0:
            return 0.0
        return float(np.sum(self.daily_units_sold)) / total_initial
class Simulator:
    """Day-by-day inventory + pricing simulator with weekly replenishment."""

    def __init__(self, demand_model: DemandModel, config: Dict[str, Any]) -> None:
        """Read simulator and inventory settings from ``config``.

        Args:
            demand_model: Object providing ``expected_demand(...)``.
            config: Pricing config; the ``simulator`` and ``inventory``
                sections are read, with defaults for missing keys.
        """
        self.demand_model = demand_model
        sim_cfg = config.get("simulator", {})
        inv_cfg = config.get("inventory", {})
        self.num_shoppers = sim_cfg.get("num_shoppers_per_day", 10_000)
        self.replenish_interval = sim_cfg.get("replenish_interval", 7)
        self.replenish_fraction = sim_cfg.get("replenish_fraction", 0.5)
        self.holding_cost = inv_cfg.get("holding_cost_per_unit_day", 0.002)

    @staticmethod
    def _accepts_day(policy: Any) -> bool:
        """Return True if ``policy.select_prices`` can take a ``day`` keyword.

        Inspecting the signature once replaces a try/except-TypeError probe,
        which also swallowed genuine TypeErrors raised *inside* a policy and
        silently re-invoked it without the day.
        """
        import inspect

        try:
            params = inspect.signature(policy.select_prices).parameters
        except (TypeError, ValueError):
            # Signature unavailable (e.g. some C callables): assume the
            # current interface, which is what was tried first before.
            return True
        if "day" in params:
            return True
        return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())

    def run(
        self,
        initial_state: InventoryState,
        policy: Any,
        horizon_days: int,
        seed: int = 123,
    ) -> SimulationResult:
        """Simulate ``horizon_days`` days under ``policy``.

        Args:
            initial_state: Starting state; it is copied, not mutated.
            policy: Object with ``select_prices(state[, day])`` returning one
                price per item.
            horizon_days: Number of days to simulate.
            seed: Seed for the Poisson demand draws.

        Returns:
            SimulationResult with one entry per simulated day.
        """
        state = initial_state.copy()
        result = SimulationResult(
            n_items=state.n_items,
            horizon_days=horizon_days,
            initial_inventories=state._initial_inventories.copy(),
        )
        rng = np.random.default_rng(seed)
        pass_day = self._accepts_day(policy)

        for day in range(horizon_days):
            if pass_day:
                prices = policy.select_prices(state, day=day)
            else:
                prices = policy.select_prices(state)
            state.current_prices = prices

            expected = self.demand_model.expected_demand(
                categories=state.categories,
                inventories=state.inventories,
                base_prices=state.base_prices,
                prices=prices,
                day=day,
                rec_scores=state.rec_scores,
                num_shoppers=self.num_shoppers,
            )
            realised = rng.poisson(lam=np.clip(expected, 0, None)).astype(np.float64)
            # Sales cannot exceed stock on hand.
            sold = np.minimum(realised, state.inventories)

            day_revenue = float(np.sum(prices * sold))
            day_cogs = float(np.sum(state.unit_costs * sold))
            # Holding cost is charged on start-of-day stock, valued at base
            # price — it is computed before today's sales are deducted.
            day_holding = float(
                self.holding_cost * np.sum(state.inventories * state.base_prices)
            )
            day_margin = day_revenue - day_cogs - day_holding

            state.inventories -= sold
            # Only items still in stock after today's sales age by one day.
            state.days_in_stock = np.where(
                state.inventories > 0,
                state.days_in_stock + 1,
                state.days_in_stock,
            )

            # KPIs are recorded before replenishment, so stockouts reflect
            # end-of-day shelves.
            result.daily_revenue.append(day_revenue)
            result.daily_margin.append(day_margin)
            result.daily_units_sold.append(float(np.sum(sold)))
            result.daily_stockout_count.append(int(np.sum(state.inventories <= 0)))

            if (
                self.replenish_interval > 0
                and (day + 1) % self.replenish_interval == 0
            ):
                # Top up to a floor; stock above the floor is left untouched.
                replenish_level = state._initial_inventories * self.replenish_fraction
                state.inventories = np.maximum(state.inventories, replenish_level)
        return result
# ────────────────────────────────────────────────────────────────────
# Baseline policies
# ────────────────────────────────────────────────────────────────────
class FixedPrice:
    """Control policy: every item is always sold at its base price."""

    name = "FixedPrice"

    def select_prices(self, state: InventoryState, day: int = 0) -> np.ndarray:
        """Return a fresh copy of the base prices; ``day`` is ignored."""
        unchanged = state.base_prices.copy()
        return unchanged
class AgeMarkdown:
    """Discount items by ``weekly_discount`` per full week in stock.

    The cumulative discount is bounded to the range [0%, 50%].
    """

    name = "AgeMarkdown"

    def __init__(self, weekly_discount: float = 0.05) -> None:
        self.weekly_discount = weekly_discount

    def select_prices(self, state: InventoryState, day: int = 0) -> np.ndarray:
        """Return base prices reduced by the age-based discount; ``day`` is ignored."""
        full_weeks = state.days_in_stock // 7
        raw_discount = self.weekly_discount * full_weeks
        capped = np.clip(raw_discount, 0.0, 0.50)
        keep_fraction = 1.0 - capped
        return state.base_prices * keep_fraction
# ────────────────────────────────────────────────────────────────────
# Catalog loaders
# ────────────────────────────────────────────────────────────────────
def _derive_price_tier(prices: np.ndarray) -> np.ndarray:
"""Split items into 3 elasticity-defining tiers by price terciles."""
valid = prices[prices > 0]
if len(valid) < 3:
return np.full(len(prices), "midrange", dtype=object)
q33, q67 = np.percentile(valid, [100 / 3, 200 / 3])
tiers = np.where(prices < q33, "budget",
np.where(prices < q67, "midrange", "luxury"))
return tiers.astype(object)
def _derive_popularity_rec_score(item_ids: np.ndarray, interactions_path: Path) -> np.ndarray:
"""Per-item popularity in [0, 1], from interaction counts. sqrt-flattened
so a handful of mega-popular items don't dominate the signal."""
if not interactions_path.exists():
return np.full(len(item_ids), 0.5)
inter = pd.read_parquet(interactions_path, columns=["item_id"])
counts = inter.groupby("item_id").size()
item_counts = pd.Series(item_ids).map(counts).fillna(0).to_numpy(dtype=np.float64)
if item_counts.max() == 0:
return np.full(len(item_ids), 0.5)
return np.sqrt(item_counts / item_counts.max())
def load_amazon_dresses_catalog(n_items: int, seed: int) -> pd.DataFrame:
    """Load the Amazon Dresses catalog written by the playbook's
    `prepare_data.py` and attach price-tier categories and popularity-based
    rec_score signals.

    Rows with a missing or non-positive price are dropped. When
    ``0 < n_items < len(catalog)`` a random sample of ``n_items`` rows is
    taken with ``seed``; otherwise (including ``n_items <= 0``) the whole
    catalog is used.

    Returns:
        DataFrame with columns ``item_id``, ``category``, ``avg_price``,
        ``rec_score``.

    Raises:
        FileNotFoundError: If the metadata parquet file does not exist.
    """
    processed = workspace_root() / "data" / "processed"
    meta_path = processed / "dress_metadata.parquet"
    inter_path = processed / "dress_interactions.parquet"
    if not meta_path.exists():
        raise FileNotFoundError(
            f"Amazon Dresses metadata not found at {meta_path}.\n"
            f"Run `bash assets/setup.sh` to download/prepare the dataset, "
            f"or pass `--synthetic` to use a generated catalog."
        )

    frame = pd.read_parquet(meta_path)
    # Keep only rows with a usable price (~2k of 16k rows are dropped per the
    # original note).
    priced = frame["price"].notna() & (frame["price"] > 0)
    frame = frame.loc[priced].reset_index(drop=True)
    if 0 < n_items < len(frame):
        frame = frame.sample(n=n_items, random_state=seed).reset_index(drop=True)

    # Tiers and popularity are derived after sampling, so terciles are
    # relative to the sampled subset.
    frame["category"] = _derive_price_tier(frame["price"].to_numpy())
    frame["rec_score"] = _derive_popularity_rec_score(frame["item_id"].to_numpy(), inter_path)
    renamed = frame.rename(columns={"price": "avg_price"})
    return renamed[["item_id", "category", "avg_price", "rec_score"]]
def build_synthetic_catalog(n_items: int = 100, seed: int = 0) -> pd.DataFrame:
    """Generate a small synthetic catalog.

    Used by the `smoke` subcommand and as the `--synthetic` fallback when the
    real dataset isn't available. Prices are lognormal (mean=3.5, sigma=0.6 in
    log space, i.e. a median of roughly $33) rounded to cents, rec scores are
    Beta(2, 5), and categories use the same luxury/midrange/budget tier scheme
    as the real catalog.

    Returns:
        DataFrame with columns ``item_id``, ``category``, ``avg_price``,
        ``rec_score``.
    """
    generator = np.random.default_rng(seed)
    # Draw order (prices, then scores) fixes the output for a given seed.
    raw_prices = generator.lognormal(mean=3.5, sigma=0.6, size=n_items)
    avg_price = np.round(raw_prices, 2)
    popularity = generator.beta(2.0, 5.0, size=n_items)
    columns = {
        "item_id": np.arange(n_items),
        "category": _derive_price_tier(avg_price),
        "avg_price": avg_price,
        "rec_score": popularity,
    }
    return pd.DataFrame(columns)
def load_catalog(n_items: int, seed: int, synthetic: bool) -> pd.DataFrame:
    """Return the catalog to train or evaluate on.

    With ``synthetic`` false the real Amazon Dresses catalog is loaded (its
    loader raises if the data is missing). With ``synthetic`` true a generated
    catalog is built; a non-positive ``n_items`` then means 1000 items.
    """
    if not synthetic:
        return load_amazon_dresses_catalog(n_items=n_items, seed=seed)
    size = n_items if n_items > 0 else 1000
    return build_synthetic_catalog(n_items=size, seed=seed)
# ────────────────────────────────────────────────────────────────────
# Eval helpers
# ────────────────────────────────────────────────────────────────────
def run_policy(
    policy: Any,
    initial_state: InventoryState,
    simulator: Simulator,
    horizon_days: int,
    seed: int = 123,
) -> Tuple[str, SimulationResult, float]:
    """Run one policy through the simulator and time it.

    Returns:
        ``(name, result, elapsed_seconds)`` where ``name`` is the policy's
        ``name`` attribute, or its class name when it has none.
    """
    label = getattr(policy, "name", type(policy).__name__)
    started = time.perf_counter()
    outcome = simulator.run(initial_state, policy, horizon_days, seed=seed)
    elapsed = time.perf_counter() - started
    return label, outcome, elapsed
def format_kpi_row(name: str, result: SimulationResult, elapsed: float, baseline_rev: float | None) -> str:
    """Format one policy's KPIs as a single table row.

    The revenue lift is ``revenue / baseline_rev``; it is shown as 1.00x when
    the baseline is missing, zero or negative. ``elapsed`` is given in seconds
    and printed in milliseconds.
    """
    rev = result.total_revenue
    lift = 1.0
    if baseline_rev and baseline_rev > 0:
        lift = rev / baseline_rev
    row = (
        f" {name:18s} | Rev: {rev:10.2f} ({lift:.2f}x) | "
        f"Margin: {result.total_margin:10.2f} | "
        f"Stockout: {result.avg_stockout_rate:5.1%} | "
        f"Sell-through: {result.sell_through_rate:5.1%} | "
        f"{elapsed*1000:5.0f}ms"
    )
    return row
def print_kpi_table(rows: List[Tuple[str, SimulationResult, float]]) -> None:
    """Print a KPI table framed by rules; nothing is printed for no rows.

    The first row's revenue is the baseline against which every row's lift is
    reported.
    """
    if len(rows) == 0:
        return
    reference_revenue = rows[0][1].total_revenue
    rule = "-" * 110
    print(rule)
    for label, outcome, seconds in rows:
        print(format_kpi_row(label, outcome, seconds, reference_revenue))
    print(rule)
# ────────────────────────────────────────────────────────────────────
# PPO agent (torch) — discrete action over price multipliers,
# shared MLP backbone, vectorized simulator envs.
# ────────────────────────────────────────────────────────────────────
# Position of each category name within CATEGORIES (its one-hot column).
CATEGORY_TO_IDX: Dict[str, int] = dict(zip(CATEGORIES, range(len(CATEGORIES))))
# Width of the category one-hot block.
N_CATEGORIES: int = len(CATEGORIES)
# Per-item feature count: 8 numeric features followed by the category one-hot.
N_STATE_FEATURES: int = N_CATEGORIES + 8
def encode_state_per_item(
    state: InventoryState,
    day: int,
    horizon: int,
    price_norm: float,
    inv_norm: float,
) -> np.ndarray:
    """Return a (n_items, N_STATE_FEATURES) float32 array.

    Features per item (in order):
        0 inventory_ratio (current / initial; 0 where initial is 0)
        1 days_in_stock / horizon
        2 day / horizon
        3 sin(2π · day_of_week / 7)
        4 cos(2π · day_of_week / 7)
        5 log1p(base_price) / log1p(price_norm)
        6 log1p(inventory) / log1p(inv_norm)
        7 rec_score
        8.. category one-hot, one column per entry of CATEGORIES in order

    A category not present in CATEGORY_TO_IDX is encoded as index 0.

    Args:
        state: Current inventory state.
        day: Current day index.
        horizon: Episode length; values below 1 are treated as 1.
        price_norm: Price scale; values below 1.0 are treated as 1.0.
        inv_norm: Inventory scale; values below 1.0 are treated as 1.0.
    """
    n = state.n_items
    span = max(horizon, 1)

    initial = state._initial_inventories
    # `where=` skips the division for zero initial stock instead of computing
    # x / 0 (with a RuntimeWarning) and discarding it afterwards.
    inv_ratio = np.zeros(n, dtype=np.float64)
    np.divide(state.inventories, initial, out=inv_ratio, where=initial > 0)

    days_norm = state.days_in_stock / span
    day_norm = np.full(n, day / span)
    day_of_week = day % 7
    dow_sin = np.full(n, np.sin(2.0 * np.pi * day_of_week / 7.0))
    dow_cos = np.full(n, np.cos(2.0 * np.pi * day_of_week / 7.0))
    bp = np.log1p(state.base_prices) / np.log1p(max(price_norm, 1.0))
    inv = np.log1p(state.inventories) / np.log1p(max(inv_norm, 1.0))
    rec = state.rec_scores

    # One vectorised comparison per known category rather than a Python loop
    # over items; unmatched labels keep the fallback index 0.
    labels = np.asarray(state.categories).astype(str)
    cat_idx = np.zeros(n, dtype=np.intp)
    for cat_name, idx in CATEGORY_TO_IDX.items():
        cat_idx[labels == cat_name] = idx
    cat_onehot = np.zeros((n, N_CATEGORIES), dtype=np.float64)
    cat_onehot[np.arange(n), cat_idx] = 1.0

    return np.column_stack(
        [inv_ratio, days_norm, day_norm, dow_sin, dow_cos, bp, inv, rec, cat_onehot]
    ).astype(np.float32)
class PricingEnv:
    """Gym-like wrapper exposing the simulator one day at a time.

    ``reset()`` returns the encoded state. ``step(actions)`` advances one
    day and returns (next_state, per_item_reward, done, info).

    Per-item reward is the per-item margin (revenue - COGS - holding cost).
    Rewards are returned as float32, while ``info["total_margin"]`` is summed
    in float64 so it matches the simulator's daily margin without float32
    rounding.
    """

    def __init__(
        self,
        initial_state: InventoryState,
        simulator: Simulator,
        multipliers: np.ndarray,
        horizon: int,
        seed: int,
    ) -> None:
        """Store the episode template and reset to day 0.

        Args:
            initial_state: Template state; copied on every reset, never
                mutated.
            simulator: Supplies the demand model, shopper count, holding cost
                and replenishment settings.
            multipliers: Price multipliers indexed by action id.
            horizon: Episode length in days.
            seed: Seed for the demand noise generator.
        """
        self.initial_state = initial_state
        self.simulator = simulator
        # Coerce so a plain list still supports fancy indexing in step().
        self.multipliers = np.asarray(multipliers, dtype=np.float64)
        self.horizon = horizon
        self.seed = seed
        # Normalisers for the state encoder, fixed from the initial state.
        self._price_norm = float(self.initial_state.base_prices.max())
        self._inv_norm = float(self.initial_state._initial_inventories.max())
        self.reset()

    def reset(self) -> np.ndarray:
        """Restore the initial state and return the encoded observation."""
        self.state = self.initial_state.copy()
        self.day = 0
        # NOTE(review): the generator is re-seeded with the same seed on every
        # reset, so each episode of this env replays the identical demand
        # noise sequence. Confirm this is intended for training.
        self.rng = np.random.default_rng(self.seed)
        return encode_state_per_item(
            self.state, self.day, self.horizon, self._price_norm, self._inv_norm
        )

    def step(self, action_idx: np.ndarray) -> Tuple[np.ndarray, np.ndarray, bool, Dict[str, float]]:
        """Apply one day of pricing decisions.

        Args:
            action_idx: Integer array, one multiplier index per item.

        Returns:
            ``(next_obs, reward_per_item, done, info)`` where ``info`` carries
            the day's ``total_margin`` and ``total_revenue``.
        """
        prices = self.state.base_prices * self.multipliers[action_idx]
        self.state.current_prices = prices
        expected = self.simulator.demand_model.expected_demand(
            categories=self.state.categories,
            inventories=self.state.inventories,
            base_prices=self.state.base_prices,
            prices=prices,
            day=self.day,
            rec_scores=self.state.rec_scores,
            num_shoppers=self.simulator.num_shoppers,
        )
        realised = self.rng.poisson(lam=np.clip(expected, 0, None)).astype(np.float64)
        # Sales cannot exceed stock on hand.
        sold = np.minimum(realised, self.state.inventories)

        revenue = prices * sold
        cogs = self.state.unit_costs * sold
        # Holding cost uses start-of-day stock, so it is computed before the
        # sales are deducted below.
        holding = self.simulator.holding_cost * self.state.inventories * self.state.base_prices
        margin = revenue - cogs - holding  # float64, per item
        reward_per_item = margin.astype(np.float32)

        self.state.inventories -= sold
        # Only items still in stock after today's sales age by one day.
        self.state.days_in_stock = np.where(
            self.state.inventories > 0,
            self.state.days_in_stock + 1,
            self.state.days_in_stock,
        )
        self.day += 1
        if (
            self.simulator.replenish_interval > 0
            and self.day % self.simulator.replenish_interval == 0
        ):
            # Top up to a floor; stock above the floor is left untouched.
            replenish_level = self.state._initial_inventories * self.simulator.replenish_fraction
            self.state.inventories = np.maximum(self.state.inventories, replenish_level)

        done = self.day >= self.horizon
        next_obs = encode_state_per_item(
            self.state, self.day, self.horizon, self._price_norm, self._inv_norm
        )
        info = {
            # Summed in float64 rather than from the float32 rewards.
            "total_margin": float(margin.sum()),
            "total_revenue": float(revenue.sum()),
        }
        return next_obs, reward_per_item, done, info
# ────────────────────────────────────────────────────────────────────
# Network
# ────────────────────────────────────────────────────────────────────
def _build_actor_critic(n_actions: int, hidden: int = 256):
    """Construct the actor-critic network (torch is imported lazily).

    The module has a three-layer ReLU MLP ``backbone`` over the
    N_STATE_FEATURES inputs, an ``actor`` head producing ``n_actions`` logits
    and a ``critic`` head producing one value (last axis squeezed).
    """
    from torch import nn

    class ActorCritic(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Three Linear+ReLU pairs; the first maps features -> hidden.
            stack = []
            width_in = N_STATE_FEATURES
            for _ in range(3):
                stack.append(nn.Linear(width_in, hidden))
                stack.append(nn.ReLU())
                width_in = hidden
            self.backbone = nn.Sequential(*stack)
            self.actor = nn.Linear(hidden, n_actions)
            self.critic = nn.Linear(hidden, 1)

            # Orthogonal init with gain sqrt(2) and zero bias on every linear
            # layer, then re-initialise the two heads with their own gains
            # (small for the policy logits, 1.0 for the value).
            linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
            for layer in linear_layers:
                nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
                nn.init.zeros_(layer.bias)
            nn.init.orthogonal_(self.actor.weight, gain=0.01)
            nn.init.orthogonal_(self.critic.weight, gain=1.0)

        def forward(self, x):
            features = self.backbone(x)
            logits = self.actor(features)
            value = self.critic(features).squeeze(-1)
            return logits, value

    return ActorCritic()
# ────────────────────────────────────────────────────────────────────
# PPO Trainer
# ────────────────────────────────────────────────────────────────────
@dataclass
class PPOConfig:
    """Hyper-parameters for PPOTrainer."""

    # Number of collect-rollout / update iterations.
    n_iters: int = 200
    # Number of parallel environments; env k is seeded with seed + k.
    n_envs: int = 16
    # Episode length in days (also the rollout length).
    horizon: int = 14
    # Catalog size as requested by the caller; the trainer only uses it in
    # its startup log line.
    n_items: int = 1000
    # Passes over each rollout per update.
    n_epochs: int = 4
    # Samples per gradient step; one sample is an (env, item, day) triple.
    minibatch_size: int = 4096
    # Adam learning rate.
    lr: float = 3e-4
    # Discount factor used by GAE.
    gamma: float = 0.99
    # GAE lambda.
    gae_lambda: float = 0.95
    # Clip range for the policy probability ratio.
    clip_eps: float = 0.2
    # Weight of the value loss in the total loss.
    value_coef: float = 0.5
    # Entropy bonus weight, linearly annealed from start to end over training.
    entropy_coef_start: float = 0.05
    entropy_coef_end: float = 0.005
    # Gradient-norm clipping threshold.
    max_grad_norm: float = 0.5
    # Rewards are divided by this before being stored for the critic.
    reward_scale: float = 100.0
    # Clip range for the value-function update.
    value_clip_eps: float = 0.2
    # "auto" selects "cuda" when available, else "cpu"; other values are used
    # verbatim.
    device: str = "auto"
    # Base seed for environment construction.
    seed: int = 0
def _compute_gae(
rewards: np.ndarray, # (T, K, N)
values: np.ndarray, # (T+1, K, N)
gamma: float,
lam: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""Generalized Advantage Estimation per item-trajectory.
Treats each (env, item) pair as an independent length-T trajectory.
"""
T = rewards.shape[0]
advantages = np.zeros_like(rewards, dtype=np.float32)
last_gae = np.zeros(rewards.shape[1:], dtype=np.float32)
for t in reversed(range(T)):
delta = rewards[t] + gamma * values[t + 1] - values[t]
last_gae = delta + gamma * lam * last_gae
advantages[t] = last_gae
returns = advantages + values[:-1]
return advantages, returns
class PPOTrainer:
    """Vectorized PPO trainer for the pricing env.

    Every (env, item) pair is treated as an independent trajectory that
    shares one actor-critic network. An iteration collects one full episode
    from each env and then runs several epochs of clipped PPO updates on it.
    """

    def __init__(self, cfg: PPOConfig, catalog: pd.DataFrame, multipliers: np.ndarray):
        """Build the environments, network and optimizer.

        Args:
            cfg: Training hyper-parameters.
            catalog: Item-features frame passed to InventoryState.initialize.
            multipliers: Price multipliers; their count is the action count.
        """
        import torch

        self.cfg = cfg
        self.multipliers = multipliers
        self.n_actions = len(multipliers)
        self.device = self._resolve_device(cfg.device)
        # Build K parallel envs with different seeds for diversity.
        self.envs: List[PricingEnv] = []
        for k in range(cfg.n_envs):
            initial_state = InventoryState.initialize(catalog, PRICING_CONFIG, seed=cfg.seed + k)
            demand = DemandModel(PRICING_CONFIG["demand_model"])
            simulator = Simulator(demand, PRICING_CONFIG)
            self.envs.append(
                PricingEnv(initial_state, simulator, multipliers, cfg.horizon, seed=cfg.seed + k)
            )
        self.net = _build_actor_critic(self.n_actions).to(self.device)
        self.optim = torch.optim.Adam(self.net.parameters(), lr=cfg.lr)
        self.train_curve: List[float] = []  # mean total revenue per iteration

    @staticmethod
    def _resolve_device(s: str) -> str:
        """Map "auto" to "cuda" when available, else "cpu"; pass others through."""
        if s != "auto":
            return s
        import torch  # only needed to probe for CUDA

        return "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _terminal_masked(last_value: np.ndarray, done: np.ndarray) -> np.ndarray:
        """Return a float32 copy of ``last_value`` with finished envs zeroed.

        Args:
            last_value: Critic output for the post-rollout state, shape (K, N).
            done: Boolean per env, shape (K,); True where the episode ended.
        """
        masked = np.array(last_value, dtype=np.float32, copy=True)
        masked[np.asarray(done, dtype=bool)] = 0.0
        return masked

    def _act(self, obs_kn: "Any", greedy: bool = False):
        """obs_kn: (K, N, F) torch tensor. Returns actions, log_probs, values, entropy."""
        import torch

        logits, value = self.net(obs_kn)
        dist = torch.distributions.Categorical(logits=logits)
        action = logits.argmax(dim=-1) if greedy else dist.sample()
        # Same log-prob computation as in update(), so old and new
        # log-probabilities are directly comparable.
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        return action, log_prob, value, entropy

    def collect_rollout(self):
        """Run K parallel envs for `horizon` steps. Returns dict of arrays.

        Keys: ``obs``, ``actions``, ``log_probs``, ``values``, ``advantages``,
        ``returns`` (each with leading shape (T, K, N)), plus per-env
        ``episode_returns`` (unscaled total margin) and ``episode_revenues``.
        """
        import torch

        cfg = self.cfg
        n_items = self.envs[0].state.n_items
        obs_buf = np.zeros((cfg.horizon, cfg.n_envs, n_items, N_STATE_FEATURES), dtype=np.float32)
        act_buf = np.zeros((cfg.horizon, cfg.n_envs, n_items), dtype=np.int64)
        logp_buf = np.zeros((cfg.horizon, cfg.n_envs, n_items), dtype=np.float32)
        val_buf = np.zeros((cfg.horizon + 1, cfg.n_envs, n_items), dtype=np.float32)
        rew_buf = np.zeros((cfg.horizon, cfg.n_envs, n_items), dtype=np.float32)
        obs = np.stack([env.reset() for env in self.envs], axis=0)  # (K, N, F)
        episode_returns = np.zeros(cfg.n_envs, dtype=np.float64)
        episode_revenues = np.zeros(cfg.n_envs, dtype=np.float64)
        done_flags = np.zeros(cfg.n_envs, dtype=bool)

        for t in range(cfg.horizon):
            obs_t = torch.from_numpy(obs).to(self.device)  # (K, N, F)
            with torch.no_grad():
                action, log_prob, value, _ = self._act(obs_t, greedy=False)
            action_np = action.cpu().numpy()
            obs_buf[t] = obs
            act_buf[t] = action_np
            logp_buf[t] = log_prob.cpu().numpy()
            val_buf[t] = value.cpu().numpy()

            next_obs_list = []
            for k, env in enumerate(self.envs):
                next_o, reward, done, info = env.step(action_np[k])
                rew_buf[t, k] = reward / cfg.reward_scale  # scale rewards for stable critic
                episode_returns[k] += info["total_margin"]
                episode_revenues[k] += info["total_revenue"]
                done_flags[k] = bool(done)
                next_obs_list.append(next_o)
            obs = np.stack(next_obs_list, axis=0)

        # Bootstrap value for the state after the last step (used by GAE).
        # An env whose episode has ended earns nothing further, so its
        # bootstrap is 0; the critic is also never trained on that terminal
        # state, so its output there would only add noise to the targets.
        # val_buf is zero-initialised, so nothing is needed when all are done.
        if not done_flags.all():
            with torch.no_grad():
                obs_T = torch.from_numpy(obs).to(self.device)
                _, last_value = self.net(obs_T)
            val_buf[cfg.horizon] = self._terminal_masked(last_value.cpu().numpy(), done_flags)

        advantages, returns = _compute_gae(rew_buf, val_buf, cfg.gamma, cfg.gae_lambda)
        return {
            "obs": obs_buf,
            "actions": act_buf,
            "log_probs": logp_buf,
            "values": val_buf[:-1],  # old values for clipped value loss
            "advantages": advantages,
            "returns": returns,
            "episode_returns": episode_returns,
            "episode_revenues": episode_revenues,
        }

    def update(self, rollout: Dict[str, np.ndarray], entropy_coef: float) -> Dict[str, float]:
        """Run PPO epochs on the collected rollout. Returns loss diagnostics.

        Args:
            rollout: Output of ``collect_rollout``.
            entropy_coef: Weight of the entropy bonus for this update.

        Returns:
            Mean ``policy_loss``, ``value_loss`` and ``entropy`` over all
            minibatches.
        """
        import torch

        cfg = self.cfg
        # Flatten (T, K, N, ...) into one batch of independent samples.
        obs = torch.from_numpy(rollout["obs"]).to(self.device).reshape(-1, N_STATE_FEATURES)
        actions = torch.from_numpy(rollout["actions"]).to(self.device).reshape(-1)
        old_log_probs = torch.from_numpy(rollout["log_probs"]).to(self.device).reshape(-1)
        old_values = torch.from_numpy(rollout["values"]).to(self.device).reshape(-1)
        advantages = torch.from_numpy(rollout["advantages"]).to(self.device).reshape(-1)
        returns = torch.from_numpy(rollout["returns"]).to(self.device).reshape(-1)
        # Normalize advantages for stable gradients
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        N = obs.shape[0]
        policy_losses, value_losses, entropies = [], [], []
        for _epoch in range(cfg.n_epochs):
            perm = torch.randperm(N, device=self.device)
            for start in range(0, N, cfg.minibatch_size):
                mb = perm[start : start + cfg.minibatch_size]
                logits, value = self.net(obs[mb])
                dist = torch.distributions.Categorical(logits=logits)
                new_log_prob = dist.log_prob(actions[mb])
                entropy = dist.entropy().mean()

                ratio = (new_log_prob - old_log_probs[mb]).exp()
                adv = advantages[mb]
                surrogate1 = ratio * adv
                surrogate2 = torch.clamp(ratio, 1 - cfg.clip_eps, 1 + cfg.clip_eps) * adv
                policy_loss = -torch.min(surrogate1, surrogate2).mean()

                # Clipped value loss: prevent value-function divergence after policy converges.
                v_clipped = old_values[mb] + torch.clamp(
                    value - old_values[mb], -cfg.value_clip_eps, cfg.value_clip_eps
                )
                v_loss_unclipped = (value - returns[mb]).pow(2)
                v_loss_clipped = (v_clipped - returns[mb]).pow(2)
                value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()

                loss = policy_loss + cfg.value_coef * value_loss - entropy_coef * entropy
                self.optim.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), cfg.max_grad_norm)
                self.optim.step()

                policy_losses.append(float(policy_loss.detach()))
                value_losses.append(float(value_loss.detach()))
                entropies.append(float(entropy.detach()))
        return {
            "policy_loss": float(np.mean(policy_losses)),
            "value_loss": float(np.mean(value_losses)),
            "entropy": float(np.mean(entropies)),
        }

    def train(self, verbose: bool = True) -> List[float]:
        """Run ``cfg.n_iters`` collect/update iterations.

        Args:
            verbose: Print a header, roughly 20 progress lines and a footer.

        Returns:
            The training curve: mean episode revenue per iteration.
        """
        cfg = self.cfg
        if verbose:
            # Report the real catalog size; cfg.n_items is only the requested
            # size and can differ from it.
            n_items = self.envs[0].state.n_items if self.envs else cfg.n_items
            print(
                f"Training PPO: {cfg.n_iters} iters × {cfg.n_envs} envs × "
                f"{cfg.horizon} days × {n_items} items on {self.device}"
            )
        t0 = time.perf_counter()
        for it in range(cfg.n_iters):
            # Linear anneal of entropy bonus from start → end across training.
            frac = it / max(1, cfg.n_iters - 1)
            entropy_coef = cfg.entropy_coef_start + (cfg.entropy_coef_end - cfg.entropy_coef_start) * frac
            rollout = self.collect_rollout()
            losses = self.update(rollout, entropy_coef=entropy_coef)
            mean_revenue = float(rollout["episode_revenues"].mean())
            mean_margin = float(rollout["episode_returns"].mean())
            self.train_curve.append(mean_revenue)
            if verbose and (it + 1) % max(1, cfg.n_iters // 20) == 0:
                print(
                    f" iter {it + 1:4d}/{cfg.n_iters} | "
                    f"rev/ep: {mean_revenue:10.0f} | margin/ep: {mean_margin:10.0f} | "
                    f"pi_loss: {losses['policy_loss']:+.3f} | "
                    f"v_loss: {losses['value_loss']:.3f} | "
                    f"H: {losses['entropy']:.3f} | "
                    f"ent_c: {entropy_coef:.3f}"
                )
        if verbose:
            print(f"Training complete in {time.perf_counter() - t0:.1f}s")
        return self.train_curve
class PPOPolicy:
    """Wrap a trained network behind the ``select_prices`` policy interface."""

    name = "PPO"

    def __init__(self, net, multipliers: np.ndarray, device: str, horizon: int,
                 price_norm: float, inv_norm: float, greedy: bool = True) -> None:
        """Keep the network and encoder settings; put the net in eval mode.

        ``price_norm`` / ``inv_norm`` are forwarded to the state encoder on
        every call. ``greedy`` selects argmax actions instead of sampling.
        """
        self.net = net
        self.multipliers = multipliers
        self.device = device
        self.horizon = horizon
        self.price_norm = price_norm
        self.inv_norm = inv_norm
        self.greedy = greedy
        self.net.eval()

    def select_prices(self, state: InventoryState, day: int = 0) -> np.ndarray:
        """Return one price per item: base price times the chosen multiplier."""
        import torch

        features = encode_state_per_item(
            state, day, self.horizon, self.price_norm, self.inv_norm
        )
        with torch.no_grad():
            batch = torch.from_numpy(features).to(self.device)
            logits, _value = self.net(batch)
            if self.greedy:
                picks = logits.argmax(dim=-1)
            else:
                picks = torch.distributions.Categorical(logits=logits).sample()
            chosen = picks.cpu().numpy()
        return state.base_prices * self.multipliers[chosen]
# ────────────────────────────────────────────────────────────────────
# Workspace + persistence helpers
# ────────────────────────────────────────────────────────────────────
def workspace_root() -> Path:
    """Return the playbook workspace directory.

    The ``PLAYBOOK_WORKSPACE`` environment variable wins when it is set;
    otherwise the current user's home directory is used.
    """
    fallback = str(Path.home())
    configured = os.environ.get("PLAYBOOK_WORKSPACE", fallback)
    return Path(configured)
def model_dir() -> Path:
    """Return ``<workspace>/models/pricing_ppo``, where the PPO checkpoint lives."""
    return workspace_root().joinpath("models", "pricing_ppo")
def processed_dir() -> Path:
    """Return ``<workspace>/data/processed``, the output directory for artefacts."""
    return workspace_root().joinpath("data", "processed")
def save_checkpoint(
    net,
    cfg: PPOConfig,
    multipliers: np.ndarray,
    catalog_meta: Dict[str, Any],
    train_curve: List[float],
    path: Path,
) -> None:
    """Serialise the trained network and its training context to ``path``.

    The file holds the network ``state_dict``, the config's attribute dict,
    the multiplier grid as a plain list, the catalog metadata, and the
    training curve. Missing parent directories are created.
    """
    import torch

    path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "state_dict": net.state_dict(),
        "config": vars(cfg),
        "multipliers": multipliers.tolist(),
        "catalog_meta": catalog_meta,
        "train_curve": train_curve,
    }
    torch.save(payload, path)
def load_checkpoint(path: Path):
    """Load a checkpoint from ``path`` onto the CPU and return it."""
    import torch

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # objects, so only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    return checkpoint
def save_training_curve_png(train_curve: List[float], path: Path) -> None:
    """Plot the per-iteration training curve and write it to ``path`` as an image.

    When matplotlib cannot be imported, a notice is printed to stderr and
    nothing is written. Missing parent directories are created.
    """
    try:
        import matplotlib

        # Select the Agg backend before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print(f"matplotlib not available, skipping {path}", file=sys.stderr)
        return

    path.parent.mkdir(parents=True, exist_ok=True)
    figure, axes = plt.subplots(figsize=(8, 4))
    axes.plot(train_curve, color="#76b900", linewidth=1.5)
    axes.set_xlabel("Iteration")
    axes.set_ylabel("Mean episode revenue")
    axes.set_title("PPO training curve")
    axes.grid(alpha=0.3)
    figure.tight_layout()
    figure.savefig(path, dpi=120)
    plt.close(figure)
# ────────────────────────────────────────────────────────────────────
# CLI
# ────────────────────────────────────────────────────────────────────
def cmd_smoke(args: argparse.Namespace) -> int:
    """Run the baseline policies on a synthetic catalog and print their KPIs.

    Reads ``n_items``, ``horizon`` and ``seed`` from ``args``. Always returns 0.
    """
    print(f"Running pricing smoke test ({args.n_items} items, {args.horizon} days)…")
    catalog = build_synthetic_catalog(n_items=args.n_items, seed=args.seed)
    labels, counts = np.unique(catalog["category"], return_counts=True)
    print(f"Catalog: {len(catalog)} items, categories: {dict(zip(labels, counts))}")

    simulator = Simulator(DemandModel(PRICING_CONFIG["demand_model"]), PRICING_CONFIG)
    initial_state = InventoryState.initialize(catalog, PRICING_CONFIG, seed=args.seed)

    rows = []
    for policy in (FixedPrice(), AgeMarkdown(weekly_discount=0.05)):
        rows.append(run_policy(policy, initial_state, simulator, args.horizon, seed=args.seed))
    print_kpi_table(rows)
    return 0
def _ppo_config_from_args(args: argparse.Namespace) -> PPOConfig:
    """Build a ``PPOConfig`` from the CLI flags that share its field names."""
    forwarded = ("n_iters", "n_envs", "horizon", "n_items", "lr", "device", "seed")
    overrides = {field_name: getattr(args, field_name) for field_name in forwarded}
    return PPOConfig(**overrides)
def _print_catalog_summary(catalog: pd.DataFrame, source: str) -> None:
    """Print a one-line catalog overview: size, median price, category counts.

    Reads the ``category`` and ``avg_price`` columns of ``catalog``.
    """
    tier_counts = catalog["category"].value_counts().to_dict()
    median_price = float(catalog["avg_price"].median())
    summary = (
        f"Catalog [{source}]: {len(catalog)} items | "
        f"price median ${median_price:.2f} | tiers: {tier_counts}"
    )
    print(summary)
def cmd_train(args: argparse.Namespace) -> int:
    """Train the PPO agent, then save its checkpoint and training-curve plot.

    The checkpoint goes to ``model_dir() / "policy.pt"`` and the plot to
    ``processed_dir() / "pricing_training_curve.png"``. Always returns 0.
    """
    cfg = _ppo_config_from_args(args)
    catalog = load_catalog(n_items=cfg.n_items, seed=cfg.seed, synthetic=args.synthetic)
    if args.synthetic:
        source = "synthetic"
    else:
        source = "Amazon Dresses"
    multipliers = np.asarray(PRICING_CONFIG["price_grid"]["multipliers"], dtype=np.float64)
    _print_catalog_summary(catalog, source)
    print(f"Multipliers: {multipliers.tolist()}, horizon: {cfg.horizon}d")

    trainer = PPOTrainer(cfg, catalog, multipliers)
    trainer.train(verbose=True)

    # Persist the policy together with the catalog facts recorded at train time.
    ckpt_path = model_dir() / "policy.pt"
    catalog_meta = dict(
        n_items=len(catalog),
        seed=cfg.seed,
        synthetic=args.synthetic,
        source=source,
        price_norm=float(catalog["avg_price"].max()),
    )
    save_checkpoint(trainer.net, cfg, multipliers, catalog_meta, trainer.train_curve, ckpt_path)
    print(f"Saved checkpoint → {ckpt_path}")

    curve_path = processed_dir() / "pricing_training_curve.png"
    save_training_curve_png(trainer.train_curve, curve_path)
    print(f"Saved training curve → {curve_path}")
    return 0
def _build_eval_summary(
    rows: List[Tuple[str, Any, float]],
    n_items: int,
    horizon: int,
    device: str,
    seed: int,
) -> Dict[str, Any]:
    """Assemble the JSON-serialisable report written by ``cmd_eval``.

    Args:
        rows: ``(policy_name, result, elapsed_seconds)`` tuples. Each result
            exposes ``total_revenue``, ``total_margin``, ``avg_stockout_rate``
            and ``sell_through_rate``. The first row is the fixed-price
            baseline used for the revenue-lift ratio.
        n_items: Number of catalog items that were actually evaluated.
        horizon: Episode length in days.
        device: Device the PPO policy ran on; stored as a string.
        seed: Evaluation seed.

    Returns:
        A dict with run metadata plus one entry per policy. The revenue lift
        falls back to 1.0 when the baseline revenue is not positive.
    """
    baseline_rev = rows[0][1].total_revenue if rows else 0.0
    return {
        "n_items": n_items,
        "horizon": horizon,
        # str() keeps the report serialisable even if a device object is passed.
        "device": str(device),
        "seed": seed,
        "policies": [
            {
                "name": name,
                "total_revenue": res.total_revenue,
                "total_margin": res.total_margin,
                "avg_stockout_rate": res.avg_stockout_rate,
                "sell_through_rate": res.sell_through_rate,
                "revenue_lift_vs_fixed_price": (
                    (res.total_revenue / baseline_rev) if baseline_rev > 0 else 1.0
                ),
                "elapsed_s": elapsed,
            }
            for name, res, elapsed in rows
        ],
    }


def cmd_eval(args: argparse.Namespace) -> int:
    """Evaluate the saved PPO policy against the FixedPrice and AgeMarkdown baselines.

    Loads ``model_dir() / "policy.pt"``, rebuilds the catalog and simulator,
    runs every policy for the checkpoint's horizon, prints the KPI table, and
    writes a JSON summary to ``processed_dir() / "pricing_eval.json"``.

    Returns:
        0 on success, 2 when no checkpoint exists.
    """
    # Imported up front so a missing torch install surfaces before any work is
    # done; the name itself is not used below.
    import torch  # noqa: F401

    ckpt_path = model_dir() / "policy.pt"
    if not ckpt_path.exists():
        print(f"No checkpoint at {ckpt_path}. Run `train` first.", file=sys.stderr)
        return 2

    ckpt = load_checkpoint(ckpt_path)
    cfg_dict = ckpt["config"]
    horizon = cfg_dict["horizon"]
    multipliers = np.asarray(ckpt["multipliers"], dtype=np.float64)
    catalog_meta = ckpt.get("catalog_meta", {})
    synthetic = bool(catalog_meta.get("synthetic", False))
    n_items = int(catalog_meta.get("n_items", cfg_dict.get("n_items", 1000)))
    catalog = load_catalog(n_items=n_items, seed=args.seed, synthetic=synthetic)
    source = catalog_meta.get("source", "synthetic" if synthetic else "Amazon Dresses")
    _print_catalog_summary(catalog, source)

    demand = DemandModel(PRICING_CONFIG["demand_model"])
    simulator = Simulator(demand, PRICING_CONFIG)
    initial_state = InventoryState.initialize(catalog, PRICING_CONFIG, seed=args.seed)

    device = PPOTrainer._resolve_device(args.device)
    net = _build_actor_critic(len(multipliers)).to(device)
    net.load_state_dict(ckpt["state_dict"])
    # NOTE(review): the checkpoint also stores catalog_meta["price_norm"], but
    # the normalisers here are recomputed from the evaluation catalog. Confirm
    # this matches what the trainer used when the eval seed differs.
    ppo_policy = PPOPolicy(
        net=net,
        multipliers=multipliers,
        device=device,
        horizon=horizon,
        price_norm=float(initial_state.base_prices.max()),
        inv_norm=float(initial_state._initial_inventories.max()),
        greedy=True,
    )

    # Report the size of the catalog actually evaluated. The configured
    # n_items is 0 for full-catalog runs and may be absent from a checkpoint.
    n_eval_items = len(catalog)
    policies = [FixedPrice(), AgeMarkdown(weekly_discount=0.05), ppo_policy]
    print(f"\nEvaluating ({n_eval_items} items, {horizon} days, device={device}):")
    rows = [run_policy(p, initial_state, simulator, horizon, seed=args.seed) for p in policies]
    print_kpi_table(rows)

    # Save eval results
    eval_path = processed_dir() / "pricing_eval.json"
    eval_path.parent.mkdir(parents=True, exist_ok=True)
    summary = _build_eval_summary(
        rows, n_items=n_eval_items, horizon=horizon, device=device, seed=args.seed
    )
    with open(eval_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(f"Saved eval summary → {eval_path}")
    return 0
def cmd_train_and_eval(args: argparse.Namespace) -> int:
    """Run training, then evaluation; stop early if training returns non-zero."""
    train_status = cmd_train(args)
    return cmd_eval(args) if train_status == 0 else train_status
def build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser with the smoke/train/eval/train_and_eval subcommands.

    Each subparser stores its handler under ``func`` so the caller can
    dispatch with ``args.func(args)``.
    """
    device_choices = ["auto", "cuda", "cpu"]

    def add_training_options(target: argparse.ArgumentParser) -> None:
        # Flags shared by `train` and `train_and_eval`, in help-display order.
        for flag, default in (("--n-iters", 200), ("--n-envs", 16), ("--horizon", 14)):
            target.add_argument(flag, type=int, default=default)
        target.add_argument(
            "--n-items",
            type=int,
            default=1000,
            help="Sample size from the Amazon Dresses catalog. Pass 0 for the full ~14k items.",
        )
        target.add_argument("--lr", type=float, default=3e-4)
        target.add_argument("--device", type=str, default="auto", choices=device_choices)
        target.add_argument("--seed", type=int, default=0)
        target.add_argument(
            "--synthetic",
            action="store_true",
            help="Use a generated synthetic catalog instead of the real Amazon Dresses data.",
        )

    parser = argparse.ArgumentParser(
        prog="pricing_agent",
        description="Dynamic pricing agent for the rec-sys playbook.",
    )
    commands = parser.add_subparsers(dest="cmd", required=False)

    smoke = commands.add_parser("smoke", help="Run baselines on a tiny synthetic catalog.")
    for flag, default in (("--n-items", 100), ("--horizon", 14), ("--seed", 0)):
        smoke.add_argument(flag, type=int, default=default)
    smoke.set_defaults(func=cmd_smoke)

    train = commands.add_parser("train", help="Train the PPO agent against the simulator.")
    add_training_options(train)
    train.set_defaults(func=cmd_train)

    evaluate = commands.add_parser("eval", help="Evaluate trained PPO + baselines.")
    evaluate.add_argument("--device", type=str, default="auto", choices=device_choices)
    evaluate.add_argument("--seed", type=int, default=0)
    evaluate.set_defaults(func=cmd_eval)

    train_then_eval = commands.add_parser(
        "train_and_eval",
        help="Train the PPO agent, then evaluate vs. baselines. Default if no subcommand is given.",
    )
    add_training_options(train_then_eval)
    train_then_eval.set_defaults(func=cmd_train_and_eval)

    return parser
_VALID_SUBCOMMANDS = {"smoke", "train", "eval", "train_and_eval"}

# The module docstring advertises `all` as the train-then-eval subcommand, but
# the parser registers it as `train_and_eval`. Accept both spellings so the
# documented invocation works.
_SUBCOMMAND_ALIASES = {"all": "train_and_eval"}


def _normalize_argv(argv: List[str]) -> List[str]:
    """Return ``argv`` with a leading subcommand the parser recognises.

    * A known alias (``all``) is rewritten to its canonical subcommand.
    * A valid subcommand or a top-level help flag is passed through unchanged.
    * Anything else (no arguments, or flags only) gets ``train_and_eval``
      prepended, so the default is the full train→eval flow.
    """
    if not argv:
        return ["train_and_eval"]
    first = argv[0]
    if first in _SUBCOMMAND_ALIASES:
        return [_SUBCOMMAND_ALIASES[first], *argv[1:]]
    if first in _VALID_SUBCOMMANDS or first in {"-h", "--help"}:
        return list(argv)
    return ["train_and_eval", *argv]


def main(argv: List[str] | None = None) -> int:
    """CLI entry point: parse ``argv`` and dispatch to the chosen subcommand.

    Args:
        argv: Argument list without the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        The integer status returned by the subcommand handler.
    """
    if argv is None:
        argv = sys.argv[1:]
    parser = build_parser()
    args = parser.parse_args(_normalize_argv(argv))
    return args.func(args)
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())