Source code for flambe.optim.noam

from flambe.optim.scheduler import LambdaLR


class NoamScheduler(LambdaLR):
    """Linear warmup followed by inverse square root decay.

    Linearly increases the learning rate factor from 0 to 1 over
    `warmup` steps, then decays it proportionally to the inverse
    square root of the step number. This scheduler is generally
    stepped after every training batch.

    """

    def __init__(self, optimizer, warmup: int, d_model: int):
        """Initialize the NoamScheduler.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Wrapped optimizer.
        warmup : int
            The number of linear warmup steps.
        d_model : int
            The hidden dimension of the model.

        """
        self.warmup = warmup
        self.d_model = d_model
        super().__init__(optimizer, lr_lambda=self.lr_lambda, last_epoch=-1)  # type: ignore
    def lr_lambda(self, step: int) -> float:
        """Compute the learning rate factor.

        Parameters
        ----------
        step : int
            The current step. Could be a training or validation step.

        Returns
        -------
        float
            The output factor

        """
        if step == 0 and self.warmup == 0:
            # Avoid a division by zero when there is no warmup phase
            return 1. / (self.d_model ** 0.5)
        elif step > self.warmup:
            # Inverse square root decay after warmup
            return 1. / (self.d_model ** 0.5) / (step ** 0.5)
        else:
            # Linear warmup
            return step / (self.d_model ** 0.5) / (self.warmup ** 1.5)
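
The factor computed above is equivalent to ``d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)``, scaled by the optimizer's base learning rate. Below is a minimal usage sketch, assuming a standard ``torch.optim`` optimizer; the ``Linear`` model, the base learning rate of 1.0, and the ``warmup``/``d_model`` values are illustrative, not prescribed by this module:

    import torch

    from flambe.optim.noam import NoamScheduler

    # A toy model; any torch.nn.Module works
    model = torch.nn.Linear(512, 512)

    # The base lr acts as a multiplier on the factor returned by lr_lambda
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
    scheduler = NoamScheduler(optimizer, warmup=4000, d_model=512)

    for step in range(10):
        optimizer.step()
        scheduler.step()  # step the scheduler after every training batch

Because the schedule is defined per step rather than per epoch, ``scheduler.step()`` is called once per batch, with the optimizer stepped first.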