torch.baddbmm
torch.baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) → Tensor
Performs a batch matrix-matrix product of matrices in batch1 and batch2. input is added to the final result.

batch1 and batch2 must be 3-D tensors each containing the same number of matrices.

If batch1 is a $(b \times n \times m)$ tensor and batch2 is a $(b \times m \times p)$ tensor, then input must be broadcastable with a $(b \times n \times p)$ tensor and out will be a $(b \times n \times p)$ tensor. alpha and beta are the same scaling factors used in torch.addbmm():

$$\text{out}_i = \beta\,\text{input}_i + \alpha\,(\text{batch1}_i \mathbin{@} \text{batch2}_i)$$

If beta is 0, then input will be ignored, and nan and inf in it will not be propagated.

For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers; otherwise they should be integers.

This operator supports TensorFloat32.

On certain ROCm devices, when using float16 inputs, this operator will use a different precision for the backward pass.
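To make the per-batch formula concrete, here is a minimal sketch that computes it with an explicit Python loop and compares the result against torch.baddbmm. The helper baddbmm_reference is hypothetical and exists only for illustration; it assumes CPU tensors, ignores out=, and is much slower than the real operator. It mirrors the documented beta = 0 rule by not reading input at all in that case.

```python
import torch


def baddbmm_reference(inp, batch1, batch2, beta=1, alpha=1):
    """Hypothetical loop-based sketch of torch.baddbmm semantics (illustration only)."""
    b, n, m = batch1.shape
    b2, m2, p = batch2.shape
    # Both batches must hold the same number of matrices with matching inner dims
    assert b == b2 and m == m2, "expected (b, n, m) and (b, m, p) shapes"
    # input only has to be broadcastable to (b, n, p)
    inp = inp.expand(b, n, p)
    out = torch.empty(b, n, p, dtype=batch1.dtype, device=batch1.device)
    for i in range(b):
        prod = batch1[i] @ batch2[i]  # (n, p) matrix product for batch entry i
        if beta == 0:
            # Documented rule: input is ignored, so NaN/inf in it cannot leak in
            out[i] = alpha * prod
        else:
            out[i] = beta * inp[i] + alpha * prod
    return out


torch.manual_seed(0)
M = torch.randn(10, 3, 5)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)

ours = baddbmm_reference(M, batch1, batch2, beta=0.5, alpha=2.0)
theirs = torch.baddbmm(M, batch1, batch2, beta=0.5, alpha=2.0)
print(torch.allclose(ours, theirs, atol=1e-5))  # True
```

The loop is only meant to spell out the math; the real operator dispatches to batched kernels, so small floating-point differences are expected, hence the tolerance in allclose.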
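- Parameters
  - input (Tensor) – the tensor to be added
  - batch1 (Tensor) – the first batch of matrices to be multiplied
  - batch2 (Tensor) – the second batch of matrices to be multiplied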
- Keyword Arguments
  - beta (Number, optional) – multiplier for input ($\beta$)
  - alpha (Number, optional) – multiplier for $\text{batch1} \mathbin{@} \text{batch2}$ ($\alpha$)
  - out (Tensor, optional) – the output tensor.
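The keyword arguments can be exercised with a short sketch. The shapes and the all-NaN input below are arbitrary choices for illustration; the sketch shows that beta=0 keeps NaNs in input out of the result, that the default beta=1 lets them through, and that out= fills a preallocated tensor of shape $(b \times n \times p)$.

```python
import torch

torch.manual_seed(0)
batch1 = torch.randn(2, 3, 4)
batch2 = torch.randn(2, 4, 3)
inp = torch.full((2, 3, 3), float("nan"))  # every entry of input is NaN

# beta=0: input is ignored, so its NaNs do not reach the result
res = torch.baddbmm(inp, batch1, batch2, beta=0, alpha=2)
print(torch.isnan(res).any().item())  # False

# Default beta=1: the NaNs in input propagate to every output entry
res_nan = torch.baddbmm(inp, batch1, batch2)
print(torch.isnan(res_nan).all().item())  # True

# out= writes the result into a preallocated (b, n, p) tensor
out = torch.empty(2, 3, 3)
torch.baddbmm(torch.zeros(2, 3, 3), batch1, batch2, alpha=0.5, out=out)
print(torch.allclose(out, 0.5 * torch.bmm(batch1, batch2), atol=1e-6))  # True
```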
Example:
>>> import torch
>>> M = torch.randn(10, 3, 5)
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> torch.baddbmm(M, batch1, batch2).size()
torch.Size([10, 3, 5])
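Because input only needs to be broadcastable with a $(b \times n \times p)$ tensor, a single bias row can be shared across every matrix in the batch. In this sketch the variable name bias and its shape (p,) are illustrative choices; the result is compared against adding the same bias to a plain torch.bmm product.

```python
import torch

torch.manual_seed(0)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
bias = torch.randn(5)  # shape (p,), broadcast to (10, 3, 5)

res = torch.baddbmm(bias, batch1, batch2)
print(res.shape)  # torch.Size([10, 3, 5])

# Same result as adding the bias to a plain batched product
expected = bias + torch.bmm(batch1, batch2)
print(torch.allclose(res, expected, atol=1e-5))  # True
```

Fusing the addition into baddbmm avoids materializing the separate bmm result before adding the bias, which is the usual reason to prefer it over bmm followed by +.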
© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/generated/torch.baddbmm.html