torch.diag_embed
torch.diag_embed(input, offset=0, dim1=-2, dim2=-1) → Tensor
Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by input. To facilitate creating batched diagonal matrices, the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
The argument offset controls which diagonal to consider:
- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.
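The three offset cases above can be sketched as follows. This is a minimal illustration rather than part of the official reference: the input `x` and the variable names `main`, `above`, and `below` are assumptions chosen for the example, and a deterministic input is used so that the printed values are reproducible.

```python
import torch

# Deterministic input: a batch of 2 vectors, each of length 3 (illustrative values).
x = torch.arange(1.0, 7.0).reshape(2, 3)

main = torch.diag_embed(x)              # offset = 0: main diagonal
above = torch.diag_embed(x, offset=1)   # offset > 0: above the main diagonal
below = torch.diag_embed(x, offset=-1)  # offset < 0: below the main diagonal

# Each 2D plane is square, with side length 3 + |offset|.
print(main.shape)   # torch.Size([2, 3, 3])
print(above.shape)  # torch.Size([2, 4, 4])
print(below.shape)  # torch.Size([2, 4, 4])

# The first plane of `above` carries the first input row on its superdiagonal.
print(above[0])
```

With the default dimensions, the `offset=-1` result is the plane-wise transpose of the `offset=1` result.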
The size of the new matrix is calculated so that the specified diagonal has the size of the last input dimension. Note that for offset other than 0, the order of dim1 and dim2 matters. Exchanging them is equivalent to changing the sign of offset.
Applying torch.diagonal() to the output of this function with the same arguments yields a tensor identical to input. However, torch.diagonal() has different default dimensions, so those need to be explicitly specified.
Parameters
- input (Tensor) – the input tensor. Must be at least 1-dimensional.
- offset (int, optional) – which diagonal to consider. Default: 0 (main diagonal).
- dim1 (int, optional) – first dimension with respect to which to take diagonal. Default: -2.
- dim2 (int, optional) – second dimension with respect to which to take diagonal. Default: -1.
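How dim1 and dim2 interact with offset can be checked with a short sketch. The input `x` and the names `swapped`, `negated`, and `placed` are illustrative assumptions, not part of the API.

```python
import torch

# Deterministic input of shape (2, 3) (illustrative values).
x = torch.arange(1.0, 7.0).reshape(2, 3)

# Swapping dim1 and dim2 while keeping offset=1 ...
swapped = torch.diag_embed(x, offset=1, dim1=-1, dim2=-2)
# ... gives the same tensor as keeping the dims and negating the offset.
negated = torch.diag_embed(x, offset=-1, dim1=-2, dim2=-1)
print(torch.equal(swapped, negated))  # True

# Non-default dims decide where the two new axes of length 3 + |offset| go.
placed = torch.diag_embed(x, offset=1, dim1=0, dim2=2)
print(placed.shape)  # torch.Size([4, 2, 4])
```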
Example:
>>> a = torch.randn(2, 3)
>>> torch.diag_embed(a)
tensor([[[ 1.5410,  0.0000,  0.0000],
         [ 0.0000, -0.2934,  0.0000],
         [ 0.0000,  0.0000, -2.1788]],

        [[ 0.5684,  0.0000,  0.0000],
         [ 0.0000, -1.0845,  0.0000],
         [ 0.0000,  0.0000, -1.3986]]])

>>> torch.diag_embed(a, offset=1, dim1=0, dim2=2)
tensor([[[ 0.0000,  1.5410,  0.0000,  0.0000],
         [ 0.0000,  0.5684,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000, -0.2934,  0.0000],
         [ 0.0000,  0.0000, -1.0845,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000, -2.1788],
         [ 0.0000,  0.0000,  0.0000, -1.3986]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]])
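The round trip through torch.diagonal() mentioned in the description can be sketched in the same way. The names `embedded`, `recovered`, and `default_embedded` are assumptions for this illustration. The last lines show why the dimensions have to be passed explicitly: torch.diagonal() defaults to dim1=0 and dim2=1, not to the last two dimensions.

```python
import torch

# Deterministic input of shape (2, 3) (illustrative values).
x = torch.arange(1.0, 7.0).reshape(2, 3)

# Embed with non-default arguments, then extract with the same arguments.
embedded = torch.diag_embed(x, offset=1, dim1=0, dim2=2)
recovered = torch.diagonal(embedded, offset=1, dim1=0, dim2=2)
print(torch.equal(recovered, x))  # True

# For a default diag_embed call, the matching dims are -2 and -1.
default_embedded = torch.diag_embed(x)
print(torch.equal(torch.diagonal(default_embedded, dim1=-2, dim2=-1), x))  # True

# Relying on torch.diagonal()'s own defaults reads a different pair of axes.
print(torch.diagonal(default_embedded).shape)  # torch.Size([3, 2])
```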
© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/generated/torch.diag_embed.html