diff --git a/training/lr_scheduler.py b/training/lr_scheduler.py
index cbbac5a..9af43bc 100644
--- a/training/lr_scheduler.py
+++ b/training/lr_scheduler.py
@@ -27,22 +27,11 @@ def get_lr(self):
             UserWarning,
         )
 
-        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
-            return [group["lr"] for group in self.optimizer.param_groups]
-
-        if self.last_epoch <= self.warmup_iters:
-            return [
-                base_lr * self.last_epoch / self.warmup_iters
-                for base_lr in self.base_lrs
-            ]
-        else:
-            l = self.last_epoch
-            w = self.warmup_iters
-            t = self.total_iters
-            decay_factor = (
-                (1.0 - (l - w) / (t - w)) / (1.0 - (l - 1 - w) / (t - w))
-            ) ** self.power
-            return [(1/2 * group["lr"]) * (1+decay_factor) for group in self.optimizer.param_groups]
+        # Use closed-form calculation for correctness and numerical stability.
+        # Previous implementation had a formula bug: (1/2 * lr) * (1 + decay_factor)
+        # is NOT equivalent to polynomial decay. It causes slower decay than intended.
+        # The closed-form directly computes: base_lr * (1 - progress)^power
+        return self._get_closed_form_lr()
 
     def _get_closed_form_lr(self):
         if self.last_epoch <= self.warmup_iters:
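
For reference, a minimal standalone sketch (not part of the patch) that runs the removed recursive rule next to the closed form, to show the "slower decay" the new comment describes. All hyperparameter values here (base_lr=0.1, warmup_iters=100, total_iters=1000, power=1.0) are illustrative assumptions, not values taken from this repository.

def closed_form_lr(base_lr, step, warmup_iters, total_iters, power):
    # Linear warmup, then polynomial decay computed directly from the step.
    if step <= warmup_iters:
        return base_lr * step / warmup_iters
    progress = (step - warmup_iters) / (total_iters - warmup_iters)
    return base_lr * (1.0 - progress) ** power

def buggy_lrs(base_lr, warmup_iters, total_iters, power):
    # Reproduces the removed rule: lr <- (lr / 2) * (1 + decay_factor).
    # Averaging decay_factor with 1 keeps each per-step multiplier closer
    # to 1 than intended, so the schedule stays above the true curve.
    span = total_iters - warmup_iters
    lr, lrs = base_lr, {}
    for step in range(warmup_iters + 1, total_iters + 1):
        decay_factor = (
            (1.0 - (step - warmup_iters) / span)
            / (1.0 - (step - 1 - warmup_iters) / span)
        ) ** power
        lr = 0.5 * lr * (1.0 + decay_factor)
        lrs[step] = lr
    return lrs

base_lr, warmup, total, power = 0.1, 100, 1000, 1.0
buggy = buggy_lrs(base_lr, warmup, total, power)
for step in (200, 550, 1000):
    print(step, round(buggy[step], 5),
          round(closed_form_lr(base_lr, step, warmup, total, power), 5))

With power=1 the buggy schedule tracks roughly base_lr * sqrt(1 - progress) away from the endpoint (about 0.071 at the halfway point versus the intended 0.050) and never actually reaches zero, while the closed form ends exactly at 0.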