# sorted in order of increasing complexity
from typing import List
from tinygrad.helpers import dedup
from tinygrad.tensor import Tensor
class Optimizer:
"""
The optimizer class for updating parameters with gradients.
Attributes:
params (List[Tensor]): List of parameters to optimize. These parameters are assumed to be differentiable, and their gradient will be used during optimization.
lr (float): Learning rate for the optimizer.
device: Device where the parameters are located.
buffers (List[Tensor]): List of non-differentiable parameters (or buffers) that need to be realized along with the parameters. These typically include batch statistics in BatchNorm, running averages in Adam, etc.
"""
def __init__(self, params: List[Tensor], lr: float):
"""
Initializes the optimizer with a list of parameters and a learning rate.
If the requires_grad attribute of a parameter is None, it is set to True, since putting a parameter into an optimizer implies the user wants to optimize it.
Deduplicates the parameters that require gradients and asserts that there is at least one parameter to optimize.
Determines the device where the parameters are located and realizes all the parameters and buffers on this device.
Args:
params (List[Tensor]): List of parameters to optimize.
lr (float): Learning rate for the optimizer.
"""
# if it's None, but being put into an optimizer, set it to True
for x in params:
if x.requires_grad is None:
x.requires_grad = True
self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
assert len(self.params) != 0, "optimizer must have at least one param"
self.device = self.params[0].device
self.buffers: List[Tensor] = dedup(
[x for x in params if not x.requires_grad]
) # buffers are still realized
self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous()
def zero_grad(self):
"""
Sets the gradient of each optimized parameter to None. This is called at the beginning of each iteration.
"""
for param in self.params:
param.grad = None
def realize(self, extra=None):
"""
Realizes all the parameters and buffers on the device. If extra parameters are provided, they will be realized as well.
Args:
extra (List[Tensor], optional): Extra parameters that need to be realized. Defaults to None.
"""
# NOTE: in extra is too late for most of the params due to issues with assign
Tensor.corealize(
extra + self.params + self.buffers
if extra is not None
else self.params + self.buffers
)
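# Illustrative sketch (not part of the library): the minimal shape of an
# Optimizer subclass. A concrete optimizer only needs to implement step(),
# reading each param.grad and assigning the updated value back in place;
# realize() from the base class then schedules the updates. The class name
# _PlainGD is hypothetical, for illustration only.
class _PlainGD(Optimizer):
    def step(self) -> None:
        for t in self.params:
            assert t.grad is not None
            # vanilla gradient descent: w <- w - lr * grad
            t.assign(t.detach() - self.lr * t.grad)
        self.realize()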
class SGD(Optimizer):
"""
Implements stochastic gradient descent (optionally with momentum).
Attributes:
params (List[Tensor]): The parameters to optimize.
lr (float): Learning rate. Default value is 0.001.
momentum (float): Momentum factor. Default value is 0, in which case no momentum is used.
weight_decay (float): Weight decay (L2 penalty). Default value is 0.
nesterov (bool): Enables Nesterov momentum. Default value is False.
"""
def __init__(
self,
params: List[Tensor],
lr=0.001,
momentum=0,
weight_decay=0.0,
nesterov=False,
):
super().__init__(params, lr)
self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov
self.b = (
[
Tensor.zeros(*t.shape, device=t.device, requires_grad=False)
for t in self.params
]
if self.momentum
else []
)
# https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
def step(self) -> None:
"""
Performs a single optimization step.
This method implements the weight update of stochastic gradient descent with optional (Nesterov) momentum and weight decay.
The updates are done in place; the method returns nothing.
"""
for i, t in enumerate(self.params):
assert t.grad is not None
g = t.grad.realize() + self.wd * t.detach()
if self.momentum:
self.b[i].assign(
self.momentum * self.b[i] + g
).realize() # NOTE: self.b[i] is zero on the first run, no if required
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
t.assign(t.detach() - g * self.lr)
self.realize(self.b)
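# Usage sketch (illustrative, not part of the library): one SGD training step
# on a toy quadratic objective, exercising zero_grad()/backward()/step().
# The function name _example_sgd_step and the toy loss are hypothetical.
def _example_sgd_step():
    w = Tensor([[1.0, -2.0], [3.0, -4.0]], requires_grad=True)
    opt = SGD([w], lr=0.1, momentum=0.9, nesterov=True)
    opt.zero_grad()                 # clear gradients from any previous step
    loss = (w * w).sum()            # toy objective: drives w toward zero
    loss.backward()                 # populates w.grad
    opt.step()                      # in-place update, realizes momentum buffer
    return w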
# LAMB is essentially just the trust ratio part of LARS applied to Adam/W, so if we just set the trust ratio to 1.0 it's just Adam/W.
def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
"""
Creates an AdamW optimizer.
AdamW is implemented here as LAMB with adam=True, i.e. LAMB with the trust ratio fixed to 1.0 plus decoupled weight decay.
Args:
params (List[Tensor]): A list of Tensors that will be updated by the optimizer.
lr (float): The learning rate. Default is 0.001.
b1 (float): Exponential decay rate for the first moment estimates. Default is 0.9.
b2 (float): Exponential decay rate for the second moment estimates. Default is 0.999.
eps (float): A small constant added to the denominator to prevent division by zero. Default is 1e-8.
wd (float): Weight decay parameter. Default is 0.01.
"""
return LAMB(params, lr, b1, b2, eps, wd, adam=True)
def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
"""
Creates an Adam optimizer.
This is LAMB with adam=True and no weight decay.
Args:
params (List[Tensor]): The parameters to optimize.
lr (float): Learning rate. Default is 0.001.
b1 (float): First beta parameter. Default is 0.9.
b2 (float): Second beta parameter. Default is 0.999.
eps (float): Epsilon value to prevent division by zero. Default is 1e-8.
Returns:
LAMB: The LAMB optimizer with Adam settings.
"""
return LAMB(params, lr, b1, b2, eps, 0.0, adam=True)
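# Illustrative sketch (not part of the library): Adam and AdamW are thin
# wrappers that return a LAMB instance with adam=True; AdamW additionally
# applies decoupled weight decay (wd). The helper name is hypothetical.
def _example_adam_vs_adamw(params: List[Tensor]):
    opt_adam = Adam(params, lr=3e-4)            # LAMB(..., wd=0.0, adam=True)
    opt_adamw = AdamW(params, lr=3e-4, wd=0.1)  # LAMB(..., wd=0.1, adam=True)
    return opt_adam, opt_adamw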
class LAMB(Optimizer):
"""
LAMB optimizer class.
Attributes:
params (List[Tensor]): The list of parameters to optimize.
lr (float): The learning rate. Defaults to 0.001.
b1 (float): The exponential decay rate for the first moment estimates. Defaults to 0.9.
b2 (float): The exponential decay rate for the second moment estimates. Defaults to 0.999.
eps (float): A small constant for numerical stability. Defaults to 1e-6.
wd (float): The weight decay coefficient. Defaults to 0.0.
adam (bool): If True, skips the LARS trust ratio so the update reduces to Adam/AdamW. Defaults to False.
"""
def __init__(
self,
params: List[Tensor],
lr=0.001,
b1=0.9,
b2=0.999,
eps=1e-6,
wd=0.0,
adam=False,
):
super().__init__(params, lr)
self.b1, self.b2, self.eps, self.wd, self.adam, self.t = (
b1,
b2,
eps,
wd,
adam,
Tensor([0], requires_grad=False).realize(),
)
self.m = [
Tensor.zeros(*t.shape, device=t.device, requires_grad=False)
for t in self.params
]
self.v = [
Tensor.zeros(*t.shape, device=t.device, requires_grad=False)
for t in self.params
]
def step(self) -> None:
"""
Perform one optimization step.
This method updates the parameters according to the LAMB or Adam algorithm.
"""
self.t.assign(self.t + 1).realize()
for i, t in enumerate(self.params):
assert t.grad is not None
g = t.grad.realize()
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
m_hat = self.m[i] / (1.0 - self.b1**self.t)
v_hat = self.v[i] / (1.0 - self.b2**self.t)
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
if not self.adam:
r1 = t.detach().square().sum().sqrt()
r2 = up.square().sum().sqrt()
r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
else:
r = 1.0
t.assign(t.detach() - self.lr * r * up)
self.realize([self.t] + self.m + self.v)
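# Usage sketch (illustrative, not part of the library): one LAMB step on a toy
# parameter. With adam=False the per-layer update is scaled by the trust ratio
# r = ||w|| / ||update||; with adam=True that scaling is skipped and the step
# reduces to Adam/AdamW. The function name and toy loss are hypothetical.
def _example_lamb_step():
    w = Tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    opt = LAMB([w], lr=1e-3, wd=1e-2)
    opt.zero_grad()
    loss = (w * w).mean()           # toy objective
    loss.backward()
    opt.step()                      # updates m, v, t, then w in place
    return w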