
torch.autograd.Function.jvp

static Function.jvp(ctx, *grad_inputs)

Defines a formula for differentiating the operation with forward mode automatic differentiation. This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many inputs as forward() received (None is passed in for non-Tensor inputs of the forward function), and it should return as many tensors as forward() had outputs. Each argument is the gradient w.r.t. the corresponding input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor, or the function is not differentiable with respect to that output, you can simply return None as the gradient for that output.

You can use the ctx object to pass any value from the forward to this function.

Return type

Any
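
The snippet below is a minimal sketch of how jvp() is typically overridden and then exercised through torch.autograd.forward_ad; the Exp function, the value stashed as ctx.result, and the tensor shapes are illustrative assumptions, not part of this page.

import torch
import torch.autograd.forward_ad as fwAD

class Exp(torch.autograd.Function):
    # Hypothetical custom Function: computes exp(x) and supplies the
    # matching forward-mode rule in jvp().
    @staticmethod
    def forward(ctx, x):
        result = torch.exp(x)
        # Stash the primal output on ctx so jvp() can reuse it.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, x_t):
        # d/dx exp(x) = exp(x), so the output tangent is exp(x) * x_t.
        out_t = x_t * ctx.result
        del ctx.result  # free the stored value once it is no longer needed
        return out_t

# Drive it with forward-mode AD: pack a primal/tangent pair into a dual
# tensor, apply the function, then read the tangent off the dual output.
primal = torch.randn(3, 3, dtype=torch.double)
tangent = torch.randn(3, 3, dtype=torch.double)

with fwAD.dual_level():
    dual_in = fwAD.make_dual(primal, tangent)
    dual_out = Exp.apply(dual_in)
    out_tangent = fwAD.unpack_dual(dual_out).tangent  # tangent * exp(primal)

Calling Exp.apply inside fwAD.dual_level() lets forward() see the primal value while jvp() receives the tangent of each dual input and returns the tangent of each output.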

https://pytorch.org/docs/2.1/generated/torch.autograd.Function.jvp.html