# Copyright (c) 2024 westlake-repl # Copyright (c) 2024 Bytedance Ltd. and/or its affiliate # SPDX-License-Identifier: MIT # This file has been modified by Junyi Chen. # # Original file was released under MIT, with the full license text # available at https://choosealicense.com/licenses/mit/. # # This modified file is released under the same license. import os import sys import math from logging import getLogger from time import time import time as t import numpy as np import torch import torch.optim as optim import torch.distributed as dist from tqdm import tqdm import deepspeed from deepspeed.ops.adam import FusedAdam from REC.data.dataset import BatchTextDataset from REC.data.dataset.collate_fn import customize_rmpad_collate from torch.utils.data import DataLoader from REC.evaluator import Evaluator, Collector from REC.utils import ensure_dir, get_local_time, early_stopping, calculate_valid_score, dict2str, \ get_tensorboard, set_color, get_gpu_usage, WandbLogger from REC.utils.lr_scheduler import * import lightning as L from lightning.fabric.strategies import DeepSpeedStrategy, DDPStrategy # torch.compile: allow more recompiles for variable packed-sequence shapes # (HLLM's collate uses cu_input_lens → shapes differ per batch) import torch._dynamo torch._dynamo.config.cache_size_limit = 64 torch._dynamo.config.capture_scalar_outputs = True # trace through .item() calls (e.g. flash-attn max_seqlen) class Trainer(object): def __init__(self, config, model): super(Trainer, self).__init__() self.config = config self.model = model self.logger = getLogger() self.wandblogger = WandbLogger(config) self.optim_args = config['optim_args'] self.epochs = config['epochs'] self.eval_step = min(config['eval_step'], self.epochs) self.stopping_step = config['stopping_step'] self.max_steps = config.get('max_steps', 0) # 0 = unlimited (use epochs) self.clip_grad_norm = config.get('clip_grad_norm', 1.0) self.valid_metric = config['valid_metric'].lower() self.valid_metric_bigger = config['valid_metric_bigger'] self.test_batch_size = config['eval_batch_size'] self.gpu_available = torch.cuda.is_available() and config['use_gpu'] self.device = config['device'] self.rank = torch.distributed.get_rank() if self.rank == 0: self.tensorboard = get_tensorboard(self.logger) self.checkpoint_dir = config['checkpoint_dir'] if self.rank == 0: ensure_dir(self.checkpoint_dir) self.saved_model_name = '{}-{}.pth'.format(self.config['model'], 0) self.saved_model_file = os.path.join(self.checkpoint_dir, self.saved_model_name) self.use_text = config['use_text'] self.start_epoch = 0 self.cur_step = 0 self.global_step_count = 0 # tracks steps across epochs for max_steps self.best_valid_score = -np.inf if self.valid_metric_bigger else np.inf self.best_valid_result = None self.train_loss_dict = dict() self.optimizer = self._build_optimizer() self.update_interval = config['update_interval'] if config['update_interval'] else 20 self.grad_accum_steps = config.get('gradient_accumulation_steps', 1) self.scheduler_config = config['scheduler_args'] if config['freeze_prefix'] or config['freeze_ad']: freeze_prefix = config['freeze_prefix'] if config['freeze_prefix'] else [] if config['freeze_ad']: freeze_prefix.extend(['item_llm', 'item_emb_tokens']) if not config['ft_item']: freeze_prefix.extend(['item_embedding']) self._freeze_params(freeze_prefix) for n, p in self.model.named_parameters(): self.logger.info(f"{n} {p.size()} {p.requires_grad}") self.eval_collector = Collector(config) self.evaluator = Evaluator(config) self.item_feature = None self.tot_item_num = None def _freeze_params(self, freeze_prefix): for name, param in self.model.named_parameters(): for prefix in freeze_prefix: if name.startswith(prefix): self.logger.info(f"freeze_params: {name}") param.requires_grad = False def _build_scheduler(self, warmup_steps=None, tot_steps=None): if self.scheduler_config['type'] == 'cosine': self.logger.info(f"Use consine scheduler with {warmup_steps} warmup {tot_steps} total steps") return get_cosine_schedule_with_warmup(self.optimizer, warmup_steps, tot_steps) elif self.scheduler_config['type'] == 'liner': self.logger.info(f"Use linear scheduler with {warmup_steps} warmup {tot_steps} total steps") return get_linear_schedule_with_warmup(self.optimizer, warmup_steps, tot_steps) else: self.logger.info(f"Use constant scheduler") return get_constant_schedule(self.optimizer) def _build_optimizer(self): if len(self.optim_args) == 4: params = self.model.named_parameters() modal_params = [] recsys_params = [] modal_decay_params = [] recsys_decay_params = [] decay_check_name = self.config['decay_check_name'] for index, (name, param) in enumerate(params): if param.requires_grad: if 'visual_encoder' in name: modal_params.append(param) else: recsys_params.append(param) if decay_check_name: if decay_check_name in name: modal_decay_params.append(param) else: recsys_decay_params.append(param) if decay_check_name: optimizer = optim.AdamW([ {'params': modal_decay_params, 'lr': self.optim_args['modal_lr'], 'weight_decay': self.optim_args['modal_decay']}, {'params': recsys_decay_params, 'lr': self.optim_args['rec_lr'], 'weight_decay': self.optim_args['rec_decay']} ]) optim_output = set_color(f'recsys_decay_params_len: {len(recsys_decay_params)} modal_params_decay_len: {len(modal_decay_params)}', 'blue') self.logger.info(optim_output) else: optimizer = optim.AdamW([ {'params': modal_params, 'lr': self.optim_args['modal_lr'], 'weight_decay': self.optim_args['modal_decay']}, {'params': recsys_params, 'lr': self.optim_args['rec_lr'], 'weight_decay': self.optim_args['rec_decay']} ]) optim_output = set_color(f'recsys_lr_params_len: {len(recsys_params)} modal_lr_params_len: {len(modal_params)}', 'blue') self.logger.info(optim_output) elif self.config['lr_mult_prefix'] and self.config['lr_mult_rate']: normal_params_dict = { "params": [], "lr": self.optim_args['learning_rate'], "weight_decay": self.optim_args['weight_decay'] } high_lr_params_dict = { "params": [], "lr": self.optim_args['learning_rate'] * self.config['lr_mult_rate'], "weight_decay": self.optim_args['weight_decay'] } self.logger.info(f'Use higher lr rate {self.config["lr_mult_rate"]} x {self.optim_args["learning_rate"]} for prefix {self.config["lr_mult_prefix"]}') for n, p in self.model.named_parameters(): if any(n.startswith(x) for x in self.config['lr_mult_prefix']): self.logger.info(f"high lr param: {n} {self.optim_args['learning_rate'] * self.config['lr_mult_rate']}") high_lr_params_dict["params"].append(p) else: normal_params_dict["params"].append(p) optimizer = optim.AdamW([normal_params_dict, high_lr_params_dict]) elif self.config['optimizer_kwargs']: params = [p for p in self.model.parameters() if p.requires_grad] self.config['optimizer_kwargs']['optimizer']['params']['lr'] = self.optim_args['learning_rate'] self.config['optimizer_kwargs']['optimizer']['params']['weight_decay'] = self.optim_args['weight_decay'] optimizer = deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam(params, **self.config['optimizer_kwargs']['optimizer']['params']) else: params = [p for p in self.model.parameters() if p.requires_grad] # use deepspeed fused adam optimizer if set in the config if self.config.get('use_fused_adam', True): optimizer = FusedAdam( params, lr=self.optim_args['learning_rate'], weight_decay=self.optim_args['weight_decay'], adam_w_mode=True, ) self.logger.info( f"Optimizer: DeepSpeed FusedAdam (GPU fused kernels), " f"adam_w_mode=True, lr={self.optim_args['learning_rate']}, " f"weight_decay={self.optim_args['weight_decay']}" ) # otherwise just use AdamW else: optimizer = optim.AdamW( params, lr=self.optim_args['learning_rate'], weight_decay=self.optim_args['weight_decay'], ) self.logger.info("Optimizer: torch.optim.AdamW (fused_adam disabled)") return optimizer def _train_epoch(self, train_data, epoch_idx, show_progress=False, valid_data=None): self.model.train() total_loss = 0 # auto_resume: skip already-completed batches on the first resumed epoch. # Cleared after this epoch so subsequent epochs run full-length. skip_batches = 0 if epoch_idx == self.start_epoch and getattr(self, '_resume_batch_idx', 0) > 0: skip_batches = self._resume_batch_idx if self.rank == 0: self.logger.info( f"auto_resume: skipping first {skip_batches} batches of epoch {epoch_idx}" ) self._resume_batch_idx = 0 if self.rank == 0: pbar = tqdm( total=len(train_data), miniters=self.update_interval, desc=set_color(f"Train [{epoch_idx:>3}/{self.epochs:>3}]", 'pink'), file=sys.stdout ) accum_steps = self.grad_accum_steps grad_norm = None bwd_time = t.time() self.optimizer.zero_grad() for batch_idx, data in enumerate(train_data): if batch_idx < skip_batches: continue start_time = bwd_time data = self.to_device(data) data_time = t.time() losses = self.model(data) fwd_time = t.time() if self.config['loss'] and self.config['loss'] == 'nce': model_out = losses losses = model_out.pop('loss') self._check_nan(losses) total_loss = total_loss + losses.item() self.lite.backward(losses / accum_steps) is_accum_step = (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == len(train_data) if is_accum_step: grad_norm = self.optimizer.step() self.optimizer.zero_grad() bwd_time = t.time() if self.scheduler_config: self.lr_scheduler.step() else: bwd_time = t.time() # Step-based checkpoint saving (counted in micro-batches) global_step = epoch_idx * len(train_data) + batch_idx save_steps = self.config.get('save_steps', 0) if save_steps > 0 and batch_idx > 0 and batch_idx % save_steps == 0: step_ckpt_name = '{}-{}-step-{}.pth'.format(self.config['model'], epoch_idx, batch_idx) step_state = { "model": self.model, "optimizer": self.optimizer, 'config': self.config, 'epoch': epoch_idx, 'batch_idx': batch_idx, 'cur_step': self.cur_step, 'best_valid_score': self.best_valid_score, 'rng_state': torch.get_rng_state(), 'cuda_rng_state': torch.cuda.get_rng_state() } self.lite.save(os.path.join(self.checkpoint_dir, step_ckpt_name), state=step_state) if self.rank == 0: self.logger.info(f"Step checkpoint saved: {step_ckpt_name}") # Keep only the latest N step checkpoints max_keep = self.config.get('max_keep_checkpoints', 3) if self.rank == 0 and max_keep > 0: import glob, shutil pattern = os.path.join(self.checkpoint_dir, '{}-*-step-*.pth'.format(self.config['model'])) existing = sorted(glob.glob(pattern), key=os.path.getmtime) while len(existing) > max_keep: old = existing.pop(0) shutil.rmtree(old, ignore_errors=True) self.logger.info(f"Removed old checkpoint: {os.path.basename(old)}") # Per-step W&B logging wandb_log_interval = self.config.get('wandb_log_interval', 100) if self.rank == 0 and wandb_log_interval > 0 and batch_idx % wandb_log_interval == 0: step_metrics = { 'step_loss': losses.item(), 'global_step': global_step, 'epoch_progress': batch_idx / len(train_data), } if self.scheduler_config: step_metrics['learning_rate'] = self.lr_scheduler.get_lr()[0] if self.config['loss'] and self.config['loss'] == 'nce' and isinstance(model_out, dict): for k, v in model_out.items(): step_metrics[k] = v.item() if hasattr(v, 'item') else v # Cross-run normalized metrics: NCE loss ceiling is log(N+1) where N = effective # negatives (random-guess baseline). Dividing by that puts any run on a 0->1 scale # (0 = perfect, 1 = random), so loss / top-k curves from different num_negatives # settings are directly comparable. Same idea for top-k lift over random = k/(N+1). nce_samples = step_metrics.get('nce_samples') if nce_samples is not None and nce_samples > 0: rand_ceil = math.log(nce_samples + 1.0) if rand_ceil > 0: step_metrics['loss_over_random'] = step_metrics['step_loss'] / rand_ceil for k in (1, 5, 10, 50, 100): acc_key = f'nce_top{k}_acc' if acc_key in step_metrics: random_p = k / (nce_samples + 1.0) if random_p > 0: step_metrics[f'{acc_key}_lift'] = step_metrics[acc_key] / random_p self.wandblogger.log_metrics(step_metrics, head='train_step') # In-epoch full-catalog validation: expensive (one pass of _full_sort_batch_eval, # ~300 s on dresses), but the only way to see paper-family Recall@K / NDCG@K before # the epoch boundary. Gated by `fast_eval_interval` (0 = disabled). Recommended: 500 # on dresses (~15% overhead), higher on larger catalogs. Logged under `valid_fast/*` # so it doesn't overwrite the authoritative per-epoch `valid/*` curve. fast_eval_interval = self.config.get('fast_eval_interval', 0) if (fast_eval_interval > 0 and valid_data is not None and batch_idx > 0 and batch_idx % fast_eval_interval == 0): torch.distributed.barrier() fast_start = t.time() fast_valid_result = self.evaluate(valid_data, load_best_model=False, show_progress=False) torch.distributed.barrier() self.model.train() # evaluate() sets model.eval(); flip back for training if self.rank == 0 and fast_valid_result: fast_elapsed = t.time() - fast_start self.logger.info( f"fast eval @ step {global_step}: " f"{dict2str(fast_valid_result)} (took {fast_elapsed:.1f}s)" ) self.wandblogger.log_metrics( {**fast_valid_result, 'global_step': global_step}, head='valid_fast' ) if show_progress and self.rank == 0 and batch_idx % self.update_interval == 0: msg = f"loss: {losses:.4f} data: {data_time-start_time:.3f} fwd: {fwd_time-data_time:.3f} bwd: {bwd_time-fwd_time:.3f}" if self.scheduler_config: msg = f"lr: {self.lr_scheduler.get_lr()[0]:.7f} " + msg if self.config['loss'] and self.config['loss'] == 'nce': for k, v in model_out.items(): if k.endswith('loss'): msg += f" {k}: {v:.3f}" if grad_norm: msg = msg + f" grad_norm: {grad_norm.sum():.4f}" pbar.set_postfix_str(msg, refresh=False) pbar.update(self.update_interval) self.logger.info("\n" + "-"*50) if self.config['debug'] and batch_idx >= 10: break self.global_step_count += 1 if self.max_steps > 0 and self.global_step_count >= self.max_steps: if self.rank == 0: self.logger.info(f"Reached max_steps={self.max_steps}, stopping training.") break return total_loss def _valid_epoch(self, valid_data, show_progress=False): torch.distributed.barrier() valid_result = self.evaluate(valid_data, load_best_model=False, show_progress=show_progress) valid_score = calculate_valid_score(valid_result, self.valid_metric) torch.distributed.barrier() return valid_score, valid_result def _save_checkpoint(self, epoch, verbose=True): r"""Store the model parameters information and training information. Args: epoch (int): the current epoch id """ state = { "model": self.model, "optimizer": self.optimizer, 'config': self.config, 'epoch': epoch, 'cur_step': self.cur_step, 'best_valid_score': self.best_valid_score, 'rng_state': torch.get_rng_state(), 'cuda_rng_state': torch.cuda.get_rng_state() } self.lite.save(os.path.join(self.checkpoint_dir, self.saved_model_name), state=state) if self.rank == 0 and verbose: self.logger.info(set_color('Saving current', 'blue') + f': {self.saved_model_file}') def _check_nan(self, loss): if torch.isnan(loss): raise ValueError('Training loss is nan') def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses): des = self.config['loss_decimal_place'] or 4 train_loss_output = (set_color('epoch %d training', 'green') + ' [' + set_color('time', 'blue') + ': %.2fs, ') % (epoch_idx, e_time - s_time) if isinstance(losses, tuple): des = (set_color('train_loss%d', 'blue') + ': %.' + str(des) + 'f') train_loss_output += ', '.join(des % (idx + 1, loss) for idx, loss in enumerate(losses)) else: des = '%.' + str(des) + 'f' train_loss_output += set_color('train loss', 'blue') + ': ' + des % losses return train_loss_output + ']' def _add_train_loss_to_tensorboard(self, epoch_idx, losses, tag='Loss/Train'): if isinstance(losses, tuple): for idx, loss in enumerate(losses): self.tensorboard.add_scalar(tag + str(idx), loss, epoch_idx) else: self.tensorboard.add_scalar(tag, losses, epoch_idx) def _add_hparam_to_tensorboard(self, best_valid_result): # base hparam hparam_dict = { 'learning_rate': self.config['learning_rate'], 'weight_decay': self.config['weight_decay'], 'train_batch_size': self.config['train_batch_size'] } # unrecorded parameter unrecorded_parameter = { parameter for parameters in self.config.parameters.values() for parameter in parameters }.union({'model', 'dataset', 'config_files', 'device'}) # other model-specific hparam hparam_dict.update({ para: val for para, val in self.config.final_config_dict.items() if para not in unrecorded_parameter }) for k in hparam_dict: k = k.replace('@', '_') if hparam_dict[k] is not None and not isinstance(hparam_dict[k], (bool, str, float, int)): hparam_dict[k] = str(hparam_dict[k]) self.tensorboard.add_hparams(hparam_dict, {'hparam/best_valid_result': best_valid_result}) def to_device(self, data): device = self.device if isinstance(data, tuple) or isinstance(data, list): tdata = () for d in data: d = d.to(device) tdata += (d,) return tdata elif isinstance(data, dict): for k, v in data.items(): data[k] = v.to(device) return data else: return data.to(device) def _maybe_resume(self, saved=True): """Load model + optimizer + RNG state from the latest checkpoint in self.checkpoint_dir when config['auto_resume'] is truthy. Sets self.start_epoch and self._resume_batch_idx so the fit() loop and _train_epoch() pick up where the previous run stopped. No-op if auto_resume is off, the checkpoint dir is missing, or no checkpoint files are found (logs the reason and returns). Must be called AFTER self.lite.setup() and BEFORE torch.compile(), so state loads into the bare Fabric-wrapped model. """ self._resume_batch_idx = 0 # default: no skip if not saved or not self.config.get('auto_resume', False): return if not os.path.isdir(self.checkpoint_dir): if self.rank == 0: self.logger.info( f"auto_resume: no checkpoint dir at {self.checkpoint_dir}; starting fresh" ) return import glob model_name = self.config['model'] step_pat = os.path.join(self.checkpoint_dir, f'{model_name}-*-step-*.pth') epoch_pat = os.path.join(self.checkpoint_dir, f'{model_name}-*.pth') candidates = sorted( set(glob.glob(step_pat) + glob.glob(epoch_pat)), key=os.path.getmtime, ) candidates = [c for c in candidates if os.path.exists(c)] if not candidates: if self.rank == 0: self.logger.info( f"auto_resume: no checkpoint found in {self.checkpoint_dir}; starting fresh" ) return latest = candidates[-1] is_step_ckpt = '-step-' in os.path.basename(latest) # DeepSpeed checkpoints are directories containing # checkpoint/mp_rank_00_model_states.pt (model + user-state scalars) and # checkpoint/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt (optimizer). # Fabric.load handles the model + optimizer; read scalars directly because # the Fabric API returns loaded state via in-place tensor update, not via # scalar return. inner = os.path.join(latest, 'checkpoint', 'mp_rank_00_model_states.pt') if not os.path.isfile(inner): if self.rank == 0: self.logger.warning( f"auto_resume: {latest} has no inner state file; starting fresh" ) return inner_state = torch.load(inner, map_location='cpu', weights_only=False) saved_epoch = int(inner_state.get('epoch', 0)) saved_batch_idx = int(inner_state.get('batch_idx', 0)) if is_step_ckpt else 0 saved_cur_step = int(inner_state.get('cur_step', 0)) saved_best = inner_state.get('best_valid_score', self.best_valid_score) rng_state = inner_state.get('rng_state', None) cuda_rng_state = inner_state.get('cuda_rng_state', None) del inner_state # Load model + optimizer via Fabric (handles DeepSpeed ZeRO sharding). state = {"model": self.model, "optimizer": self.optimizer} self.lite.load(latest, state, strict=False) if rng_state is not None: if not isinstance(rng_state, torch.Tensor): rng_state = torch.tensor(rng_state, dtype=torch.uint8) torch.set_rng_state(rng_state) if cuda_rng_state is not None and torch.cuda.is_available(): if not isinstance(cuda_rng_state, torch.Tensor): cuda_rng_state = torch.tensor(cuda_rng_state, dtype=torch.uint8) torch.cuda.set_rng_state(cuda_rng_state) if is_step_ckpt: self.start_epoch = saved_epoch self._resume_batch_idx = saved_batch_idx else: # End-of-epoch ckpt: continue from the NEXT epoch at batch 0. self.start_epoch = saved_epoch + 1 self.cur_step = saved_cur_step self.best_valid_score = saved_best if self.rank == 0: kind = "step" if is_step_ckpt else "epoch" self.logger.info( f"auto_resume: Resuming from {os.path.basename(latest)} " f"({kind} ckpt: epoch={saved_epoch}, batch_idx={saved_batch_idx}, " f"cur_step={saved_cur_step}, best_valid_score={saved_best})" ) def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None): if self.scheduler_config: warmup_rate = self.scheduler_config.get('warmup', 0.001) micro_steps = self.max_steps if self.max_steps > 0 else len(train_data) * self.epochs tot_steps = micro_steps // self.grad_accum_steps warmup_steps = tot_steps * warmup_rate self.lr_scheduler = self._build_scheduler(warmup_steps=warmup_steps, tot_steps=tot_steps) world_size, local_world_size = int(os.environ['WORLD_SIZE']), int(os.environ['LOCAL_WORLD_SIZE']) nnodes = world_size // local_world_size precision = self.config['precision'] if self.config['precision'] else '32' if self.config['strategy'] == 'deepspeed': self.logger.info(f"Use deepspeed strategy") strategy = DeepSpeedStrategy( stage=self.config["stage"], precision=precision, exclude_frozen_parameters=self.config.get('exclude_frozen_parameters', True), ) self.lite = L.Fabric(accelerator='gpu', strategy=strategy, precision=precision, num_nodes=nnodes) else: self.logger.info(f"Use DDP strategy") strategy = DDPStrategy(find_unused_parameters=True) self.lite = L.Fabric(accelerator='gpu', strategy=strategy, precision=precision, num_nodes=nnodes) self.lite.launch() self.model, self.optimizer = self.lite.setup(self.model, self.optimizer) # Resume BEFORE torch.compile so model + optimizer state load into the # bare Fabric-wrapped model, not the compiled wrapper. self._maybe_resume(saved=saved) # Goal 4: torch.compile the Fabric-wrapped model (includes LoRA adapters) if self.config.get('torch_compile', False): compile_mode = self.config.get('torch_compile_mode', 'default') self.logger.info( f"torch.compile: enabled, mode='{compile_mode}' " f"(first step will stall ~60-120s for graph capture)" ) self.model = torch.compile(self.model, mode=compile_mode) else: self.logger.info("torch.compile: disabled") valid_step = 0 for epoch_idx in range(self.start_epoch, self.epochs): # train if self.config['need_training'] == None or self.config['need_training']: train_data.sampler.set_epoch(epoch_idx) training_start_time = time() train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress, valid_data=valid_data) self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss training_end_time = time() train_loss_output = \ self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss) if verbose: self.logger.info(train_loss_output) if self.rank == 0: self._add_train_loss_to_tensorboard(epoch_idx, train_loss) self.wandblogger.log_metrics({'epoch': epoch_idx, 'train_loss': train_loss, 'train_step': epoch_idx}, head='train') # Break out of epoch loop if max_steps reached if self.max_steps > 0 and self.global_step_count >= self.max_steps: if saved: self._save_checkpoint(epoch_idx, verbose=verbose) break if self.eval_step <= 0 or not valid_data: if saved: self._save_checkpoint(epoch_idx, verbose=verbose) continue if (epoch_idx + 1) % self.eval_step == 0: valid_start_time = time() valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress) self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping( valid_score, self.best_valid_score, self.cur_step, max_step=self.stopping_step, bigger=self.valid_metric_bigger ) valid_end_time = time() valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue') + ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \ (epoch_idx, valid_end_time - valid_start_time, valid_score) valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result) if verbose: self.logger.info(valid_score_output) self.logger.info(valid_result_output) if self.rank == 0: self.tensorboard.add_scalar('Vaild_score', valid_score, epoch_idx) for name, value in valid_result.items(): self.tensorboard.add_scalar(name.replace('@', '_'), value, epoch_idx) self.wandblogger.log_metrics({**valid_result, 'valid_step': valid_step}, head='valid') if update_flag: if saved: self._save_checkpoint(epoch_idx, verbose=verbose) self.best_valid_result = valid_result if callback_fn: callback_fn(epoch_idx, valid_score) # Guard: suppress early-stop until at least `min_epochs_before_early_stop` # epochs have completed. With eval_step=1 + stopping_step=2, the default # upstream behaviour lets training exit after 3 epochs, which is noisy on # runs where loss hasn't even crossed log(N+1) random baseline yet. min_epochs_before_stop = self.config.get('min_epochs_before_early_stop', 1) if stop_flag and (epoch_idx + 1) < min_epochs_before_stop: if verbose: self.logger.info( f"early-stop suppressed: epoch {epoch_idx + 1} < " f"min_epochs_before_early_stop={min_epochs_before_stop}" ) stop_flag = False if stop_flag: stop_output = 'Finished training, best eval result in epoch %d' % \ (epoch_idx - self.cur_step * self.eval_step) if verbose: self.logger.info(stop_output) break valid_step += 1 return self.best_valid_score, self.best_valid_result @torch.no_grad() def _full_sort_batch_eval(self, batched_data): user, time_seq, history_index, positive_u, positive_i = batched_data interaction = self.to_device(user) time_seq = self.to_device(time_seq) if self.config['model'] == 'HLLM': if self.config['stage'] == 3: scores = self.model.module.predict(interaction, time_seq, self.item_feature) else: scores = self.model((interaction, time_seq, self.item_feature), mode='predict') else: scores = self.model.module.predict(interaction, time_seq, self.item_feature) scores = scores.view(-1, self.tot_item_num) scores[:, 0] = -np.inf if history_index is not None: scores[history_index] = -np.inf return scores, positive_u, positive_i @torch.no_grad() def compute_item_feature(self, config, data): if self.use_text: item_data = BatchTextDataset(config, data) item_batch_size = config['MAX_ITEM_LIST_LENGTH'] * config['train_batch_size'] item_loader = DataLoader(item_data, batch_size=item_batch_size, num_workers=14, shuffle=False, pin_memory=True, collate_fn=customize_rmpad_collate) self.logger.info(f"Inference item_data with {item_batch_size = } {len(item_loader) = }") self.item_feature = [] with torch.no_grad(): for idx, items in tqdm(enumerate(item_loader), total=len(item_loader)): items = self.to_device(items) items = self.model(items, mode='compute_item') self.item_feature.append(items) if isinstance(items, tuple): self.item_feature = torch.cat([x[0] for x in self.item_feature]), torch.cat([x[1] for x in self.item_feature]) else: self.item_feature = torch.cat(self.item_feature) if self.config['stage'] == 3: self.item_feature = self.item_feature.bfloat16() else: with torch.no_grad(): self.item_feature = self.model.module.compute_item_all() def distributed_concat(self, tensor, num_total_examples): output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather(output_tensors, tensor) concat = torch.cat(output_tensors, dim=0) return concat.sum() / num_total_examples def evaluate(self, eval_data, load_best_model=True, model_file=None, show_progress=False, init_model=False): if not eval_data: return if init_model: world_size, local_world_size = int(os.environ['WORLD_SIZE']), int(os.environ['LOCAL_WORLD_SIZE']) nnodes = world_size // local_world_size if self.config['strategy'] == 'deepspeed': self.logger.info(f"Use deepspeed strategy") precision = self.config['precision'] if self.config['precision'] else '32' strategy = DeepSpeedStrategy( stage=self.config['stage'], precision=precision, exclude_frozen_parameters=self.config.get('exclude_frozen_parameters', True), ) self.lite = L.Fabric(accelerator='gpu', strategy=strategy, precision=precision, num_nodes=nnodes) self.lite.launch() self.model, self.optimizer = self.lite.setup(self.model, self.optimizer) else: self.logger.info(f"Use DDP strategy") precision = self.config['precision'] if self.config['precision'] else '32' strategy = DDPStrategy(find_unused_parameters=True) self.lite = L.Fabric(accelerator='gpu', strategy=strategy, precision=precision, num_nodes=nnodes) self.lite.launch() self.model = self.lite.setup(self.model) if load_best_model: checkpoint_file = model_file or self.saved_model_file state = {"model": self.model} self.lite.load(checkpoint_file, state, strict=False) message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file) self.logger.info(message_output) with torch.no_grad(): self.model.eval() eval_func = self._full_sort_batch_eval self.tot_item_num = eval_data.dataset.dataload.item_num self.compute_item_feature(self.config, eval_data.dataset.dataload) iter_data = ( tqdm( eval_data, total=len(eval_data), ncols=150, desc=set_color(f"Evaluate ", 'pink'), file=sys.stdout ) if show_progress and self.rank == 0 else eval_data ) fwd_time = t.time() for batch_idx, batched_data in enumerate(iter_data): start_time = fwd_time data_time = t.time() scores, positive_u, positive_i = eval_func(batched_data) fwd_time = t.time() if show_progress and self.rank == 0: iter_data.set_postfix_str(f"data: {data_time-start_time:.3f} fwd: {fwd_time-data_time:.3f}", refresh=False) self.eval_collector.eval_batch_collect(scores, positive_u, positive_i) num_total_examples = len(eval_data.sampler.dataset) struct = self.eval_collector.get_data_struct() result = self.evaluator.evaluate(struct) metric_decimal_place = 5 if self.config['metric_decimal_place'] == None else self.config['metric_decimal_place'] for k, v in result.items(): result_cpu = self.distributed_concat(torch.tensor([v]).to(self.device), num_total_examples).cpu() result[k] = round(result_cpu.item(), metric_decimal_place) self.wandblogger.log_eval_metrics(result, head='eval') return result