import math import numpy as np import torch # eps for numerical stability eps = 1e-15 class YFOptimizer(object): def __init__(self, var_list, lr=0.1, mu=0.0, clip_thresh=None, weight_decay=0.0, beta=0.999, curv_win_width=20, zero_debias=True, sparsity_debias=True, delta_mu=0.0, auto_clip_fac=None, force_non_inc_step=False): ''' clip thresh is the threshold value on ||lr * gradient|| delta_mu can be place holder/variable/python scalar. They are used for additional momentum in situations such as asynchronous-parallel training. The default is 0.0 for basic usage of the optimizer. Args: lr: python scalar. The initial value of learning rate, we use 1.0 in our paper. mu: python scalar. The initial value of momentum, we use 0.0 in our paper. clip_thresh: python scalar. The manaully-set clipping threshold for tf.clip_by_global_norm. if None, the automatic clipping can be carried out. The automatic clipping feature is parameterized by argument auto_clip_fac. The auto clip feature can be switched off with auto_clip_fac = None beta: python scalar. The smoothing parameter for estimations. sparsity_debias: gradient norm and curvature are biased to larger values when calculated with sparse gradient. This is useful when the model is very sparse, e.g. LSTM with word embedding. For non-sparse CNN, turning it off could slightly accelerate the speed. delta_mu: for extensions. Not necessary in the basic use. force_non_inc_step: in some very rare cases, it is necessary to force ||lr * gradient|| to be not increasing dramatically for stableness after some iterations. In practice, if turned on, we enforce lr * sqrt(smoothed ||grad||^2) to be less than 2x of the minimal value of historical value on smoothed || lr * grad ||. This feature is turned off by default. Other features: If you want to manually control the learning rates, self.lr_factor is an interface to the outside, it is an multiplier for the internal learning rate in YellowFin. It is helpful when you want to do additional hand tuning or some decaying scheme to the tuned learning rate in YellowFin. Example on using lr_factor can be found here: https://github.com/JianGoForIt/YellowFin_Pytorch/blob/master/pytorch-cifar/main.py#L109 ''' self._lr = lr self._mu = mu # we convert var_list from generator to list so that # it can be used for multiple times self._var_list = list(var_list) self._clip_thresh = clip_thresh self._auto_clip_fac = auto_clip_fac self._beta = beta self._curv_win_width = curv_win_width self._zero_debias = zero_debias self._sparsity_debias = sparsity_debias self._force_non_inc_step = force_non_inc_step self._optimizer = torch.optim.SGD(self._var_list, lr=self._lr, momentum=self._mu, weight_decay=weight_decay) self._iter = 0 # global states are the statistics self._global_state = {} # for decaying learning rate and etc. self._lr_factor = 1.0 def state_dict(self): # for checkpoint saving sgd_state_dict = self._optimizer.state_dict() global_state = self._global_state lr_factor = self._lr_factor iter = self._iter lr = self._lr mu = self._mu clip_thresh = self._clip_thresh beta = self._beta curv_win_width = self._curv_win_width zero_debias = self._zero_debias h_min = self._h_min h_max = self._h_max return { "sgd_state_dict": sgd_state_dict, "global_state": global_state, "lr_factor": lr_factor, "iter": iter, "lr": lr, "mu": mu, "clip_thresh": clip_thresh, "beta": beta, "curv_win_width": curv_win_width, "zero_debias": zero_debias, "h_min": h_min, "h_max": h_max } def load_state_dict(self, state_dict): # for checkpoint saving self._optimizer.load_state_dict(state_dict['sgd_state_dict']) self._global_state = state_dict['global_state'] self._lr_factor = state_dict['lr_factor'] self._iter = state_dict['iter'] self._lr = state_dict['lr'] self._mu = state_dict['mu'] self._clip_thresh = state_dict['clip_thresh'] self._beta = state_dict['beta'] self._curv_win_width = state_dict['curv_win_width'] self._zero_debias = state_dict['zero_debias'] self._h_min = state_dict["h_min"] self._h_max = state_dict["h_max"] return def set_lr_factor(self, factor): self._lr_factor = factor return def get_lr_factor(self): return self._lr_factor def zero_grad(self): self._optimizer.zero_grad() return def zero_debias_factor(self): return 1.0 - self._beta ** (self._iter + 1) def zero_debias_factor_delay(self, delay): # for exponentially averaged stat which starts at non-zero iter return 1.0 - self._beta ** (self._iter - delay + 1) def curvature_range(self): global_state = self._global_state if self._iter == 0: global_state["curv_win"] = torch.FloatTensor(self._curv_win_width, 1).zero_() curv_win = global_state["curv_win"] grad_norm_squared = self._global_state["grad_norm_squared"] curv_win[self._iter % self._curv_win_width] = np.log(grad_norm_squared + eps) valid_end = min(self._curv_win_width, self._iter + 1) # we use running average over log scale, accelerating # h_max / min in the begining to follow the varying trend of curvature. beta = self._beta if self._iter == 0: global_state["h_min_avg"] = 0.0 global_state["h_max_avg"] = 0.0 self._h_min = 0.0 self._h_max = 0.0 global_state["h_min_avg"] = \ global_state["h_min_avg"] * beta + (1 - beta) * torch.min(curv_win[:valid_end] ) global_state["h_max_avg"] = \ global_state["h_max_avg"] * beta + (1 - beta) * torch.max(curv_win[:valid_end] ) if self._zero_debias: debias_factor = self.zero_debias_factor() self._h_min = np.exp(global_state["h_min_avg"] / debias_factor) self._h_max = np.exp(global_state["h_max_avg"] / debias_factor) else: self._h_min = np.exp(global_state["h_min_avg"] ) self._h_max = np.exp(global_state["h_max_avg"] ) if self._sparsity_debias: self._h_min *= self._sparsity_avg self._h_max *= self._sparsity_avg return def grad_variance(self): global_state = self._global_state beta = self._beta self._grad_var = np.array(0.0, dtype=np.float32) for group in self._optimizer.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data state = self._optimizer.state[p] if self._iter == 0: state["grad_avg"] = grad.new().resize_as_(grad).zero_() state["grad_avg_squared"] = 0.0 state["grad_avg"].mul_(beta).add_(1 - beta, grad) self._grad_var += torch.sum(state["grad_avg"] * state["grad_avg"] ) if self._zero_debias: debias_factor = self.zero_debias_factor() else: debias_factor = 1.0 self._grad_var /= -(debias_factor**2) self._grad_var += global_state['grad_norm_squared_avg'] / debias_factor # in case of negative variance: the two term are using different debias factors self._grad_var = max(self._grad_var, eps) if self._sparsity_debias: self._grad_var *= self._sparsity_avg return def dist_to_opt(self): global_state = self._global_state beta = self._beta if self._iter == 0: global_state["grad_norm_avg"] = 0.0 global_state["dist_to_opt_avg"] = 0.0 global_state["grad_norm_avg"] = \ global_state["grad_norm_avg"] * beta + (1 - beta) * math.sqrt(global_state["grad_norm_squared"] ) global_state["dist_to_opt_avg"] = \ global_state["dist_to_opt_avg"] * beta \ + (1 - beta) * global_state["grad_norm_avg"] / (global_state['grad_norm_squared_avg'] + eps) if self._zero_debias: debias_factor = self.zero_debias_factor() self._dist_to_opt = global_state["dist_to_opt_avg"] / debias_factor else: self._dist_to_opt = global_state["dist_to_opt_avg"] if self._sparsity_debias: self._dist_to_opt /= (np.sqrt(self._sparsity_avg) + eps) return def grad_sparsity(self): global_state = self._global_state if self._iter == 0: global_state["sparsity_avg"] = 0.0 non_zero_cnt = 0.0 all_entry_cnt = 0.0 for group in self._optimizer.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data grad_non_zero = grad.nonzero() if grad_non_zero.dim() > 0: non_zero_cnt += grad_non_zero.size()[0] all_entry_cnt += torch.numel(grad) beta = self._beta global_state["sparsity_avg"] = beta * global_state["sparsity_avg"] \ + (1 - beta) * non_zero_cnt / float(all_entry_cnt) self._sparsity_avg = \ global_state["sparsity_avg"] / self.zero_debias_factor() return def lr_grad_norm_avg(self): # this is for enforcing lr * grad_norm not # increasing dramatically in case of instability. # Not necessary for basic use. global_state = self._global_state beta = self._beta if "lr_grad_norm_avg" not in global_state: global_state['grad_norm_squared_avg_log'] = 0.0 global_state['grad_norm_squared_avg_log'] = \ global_state['grad_norm_squared_avg_log'] * beta \ + (1 - beta) * np.log(global_state['grad_norm_squared'] + eps) if "lr_grad_norm_avg" not in global_state: global_state["lr_grad_norm_avg"] = \ 0.0 * beta + (1 - beta) * np.log(self._lr * np.sqrt(global_state['grad_norm_squared'] ) + eps) # we monitor the minimal smoothed ||lr * grad|| global_state["lr_grad_norm_avg_min"] = \ np.exp(global_state["lr_grad_norm_avg"] / self.zero_debias_factor() ) else: global_state["lr_grad_norm_avg"] = global_state["lr_grad_norm_avg"] * beta \ + (1 - beta) * np.log(self._lr * np.sqrt(global_state['grad_norm_squared'] ) + eps) global_state["lr_grad_norm_avg_min"] = \ min(global_state["lr_grad_norm_avg_min"], np.exp(global_state["lr_grad_norm_avg"] / self.zero_debias_factor() ) ) def after_apply(self): # compute running average of gradient and norm of gradient beta = self._beta global_state = self._global_state if self._iter == 0: global_state["grad_norm_squared_avg"] = 0.0 global_state["grad_norm_squared"] = 0.0 for group in self._optimizer.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data global_state['grad_norm_squared'] += torch.sum(grad * grad) global_state['grad_norm_squared_avg'] = \ global_state['grad_norm_squared_avg'] * beta + (1 - beta) * global_state['grad_norm_squared'] if self._sparsity_debias: self.grad_sparsity() self.curvature_range() self.grad_variance() self.dist_to_opt() if self._iter > 0: self.get_mu() self.get_lr() self._lr = beta * self._lr + (1 - beta) * self._lr_t self._mu = beta * self._mu + (1 - beta) * self._mu_t return def get_lr(self): self._lr_t = (1.0 - math.sqrt(self._mu_t) )**2 / (self._h_min + eps) return def get_cubic_root(self): # We have the equation x^2 D^2 + (1-x)^4 * C / h_min^2 # where x = sqrt(mu). # We substitute x, which is sqrt(mu), with x = y + 1. # It gives y^3 + py = q # where p = (D^2 h_min^2)/(2*C) and q = -p. # We use the Vieta's substution to compute the root. # There is only one real solution y (which is in [0, 1] ). # http://mathworld.wolfram.com/VietasSubstitution.html # eps in the numerator is to prevent momentum = 1 in case of zero gradient p = (self._dist_to_opt + eps)**2 * (self._h_min + eps)**2 / 2 / (self._grad_var + eps) w3 = (-math.sqrt(p**2 + 4.0 / 27.0 * p**3) - p) / 2.0 w = math.copysign(1.0, w3) * math.pow(math.fabs(w3), 1.0/3.0) y = w - p / 3.0 / (w + eps) x = y + 1 return x def get_mu(self): root = self.get_cubic_root() dr = self._h_max / self._h_min self._mu_t = max(root**2, ( (np.sqrt(dr) - 1) / (np.sqrt(dr) + 1) )**2 ) return def update_hyper_param(self): for group in self._optimizer.param_groups: group['momentum'] = self._mu if self._force_non_inc_step == False: group['lr'] = self._lr * self._lr_factor elif self._iter > self._curv_win_width: # force to guarantee lr * grad_norm not increasing dramatically. # Not necessary for basic use. Please refer to the comments # in YFOptimizer.__init__ for more details self.lr_grad_norm_avg() debias_factor = self.zero_debias_factor() group['lr'] = min(self._lr * self._lr_factor, 2.0 * self._global_state["lr_grad_norm_avg_min"] \ / np.sqrt(np.exp(self._global_state['grad_norm_squared_avg_log'] / debias_factor) ) ) return def auto_clip_thresh(self): # Heuristic to automatically prevent sudden exploding gradient # Not necessary for basic use. return math.sqrt(self._h_max) * self._auto_clip_fac def step(self): # add weight decay for group in self._optimizer.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data if group['weight_decay'] != 0: grad = grad.add(group['weight_decay'], p.data) if self._clip_thresh != None: torch.nn.utils.clip_grad_norm(self._var_list, self._clip_thresh) elif (self._iter != 0 and self._auto_clip_fac != None): # do not clip the first iteration torch.nn.utils.clip_grad_norm(self._var_list, self.auto_clip_thresh() ) # apply update self._optimizer.step() # after appply self.after_apply() # update learning rate and momentum self.update_hyper_param() self._iter += 1 return