# tf.distribute.experimental.CollectiveHints

Hints for collective operations like AllReduce.

```python
tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=0, timeout_seconds=None
)
```
This can be passed to methods like `tf.distribute.get_replica_context().all_reduce()` to optimize collective operation performance. Note that these are only hints, which may or may not change the actual behavior. Some options only apply to certain strategies and are ignored by others.
One common optimization is to break a gradient all-reduce into multiple packs so that weight updates can overlap with the gradient all-reduce.
Examples:
- `bytes_per_pack`

```python
# Break the gradient all-reduce into packs of roughly 50 MB each.
hints = tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=50 * 1024 * 1024)
grads = tf.distribute.get_replica_context().all_reduce(
    'sum', grads, experimental_hints=hints)
# The gradients are already aggregated, so skip the optimizer's own
# cross-replica aggregation.
optimizer.apply_gradients(zip(grads, vars),
                          experimental_aggregate_gradients=False)
```
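To show where this pattern typically lives, here is a minimal, hedged sketch of a complete training step under `MultiWorkerMirroredStrategy` (the only strategy the pack-size hint currently applies to); the model, loss, and optimizer below are illustrative stand-ins, not part of this API:

```python
import tensorflow as tf

# Assumes a multi-worker cluster configured elsewhere (e.g. via TF_CONFIG).
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
hints = tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=50 * 1024 * 1024)

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step(x, y):
  def step_fn(x, y):
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # All-reduce in ~50 MB packs so the all-reduce of later packs can
    # overlap with applying the gradients of earlier packs.
    grads = tf.distribute.get_replica_context().all_reduce(
        'sum', grads, experimental_hints=hints)
    optimizer.apply_gradients(
        zip(grads, model.trainable_variables),
        experimental_aggregate_gradients=False)
    return loss
  return strategy.run(step_fn, args=(x, y))
```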
- `timeout_seconds`

```python
# The timeout hint currently only takes effect with
# MultiWorkerMirroredStrategy (see the Args table below).
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
hints = tf.distribute.experimental.CollectiveHints(
    timeout_seconds=120)
try:
  strategy.reduce("sum", v, axis=None, experimental_hints=hints)
except tf.errors.DeadlineExceededError:
  # The collective did not complete within 120 seconds; likely a hang.
  do_something()
```
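In the example above, `v` is assumed to be a value produced on the replicas (for instance by `strategy.run`), and `do_something()` is a placeholder for whatever debugging action is appropriate, such as dumping stack traces or logging program state.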
| Args | |
|---|---|
| `bytes_per_pack` | A non-negative integer. Breaks collective operations into packs of a certain size. If it is zero, the value is determined automatically. This currently only applies to all-reduce with `MultiWorkerMirroredStrategy`. |
| `timeout_seconds` | A float or `None`, timeout in seconds. If not `None`, the collective raises `tf.errors.DeadlineExceededError` if it takes longer than this timeout. This can be useful when debugging hanging issues, but it should only be used for debugging, since it creates a new thread for each collective, i.e. an overhead of `timeout_seconds * num_collectives_per_second` more threads. This only works for `tf.distribute.experimental.MultiWorkerMirroredStrategy`. |
| Raises | |
|---|---|
| `ValueError` | When arguments have an invalid value. |
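As a hedged illustration of this validation (the exact error message may vary between TensorFlow versions), constructing hints with a negative pack size fails at creation time:

```python
import tensorflow as tf

try:
  tf.distribute.experimental.CollectiveHints(bytes_per_pack=-1)
except ValueError as e:
  # Expected: bytes_per_pack must be non-negative.
  print(e)
```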
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
 https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/distribute/experimental/CollectiveHints