pytorch / 2 / generated / torch.masked_select.html

torch.masked_select

torch.masked_select(input, mask, *, out=None) → Tensor

Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask which is a BoolTensor.

The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.

Note

The returned tensor does not use the same storage as the original tensor

Parameters
  • input (Tensor) – the input tensor.
  • mask (BoolTensor) – the tensor containing the binary mask to index with
Keyword Arguments

out (Tensor, optional) – the output tensor.

Example:

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],
        [-1.2035,  1.2252,  0.5002,  0.6248],
        [ 0.1307, -2.0608,  0.1244,  2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[False, False, False, False],
        [False, True, True, True],
        [False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252,  0.5002,  0.6248,  2.0139])

© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/generated/torch.masked_select.html