mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-06-20 13:19:34 +00:00
547 lines
24 KiB
Python
547 lines
24 KiB
Python
"""Train a LightGBM lambdarank re-ranker on HLLM embeddings + handcrafted features.
|
||
|
||
Two-stage retrieval-and-rank pipeline: HLLM/FAISS produces top-100 candidates
per user, this script trains a LightGBM model to re-order them.

Pipeline per user (leave-last-out):
  1. Sort interactions by timestamp; hold out the last item as the positive.
  2. Build user_emb = mean(HLLM[history items]), L2-normalized.
  3. Retrieve top-100 candidates via torch.mm + topk over the item embedding
     matrix (mathematically equivalent to FAISS IndexFlatIP).
  4. If the held-out item lands in the top-100, label that row 1 and the
     other 99 rows 0 (a training-signal row group). Skip users whose
     positive missed the top-100 — there is nothing to learn for them.
  5. Engineer 23 features per (user, candidate) pair: item popularity
     windows, user history stats, three HLLM similarity signals
     (dot-product, max/avg vs. recent history), price ratios, is_repurchase.

Train/valid: 80/20 split over user groups. LightGBM with the lambdarank
objective and group sizes = candidate-set size per user.

Inference contract: assets/app.py and assets/benchmark_retrieval.py load
the saved model and call model.predict(features) to score candidates.

Feature provenance: the handcrafted feature set is adapted from the
1st-place solution to the H&M Personalized Fashion Recommendations
Kaggle competition, which combined item popularity windows, user
history aggregates, price relationships, and pairwise text-similarity
signals into a LightGBM lambdarank model. We keep the H&M structure
and swap the original TF-IDF text similarities for HLLM embedding
similarities (`hllm_dot_product`, `hllm_max_hist_sim`, `hllm_avg_hist_sim`).

Generalization assumption: this feature set is expected to transfer to
other sparse retail datasets (long-tail item distributions, repeat-purchase
signal, mixed continuous + categorical + missing features). If you adapt
this to a different domain — content/video, travel, music, location-based
— expect to add domain-specific signals (watch time, location distance,
session co-occurrence, etc.) and possibly drop features that no longer
apply (e.g. `is_repurchase` is meaningless for one-time purchases).
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import lightgbm as lgb
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
|
||
|
||
def default_workspace() -> Path:
    """Return the root directory used for default data and model paths.

    The ``PLAYBOOK_WORKSPACE`` environment variable wins when it is set;
    otherwise the current user's home directory is used.
    """
    configured = os.environ.get('PLAYBOOK_WORKSPACE')
    if configured is None:
        configured = os.path.expanduser('~')
    return Path(configured)
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# Data loading
|
||
# ----------------------------------------------------------------------
|
||
|
||
def load_inputs(processed_dir: Path):
    """Read the four preprocessed artifacts the re-ranker is trained from.

    Args:
        processed_dir: Directory holding the parquet and ``.npy`` files.

    Returns:
        ``(interactions, metadata, embeddings, item_ids)`` where the first two
        are DataFrames, ``embeddings`` is cast to float32 and ``item_ids`` is
        cast to a string array.
    """
    interactions_path = processed_dir / 'dress_interactions.parquet'
    metadata_path = processed_dir / 'dress_metadata.parquet'
    embeddings_path = processed_dir / 'hllm_item_embeddings.npy'
    id_map_path = processed_dir / 'hllm_item_id_map.npy'

    interactions = pd.read_parquet(interactions_path)
    metadata = pd.read_parquet(metadata_path)
    embeddings = np.load(embeddings_path).astype(np.float32)
    # NOTE(review): allow_pickle=True executes pickled payloads on load; only
    # point this at id maps produced by a trusted pipeline run.
    raw_item_ids = np.load(id_map_path, allow_pickle=True)
    return interactions, metadata, embeddings, raw_item_ids.astype(str)
|
||
|
||
|
||
def build_user_samples(interactions: pd.DataFrame, item_to_idx: dict[str, int]):
    """Build leave-last-out samples, one per user with >=2 mapped interactions.

    Interactions are ordered by ``(user_id, timestamp)``; rows whose item is
    not a key of ``item_to_idx`` are ignored. For every remaining user the
    final mapped item becomes the positive and everything before it is the
    history.

    Returns:
        A list of ``(uid, history_idx, positive_idx, history_timestamps)``
        tuples, emitted in ascending ``user_id`` order.
    """
    ordered = interactions.sort_values(['user_id', 'timestamp'])
    samples = []
    for uid, user_rows in ordered.groupby('user_id', sort=False):
        items = user_rows['item_id'].tolist()
        stamps = user_rows['timestamp'].tolist()
        # Keep index and timestamp paired so both are filtered identically.
        mapped = [
            (item_to_idx[item], stamp)
            for item, stamp in zip(items, stamps)
            if item in item_to_idx
        ]
        if len(mapped) < 2:
            continue
        *history, (positive_idx, _) = mapped
        history_idx = [idx for idx, _ in history]
        history_ts = [stamp for _, stamp in history]
        samples.append((uid, history_idx, positive_idx, history_ts))
    return samples
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# Candidate retrieval (FAISS-equivalent on GPU)
|
||
# ----------------------------------------------------------------------
|
||
|
||
def retrieve_candidates_gpu(
    user_emb_matrix: np.ndarray,
    item_emb_matrix: np.ndarray,
    top_k: int,
    device: torch.device,
    chunk_size: int = 4096,
):
    """Score every user against every item and keep the ``top_k`` best items.

    Users are processed ``chunk_size`` rows at a time so only one chunk of the
    user-by-item score matrix exists on ``device`` at once.

    Returns:
        ``(indices, scores)``: an int64 and a float32 array, both of shape
        ``(n_users, top_k)``, ordered best-first within each row.
    """
    items_on_device = torch.from_numpy(item_emb_matrix).to(device)
    total_users = user_emb_matrix.shape[0]
    out_idx = np.empty((total_users, top_k), dtype=np.int64)
    out_score = np.empty((total_users, top_k), dtype=np.float32)

    for lo in range(0, total_users, chunk_size):
        hi = min(lo + chunk_size, total_users)
        user_chunk = torch.from_numpy(user_emb_matrix[lo:hi]).to(device)
        best = torch.topk(torch.mm(user_chunk, items_on_device.T), top_k, dim=1)
        out_idx[lo:hi] = best.indices.cpu().numpy()
        out_score[lo:hi] = best.values.cpu().numpy()

    return out_idx, out_score
|
||
|
||
|
||
def build_user_embeddings(samples, embeddings: np.ndarray) -> np.ndarray:
    """Represent each user as the unit-length centroid of their history items.

    Args:
        samples: Tuples of ``(uid, history_idx, positive_idx, timestamps)``;
            only ``history_idx`` is read.
        embeddings: Item embedding matrix indexed by item row.

    Returns:
        A float32 array with one row per sample.
    """
    user_emb = np.empty((len(samples), embeddings.shape[1]), dtype=np.float32)
    for row, sample in enumerate(samples):
        _uid, history, _positive, _stamps = sample
        centroid = embeddings[history].mean(axis=0)
        norm = float(np.linalg.norm(centroid))
        # The epsilon keeps an all-zero centroid from dividing by zero.
        user_emb[row] = centroid / (norm + 1e-8)
    return user_emb
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# Item / user statistics for handcrafted features
|
||
# ----------------------------------------------------------------------
|
||
|
||
def compute_item_stats(interactions: pd.DataFrame, metadata: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-item popularity, lifecycle and content features.

    All time windows are measured back from the latest timestamp found in
    ``interactions``. Items that never appear in ``interactions`` are not part
    of the result; items missing from ``metadata`` get zeros for the content
    columns.

    Returns:
        A DataFrame indexed by ``item_id`` with purchase counts, 30/90/180-day
        popularity, trend, age/recency in days, ``log_price``,
        ``title_length``, ``desc_length`` and ``has_image``.
    """
    seconds_per_day = 86400
    latest_ts = int(interactions['timestamp'].max())

    stats = interactions.groupby('item_id').agg(
        item_total_purchases=('user_id', 'size'),
        item_unique_buyers=('user_id', 'nunique'),
        item_first_seen=('timestamp', 'min'),
        item_last_seen=('timestamp', 'max'),
    )

    # One popularity column per look-back window; absent items become 0 below.
    for window_days in (30, 90, 180):
        cutoff = latest_ts - window_days * seconds_per_day
        in_window = interactions[interactions['timestamp'] >= cutoff]
        window_counts = in_window['item_id'].value_counts().rename(f'item_pop_{window_days}d')
        stats = stats.join(window_counts, how='left')
    stats = stats.fillna(0)

    stats = stats.assign(
        # Last-30-day volume relative to the average 30-day slice of the last 180.
        item_trend=(stats['item_pop_30d'] + 1) / (stats['item_pop_180d'] / 6 + 1),
        item_age_days=(latest_ts - stats['item_first_seen']) / seconds_per_day,
        item_recency_days=(latest_ts - stats['item_last_seen']) / seconds_per_day,
    ).drop(columns=['item_first_seen', 'item_last_seen'])

    indexed_meta = metadata.set_index('item_id')
    price = indexed_meta['price']
    content = indexed_meta.assign(
        # Missing prices fall back to the catalogue median before the log.
        log_price=np.log1p(price.fillna(price.median())),
        title_length=indexed_meta['title'].fillna('').str.len(),
        desc_length=indexed_meta['description'].fillna('').str.len(),
        has_image=(indexed_meta['image_url'].fillna('').str.len() > 0).astype(np.int8),
    )[['log_price', 'title_length', 'desc_length', 'has_image']]

    return stats.join(content, how='left').fillna(0)
|
||
|
||
|
||
def compute_user_stats(samples, metadata: pd.DataFrame, item_idx_to_id: list[str]) -> dict:
    """Summarise each user's history; the held-out positive is never read.

    Args:
        samples: Tuples of ``(uid, history_idx, positive_idx, history_ts)``.
        metadata: Frame with ``item_id`` and ``price`` columns.
        item_idx_to_id: Maps an item row index back to its ``item_id``.

    Returns:
        ``{uid: {...}}`` with purchase count, distinct-item count, mean and
        standard deviation of known positive prices, and
        ``user_recency_days`` (the span between the first and last history
        timestamps, in days).
    """
    price_by_item = metadata.set_index('item_id')['price'].to_dict()
    seconds_per_day = 86400
    stats = {}
    for uid, hist_idxs, _, ts in samples:
        # Unknown items default to 0 and NaN fails the > 0 test, so both are dropped.
        known_prices = []
        for idx in hist_idxs:
            price = float(price_by_item.get(item_idx_to_id[idx], 0.0))
            if price > 0:
                known_prices.append(price)

        avg_price = float(np.mean(known_prices)) if known_prices else 0.0
        price_std = float(np.std(known_prices)) if len(known_prices) > 1 else 0.0
        span_days = (max(ts) - min(ts)) / seconds_per_day if len(ts) > 1 else 0.0

        stats[uid] = {
            'user_total_purchases': len(hist_idxs),
            'user_unique_items': len(set(hist_idxs)),
            'user_avg_price': avg_price,
            'user_price_std': price_std,
            'user_recency_days': span_days,
        }
    return stats
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# Feature matrix construction
|
||
# ----------------------------------------------------------------------
|
||
|
||
# Column order of the feature matrix handed to LightGBM (23 features).
FEATURE_COLS = [
    # --- HLLM similarity signals (stand in for the H&M TF-IDF similarities)
    'hllm_dot_product',
    'hllm_max_hist_sim',
    'hllm_avg_hist_sim',
    # --- Item popularity and lifecycle
    'item_total_purchases',
    'item_unique_buyers',
    'item_pop_30d',
    'item_pop_90d',
    'item_pop_180d',
    'item_trend',
    'item_recency_days',
    'item_age_days',
    # --- Item content
    'log_price',
    'title_length',
    'desc_length',
    'has_image',
    # --- User history aggregates
    'user_total_purchases',
    'user_unique_items',
    'user_avg_price',
    'user_price_std',
    'user_recency_days',
    # --- User x candidate cross features
    'price_ratio',
    'price_diff',
    'is_repurchase',
]
|
||
|
||
|
||
def build_feature_matrix(
    samples,
    cand_idx: np.ndarray,
    cand_scores: np.ndarray,
    embeddings: np.ndarray,
    item_idx_to_id: list[str],
    item_stats: pd.DataFrame,
    user_stats: dict,
    history_recent_k: int = 10,
):
    """Stack one feature row per (user, candidate) pair.

    Users whose held-out positive is absent from their candidate set are
    dropped. Rows are only allocated for the users that are kept, so peak
    memory scales with retriever recall rather than with the total user count.

    Args:
        samples: Tuples of ``(uid, history_idx, positive_idx, history_ts)``,
            aligned row-for-row with ``cand_idx``.
        cand_idx: ``(n_users, top_k)`` candidate item indices.
        cand_scores: ``(n_users, top_k)`` retrieval scores for ``cand_idx``.
        embeddings: Item embedding matrix indexed by item row.
        item_idx_to_id: Item id for every embedding row.
        item_stats: Per-item feature frame indexed by item id.
        user_stats: ``{uid: {feature_name: value}}`` for the user features.
        history_recent_k: Number of most recent history items used for the
            max/avg similarity features.

    Returns:
        X: (n_rows, n_features) float32, columns ordered as ``FEATURE_COLS``
        y: (n_rows,) int8 — 1 for the held-out positive, 0 otherwise
        groups: (n_kept_users,) int64 — group sizes for lambdarank
        n_kept: number of users that contributed rows
    """
    n_users, top_k = cand_idx.shape
    feature_to_col = {name: pos for pos, name in enumerate(FEATURE_COLS)}

    # Per-item feature table aligned with embedding row order; items missing
    # from item_stats end up as all-zero rows.
    item_stats_arr = item_stats.reindex(item_idx_to_id).fillna(0)
    item_feat_lookup = item_stats_arr.to_numpy(dtype=np.float32)
    item_cols_in_stats = [
        'item_total_purchases', 'item_unique_buyers',
        'item_pop_30d', 'item_pop_90d', 'item_pop_180d', 'item_trend',
        'item_recency_days', 'item_age_days',
        'log_price', 'title_length', 'desc_length', 'has_image',
    ]
    stats_col_idx = [item_stats_arr.columns.get_loc(c) for c in item_cols_in_stats]
    item_dest_cols = [feature_to_col[c] for c in item_cols_in_stats]
    user_cols = [
        'user_total_purchases', 'user_unique_items',
        'user_avg_price', 'user_price_std', 'user_recency_days',
    ]

    # Pass 1: find where (if anywhere) each user's positive was retrieved.
    hit = np.zeros((n_users, top_k), dtype=bool)
    for u_i, (_, _, pos_idx, _) in enumerate(samples):
        hit[u_i] = cand_idx[u_i] == pos_idx
    kept_users = np.flatnonzero(hit.any(axis=1))
    n_kept = int(kept_users.size)

    # Pass 2: fill features for kept users only.
    X = np.zeros((n_kept * top_k, len(FEATURE_COLS)), dtype=np.float32)
    y = np.zeros(n_kept * top_k, dtype=np.int8)

    for out_i, u_i in enumerate(kept_users):
        uid, hist_idxs, _, _ = samples[u_i]
        cand = cand_idx[u_i]                      # (top_k,)
        row_start = out_i * top_k
        rows = slice(row_start, row_start + top_k)

        # Retrieval score is reused directly as the dot-product feature.
        X[rows, feature_to_col['hllm_dot_product']] = cand_scores[u_i]

        # Similarity of each candidate to the most recent history items.
        # NOTE(review): treating these dot products as cosine similarities
        # assumes the item embeddings were L2-normalized upstream — confirm.
        recent_hist = hist_idxs[-history_recent_k:]
        sim_matrix = embeddings[cand] @ embeddings[recent_hist].T   # (top_k, h)
        X[rows, feature_to_col['hllm_max_hist_sim']] = sim_matrix.max(axis=1)
        X[rows, feature_to_col['hllm_avg_hist_sim']] = sim_matrix.mean(axis=1)

        # Item-level features: one gather covers every item column.
        X[rows, item_dest_cols] = item_feat_lookup[np.ix_(cand, stats_col_idx)]

        # User-level features are constant across the user's candidates.
        u_stats = user_stats[uid]
        for col_name in user_cols:
            X[rows, feature_to_col[col_name]] = u_stats[col_name]

        # Cross features: candidate price relative to the user's average.
        cand_price = np.expm1(X[rows, feature_to_col['log_price']])
        u_avg_price = u_stats['user_avg_price']
        X[rows, feature_to_col['price_ratio']] = cand_price / (u_avg_price + 1e-8)
        X[rows, feature_to_col['price_diff']] = cand_price - u_avg_price
        X[rows, feature_to_col['is_repurchase']] = np.isin(cand, hist_idxs)

        # Label the first occurrence of the positive within the candidates.
        y[row_start + int(np.argmax(hit[u_i]))] = 1

    groups = np.full(n_kept, top_k, dtype=np.int64)
    return X, y, groups, n_kept
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# Training
|
||
# ----------------------------------------------------------------------
|
||
|
||
def train(args: argparse.Namespace) -> dict:
    """Run the full re-ranker training pipeline and persist its artifacts.

    Steps: load inputs, build leave-last-out samples, retrieve top-K
    candidates, build the feature matrix, split user groups 80/20, train a
    LightGBM lambdarank model with early stopping, then write
    ``reranker_lightgbm.txt`` and ``metrics.json`` into ``args.output_dir``.

    Args:
        args: Parsed CLI namespace (``processed_dir``, ``output_dir``,
            ``top_k``, ``num_rounds``, ``early_stopping_rounds``,
            ``num_leaves``, ``max_depth``, ``lr``, ``lambda_l1``,
            ``lambda_l2``, ``lambdarank_truncation_level``, ``label_gain``,
            ``seed``).

    Returns:
        The metrics dictionary that is also written to ``metrics.json``.

    Raises:
        RuntimeError: If fewer than 100 users have their held-out item in the
            candidate set (including the case of no eligible users at all).
    """
    processed_dir = Path(args.processed_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print('=' * 60)
    print('HLLM Re-ranker Training (LightGBM lambdarank)')
    print('=' * 60)
    print(f' Processed dir: {processed_dir}')
    print(f' Output dir: {output_dir}')
    print(f' Top-K: {args.top_k}')
    print(f' Boost rounds: {args.num_rounds}')

    # ------ Load ------
    print('\n--- Loading data ---')
    t = time.time()
    interactions, metadata, embeddings, item_ids = load_inputs(processed_dir)
    # The padding row has an embedding but must never be a history/positive item.
    item_to_idx = {iid: i for i, iid in enumerate(item_ids) if iid != '[PAD]'}
    item_idx_to_id = list(item_ids)
    print(f' {len(interactions):,} interactions, '
          f'{embeddings.shape[0]:,} item embeddings ({embeddings.shape[1]} dim) in {time.time()-t:.1f}s')

    # ------ Build user samples (leave-last-out) ------
    print('\n--- Building user samples ---')
    t = time.time()
    samples = build_user_samples(interactions, item_to_idx)
    print(f' Users with >=2 mapped interactions: {len(samples):,} ({time.time()-t:.1f}s)')

    # ------ User embeddings + GPU candidate retrieval ------
    print('\n--- Retrieving top-K candidates from HLLM embeddings ---')
    t = time.time()
    user_emb = build_user_embeddings(samples, embeddings)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cand_idx, cand_scores = retrieve_candidates_gpu(
        user_emb, embeddings, args.top_k, device,
    )
    print(f' ({device}) Retrieved {cand_idx.shape[0]:,} × {args.top_k} '
          f'candidates in {time.time()-t:.1f}s')

    # ------ Statistics for handcrafted features ------
    print('\n--- Computing item and user statistics ---')
    t = time.time()
    item_stats = compute_item_stats(interactions, metadata)
    user_stats = compute_user_stats(samples, metadata, item_idx_to_id)
    print(f' Item stats: {len(item_stats):,} rows × {len(item_stats.columns)} cols; '
          f'user stats: {len(user_stats):,} users ({time.time()-t:.1f}s)')

    # ------ Feature matrix ------
    print('\n--- Building feature matrix ---')
    t = time.time()
    X, y, groups, n_kept = build_feature_matrix(
        samples, cand_idx, cand_scores, embeddings,
        item_idx_to_id, item_stats, user_stats,
    )
    # Guarded so an empty sample list reaches the RuntimeError below instead
    # of dying on a ZeroDivisionError here.
    recall = n_kept / len(samples) if samples else 0.0
    print(f' Kept {n_kept:,}/{len(samples):,} users '
          f'(positive in top-{args.top_k}: {recall:.1%}); '
          f'{X.shape[0]:,} rows × {X.shape[1]} features ({time.time()-t:.1f}s)')

    if n_kept < 100:
        raise RuntimeError(
            f'Only {n_kept} users have their held-out item in the top-{args.top_k} '
            'FAISS candidates. Retriever recall is too low to train a re-ranker. '
            'Train the retriever for more epochs or increase --top-k.'
        )

    # ------ Train/valid split (by user group) ------
    rng = np.random.default_rng(args.seed)
    n_groups = len(groups)
    perm = rng.permutation(n_groups)
    split = max(1, int(0.8 * n_groups))
    train_groups, valid_groups = perm[:split], perm[split:]

    # Every group has exactly top_k consecutive rows, so a per-group flag
    # expands directly into a per-row mask; valid is the complement of train.
    rows_per_user = args.top_k
    is_train_group = np.zeros(n_groups, dtype=bool)
    is_train_group[train_groups] = True
    train_row_mask = np.repeat(is_train_group, rows_per_user)
    valid_row_mask = ~train_row_mask

    X_train, y_train = X[train_row_mask], y[train_row_mask]
    X_valid, y_valid = X[valid_row_mask], y[valid_row_mask]
    g_train = np.full(len(train_groups), rows_per_user, dtype=np.int64)
    g_valid = np.full(len(valid_groups), rows_per_user, dtype=np.int64)

    # ------ Train ------
    print('\n--- Training LightGBM lambdarank ---')
    print(f' Train: {len(train_groups):,} users / {X_train.shape[0]:,} rows | '
          f'Valid: {len(valid_groups):,} users / {X_valid.shape[0]:,} rows')

    train_set = lgb.Dataset(X_train, y_train, group=g_train, feature_name=FEATURE_COLS)
    valid_set = lgb.Dataset(X_valid, y_valid, group=g_valid, feature_name=FEATURE_COLS,
                            reference=train_set)

    params = {
        'objective': 'lambdarank',
        'metric': 'ndcg',
        'ndcg_eval_at': [5, 10, 20],
        'lambdarank_truncation_level': args.lambdarank_truncation_level,
        'learning_rate': args.lr,
        'num_leaves': args.num_leaves,
        'max_depth': args.max_depth,
        'min_child_samples': 50,
        'min_gain_to_split': 0.0,
        'lambda_l1': args.lambda_l1,
        'lambda_l2': args.lambda_l2,
        'feature_fraction': 0.8,
        'bagging_fraction': 0.8,
        'bagging_freq': 1,
        'verbose': -1,
        'seed': args.seed,
    }
    if args.label_gain == 'binary':
        params['label_gain'] = [0, 1]

    eval_history: dict = {}
    t = time.time()
    model = lgb.train(
        params, train_set,
        num_boost_round=args.num_rounds,
        valid_sets=[train_set, valid_set],
        valid_names=['train', 'valid'],
        callbacks=[
            lgb.log_evaluation(period=10),
            lgb.record_evaluation(eval_history),
            lgb.early_stopping(stopping_rounds=args.early_stopping_rounds, verbose=True),
        ],
    )
    train_seconds = time.time() - t

    best_round = int(model.best_iteration)
    # eval_history is 0-indexed while best_iteration is 1-indexed.
    train_ndcg10 = float(eval_history['train']['ndcg@10'][best_round - 1]) if best_round else 0.0
    valid_ndcg10 = float(model.best_score['valid']['ndcg@10'])
    valid_ndcg5 = float(model.best_score['valid']['ndcg@5'])
    valid_ndcg20 = float(model.best_score['valid']['ndcg@20'])
    print(
        f'\nTrained {model.current_iteration()} rounds in {train_seconds:.1f}s '
        f'(best={best_round}, early-stopped at {args.early_stopping_rounds} rounds patience)'
    )
    print(
        f' Best valid: NDCG@5={valid_ndcg5:.4f} NDCG@10={valid_ndcg10:.4f} NDCG@20={valid_ndcg20:.4f}'
    )
    print(
        f' Train/valid gap @10: {train_ndcg10 - valid_ndcg10:+.4f} '
        f"(train NDCG@10={train_ndcg10:.4f}) — large positive = overfitting"
    )

    # ------ Feature importance ------
    importance_gain = model.feature_importance(importance_type='gain')
    importance_split = model.feature_importance(importance_type='split')
    fi = sorted(
        zip(FEATURE_COLS, importance_gain, importance_split),
        key=lambda x: -x[1],
    )
    print('\nTop 10 features by gain:')
    for name, gain, split_count in fi[:10]:
        print(f' {name:28s} gain={gain:>12,.0f} splits={split_count:>5,}')

    # ------ Save ------
    model.save_model(str(output_dir / 'reranker_lightgbm.txt'))
    metrics = {
        'model': 'lightgbm_lambdarank',
        'feature_cols': FEATURE_COLS,
        'top_k': args.top_k,
        'num_rounds': args.num_rounds,
        'early_stopping_rounds': args.early_stopping_rounds,
        'best_iteration': best_round,
        'last_iteration': int(model.current_iteration()),
        'best_train_ndcg10': train_ndcg10,
        'best_valid_ndcg5': valid_ndcg5,
        'best_valid_ndcg10': valid_ndcg10,
        'best_valid_ndcg20': valid_ndcg20,
        'train_valid_gap_ndcg10': train_ndcg10 - valid_ndcg10,
        'hyperparams': {
            'num_leaves': args.num_leaves,
            'max_depth': args.max_depth,
            'min_child_samples': params['min_child_samples'],
            'learning_rate': args.lr,
            'lambda_l1': args.lambda_l1,
            'lambda_l2': args.lambda_l2,
            'feature_fraction': params['feature_fraction'],
            'bagging_fraction': params['bagging_fraction'],
            'lambdarank_truncation_level': args.lambdarank_truncation_level,
            'label_gain': args.label_gain,
            'seed': args.seed,
        },
        'eval_history': {
            'train_ndcg10': [float(v) for v in eval_history['train']['ndcg@10']],
            'valid_ndcg10': [float(v) for v in eval_history['valid']['ndcg@10']],
            'valid_ndcg5': [float(v) for v in eval_history['valid']['ndcg@5']],
            'valid_ndcg20': [float(v) for v in eval_history['valid']['ndcg@20']],
        },
        'n_users_trained': len(train_groups),
        'n_users_valid': len(valid_groups),
        'n_users_dropped': len(samples) - n_kept,
        'retriever_recall_at_top_k': recall,
        'train_seconds': train_seconds,
        'feature_importance': [
            {'feature': n, 'gain': int(g), 'split': int(s)} for n, g, s in fi
        ],
    }
    (output_dir / 'metrics.json').write_text(json.dumps(metrics, indent=2) + '\n')
    print(f'\nSaved model to {output_dir / "reranker_lightgbm.txt"}')
    print(f'Saved metrics to {output_dir / "metrics.json"}')
    return metrics
|
||
|
||
|
||
def _build_arg_parser(workspace: Path) -> argparse.ArgumentParser:
    """Create the CLI parser; path defaults are resolved under ``workspace``."""
    default_processed = str(workspace / 'data' / 'processed')
    default_output = str(workspace / 'models' / 'reranker_lightgbm')

    parser = argparse.ArgumentParser(description=__doc__)
    add = parser.add_argument

    # Input / output locations
    add('--processed-dir', default=default_processed)
    add('--output-dir', default=default_output)

    # Retrieval and boosting budget
    add('--top-k', type=int, default=100,
        help='Candidates per user from the HLLM retriever (default: 100).')
    add('--num-rounds', type=int, default=1000)
    add('--early-stopping-rounds', type=int, default=50)

    # Tree shape and regularization
    add('--num-leaves', type=int, default=63)
    add('--max-depth', type=int, default=8,
        help='Tree depth cap. -1 disables. Defaults to 8 to constrain '
             'overfitting on the small post-recall@K training set; see '
             'docs/experiment-log.md ablation 2026-05-09.')
    add('--lr', type=float, default=0.05)
    add('--lambda-l1', type=float, default=0.0)
    add('--lambda-l2', type=float, default=1.0,
        help='L2 regularization. Defaults to 1.0; see ablation 2026-05-09.')

    # Ranking objective knobs
    add('--lambdarank-truncation-level', type=int, default=30,
        help='LightGBM default 30. Set near the eval NDCG cutoff for better head-of-list gradients.')
    add('--label-gain', choices=['default', 'binary'], default='default',
        help='"default" = LightGBM graded-relevance gains; "binary" = [0,1] for 0/1 labels.')
    add('--seed', type=int, default=42)
    return parser


def main() -> int:
    """Parse command-line options, run training, and report total wall time.

    Returns:
        ``0`` once training has completed.
    """
    args = _build_arg_parser(default_workspace()).parse_args()

    started = time.time()
    train(args)
    print(f'\nTotal wall time: {time.time()-started:.1f}s')
    return 0


if __name__ == '__main__':
    raise SystemExit(main())
|