
torch.Tensor.register_hook

Tensor.register_hook(hook)

Registers a backward hook.

The hook will be called every time a gradient with respect to the Tensor is computed. The hook should have the following signature:

hook(grad) -> Tensor or None

The hook should not modify its argument, but it can optionally return a new gradient which will be used in place of grad.
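As a minimal sketch of the two return conventions, the snippet below registers one hook that only observes the gradient (it returns None) and one that returns a replacement. The names `record`, `clip`, and `seen` are illustrative and not part of the PyTorch API.

```python
import torch

seen = []  # illustrative storage for observed gradients

def record(grad):
    # Observe only: returning None leaves the gradient unchanged.
    seen.append(grad.clone())

def clip(grad):
    # Build and return a new tensor rather than modifying grad in place.
    return grad.clamp(min=-1.0, max=1.0)

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
x.register_hook(record)
x.register_hook(clip)

# d(loss)/dx is the coefficient vector [0.5, 2.0, -4.0].
loss = (x * torch.tensor([0.5, 2.0, -4.0])).sum()
loss.backward()

# x.grad holds the clamped values 0.5, 1.0, -1.0.
print(x.grad)
```

Returning a fresh tensor from `clip` keeps the hook consistent with the rule that the argument itself should not be modified.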

This function returns a handle with a method handle.remove() that removes the hook from the tensor.
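A short sketch of the effect of handle.remove(): the same backward pass is run once with the hook attached and once after it has been removed. The variable names and the factor of 10 are arbitrary choices for illustration.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
handle = w.register_hook(lambda grad: grad * 10)

# With the hook attached, the gradient of (w * 3).sum() is scaled by 10.
(w * 3).sum().backward()
scaled = w.grad.clone()    # values: 30, 30

# Detach the hook and clear the accumulated gradient before the second pass.
handle.remove()
w.grad = None
(w * 3).sum().backward()
unscaled = w.grad.clone()  # values: 3, 3
```

Resetting `w.grad` between the two passes matters here, because gradients on leaf tensors accumulate across calls to backward().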

Example:

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad
tensor([2., 4., 6.])

>>> h.remove()  # removes the hook
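Hooks are not limited to leaf tensors. The sketch below, which uses illustrative names (`captured`, `save_grad`), registers a hook on an intermediate result to capture its gradient during the backward pass. This can be handy because autograd does not keep the .grad of non-leaf tensors by default.

```python
import torch

captured = {}  # illustrative container for the intermediate gradient

def save_grad(grad):
    # Store a copy; returning None leaves the gradient unchanged.
    captured["y"] = grad.clone()

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 2                  # intermediate (non-leaf) tensor
handle = y.register_hook(save_grad)

loss = (y ** 2).sum()      # d(loss)/dy = 2 * y
loss.backward()

# captured["y"] holds 4, 8 and x.grad holds 8, 16.
handle.remove()
```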

© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/1.13/generated/torch.Tensor.register_hook.html