torch.Tensor.sparse_mask

Tensor.sparse_mask(mask) → Tensor

Returns a new sparse tensor whose values are taken from the strided tensor self at the positions given by the indices of the sparse tensor mask. The values of mask are ignored; only its indices are used. self and mask must have the same shape.
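As a minimal sketch of the filtering described above, the snippet below uses a small deterministic dense tensor and a COO mask. The names dense, mask_indices, mask_values, result, and picked are illustrative and do not come from the original page. The mask values are deliberately set to 100.0 to show that they do not affect the result.

```python
import torch

# Dense (strided) source tensor.
dense = torch.tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])

# Sparse COO mask with entries at (0, 2), (1, 0) and (1, 1).
# Only the indices matter; the values (all 100.0) are ignored.
mask_indices = torch.tensor([[0, 1, 1],
                             [2, 0, 1]])
mask_values = torch.full((3,), 100.0)
mask = torch.sparse_coo_tensor(mask_indices, mask_values, dense.shape).coalesce()

# Keep only the entries of `dense` at the mask's indices.
# coalesce() is called defensively so that values() can be read.
result = dense.sparse_mask(mask).coalesce()
picked = result.values().tolist()

print(picked)             # [3.0, 4.0, 5.0]
print(result.to_dense())  # zeros everywhere except the three masked positions
```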

Note

The returned sparse tensor has the same indices as the sparse tensor mask, even where the corresponding values in self are zero. Such entries are stored as explicit zeros rather than dropped.
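The note above can be illustrated with a small sketch, assuming a 2 x 2 dense tensor that is zero at one of the masked positions. The variable names are hypothetical. For contrast, the sketch also converts the same dense tensor with to_sparse(), which stores only the nonzero entry.

```python
import torch

# Dense tensor that is zero everywhere except at (0, 1).
dense = torch.zeros(2, 2)
dense[0, 1] = 7.0

# Mask with entries at (0, 1) and (1, 0); dense is zero at (1, 0).
mask = torch.sparse_coo_tensor(
    torch.tensor([[0, 1],
                  [1, 0]]),
    torch.ones(2),
    (2, 2),
).coalesce()

out = dense.sparse_mask(mask).coalesce()
out_indices = out.indices().tolist()
out_values = out.values().tolist()

# Both mask positions are kept, including the explicit zero at (1, 0).
print(out_indices)  # [[0, 1], [1, 0]]
print(out_values)   # [7.0, 0.0]

# By contrast, to_sparse() keeps only the nonzero entry.
nnz_to_sparse = dense.to_sparse()._nnz()
print(nnz_to_sparse)  # 1
```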

Parameters:

mask (Tensor) – a sparse tensor whose indices are used as a filter

Example:

>>> import torch
>>> nse = 5
>>> dims = (5, 5, 2, 2)
>>> I = torch.cat([torch.randint(0, dims[0], size=(nse,)),
...                torch.randint(0, dims[1], size=(nse,))], 0).reshape(2, nse)
>>> V = torch.randn(nse, dims[2], dims[3])
>>> S = torch.sparse_coo_tensor(I, V, dims).coalesce()
>>> D = torch.randn(dims)
>>> D.sparse_mask(S)
tensor(indices=tensor([[0, 0, 0, 2],
                       [0, 1, 4, 3]]),
       values=tensor([[[ 1.6550,  0.2397],
                       [-0.1611, -0.0779]],

                      [[ 0.2326, -1.0558],
                       [ 1.4711,  1.9678]],

                      [[-0.5138, -0.0411],
                       [ 1.9417,  0.5158]],

                      [[ 0.0793,  0.0036],
                       [-0.2569, -0.1055]]]),
       size=(5, 5, 2, 2), nnz=4, layout=torch.sparse_coo)
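For a hybrid mask like the one in the example (two sparse dimensions followed by two dense dimensions), the result can be cross-checked against ordinary advanced indexing. The following is a sketch under the assumption of fixed, hand-picked indices instead of random ones; idx, vals, masked, expected, and same are illustrative names.

```python
import torch

torch.manual_seed(0)  # fixed seed so repeated runs use the same random values

dims = (5, 5, 2, 2)
# Four distinct positions in the two sparse dimensions.
idx = torch.tensor([[0, 0, 2, 4],
                    [1, 4, 3, 0]])
vals = torch.randn(idx.shape[1], dims[2], dims[3])
S = torch.sparse_coo_tensor(idx, vals, dims).coalesce()
D = torch.randn(dims)

masked = D.sparse_mask(S).coalesce()

# Each stored value is the 2 x 2 block of D at the matching (row, col) index.
rows, cols = S.indices()
expected = D[rows, cols]  # shape (nnz, 2, 2)
same = torch.equal(masked.values(), expected)
print(same)  # True
```

Here sparse_mask gathers whole dense blocks, so the values tensor has shape (nnz, 2, 2) rather than (nnz,).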

© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/1.13/generated/torch.Tensor.sparse_mask.html