
torch.where

torch.where(condition, x, y) → Tensor

Return a tensor of elements selected from either x or y, depending on condition.

The operation is defined as:

\text{out}_i = \begin{cases} \text{x}_i & \text{if } \text{condition}_i \\ \text{y}_i & \text{otherwise} \end{cases}
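As a rough illustration of the definition above, the sketch below compares torch.where against an explicit per-element loop. The tensor values and the names `expected` and `out` are made up for this example; the loop is only a reference for the piecewise rule, not how the operation is implemented.

```python
import torch

condition = torch.tensor([True, False, True, False])
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = torch.tensor([10.0, 20.0, 30.0, 40.0])

# Reference: apply the piecewise rule one element at a time.
expected = torch.stack([x[i] if condition[i] else y[i] for i in range(len(x))])

# Vectorized form: picks x[i] where condition[i] is True, y[i] otherwise.
out = torch.where(condition, x, y)

print(out)  # tensor([ 1., 20.,  3., 40.])
```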

Note

The tensors condition, x, y must be broadcastable.
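A minimal sketch of the broadcasting requirement, assuming inputs that share a dtype. The shapes (3, 1), (1, 4), and a 0-dim tensor are arbitrary choices for illustration; they broadcast together to (3, 4).

```python
import torch

# condition has shape (3, 1); x has shape (1, 4); y is a 0-dim tensor.
condition = torch.tensor([[True], [False], [True]])
x = torch.arange(4.0).reshape(1, 4)  # [[0., 1., 2., 3.]]
y = torch.tensor(-1.0)

out = torch.where(condition, x, y)

print(out.shape)  # torch.Size([3, 4])
# Rows 0 and 2 take the values of x; row 1 is filled with y.
print(out)
```

Each row of condition selects a whole row here because its single column is stretched across the four columns of x.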

Parameters:
  • condition (BoolTensor) – Where True (nonzero), yield the element from x; otherwise yield the element from y
  • x (Tensor or Scalar) – value (if x is a scalar) or values selected at indices where condition is True
  • y (Tensor or Scalar) – value (if y is a scalar) or values selected at indices where condition is False
Returns:

A tensor whose shape is the broadcast shape of condition, x, and y

Return type:

Tensor

Example:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)
torch.where(condition) → tuple of LongTensor

torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).

Note

See also torch.nonzero().
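The single-argument form can be sketched as follows. The tensor `t` and the names `rows`, `cols`, and `selected` are illustrative only; the example checks the stated equivalence with torch.nonzero(condition, as_tuple=True) and uses the returned index tensors to gather the matching elements.

```python
import torch

t = torch.tensor([[0, 5, 0],
                  [7, 0, 9]])

# One index tensor per dimension, listing the positions where t > 0.
rows, cols = torch.where(t > 0)
print(rows)  # tensor([0, 1, 1])
print(cols)  # tensor([1, 0, 2])

# Same result as torch.nonzero with as_tuple=True.
via_nonzero = torch.nonzero(t > 0, as_tuple=True)
same = all(torch.equal(a, b) for a, b in zip((rows, cols), via_nonzero))

# The index tuple can be used directly for advanced indexing.
selected = t[rows, cols]
print(selected)  # tensor([5, 7, 9])
```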

© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/1.13/generated/torch.where.html