Source code for flambe.optim.linear

from flambe.optim.scheduler import LambdaLR


class WarmupLinearScheduler(LambdaLR):
    """Linear warmup and then linear decay.

    Linearly increases the learning rate factor from 0 to 1 over
    `warmup` training steps, then linearly decreases it from 1 to 0
    over the remaining `n_steps - warmup` steps. This scheduler is
    generally stepped after every training batch.

    """

    def __init__(self, optimizer, warmup: int, n_steps: int):
        """Initialize the WarmupLinearScheduler.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Wrapped optimizer.
        warmup : int
            The number of linear warmup steps.
        n_steps : int
            The total number of training steps.

        """
        self.warmup = warmup
        self.n_steps = n_steps
        super().__init__(optimizer, lr_lambda=self.lr_lambda, last_epoch=-1)  # type: ignore
    def lr_lambda(self, step: int) -> float:
        """Compute the learning rate factor.

        Parameters
        ----------
        step : int
            The current step. Could be a training or validation step.

        Returns
        -------
        float
            The multiplicative learning rate factor.

        """
        if step < self.warmup:
            return float(step) / float(max(1, self.warmup))
        return max(0.0, float(self.n_steps - step) / float(max(1.0, self.n_steps - self.warmup)))
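

A minimal usage sketch of the scheduler, assuming a standard PyTorch training setup. The model, optimizer choice, and hyperparameter values below are illustrative, not part of flambe:

import torch

# Hypothetical toy model and optimizer (illustrative only).
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Warm up over the first 100 steps, decay over the remaining 900.
scheduler = WarmupLinearScheduler(optimizer, warmup=100, n_steps=1000)

for step in range(1000):
    # ... forward pass, loss.backward() would go here ...
    optimizer.step()
    scheduler.step()  # advance the schedule once per training batch

# After warmup the factor decays linearly: at step 550 it is
# (1000 - 550) / (1000 - 100) = 0.5, i.e. half the base learning rate.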