torch.where
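torch.where(condition, input, other, *, out=None) → Tensor
Return a tensor of elements selected from either input or other, depending on condition.
The operation is defined as:
\text{out}_i = \begin{cases} \text{input}_i & \text{if } \text{condition}_i \\ \text{other}_i & \text{otherwise} \end{cases}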
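The element-wise rule can be checked against a plain Python loop. This is a minimal sketch; the tensor names and values (`condition`, `a`, `b`) are made up for illustration and are not part of the original page.

```python
import torch

# Illustrative inputs (assumed values, chosen so the result is easy to read)
condition = torch.tensor([True, False, True, False])
a = torch.tensor([1.0, 2.0, 3.0, 4.0])
b = torch.tensor([10.0, 20.0, 30.0, 40.0])

# Pick a[i] where condition[i] is True, otherwise b[i]
result = torch.where(condition, a, b)

# Reference: apply the same per-element rule with an explicit loop
expected = torch.tensor([
    a[i].item() if condition[i] else b[i].item()
    for i in range(len(condition))
])

print(result)
print(torch.equal(result, expected))
```

The loop is only there to make the definition concrete; in practice the vectorized call is the one to use.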
Note
The tensors condition, input, other must be broadcastable.
Parameters
- condition (BoolTensor) – When True (nonzero), yield input, otherwise yield other
- input (Tensor or Scalar) – value (if input is a scalar) or values selected at indices where condition is True
- other (Tensor or Scalar) – value (if other is a scalar) or values selected at indices where condition is False
Keyword Arguments
- out (Tensor, optional) – the output tensor.
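As a rough illustration of the out keyword, the sketch below writes the result into a preallocated tensor. It assumes a PyTorch version whose torch.where accepts out (as the signature on this page indicates) and uses tensor arguments of matching dtype; the names `buffer` and `returned` are hypothetical.

```python
import torch

condition = torch.tensor([True, False, True])
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.zeros(3)

# Preallocated destination with the shape and dtype of the expected result
buffer = torch.empty(3)

# The selected values are written into `buffer`
returned = torch.where(condition, a, b, out=buffer)

print(buffer)
```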
Returns
A tensor of shape equal to the broadcasted shape of condition, input, other
Return type
Tensor
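To make the broadcasting behavior concrete, here is a small sketch in which condition has shape (3, 1), input has shape (1, 4), and other is a scalar. The shapes and values are assumptions chosen for illustration.

```python
import torch

condition = torch.tensor([[True], [False], [True]])  # shape (3, 1)
row = torch.tensor([[1.0, 2.0, 3.0, 4.0]])           # shape (1, 4)

# `other` is a Python scalar, so it broadcasts against everything
result = torch.where(condition, row, -1.0)

# The result shape should match the broadcast of the tensor arguments
expected_shape = torch.broadcast_shapes(condition.shape, row.shape)

print(result.shape)
print(result)
```

Rows where condition is True repeat `row`; the remaining row is filled with the scalar.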
Example:
>>> import torch
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, 1.0, 0.0)
tensor([[0., 1.],
        [1., 0.],
        [1., 0.]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)

torch.where(condition) → tuple of LongTensor
torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
Note
See also torch.nonzero().
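The single-argument form can be compared directly with torch.nonzero(condition, as_tuple=True). The sketch below uses a small made-up tensor `x`; the variable names are illustrative only.

```python
import torch

x = torch.tensor([[0.0, 1.5],
                  [-2.0, 0.0]])
condition = x != 0

# One index tensor per dimension, listing where condition is True
rows, cols = torch.where(condition)

# Same call expressed through torch.nonzero
nz_rows, nz_cols = torch.nonzero(condition, as_tuple=True)

# The index tuple can be used to gather the selected elements
selected = x[rows, cols]

print(rows, cols)
print(selected)
```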
© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/generated/torch.where.html