torch.set_float32_matmul_precision
torch.set_float32_matmul_precision(precision)
[source]
Sets the internal precision of float32 matrix multiplications.
Running float32 matrix multiplications in lower precision may significantly increase performance, and in some programs the loss of precision has a negligible impact.
Supports three settings:
- “highest”, float32 matrix multiplications use the float32 datatype for internal computations.
- “high”, float32 matrix multiplications use the TensorFloat32 or bfloat16_3x datatypes for internal computations, if fast matrix multiplication algorithms using those datatypes internally are available. Otherwise float32 matrix multiplications are computed as if the precision is “highest”.
- “medium”, float32 matrix multiplications use the bfloat16 datatype for internal computations, if a fast matrix multiplication algorithm using that datatype internally is available. Otherwise float32 matrix multiplications are computed as if the precision is “high”.
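As a minimal usage sketch (assuming a CUDA-capable build and an available CUDA device; the tensor sizes are arbitrary), the setting is a global flag that takes effect for subsequent float32 matrix multiplications:

import torch

a = torch.randn(1024, 1024, dtype=torch.float32, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.float32, device="cuda")

# Default: internal computation stays in full float32.
torch.set_float32_matmul_precision("highest")
exact = a @ b

# Allow faster, lower-precision internal math where a suitable kernel exists.
torch.set_float32_matmul_precision("high")
fast = a @ b

# The output dtype is float32 in both cases; only the internal precision differs.
print(exact.dtype, fast.dtype)      # torch.float32 torch.float32
print((exact - fast).abs().max())   # typically nonzero on GPUs with TensorFloat32 support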
Note
This does not change the output dtype of float32 matrix multiplications, it controls how the internal computation of the matrix multiplication is performed.
Note
This does not change the precision of convolution operations. Other flags, like torch.backends.cudnn.allow_tf32, may control the precision of convolution operations.
Note
This flag currently only affects one native device type: CUDA. If “high” or “medium” is set, the TensorFloat32 datatype will be used when computing float32 matrix multiplications, equivalent to setting torch.backends.cuda.matmul.allow_tf32 = True. When “highest” (the default) is set, the float32 datatype is used for internal computations, equivalent to setting torch.backends.cuda.matmul.allow_tf32 = False.
Parameters:
precision (str) – can be set to “highest” (default), “high”, or “medium” (see above).
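To illustrate the equivalence described in the CUDA note above, here is a small sketch; it assumes (per that note) that torch.backends.cuda.matmul.allow_tf32 reflects the current float32 matmul precision setting, and reading the flag does not require a CUDA device to be present:

import torch

# "highest" (the default) keeps full float32 internal math, i.e. TF32 disabled for matmuls.
torch.set_float32_matmul_precision("highest")
print(torch.backends.cuda.matmul.allow_tf32)   # expected: False

# "high" (or "medium") enables the TensorFloat32 path on CUDA, the same effect as
# setting torch.backends.cuda.matmul.allow_tf32 = True directly.
torch.set_float32_matmul_precision("high")
print(torch.backends.cuda.matmul.allow_tf32)   # expected: True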