# Source: dawn-bench-models/pytorch/CIFAR10/benchmark/yellowfin.py
# (original listing: 384 lines, 14 KiB, Python)

import math
import numpy as np
import torch
# eps for numerical stability: added to denominators and to log() arguments so
# that zero gradient statistics never cause a division by zero or log(0).
eps = 1e-15


class YFOptimizer(object):
    """YellowFin: automatic learning-rate and momentum tuning on top of SGD.

    The parameter update itself is delegated to a wrapped ``torch.optim.SGD``.
    After every update, running averages of gradient statistics are used to
    estimate the curvature range (h_min, h_max), the gradient variance C and
    the distance to the optimum D. A new momentum and learning rate are
    derived from these and written back into the SGD param groups.
    """

    def __init__(self, var_list, lr=0.1, mu=0.0, clip_thresh=None, weight_decay=0.0,
                 beta=0.999, curv_win_width=20, zero_debias=True, sparsity_debias=True,
                 delta_mu=0.0, auto_clip_fac=None, force_non_inc_step=False):
        '''
        Args:
          var_list: iterable of parameters (or of param-group dicts) to optimize.
            It is passed to ``torch.optim.SGD``. A generator is turned into a
            list so that it can be iterated more than once.
          lr: python scalar. The initial learning rate (the paper uses 1.0).
          mu: python scalar. The initial momentum (the paper uses 0.0).
          clip_thresh: python scalar or None. Manually-set threshold on the
            global gradient norm ||gradient||. It is applied with
            ``torch.nn.utils.clip_grad_norm_`` before every update. If None,
            automatic clipping can be used instead (see auto_clip_fac).
          weight_decay: L2 penalty, forwarded to ``torch.optim.SGD``.
          beta: python scalar. The smoothing factor of all running averages.
          curv_win_width: number of most recent log(||grad||^2) values whose
            min / max define the curvature range (h_min, h_max).
          zero_debias: if True, running averages are divided by
            1 - beta^(iter + 1). This removes their bias towards the zero
            initial value.
          sparsity_debias: gradient norm and curvature are biased towards larger
            values when computed from sparse gradients. Debiasing helps when the
            model is very sparse, e.g. an LSTM with word embeddings. For
            non-sparse CNNs, turning it off can slightly speed up training.
          delta_mu: reserved for extensions, e.g. additional momentum in
            asynchronous-parallel training. It is accepted for API
            compatibility; this implementation does not use it.
          auto_clip_fac: python scalar or None. When clip_thresh is None, the
            gradient norm is clipped to sqrt(h_max) * auto_clip_fac, starting
            from the second iteration. None switches automatic clipping off.
          force_non_inc_step: in some very rare cases, stability requires that
            ||lr * gradient|| not increase dramatically. If turned on, then
            after the first curv_win_width iterations we enforce that
            lr * sqrt(smoothed ||grad||^2) stays below 2x the minimal
            historical value of the smoothed ||lr * grad||. Off by default.

        Other features:
          To control the learning rate manually, use set_lr_factor /
          get_lr_factor. The factor multiplies the learning rate tuned
          internally by YellowFin. This is helpful for additional hand tuning
          or for a decay schedule on top of YellowFin. Example:
          https://github.com/JianGoForIt/YellowFin_Pytorch/blob/master/pytorch-cifar/main.py#L109
        '''
        self._lr = lr
        self._mu = mu
        # materialize generators (e.g. model.parameters()) so the list can be
        # iterated more than once
        self._var_list = list(var_list)
        self._clip_thresh = clip_thresh
        self._auto_clip_fac = auto_clip_fac
        self._beta = beta
        self._curv_win_width = curv_win_width
        self._zero_debias = zero_debias
        self._sparsity_debias = sparsity_debias
        self._force_non_inc_step = force_non_inc_step
        self._optimizer = torch.optim.SGD(self._var_list, lr=self._lr,
                                          momentum=self._mu, weight_decay=weight_decay)
        self._iter = 0
        # running statistics shared by all parameters
        self._global_state = {}
        # multiplier on the tuned learning rate (for decaying schedules etc.)
        self._lr_factor = 1.0
        # curvature-range estimates. They are recomputed on every step; they
        # are defined here so that state_dict() also works before the first step.
        self._h_min = 0.0
        self._h_max = 0.0

    def state_dict(self):
        """Return a checkpoint dict: wrapped SGD state plus YellowFin statistics.

        The global statistics dict is returned by reference, not copied.
        """
        return {
            "sgd_state_dict": self._optimizer.state_dict(),
            "global_state": self._global_state,
            "lr_factor": self._lr_factor,
            "iter": self._iter,
            "lr": self._lr,
            "mu": self._mu,
            "clip_thresh": self._clip_thresh,
            "beta": self._beta,
            "curv_win_width": self._curv_win_width,
            "zero_debias": self._zero_debias,
            "h_min": self._h_min,
            "h_max": self._h_max,
        }

    def load_state_dict(self, state_dict):
        """Restore a checkpoint produced by state_dict().

        The per-parameter "grad_avg" buffers live in the wrapped SGD state, so
        they are restored by ``torch.optim.SGD.load_state_dict``.
        """
        self._optimizer.load_state_dict(state_dict['sgd_state_dict'])
        self._global_state = state_dict['global_state']
        self._lr_factor = state_dict['lr_factor']
        self._iter = state_dict['iter']
        self._lr = state_dict['lr']
        self._mu = state_dict['mu']
        self._clip_thresh = state_dict['clip_thresh']
        self._beta = state_dict['beta']
        self._curv_win_width = state_dict['curv_win_width']
        self._zero_debias = state_dict['zero_debias']
        self._h_min = state_dict["h_min"]
        self._h_max = state_dict["h_max"]
        return

    def set_lr_factor(self, factor):
        """Set the multiplier applied on top of the tuned learning rate."""
        self._lr_factor = factor
        return

    def get_lr_factor(self):
        """Return the multiplier applied on top of the tuned learning rate."""
        return self._lr_factor

    def zero_grad(self):
        """Clear gradients through the wrapped SGD optimizer."""
        self._optimizer.zero_grad()
        return

    def zero_debias_factor(self):
        """Debias factor 1 - beta^(iter + 1) for averages started at iteration 0."""
        return 1.0 - self._beta ** (self._iter + 1)

    def zero_debias_factor_delay(self, delay):
        """Debias factor for an exponential average that started at iteration `delay`."""
        return 1.0 - self._beta ** (self._iter - delay + 1)

    def curvature_range(self):
        """Update the curvature-range estimates self._h_min / self._h_max.

        log(||grad||^2) of the last curv_win_width iterations is kept in a
        circular window. The window's min and max are smoothed in log space with
        beta, optionally zero-debiased and exponentiated. When sparsity_debias
        is on, they are also scaled by the smoothed gradient density.
        """
        global_state = self._global_state
        beta = self._beta
        if self._iter == 0:
            global_state["curv_win"] = torch.zeros(self._curv_win_width, 1,
                                                   dtype=torch.float32)
            global_state["h_min_avg"] = 0.0
            global_state["h_max_avg"] = 0.0
        curv_win = global_state["curv_win"]
        grad_norm_squared = global_state["grad_norm_squared"]
        curv_win[self._iter % self._curv_win_width] = float(np.log(grad_norm_squared + eps))
        # until the window has filled up, only the first iter + 1 slots are valid
        valid_end = min(self._curv_win_width, self._iter + 1)
        window = curv_win[:valid_end]
        # we use a running average over log scale, accelerating h_max / h_min
        # in the beginning to follow the varying trend of curvature.
        global_state["h_min_avg"] = \
            global_state["h_min_avg"] * beta + (1 - beta) * torch.min(window).item()
        global_state["h_max_avg"] = \
            global_state["h_max_avg"] * beta + (1 - beta) * torch.max(window).item()
        debias_factor = self.zero_debias_factor() if self._zero_debias else 1.0
        self._h_min = np.exp(global_state["h_min_avg"] / debias_factor)
        self._h_max = np.exp(global_state["h_max_avg"] / debias_factor)
        if self._sparsity_debias:
            self._h_min *= self._sparsity_avg
            self._h_max *= self._sparsity_avg
        return

    def grad_variance(self):
        """Estimate the gradient variance self._grad_var.

        Computes C = E[||g||^2] - ||E[g]||^2. E[g] is a per-parameter running
        average kept in the SGD per-parameter state under "grad_avg".
        E[||g||^2] is global_state['grad_norm_squared_avg'], maintained by
        after_apply. Both terms are zero-debiased when zero_debias is on.
        """
        global_state = self._global_state
        beta = self._beta
        grad_avg_norm_squared = 0.0
        for group in self._optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self._optimizer.state[p]
                # also initialize parameters whose first gradient arrives after iter 0
                if self._iter == 0 or "grad_avg" not in state:
                    state["grad_avg"] = torch.zeros_like(grad)
                state["grad_avg"].mul_(beta).add_(grad, alpha=1 - beta)
                grad_avg_norm_squared += \
                    torch.sum(state["grad_avg"] * state["grad_avg"]).item()
        debias_factor = self.zero_debias_factor() if self._zero_debias else 1.0
        grad_var = global_state['grad_norm_squared_avg'] / debias_factor \
            - grad_avg_norm_squared / debias_factor ** 2
        # in case of negative variance: the two terms use different debias factors
        self._grad_var = max(grad_var, eps)
        if self._sparsity_debias:
            self._grad_var *= self._sparsity_avg
        return

    def dist_to_opt(self):
        """Estimate the distance to the optimum, self._dist_to_opt.

        Smooths E[||g||] / E[||g||^2] with beta, optionally zero-debiases it
        and, with sparsity_debias, divides it by sqrt(gradient density).
        """
        global_state = self._global_state
        beta = self._beta
        if self._iter == 0:
            global_state["grad_norm_avg"] = 0.0
            global_state["dist_to_opt_avg"] = 0.0
        global_state["grad_norm_avg"] = \
            global_state["grad_norm_avg"] * beta \
            + (1 - beta) * math.sqrt(global_state["grad_norm_squared"])
        global_state["dist_to_opt_avg"] = \
            global_state["dist_to_opt_avg"] * beta \
            + (1 - beta) * global_state["grad_norm_avg"] \
            / (global_state['grad_norm_squared_avg'] + eps)
        if self._zero_debias:
            self._dist_to_opt = global_state["dist_to_opt_avg"] / self.zero_debias_factor()
        else:
            self._dist_to_opt = global_state["dist_to_opt_avg"]
        if self._sparsity_debias:
            self._dist_to_opt /= (np.sqrt(self._sparsity_avg) + eps)
        return

    def grad_sparsity(self):
        """Update self._sparsity_avg, the smoothed fraction of non-zero gradient entries.

        Assumes at least one parameter has a gradient (otherwise the density
        is 0 / 0 and ZeroDivisionError is raised).
        """
        global_state = self._global_state
        if self._iter == 0:
            global_state["sparsity_avg"] = 0.0
        non_zero_cnt = 0.0
        all_entry_cnt = 0.0
        for group in self._optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                non_zero_cnt += (grad != 0).sum().item()
                all_entry_cnt += grad.numel()
        beta = self._beta
        global_state["sparsity_avg"] = beta * global_state["sparsity_avg"] \
            + (1 - beta) * non_zero_cnt / float(all_entry_cnt)
        self._sparsity_avg = \
            global_state["sparsity_avg"] / self.zero_debias_factor()
        return

    def lr_grad_norm_avg(self):
        """Track the smoothed log ||grad||^2 and the minimal smoothed ||lr * grad||.

        Used only when force_non_inc_step is on, to keep lr * grad_norm from
        increasing dramatically in case of instability. These averages start
        at a non-zero iteration. They are therefore debiased relative to that
        start iteration (zero_debias_factor_delay), not relative to iteration 0.
        """
        global_state = self._global_state
        beta = self._beta
        first_call = "lr_grad_norm_avg" not in global_state
        if first_call:
            global_state['grad_norm_squared_avg_log'] = 0.0
            global_state['lr_grad_norm_avg'] = 0.0
            global_state['lr_grad_norm_avg_start_iter'] = self._iter
        # checkpoints written before the start iteration was recorded fall
        # back to delay 0, i.e. the previous zero_debias_factor() behaviour
        debias_factor = self.zero_debias_factor_delay(
            global_state.get('lr_grad_norm_avg_start_iter', 0))
        global_state['grad_norm_squared_avg_log'] = \
            global_state['grad_norm_squared_avg_log'] * beta \
            + (1 - beta) * np.log(global_state['grad_norm_squared'] + eps)
        global_state["lr_grad_norm_avg"] = global_state["lr_grad_norm_avg"] * beta \
            + (1 - beta) * np.log(self._lr * np.sqrt(global_state['grad_norm_squared']) + eps)
        # we monitor the minimal smoothed ||lr * grad||
        lr_grad_norm = np.exp(global_state["lr_grad_norm_avg"] / debias_factor)
        if first_call:
            global_state["lr_grad_norm_avg_min"] = lr_grad_norm
        else:
            global_state["lr_grad_norm_avg_min"] = \
                min(global_state["lr_grad_norm_avg_min"], lr_grad_norm)

    def after_apply(self):
        """Update all gradient statistics, then the tuned lr / momentum (from iter 1 on)."""
        beta = self._beta
        global_state = self._global_state
        if self._iter == 0:
            global_state["grad_norm_squared_avg"] = 0.0
        grad_norm_squared = 0.0
        for group in self._optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                # .item() keeps the statistic a Python float, so the numpy / math
                # calls downstream also work for CUDA tensors
                grad_norm_squared += torch.sum(grad * grad).item()
        global_state['grad_norm_squared'] = grad_norm_squared
        global_state['grad_norm_squared_avg'] = \
            global_state['grad_norm_squared_avg'] * beta + (1 - beta) * grad_norm_squared
        if self._sparsity_debias:
            self.grad_sparsity()
        self.curvature_range()
        self.grad_variance()
        self.dist_to_opt()
        if self._iter > 0:
            self.get_mu()
            self.get_lr()
            self._lr = beta * self._lr + (1 - beta) * self._lr_t
            self._mu = beta * self._mu + (1 - beta) * self._mu_t
        return

    def get_lr(self):
        """Set self._lr_t = (1 - sqrt(mu_t))^2 / h_min."""
        self._lr_t = (1.0 - math.sqrt(self._mu_t)) ** 2 / (self._h_min + eps)
        return

    def get_cubic_root(self):
        """Return x = sqrt(mu) minimizing x^2 D^2 + (1 - x)^4 C / h_min^2."""
        # Substituting x = y + 1 gives y^3 + p y = q,
        # where p = (D^2 h_min^2) / (2 C) and q = -p.
        # Vieta's substitution computes the root. For p > 0 there is exactly one
        # real root y, which lies in (-1, 0), so x = y + 1 lies in (0, 1).
        # http://mathworld.wolfram.com/VietasSubstitution.html
        # eps in the numerator prevents momentum = 1 in case of zero gradient
        p = (self._dist_to_opt + eps) ** 2 * (self._h_min + eps) ** 2 / 2 / (self._grad_var + eps)
        w3 = (-math.sqrt(p ** 2 + 4.0 / 27.0 * p ** 3) - p) / 2.0
        w = math.copysign(1.0, w3) * math.pow(math.fabs(w3), 1.0 / 3.0)
        y = w - p / 3.0 / (w + eps)
        x = y + 1
        return x

    def get_mu(self):
        """Set self._mu_t from the cubic root and the curvature dynamic range.

        mu_t is the larger of root^2 and ((sqrt(dr) - 1) / (sqrt(dr) + 1))^2,
        where dr = h_max / h_min.
        """
        root = self.get_cubic_root()
        # guard against h_min == 0 (e.g. all-zero gradients with sparsity_debias)
        dr = self._h_max / max(self._h_min, eps)
        self._mu_t = max(root ** 2, ((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1)) ** 2)
        return

    def update_hyper_param(self):
        """Write the tuned momentum and learning rate into every SGD param group.

        With force_non_inc_step on, the learning rate is left unchanged for the
        first curv_win_width + 1 iterations. After that it is capped at
        2 * min(smoothed ||lr * grad||) / smoothed ||grad||.
        """
        lr_cap = None
        if self._force_non_inc_step and self._iter > self._curv_win_width:
            # force to guarantee lr * grad_norm not increasing dramatically.
            # Computed once per step (not once per param group) so the running
            # averages advance exactly once per iteration.
            self.lr_grad_norm_avg()
            global_state = self._global_state
            debias_factor = self.zero_debias_factor_delay(
                global_state.get('lr_grad_norm_avg_start_iter', 0))
            lr_cap = 2.0 * global_state["lr_grad_norm_avg_min"] \
                / np.sqrt(np.exp(global_state['grad_norm_squared_avg_log'] / debias_factor))
        for group in self._optimizer.param_groups:
            group['momentum'] = self._mu
            if not self._force_non_inc_step:
                group['lr'] = self._lr * self._lr_factor
            elif lr_cap is not None:
                group['lr'] = min(self._lr * self._lr_factor, lr_cap)
        return

    def auto_clip_thresh(self):
        """Return the automatic clipping threshold sqrt(h_max) * auto_clip_fac.

        This is a heuristic to prevent sudden exploding gradients. It is not
        necessary for basic use.
        """
        return math.sqrt(self._h_max) * self._auto_clip_fac

    def step(self):
        """Clip gradients if configured, apply the SGD update, then retune lr / momentum."""
        # weight decay is applied by the wrapped torch.optim.SGD (see __init__)
        params = [p for group in self._optimizer.param_groups for p in group['params']]
        if self._clip_thresh is not None:
            torch.nn.utils.clip_grad_norm_(params, self._clip_thresh)
        elif self._iter != 0 and self._auto_clip_fac is not None:
            # do not clip the first iteration: h_max is only estimated after it
            torch.nn.utils.clip_grad_norm_(params, self.auto_clip_thresh())
        # apply update
        self._optimizer.step()
        # after apply: update statistics and tuned lr / momentum
        self.after_apply()
        # push learning rate and momentum into the SGD param groups
        self.update_hyper_param()
        self._iter += 1
        return