
MultiheadAttention

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None) [source]

Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.

Multi-Head Attention is defined as:

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O

where \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V).
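
Here the per-head Attention function is the scaled dot-product attention from the same paper (for reference; d_k is the per-head dimension embed_dim // num_heads):

\text{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V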

forward() will use a special optimized implementation if all of the following conditions are met:

  • self attention is being computed (i.e., query, key, and value are the same tensor; this restriction will be loosened in the future)
  • Either autograd is disabled (using torch.inference_mode or torch.no_grad) or no tensor argument requires_grad
  • training is disabled (using .eval())
  • dropout is 0
  • add_bias_kv is False
  • add_zero_attn is False
  • batch_first is True and the input is batched
  • kdim and vdim are equal to embed_dim
  • at most one of key_padding_mask or attn_mask is passed
  • if a NestedTensor is passed, neither key_padding_mask nor attn_mask is passed

If the optimized implementation is in use, a NestedTensor can be passed for query/key/value to represent padding more efficiently than using a padding mask. In this case, a NestedTensor will be returned, and an additional speedup proportional to the fraction of the input that is padding can be expected.
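
A minimal sketch of a call that satisfies the conditions above (the sizes and variable names are illustrative, not part of the API):

>>> import torch
>>> import torch.nn as nn
>>> mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True).eval()  # dropout=0.0 by default, training disabled
>>> x = torch.randn(2, 10, 64)  # batched input: (batch, seq, embed_dim)
>>> with torch.inference_mode():  # autograd disabled
...     attn_output, _ = mha(x, x, x, need_weights=False)  # self attention: query, key, and value are the same tensor
>>> attn_output.shape
torch.Size([2, 10, 64])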

Parameters:
  • embed_dim – Total dimension of the model.
  • num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).
  • dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
  • bias – If specified, adds bias to input / output projection layers. Default: True.
  • add_bias_kv – If specified, adds bias to the key and value sequences at dim=0. Default: False.
  • add_zero_attn – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.
  • kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).
  • vdim – Total number of features for values. Default: None (uses vdim=embed_dim).
  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

Examples:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
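
A self-contained variant of the example above with concrete (illustrative) sizes, using the default batch_first=False layout of (seq, batch, feature):

>>> import torch
>>> import torch.nn as nn
>>> embed_dim, num_heads = 256, 8
>>> L, S, N = 5, 7, 3  # target length, source length, batch size
>>> query = torch.randn(L, N, embed_dim)
>>> key = torch.randn(S, N, embed_dim)
>>> value = torch.randn(S, N, embed_dim)
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
>>> attn_output.shape, attn_output_weights.shape
(torch.Size([5, 3, 256]), torch.Size([3, 5, 7]))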
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True) [source]
Parameters:
  • query (Tensor) – Query embeddings of shape (L, E_q) for unbatched input, (L, N, E_q) when batch_first=False or (N, L, E_q) when batch_first=True, where L is the target sequence length, N is the batch size, and E_q is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.
  • key (Tensor) – Key embeddings of shape (S, E_k) for unbatched input, (S, N, E_k) when batch_first=False or (N, S, E_k) when batch_first=True, where S is the source sequence length, N is the batch size, and E_k is the key embedding dimension kdim. See “Attention Is All You Need” for more details.
  • value (Tensor) – Value embeddings of shape (S, E_v) for unbatched input, (S, N, E_v) when batch_first=False or (N, S, E_v) when batch_first=True, where S is the source sequence length, N is the batch size, and E_v is the value embedding dimension vdim. See “Attention Is All You Need” for more details.
  • key_padding_mask (Optional[Tensor]) – If specified, a mask of shape (N, S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be (S). Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding key value.
  • need_weights (bool) – If True, returns attn_output_weights in addition to attn_output. Default: True.
  • attn_mask (Optional[Tensor]) – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S) or (N·num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. A 2D mask will be broadcast across the batch while a 3D mask allows a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight. See the mask example at the end of this section.
  • average_attn_weights (bool) – If True, the returned attn_weights are averaged across heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Default: True (i.e. average weights across heads).
Return type:

Tuple[Tensor, Optional[Tensor]]

Outputs:
  • attn_output - Attention outputs of shape (L, E) when input is unbatched, (L, N, E) when batch_first=False or (N, L, E) when batch_first=True, where L is the target sequence length, N is the batch size, and E is the embedding dimension embed_dim.
  • attn_output_weights - Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads of shape (L, S) when input is unbatched or (N, L, S) when batched, where N is the batch size, L is the target sequence length, and S is the source sequence length. If average_attn_weights=False, returns attention weights per head of shape (num_heads, L, S) when input is unbatched or (N, num_heads, L, S) when batched.

Note

batch_first argument is ignored for unbatched inputs.
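
A sketch of the two mask arguments of forward(), with illustrative sizes and a causal-style attn_mask (for boolean masks, True marks positions that may not be attended to):

>>> import torch
>>> import torch.nn as nn
>>> N, L, S, E = 2, 4, 6, 32
>>> mha = nn.MultiheadAttention(E, num_heads=4, batch_first=True)
>>> q = torch.randn(N, L, E)
>>> kv = torch.randn(N, S, E)
>>> key_padding_mask = torch.zeros(N, S, dtype=torch.bool)  # (N, S): True marks padding keys to ignore
>>> key_padding_mask[0, -2:] = True  # pretend the last two keys of the first sequence are padding
>>> attn_mask = torch.triu(torch.ones(L, S), diagonal=1).bool()  # (L, S): block attention to "future" positions
>>> attn_output, attn_weights = mha(q, kv, kv, key_padding_mask=key_padding_mask,
...                                 attn_mask=attn_mask, average_attn_weights=False)
>>> attn_output.shape, attn_weights.shape  # per-head weights: (N, num_heads, L, S)
(torch.Size([2, 4, 32]), torch.Size([2, 4, 4, 6]))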
