
Flatten

class torch.nn.Flatten(start_dim=1, end_dim=-1)

Flattens a contiguous range of dimensions of a tensor (from start_dim through end_dim, inclusive) into a single dimension. Intended for use with Sequential. See torch.flatten() for details.
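A minimal sketch of the typical Sequential use: convolutional features are flattened so that a linear layer can consume them. The layer sizes are arbitrary illustrative choices, not anything prescribed by the API. The final check illustrates that the module gives the same result as the functional torch.flatten() call with the same start_dim.

```python
import torch
from torch import nn

# Illustrative stack; channel and output sizes are arbitrary choices.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),  # (N, 1, 5, 5) -> (N, 4, 5, 5)
    nn.Flatten(),                               # (N, 4, 5, 5) -> (N, 100)
    nn.Linear(4 * 5 * 5, 10),                   # (N, 100) -> (N, 10)
)

x = torch.randn(8, 1, 5, 5)
out = model(x)
print(out.shape)  # torch.Size([8, 10])

# The module form matches the functional form with the same arguments.
feat = torch.randn(8, 4, 5, 5)
assert torch.equal(nn.Flatten()(feat), torch.flatten(feat, start_dim=1))
```

Keeping the batch dimension (start_dim=1 by default) is what lets the Linear layer see one feature vector per sample.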

Shape:
  • Input: $(*, S_{\text{start}}, \ldots, S_{i}, \ldots, S_{\text{end}}, *)$, where $S_{i}$ is the size at dimension $i$ and $*$ means any number of dimensions, including none.
  • Output: $(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)$.
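The Shape rule can be checked by hand: the sizes from start to end are multiplied together, and the dimensions on either side are left alone. The helper below, flattened_shape, is hypothetical (not part of PyTorch) and assumes a tensor with at least one dimension; it simply applies the rule to a size tuple and compares the result with nn.Flatten.

```python
import math
import torch
from torch import nn

def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: apply the Shape rule to a size tuple (ndim >= 1)."""
    ndim = len(shape)
    # Negative dims count from the end, as elsewhere in PyTorch.
    start = start_dim % ndim
    end = end_dim % ndim
    # Multiply the sizes S_start ... S_end into one dimension.
    merged = math.prod(shape[start:end + 1])
    # Leading and trailing dims (the two "*" parts) are kept unchanged.
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])

shape = (2, 3, 4, 5, 6)
expected = flattened_shape(shape, start_dim=1, end_dim=3)
print(expected)  # (2, 60, 6)
actual = nn.Flatten(1, 3)(torch.zeros(shape)).shape
assert tuple(actual) == expected
```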
Parameters
  • start_dim (int) – first dim to flatten (default = 1).
  • end_dim (int) – last dim to flatten (default = -1).
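A short sketch of how the two parameters behave at their edges, assuming an arbitrary 4-dimensional input: negative values index from the last dimension, start_dim equal to end_dim leaves the shape unchanged, and start_dim=0 also merges the batch dimension.

```python
import torch
from torch import nn

x = torch.randn(2, 3, 4, 5)

# Negative indices count from the last dimension: -3..-2 means dims 1..2.
print(nn.Flatten(-3, -2)(x).shape)  # torch.Size([2, 12, 5])

# start_dim == end_dim "flattens" a single dimension, so the shape is unchanged.
print(nn.Flatten(2, 2)(x).shape)    # torch.Size([2, 3, 4, 5])

# start_dim=0 with the default end_dim=-1 collapses everything, batch included.
print(nn.Flatten(0)(x).shape)       # torch.Size([120])
```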
Examples:
>>> import torch
>>> from torch import nn
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
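As a follow-up to the examples, the sketch below illustrates two points: with default parameters the result matches reshaping to (batch, -1), so elements keep row-major order; and "contiguous range" refers to adjacent dimensions, not memory layout, so a transposed (non-contiguous) tensor can still be flattened.

```python
import torch
from torch import nn

input = torch.randn(32, 1, 5, 5)

# Default nn.Flatten() matches reshaping to (batch, -1) in row-major order.
out = nn.Flatten()(input)
assert torch.equal(out, input.reshape(32, -1))

# A transposed view is not contiguous in memory but can still be flattened.
t = input.transpose(2, 3)
print(t.is_contiguous())      # False
print(nn.Flatten()(t).shape)  # torch.Size([32, 25])
```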

© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/generated/torch.nn.Flatten.html