GRUCell
class torch.nn.GRUCell(input_size, hidden_size, bias=True, device=None, dtype=None)
[source]
A gated recurrent unit (GRU) cell that computes

$$
\begin{array}{ll}
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\
h' = (1 - z) \odot n + z \odot h
\end{array}
$$

where $\sigma$ is the sigmoid function, and $\odot$ is the Hadamard product.
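A minimal sketch (not the library implementation) that applies these equations by hand using the parameters of an nn.GRUCell and checks the result against the module; the per-gate ordering (reset, update, new) inside weight_ih and weight_hh is assumed to match the convention documented for nn.GRU:

import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.GRUCell(input_size=10, hidden_size=20)
x = torch.randn(4, 10)   # batch of 4 input vectors
h = torch.randn(4, 20)   # previous hidden state

# Split the stacked parameters into per-gate blocks (assumed order: r, z, n).
W_ir, W_iz, W_in = cell.weight_ih.chunk(3, dim=0)
W_hr, W_hz, W_hn = cell.weight_hh.chunk(3, dim=0)
b_ir, b_iz, b_in = cell.bias_ih.chunk(3)
b_hr, b_hz, b_hn = cell.bias_hh.chunk(3)

# The GRU equations above, written out explicitly.
r = torch.sigmoid(x @ W_ir.T + b_ir + h @ W_hr.T + b_hr)
z = torch.sigmoid(x @ W_iz.T + b_iz + h @ W_hz.T + b_hz)
n = torch.tanh(x @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))
h_next = (1 - z) * n + z * h

print(torch.allclose(h_next, cell(x, h), atol=1e-6))  # expected: True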
- Parameters
  - input_size (int) – The number of expected features in the input x
  - hidden_size (int) – The number of features in the hidden state h
  - bias (bool) – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
- Inputs: input, hidden
  - input: tensor containing input features
  - hidden: tensor containing the initial hidden state for each element in the batch. Defaults to zero if not provided.
- Outputs: h'
  - h': tensor containing the next hidden state for each element in the batch
- Shape:
  - input: (N, H_in) or (H_in) – tensor containing input features, where H_in = input_size.
  - hidden: (N, H_out) or (H_out) – tensor containing the initial hidden state, where H_out = hidden_size. Defaults to zero if not provided.
  - output: (N, H_out) or (H_out) – tensor containing the next hidden state.
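Both batched and unbatched inputs are accepted, as the shapes above indicate. A short illustration (sizes chosen arbitrarily):

import torch
import torch.nn as nn

cell = nn.GRUCell(10, 20)

batched = torch.randn(5, 10)      # (N, H_in); hidden defaults to zeros
print(cell(batched).shape)        # torch.Size([5, 20])

unbatched = torch.randn(10)       # (H_in,)
print(cell(unbatched).shape)      # torch.Size([20])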
- Variables
  - weight_ih (torch.Tensor) – the learnable input-hidden weights, of shape (3*hidden_size, input_size)
  - weight_hh (torch.Tensor) – the learnable hidden-hidden weights, of shape (3*hidden_size, hidden_size)
  - bias_ih – the learnable input-hidden bias, of shape (3*hidden_size)
  - bias_hh – the learnable hidden-hidden bias, of shape (3*hidden_size)
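A quick check of the parameter shapes listed above, using input_size=10 and hidden_size=20 (arbitrary values):

import torch.nn as nn

cell = nn.GRUCell(10, 20)
print(cell.weight_ih.shape)  # torch.Size([60, 10]) -> (3*hidden_size, input_size)
print(cell.weight_hh.shape)  # torch.Size([60, 20]) -> (3*hidden_size, hidden_size)
print(cell.bias_ih.shape)    # torch.Size([60])     -> (3*hidden_size)
print(cell.bias_hh.shape)    # torch.Size([60])     -> (3*hidden_size)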
Note
All the weights and biases are initialized from $\mathcal{U}(-\sqrt{k}, \sqrt{k})$ where $k = \frac{1}{\text{hidden\_size}}$.
On certain ROCm devices, when using float16 inputs this module will use different precision for backward.
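The initialization bound can be checked empirically; a small sketch assuming the uniform initialization described in the note above:

import math
import torch.nn as nn

cell = nn.GRUCell(10, 20)
bound = math.sqrt(1 / 20)  # sqrt(k) with k = 1/hidden_size
within = all((p >= -bound).all() and (p <= bound).all() for p in cell.parameters())
print(within)  # expected: True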
Examples:
>>> rnn = nn.GRUCell(10, 20)
>>> input = torch.randn(6, 3, 10)
>>> hx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
...     hx = rnn(input[i], hx)
...     output.append(hx)
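Continuing the example, the collected per-step hidden states can be stacked into a single tensor (illustrative only):

>>> out = torch.stack(output)
>>> out.shape
torch.Size([6, 3, 20])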