The LookAhead Optimizer
Below is a PyTorch implementation of the Ranger deep-learning optimizer:
```python
import math

import torch
from torch.optim.optimizer import Optimizer


class Ranger(Optimizer):
    """Ranger: RAdam combined with Lookahead and gradient centralization."""

    def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5,
                 betas=(0.95, 0.999), eps=1e-5, weight_decay=0):
        defaults = dict(lr=lr, alpha=alpha, k=k, N_sma_threshhold=N_sma_threshhold,
                        betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            # Read the hyperparameters of this parameter group
            beta1, beta2 = group['betas']
            N_sma_threshhold = group['N_sma_threshhold']
            lr = group['lr']
            eps = group['eps']
            k = group['k']
            alpha = group['alpha']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Ranger optimizer does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    # Slow (Lookahead) weights start as a copy of the parameters
                    state['slow_buffer'] = p.data.clone()

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                # Gradient centralization: for multi-dimensional parameters,
                # centre the gradient and scale it by its standard deviation
                if grad.dim() > 1:
                    dims = tuple(range(1, grad.dim()))
                    mean = grad.mean(dim=dims, keepdim=True)
                    std = grad.var(dim=dims, keepdim=True).sqrt()
                    grad = (grad - mean) / (std + eps)

                # Exponential moving averages of the gradient and the squared
                # gradient, updated in place so they persist in the state
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # RAdam: length of the approximated simple moving average (SMA)
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                bias_correction1 = 1 - beta1 ** state['step']

                # Decoupled weight decay, applied directly to the parameters
                if weight_decay != 0:
                    p.data.mul_(1 - lr * weight_decay)

                if N_sma > N_sma_threshhold:
                    # Variance of the adaptive term is tractable: rectified Adam step
                    rect = math.sqrt(((N_sma - 4) * (N_sma - 2) * N_sma_max) /
                                     ((N_sma_max - 4) * (N_sma_max - 2) * N_sma))
                    step_size = lr * rect * math.sqrt(1 - beta2_t) / bias_correction1
                    denom = exp_avg_sq.sqrt().add_(eps)
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    # Warmup phase: fall back to an SGD-with-momentum step
                    step_size = lr / bias_correction1
                    p.data.add_(exp_avg, alpha=-step_size)

                # Lookahead: every k steps, pull the slow weights toward the
                # fast weights by a factor of alpha and sync the fast weights
                if state['step'] % k == 0:
                    slow = state['slow_buffer']
                    slow.add_(p.data - slow, alpha=alpha)
                    p.data.copy_(slow)

        return loss
```
The code above implements the Ranger optimizer, which combines RAdam with LookAhead and adds techniques such as gradient centralization, a dynamically rectified step size, and weight decay. It can be used as a drop-in optimizer when training deep learning models in PyTorch.
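As a minimal usage sketch (the toy model, the random batch, and the hyperparameter values below are placeholders for illustration only), the optimizer plugs into a standard PyTorch training loop like any other `torch.optim` optimizer:
```python
import torch
import torch.nn as nn

# Placeholder network and loss; substitute your own model and DataLoader.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
criterion = nn.CrossEntropyLoss()
optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6, weight_decay=1e-4)

for step in range(100):
    inputs = torch.randn(16, 32)           # dummy input batch
    targets = torch.randint(0, 10, (16,))  # dummy class labels
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()  # RAdam update plus a Lookahead sync every k steps
```
Because the Lookahead synchronization happens inside `step()` every `k` iterations, no extra wrapper or scheduler logic is needed in the training loop itself.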