SGD
class torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)
Implements stochastic gradient descent (optionally with momentum).
Nesterov momentum is based on the formula from "On the importance of initialization and momentum in deep learning".
- Parameters
-
- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
- lr (float) – learning rate
- momentum (float, optional) – momentum factor (default: 0)
- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
- dampening (float, optional) – dampening for momentum (default: 0)
- nesterov (bool, optional) – enables Nesterov momentum (default: False)
- maximize (bool, optional) – maximize the params based on the objective, instead of minimizing (default: False)
- foreach (bool, optional) – whether foreach implementation of optimizer is used. If unspecified by the user (so foreach is None), we will try to use foreach over the for-loop implementation on CUDA, since it is usually significantly more performant. Note that the foreach implementation uses ~ sizeof(params) more peak memory than the for-loop version due to the intermediates being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer parameters through the optimizer at a time or switch this flag to False (default: None)
- differentiable (bool, optional) – whether autograd should occur through the optimizer step in training. Otherwise, the step() function runs in a torch.no_grad() context. Setting to True can impair performance, so leave it False if you don’t intend to run autograd through this instance (default: False)
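How momentum, dampening, weight_decay, nesterov and maximize combine in a single update can be sketched in plain Python for one scalar parameter. This is an illustrative reading of the documented behaviour, not the library's code: the function name `sgd_step` and its calling convention are made up for this example, and the real optimizer works on tensors.

```python
def sgd_step(p, grad, buf, *, lr, momentum=0.0, dampening=0.0,
             weight_decay=0.0, nesterov=False, maximize=False):
    """One scalar SGD update (hypothetical helper). Returns (new_p, new_buf)."""
    # maximize flips the sign of the gradient so the objective is ascended.
    g = -grad if maximize else grad
    # weight decay adds an L2 penalty term to the gradient.
    if weight_decay != 0:
        g = g + weight_decay * p
    if momentum != 0:
        if buf is None:
            # First step: the buffer starts as the gradient itself.
            buf = g
        else:
            buf = momentum * buf + (1 - dampening) * g
        # Nesterov looks ahead along the buffer; plain momentum uses it directly.
        g = g + momentum * buf if nesterov else buf
    return p - lr * g, buf


# Two momentum steps with a constant gradient of 0.5.
p1, buf1 = sgd_step(1.0, 0.5, None, lr=0.1, momentum=0.9)
p2, buf2 = sgd_step(p1, 0.5, buf1, lr=0.1, momentum=0.9)

# Weight decay alone moves a parameter toward zero even with a zero gradient.
p_decayed, _ = sgd_step(1.0, 0.0, None, lr=0.1, weight_decay=0.5)
```

The buffer is passed in and returned explicitly here only to keep the sketch stateless; the optimizer itself stores it per parameter.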
Example
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
The implementation of SGD with Momentum/Nesterov subtly differs from Sutskever et al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
$$v_{t+1} = \mu \cdot v_t + g_{t+1}, \qquad p_{t+1} = p_t - \text{lr} \cdot v_{t+1},$$
where $p$, $g$, $v$, and $\mu$ denote the parameters, gradient, velocity, and momentum respectively.
This is in contrast to Sutskever et al. and other frameworks which employ an update of the form
$$v_{t+1} = \mu \cdot v_t + \text{lr} \cdot g_{t+1}, \qquad p_{t+1} = p_t - v_{t+1}.$$
The Nesterov version is analogously modified.
Moreover, the initial value of the momentum buffer is set to the gradient value at the first step. This is in contrast to some other frameworks that initialize it to all zeros.
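The practical difference between the two formulations can be sketched in plain Python. The function names below are hypothetical and the sketch uses scalars with zero dampening, so starting the velocity at 0 gives the same first step as initializing the buffer to the first gradient. With a constant learning rate both forms follow the same parameter trajectory; once the learning rate changes mid-run they diverge, because only the second form bakes past learning rates into the velocity.

```python
def pytorch_style(p, grads, lrs, mu):
    """v <- mu*v + g, then p <- p - lr*v (the form described in this note)."""
    v = 0.0
    for g, lr in zip(grads, lrs):
        v = mu * v + g
        p = p - lr * v
    return p


def sutskever_style(p, grads, lrs, mu):
    """v <- mu*v + lr*g, then p <- p - v (the form used by Sutskever et al.)."""
    v = 0.0
    for g, lr in zip(grads, lrs):
        v = mu * v + lr * g
        p = p - v
    return p


grads = [1.0, 1.0, 1.0]
constant = [0.1, 0.1, 0.1]   # fixed learning rate
decayed = [0.1, 0.1, 0.01]   # learning rate dropped before the last step

same_a = pytorch_style(0.0, grads, constant, 0.9)
same_b = sutskever_style(0.0, grads, constant, 0.9)
diff_a = pytorch_style(0.0, grads, decayed, 0.9)
diff_b = sutskever_style(0.0, grads, decayed, 0.9)
```

In the decayed run the first form takes a much smaller final step, since the new learning rate scales the whole accumulated velocity rather than only the newest gradient.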
add_param_group(param_group)
Add a param group to the Optimizer's param_groups.
This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.
- Parameters
-
param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.
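A minimal sketch of the fine-tuning pattern described above, assuming a toy two-layer model whose first layer plays the role of a frozen backbone. The model, the variable names and the learning rates are illustrative choices, not recommendations.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
backbone, head = model[0], model[2]

# Start by training only the head; the backbone stays frozen.
for p in backbone.parameters():
    p.requires_grad_(False)
optimizer = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)

# Later in training: unfreeze the backbone and hand it to the optimizer
# as a new group with its own, smaller learning rate.
for p in backbone.parameters():
    p.requires_grad_(True)
optimizer.add_param_group({"params": backbone.parameters(), "lr": 0.01})

group_lrs = [group["lr"] for group in optimizer.param_groups]
```

Options that the new group does not set, such as momentum here, are filled in from the defaults the optimizer was constructed with.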
load_state_dict(state_dict)
Loads the optimizer state.
- Parameters
-
state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
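As a rough illustration, the round trip below takes a state dict from one optimizer and loads it into a freshly constructed one over the same parameters. It keeps everything in memory; writing the dict to disk is left out. The deep copy is a precaution taken in this sketch so the snapshot does not share tensors with the live optimizer.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one step so that momentum buffers exist in the state.
model(torch.randn(5, 3)).sum().backward()
optimizer.step()

checkpoint = copy.deepcopy(optimizer.state_dict())

# A new optimizer over the same parameters, in the same order,
# deliberately built with different hyperparameters.
restored = torch.optim.SGD(model.parameters(), lr=0.5, momentum=0.0)
restored.load_state_dict(checkpoint)

restored_lr = restored.param_groups[0]["lr"]
buffers_match = torch.equal(
    restored.state[model.weight]["momentum_buffer"],
    optimizer.state[model.weight]["momentum_buffer"],
)
```

After loading, the hyperparameters stored in the checkpoint replace the ones passed to the constructor.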
register_load_state_dict_post_hook(hook, prepend=False)
Register a load_state_dict post-hook which will be called after load_state_dict() is called. It should have the following signature:
hook(optimizer) -> None
The optimizer argument is the optimizer instance being used.
The hook will be called with argument self after calling load_state_dict on self. The registered hook can be used to perform post-processing after load_state_dict has loaded the state_dict.
- Parameters
-
- hook (Callable) – The user defined hook to be registered.
- prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)
- Returns
-
a handle that can be used to remove the added hook by calling handle.remove()
- Return type
-
torch.utils.hooks.RemovableHandle
register_load_state_dict_pre_hook(hook, prepend=False)
Register a load_state_dict pre-hook which will be called before load_state_dict() is called. It should have the following signature:
hook(optimizer, state_dict) -> state_dict or None
The optimizer argument is the optimizer instance being used and the state_dict argument is a shallow copy of the state_dict the user passed in to load_state_dict. The hook may modify the state_dict in place or optionally return a new one. If a state_dict is returned, it is the one loaded into the optimizer.
The hook will be called with arguments self and state_dict before calling load_state_dict on self. The registered hook can be used to perform pre-processing before the load_state_dict call is made.
- Parameters
-
- hook (Callable) – The user defined hook to be registered.
- prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)
- Returns
-
a handle that can be used to remove the added hook by calling handle.remove()
- Return type
-
torch.utils.hooks.RemovableHandle
register_state_dict_post_hook(hook, prepend=False)
Register a state dict post-hook which will be called after state_dict() is called. It should have the following signature:
hook(optimizer, state_dict) -> state_dict or None
The hook will be called with arguments self and state_dict after generating a state_dict on self. The hook may modify the state_dict in place or optionally return a new one. The registered hook can be used to perform post-processing on the state_dict before it is returned.
- Parameters
-
- hook (Callable) – The user defined hook to be registered.
- prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)
- Returns
-
a handle that can be used to remove the added hook by calling handle.remove()
- Return type
-
torch.utils.hooks.RemovableHandle
register_state_dict_pre_hook(hook, prepend=False)
Register a state dict pre-hook which will be called before state_dict() is called. It should have the following signature:
hook(optimizer) -> None
The optimizer argument is the optimizer instance being used. The hook will be called with argument self before calling state_dict on self. The registered hook can be used to perform pre-processing before the state_dict call is made.
- Parameters
-
- hook (Callable) – The user defined hook to be registered.
- prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)
- Returns
-
a handle that can be used to remove the added hook by calling handle.remove()
- Return type
-
torch.utils.hooks.RemovableHandle
register_step_post_hook(hook)
Register an optimizer step post hook which will be called after optimizer step. It should have the following signature:
hook(optimizer, args, kwargs) -> None
The optimizer argument is the optimizer instance being used.
- Parameters
-
hook (Callable) – The user defined hook to be registered.
- Returns
-
a handle that can be used to remove the added hook by calling handle.remove()
- Return type
-
torch.utils.hooks.RemovableHandle
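A step post-hook can serve as a lightweight logging point. The sketch below records the learning rate of the first parameter group after each step and then removes the hook; the list `step_log` and the hook name are illustrative.

```python
import torch
from torch import nn

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
step_log = []


def record_step(opt, args, kwargs):
    # Runs after every optimizer.step() while the hook is registered.
    step_log.append(opt.param_groups[0]["lr"])


handle = optimizer.register_step_post_hook(record_step)
for _ in range(2):
    optimizer.zero_grad()
    model(torch.ones(1, 2)).sum().backward()
    optimizer.step()

# After removal, further steps are no longer recorded.
handle.remove()
optimizer.step()
```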
register_step_pre_hook(hook)
Register an optimizer step pre hook which will be called before optimizer step. It should have the following signature:
hook(optimizer, args, kwargs) -> None or modified args and kwargs
The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.
- Parameters
-
hook (Callable) – The user defined hook to be registered.
- Returns
-
a handle that can be used to remove the added hook by calling handle.remove()
- Return type
-
torch.utils.hooks.RemovableHandle
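One conceivable use of a step pre-hook is to clip gradients just before every update, so the training loop does not have to remember to do it. This is a sketch of that idea under the assumption that a global norm limit of 1.0 is wanted; the hook name and the limit are arbitrary.

```python
import torch

w = torch.nn.Parameter(torch.zeros(3))
optimizer = torch.optim.SGD([w], lr=0.1)


def clip_before_step(opt, args, kwargs):
    # Collect every parameter the optimizer manages and clip their
    # gradients to a total norm of at most 1.0.
    params = [p for group in opt.param_groups for p in group["params"]]
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
    # Returning None leaves args and kwargs unchanged.


handle = optimizer.register_step_pre_hook(clip_before_step)

# A loss whose raw gradient is very large at w = 0.
loss = (w - 100.0).pow(2).sum()
loss.backward()
optimizer.step()

# With plain SGD the update norm is lr times the clipped gradient norm.
update_norm = w.detach().norm().item()
handle.remove()
```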
state_dict()
Returns the state of the optimizer as a dict.
It contains two entries:
- state: a Dict holding current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.
- param_groups: a List containing all parameter groups where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group.
NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameters) in order to match state WITHOUT additional verification.
A returned state dict might look something like:
{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
        }
    ]
}
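The layout described above can be inspected on a small model. The sketch assumes a single linear layer, which has two parameters (weight and bias), and takes one step so that momentum buffers are present.

```python
import torch
from torch import nn

model = nn.Linear(3, 1)  # two parameters: weight (1x3) and bias (1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Momentum buffers are only created once a step has been taken.
model(torch.ones(2, 3)).sum().backward()
optimizer.step()

sd = optimizer.state_dict()
top_level_keys = sorted(sd.keys())
param_ids = sd["param_groups"][0]["params"]
buffer_shapes = {
    pid: tuple(entry["momentum_buffer"].shape)
    for pid, entry in sd["state"].items()
}
```

The integer IDs in `param_ids` are the same keys used in `sd["state"]`, which is how per-parameter state is tied back to its group.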
step(closure=None)
Performs a single optimization step.
- Parameters
-
closure (Callable, optional) – A closure that reevaluates the model and returns the loss.
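A closure, when supplied, is expected to clear the gradients, recompute the loss, run the backward pass and return the loss. A minimal sketch with a toy regression problem follows; the data and model are placeholders.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
inputs, targets = torch.randn(8, 2), torch.randn(8, 1)


def closure():
    # Re-evaluate the model from scratch and return the loss.
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    return loss


# step() calls the closure and hands back the loss it returned.
returned_loss = optimizer.step(closure)
```

SGD evaluates the closure once per step, so for this optimizer the closure form is mostly a convenience; it keeps the loop compatible with optimizers that need to re-evaluate the loss several times.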
zero_grad(set_to_none=True)
Resets the gradients of all optimized torch.Tensors.
- Parameters
-
set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:
1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.
2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.
3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).
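The difference between the two modes is easy to observe directly. In this sketch the flag is passed explicitly in both calls so the behaviour does not depend on the default.

```python
import torch
from torch import nn

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Mode 1: gradients are dropped entirely.
model(torch.ones(1, 2)).sum().backward()
optimizer.zero_grad(set_to_none=True)
grad_when_none = model.weight.grad

# Mode 2: gradient tensors are kept and filled with zeros.
model(torch.ones(1, 2)).sum().backward()
optimizer.zero_grad(set_to_none=False)
grad_when_zeroed = model.weight.grad
```

Code that reads `.grad` between iterations, for example for logging or manual clipping, has to handle the `None` case when the first mode is used.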
© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/generated/torch.optim.SGD.html