dgx-spark-playbooks/nvidia/station-rec-sys/assets/patches/HLLM/wandblogger.py
2026-05-26 18:25:53 +00:00

74 lines
2.7 KiB
Python

# Copyright (c) 2024 westlake-repl
# SPDX-License-Identifier: MIT
class WandbLogger(object):
    """Log training and evaluation metrics to Weights and Biases (W&B).

    Logging is switched on by ``config.log_wandb``. When that flag is falsy,
    ``wandb`` is never imported and every public method is a no-op.
    """

    def __init__(self, config):
        """Store the configuration and, if logging is enabled, set up W&B.

        Args:
            config: Configuration object used by RecBole. It is read through
                attribute access: ``log_wandb`` toggles logging and
                ``wandb_project`` names the W&B project. The whole object is
                also forwarded to ``wandb.init`` as the run config.

        Raises:
            ImportError: If logging is enabled but ``wandb`` is not installed.
        """
        self.config = config
        self.log_wandb = config.log_wandb
        # Define the handle up front so the attribute exists (as None) even
        # when logging is disabled and ``setup`` returns early.
        self._wandb = None
        self.setup()

    def setup(self):
        """Import ``wandb``, start a run if none is active, define step axes.

        Does nothing when ``self.log_wandb`` is falsy.

        Raises:
            ImportError: If ``wandb`` cannot be imported.
        """
        if not self.log_wandb:
            return

        # Imported lazily so that wandb is only required when it is used.
        try:
            import wandb
        except ImportError as err:
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb. "
                "Run `pip install wandb` to install it."
            ) from err
        self._wandb = wandb

        # Initialize a W&B run only if one is not already active.
        if self._wandb.run is None:
            self._wandb.init(
                project=self.config.wandb_project,
                config=self.config,
            )

        self._set_steps()

    def log_metrics(self, metrics, head='train', commit=True):
        """Send a dictionary of metrics to W&B.

        Args:
            metrics (dict): Mapping of metric name to value.
            head (str): Prefix applied to metric names as ``{head}/{name}``
                (names containing ``'_step'`` are left untouched). A falsy
                value logs ``metrics`` unchanged.
            commit (bool): Forwarded as the ``commit`` argument of
                ``wandb.log``.
        """
        if not self.log_wandb:
            return
        if head:
            metrics = self._add_head_to_metrics(metrics, head)
        self._wandb.log(metrics, commit=commit)

    def log_eval_metrics(self, metrics, head='eval'):
        """Write evaluation metrics into the summary of the current run.

        Args:
            metrics (dict): Mapping of metric name to value.
            head (str): Prefix applied to metric names as ``{head}/{name}``.
        """
        if not self.log_wandb:
            return
        for key, value in self._add_head_to_metrics(metrics, head).items():
            self._wandb.run.summary[key] = value

    def _set_steps(self):
        """Declare the step axis, and summary mode, of the logged metrics."""
        define_metric = self._wandb.define_metric

        define_metric('train/*', step_metric='train_step')
        define_metric('valid/*', step_metric='valid_step')
        define_metric('train_step/*', step_metric='global_step')

        # Promote cross-run-comparable NCE metrics to the run-summary table so
        # they appear on the W&B runs-list / dashboard without extra config.
        #   loss_over_random    = step_loss / log(N+1): 1.0 = random, 0.0 = perfect.
        #   nce_top{k}_acc_lift = acc / (k/(N+1)):      1.0 = random, >1 = lift.
        # summary='min' / 'max' makes W&B auto-surface the best value per run.
        define_metric(
            'train_step/loss_over_random',
            step_metric='global_step',
            summary='min',
        )
        for k in (1, 5, 10, 50, 100):
            define_metric(
                f'train_step/nce_top{k}_acc_lift',
                step_metric='global_step',
                summary='max',
            )

    def _add_head_to_metrics(self, metrics, head):
        """Return a copy of ``metrics`` with names prefixed by ``head``.

        Args:
            metrics (dict): Mapping of metric name to value.
            head (str): Prefix to prepend as ``{head}/{name}``.

        Returns:
            dict: A new mapping. Keys containing ``'_step'`` are kept verbatim
            (they are step counters, not metrics to be grouped); every other
            key becomes ``f'{head}/{key}'``.
        """
        return {
            (key if '_step' in key else f'{head}/{key}'): value
            for key, value in metrics.items()
        }