torch.autograd.Function.vmap
static Function.vmap(info, in_dims, *args)
    
Defines a rule for the behavior of this autograd.Function underneath torch.vmap(). For a torch.autograd.Function() to support torch.vmap(), you must either override this staticmethod or set generate_vmap_rule to True (you may not do both).

If you choose to override this staticmethod, it must accept:
- an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap().
- an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying which dimension of the Tensor is being vmapped over.
- *args, which are the same as the args to forward().
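For comparison, the other option, setting generate_vmap_rule to True instead of overriding vmap, might look like the minimal sketch below. It assumes PyTorch 2.1 or later and the function-transform-compatible style in which forward() does not take ctx and a separate setup_context() staticmethod is defined. The class name MySin is hypothetical.

```python
import torch

class MySin(torch.autograd.Function):
    # Let PyTorch derive the vmap rule by running forward on batched inputs.
    # With this set, the class must NOT also define a vmap staticmethod.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # forward takes no ctx because setup_context is defined separately
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d/dx sin(x) = cos(x)
        return grad_output * torch.cos(x)

x = torch.randn(4, 3)
out = torch.vmap(MySin.apply)(x)  # maps over dim 0 of x
```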
The return value of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should have the same structure as output and contain one out_dim per output, specifying whether that output has the vmapped dimension and, if so, at which index. Please see Extending torch.func with autograd.Function for more details.
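As an illustration of overriding this staticmethod, the following is a minimal sketch, assuming PyTorch 2.1 or later with forward() written without ctx. The hypothetical SumRows Function sums each sample over its first dimension. Its vmap rule reads the batch dimension from in_dims, moves it to the front, performs the batched computation, and returns (output, out_dims), where out_dims is the single int 0 because there is one output batched at index 0.

```python
import torch

# Hypothetical example Function: per sample, sums its input over dim 0.
class SumRows(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # Per-sample semantics: reduce over the first dimension.
        return x.sum(dim=0)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.x_shape = x.shape

    @staticmethod
    def backward(ctx, grad_output):
        # The derivative of a sum is 1 everywhere, so broadcast the
        # gradient back over the reduced dimension.
        return grad_output.unsqueeze(0).expand(ctx.x_shape)

    @staticmethod
    def vmap(info, in_dims, x):
        # One in_dims entry per argument in *args; this sketch assumes
        # x is actually being vmapped over (x_bdim is not None).
        (x_bdim,) = in_dims
        # Move the vmapped dimension to the front, so the per-sample
        # dim 0 now lives at physical dim 1.
        x = x.movedim(x_bdim, 0)
        assert x.shape[0] == info.batch_size
        out = x.sum(dim=1)
        # Return (output, out_dims): the single output is batched at dim 0.
        return out, 0

x = torch.randn(5, 3, 4)
out = torch.vmap(SumRows.apply)(x)               # vmap over dim 0
out1 = torch.vmap(SumRows.apply, in_dims=1)(x)   # vmap over dim 1
```

Moving the vmapped dimension to index 0 first keeps the rule simple: the per-sample reduction dimension is then always at a fixed physical index, regardless of which in_dims value torch.vmap() passed in.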
 
© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
 https://pytorch.org/docs/2.1/generated/torch.autograd.Function.vmap.html