
Unflatten

class torch.nn.Unflatten(dim, unflattened_size)

Unflattens one dimension of a tensor, expanding it into the desired shape. Intended for use with Sequential.

  • dim specifies the dimension of the input tensor to be unflattened. It is an int for a Tensor input or a str for a NamedTensor input.
  • unflattened_size is the new shape of the unflattened dimension. It can be a tuple of ints, a list of ints, or a torch.Size for a Tensor input, or a NamedShape (a tuple of (name, size) tuples) for a NamedTensor input.
Shape:
  • Input: $(*, S_{\text{dim}}, *)$, where $S_{\text{dim}}$ is the size at dimension dim and $*$ means any number of dimensions, including none.
  • Output: $(*, U_1, ..., U_n, *)$, where $U$ = unflattened_size and $\prod_{i=1}^n U_i = S_{\text{dim}}$.
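
The shape rule above can be checked with a small sketch. The tensor sizes here are arbitrary illustrative choices, and catching RuntimeError reflects the exception type I would expect for a size mismatch rather than something stated on this page.

```python
import torch
from torch import nn

# Unflatten the middle dimension (size 12) of a (3, 12, 7) tensor into (3, 4).
x = torch.randn(3, 12, 7)
y = nn.Unflatten(1, (3, 4))(x)
out_shape = tuple(y.shape)  # leading and trailing dimensions are left untouched

# The new sizes must multiply to the size of the unflattened dimension
# (3 * 4 == 12). A mismatched product (3 * 5 == 15) is rejected when the
# module is called. RuntimeError is an assumption about the exception type.
try:
    nn.Unflatten(1, (3, 5))(x)
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
```

Only dimension 1 changes: the leading 3 and trailing 7 pass through, matching the `*` placeholders in the shape description.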
Parameters
  • dim (Union[int, str]) – Dimension to be unflattened
  • unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – New shape of the unflattened dimension
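
As a rough illustration of the accepted types for unflattened_size, the sketch below passes the same shape as a tuple, a list, and a torch.Size, and compares the result with the tensor method torch.Tensor.unflatten. The input values are made up for the example.

```python
import torch
from torch import nn

x = torch.arange(24).reshape(2, 12)

# The same target shape expressed as a tuple, a list, and a torch.Size.
sizes = [(3, 4), [3, 4], torch.Size([3, 4])]
outputs = [nn.Unflatten(1, s)(x) for s in sizes]

shapes = [tuple(o.shape) for o in outputs]
same_values = all(torch.equal(outputs[0], o) for o in outputs[1:])

# The module form should agree with calling unflatten on the tensor directly.
matches_tensor_method = torch.equal(outputs[0], x.unflatten(1, (3, 4)))
```

The module form is mainly a convenience for placing the reshape inside a Sequential; outside a Sequential, calling unflatten on the tensor is usually simpler.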

Examples

>>> import torch
>>> from torch import nn
>>> input = torch.randn(2, 50)
>>> # With tuple of ints
>>> m = nn.Sequential(
...     nn.Linear(50, 50),
...     nn.Unflatten(1, (2, 5, 5))
... )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
...     nn.Linear(50, 50),
...     nn.Unflatten(1, torch.Size([2, 5, 5]))
... )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With NamedShape (tuple of (name, size) tuples)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
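
One common pairing inside a Sequential is nn.Flatten followed by nn.Unflatten. The following is a minimal sketch, assuming the default nn.Flatten arguments (which flatten everything after the batch dimension), showing that the pair restores the original tensor. The shapes are chosen to mirror the examples above.

```python
import torch
from torch import nn

x = torch.randn(2, 2, 5, 5)

round_trip = nn.Sequential(
    nn.Flatten(),                # (2, 2, 5, 5) -> (2, 50)
    nn.Unflatten(1, (2, 5, 5)),  # (2, 50) -> (2, 2, 5, 5)
)

flat_shape = tuple(round_trip[0](x).shape)  # shape after the Flatten step alone
y = round_trip(x)
restored = torch.equal(x, y)  # element order is preserved by both steps
```

In a real model, layers that operate on the flattened features (such as nn.Linear) would sit between the two modules, as in the examples above.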

© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/generated/torch.nn.Unflatten.html