# Copyright (c) 2024 westlake-repl
# SPDX-License-Identifier: MIT


class WandbLogger:
    """Log metrics to Weights and Biases (W&B).

    All logging methods are no-ops when ``config.log_wandb`` is falsy, in
    which case the ``wandb`` package is never imported.
    """

    def __init__(self, config):
        """Store the configuration and, if enabled, set up W&B.

        Args:
            config: Configuration object used by RecBole. It is read via
                attribute access: ``log_wandb`` always, and
                ``wandb_project`` when a new W&B run has to be started.
                The whole object is also forwarded to ``wandb.init`` as the
                run config.
        """
        self.config = config
        self.log_wandb = config.log_wandb
        # Always defined, so the attribute exists even when logging is off.
        self._wandb = None
        self.setup()

    def setup(self):
        """Import ``wandb``, start a run if none is active, register metrics.

        Does nothing when ``self.log_wandb`` is falsy.

        Raises:
            ImportError: If logging is enabled but ``wandb`` cannot be
                imported.
        """
        if not self.log_wandb:
            return

        # Imported lazily so wandb is only required when logging is enabled.
        try:
            import wandb
        except ImportError as err:
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb. "
                "Run `pip install wandb` to install it."
            ) from err
        self._wandb = wandb

        # Initialize a W&B run only if one is not already active.
        if self._wandb.run is None:
            self._wandb.init(
                project=self.config.wandb_project,
                config=self.config,
            )

        self._set_steps()

    def log_metrics(self, metrics, head='train', commit=True):
        """Send ``metrics`` to ``wandb.log``.

        Args:
            metrics (dict): Mapping of metric name to value.
            head (str): Prefix applied to metric names through
                ``_add_head_to_metrics``. A falsy value logs the metrics
                with their names unchanged.
            commit (bool): Forwarded to ``wandb.log``.
        """
        if not self.log_wandb:
            return
        if head:
            metrics = self._add_head_to_metrics(metrics, head)
        self._wandb.log(metrics, commit=commit)

    def log_eval_metrics(self, metrics, head='eval'):
        """Write ``metrics`` into the summary of the current W&B run.

        Args:
            metrics (dict): Mapping of metric name to value.
            head (str): Prefix applied to metric names through
                ``_add_head_to_metrics``.
        """
        if not self.log_wandb:
            return
        metrics = self._add_head_to_metrics(metrics, head)
        for k, v in metrics.items():
            self._wandb.run.summary[k] = v

    def _set_steps(self):
        """Register the step metric for each metric namespace with W&B."""
        self._wandb.define_metric('train/*', step_metric='train_step')
        self._wandb.define_metric('valid/*', step_metric='valid_step')
        self._wandb.define_metric('train_step/*', step_metric='global_step')
        # Promote cross-run-comparable NCE metrics to the run-summary table so
        # they appear on the W&B runs-list / dashboard without extra config.
        #   loss_over_random    = step_loss / log(N+1): 1.0 = random, 0.0 = perfect.
        #   nce_top{k}_acc_lift = acc / (k/(N+1)):      1.0 = random, >1 = lift.
        # summary='min' / 'max' makes W&B auto-surface the best value per run.
        self._wandb.define_metric(
            'train_step/loss_over_random',
            step_metric='global_step',
            summary='min',
        )
        for k in (1, 5, 10, 50, 100):
            self._wandb.define_metric(
                f'train_step/nce_top{k}_acc_lift',
                step_metric='global_step',
                summary='max',
            )

    def _add_head_to_metrics(self, metrics, head):
        """Return a copy of ``metrics`` with names prefixed by ``head/``.

        Any key containing the substring ``'_step'`` is copied unchanged.
        This is a substring match, so it covers plain step counters such as
        ``train_step`` as well as names like ``train_step/loss_over_random``.

        Args:
            metrics (dict): Mapping of metric name to value.
            head (str): Prefix to prepend, separated by ``/``.

        Returns:
            dict: New mapping; the input is not modified.
        """
        return {
            (k if '_step' in k else f'{head}/{k}'): v
            for k, v in metrics.items()
        }