Flatten
class torch.nn.Flatten(start_dim=1, end_dim=-1)
Flattens a contiguous range of dims of the input tensor into a single dim. Intended for use with Sequential. See torch.flatten() for details.
- Shape:
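  - Input: $(*, S_{\text{start}}, \ldots, S_i, \ldots, S_{\text{end}}, *)$, where $S_i$ is the size at dimension $i$ and $*$ means any number of dimensions including none.
  - Output: $(*, \prod_{i=\text{start}}^{\text{end}} S_i, *)$.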
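The shape rule can be checked without PyTorch. The helper below, `flatten_output_shape`, is hypothetical (it is not a PyTorch API): a minimal sketch, assuming negative dims wrap around the way Python indexing does, that keeps the dims outside the range and replaces the range with the product of its sizes. Raising `ValueError` for an invalid range is this sketch's own choice; PyTorch reports its own error.

```python
from math import prod


def flatten_output_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper (not a PyTorch API): the output shape of
    nn.Flatten(start_dim, end_dim) applied to an input of `shape`."""
    ndim = len(shape)
    # Negative dims count from the end, as in Python indexing.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start <= end < ndim):
        raise ValueError(f"invalid dims ({start_dim}, {end_dim}) for a {ndim}-d input")
    # Dims before `start` and after `end` are kept; dims start..end collapse
    # into one dim whose size is the product of their sizes.
    return tuple(shape[:start]) + (prod(shape[start:end + 1]),) + tuple(shape[end + 1:])


print(flatten_output_shape((32, 1, 5, 5)))        # (32, 25)
print(flatten_output_shape((32, 1, 5, 5), 0, 2))  # (160, 5)
```

Both calls reproduce the sizes in the examples on this page.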
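- Parameters
  - start_dim (int): first dim to flatten (default = 1).
  - end_dim (int): last dim to flatten (default = -1).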
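The defaults differ from torch.flatten(): nn.Flatten starts at dim 1, so a leading batch dimension is preserved, while torch.flatten() starts at dim 0 and collapses everything. A short sketch contrasting the two, and showing that they agree once the same start_dim and end_dim are passed explicitly:

```python
import torch
from torch import nn

x = torch.randn(32, 1, 5, 5)

# nn.Flatten defaults to start_dim=1, so the batch dim (dim 0) is kept.
print(nn.Flatten()(x).shape)   # torch.Size([32, 25])
# torch.flatten defaults to start_dim=0 and collapses every dim.
print(torch.flatten(x).shape)  # torch.Size([800])

# With explicit arguments the module and the function give the same result.
m = nn.Flatten(start_dim=2, end_dim=-1)
print(torch.equal(m(x), torch.flatten(x, start_dim=2, end_dim=-1)))  # True
print(m(x).shape)              # torch.Size([32, 1, 25])
```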
- Examples:

    >>> import torch
    >>> from torch import nn
    >>> input = torch.randn(32, 1, 5, 5)
    >>> # With default parameters
    >>> m = nn.Flatten()
    >>> output = m(input)
    >>> output.size()
    torch.Size([32, 25])
    >>> # With non-default parameters
    >>> m = nn.Flatten(0, 2)
    >>> output = m(input)
    >>> output.size()
    torch.Size([160, 5])
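Since the module exists for use with Sequential, a typical placement is between a convolutional stage and a fully connected layer. The model below is a hypothetical toy: the channel count and layer sizes are illustrative only. It shows nn.Flatten turning the (N, C, H, W) feature map into the (N, C*H*W) input that nn.Linear expects, while keeping the batch dim N.

```python
import torch
from torch import nn

# Hypothetical toy classifier; the sizes are chosen only for illustration.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (N, 1, 5, 5) -> (N, 8, 5, 5)
    nn.ReLU(),
    nn.Flatten(),                               # (N, 8, 5, 5) -> (N, 200)
    nn.Linear(8 * 5 * 5, 10),                   # (N, 200) -> (N, 10)
)

x = torch.randn(32, 1, 5, 5)
print(model(x).shape)  # torch.Size([32, 10])
```

The in_features of the Linear layer must equal the product of the flattened dims, here 8 * 5 * 5 = 200.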
© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/generated/torch.nn.Flatten.html