tf.distribute.ReplicaContext

tf.distribute.Strategy API when in a replica context.
tf.distribute.ReplicaContext(
strategy, replica_id_in_sync_group
)
You can use tf.distribute.get_replica_context to get an instance of ReplicaContext. This should be inside your replicated step function, such as in a tf.distribute.Strategy.run call.
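For illustration, here is a minimal sketch (assuming a tf.distribute.MirroredStrategy; any strategy works the same way) that reads the current ReplicaContext inside a step function executed by strategy.run:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # example strategy; any strategy works

@tf.function
def step_fn():
    # Valid only inside the replicated step function (replica context).
    ctx = tf.distribute.get_replica_context()
    return ctx.replica_id_in_sync_group

# Runs step_fn on every replica and returns the per-replica values.
per_replica_ids = strategy.run(step_fn)
```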
Attributes | |
---|---|
devices | The devices this replica is to be executed on, as a tuple of strings. |
num_replicas_in_sync | Returns the number of replicas over which gradients are aggregated. |
replica_id_in_sync_group | Returns the id of the replica being defined. This identifies the replica that is part of a sync group. Currently we assume that all sync groups contain the same number of replicas. The value of the replica id can range from 0 to num_replicas_in_sync - 1. Note: this is not guaranteed to be the same ID as the XLA replica ID used for low-level operations such as collective_permute. |
strategy | The current tf.distribute.Strategy object. |
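As one common use of these attributes, the sketch below (an illustrative example, not from the original page) divides a per-replica mean by num_replicas_in_sync so that a later SUM reduction across replicas yields a global average:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def step_fn(per_example_loss):
    ctx = tf.distribute.get_replica_context()
    # Scaling by the number of replicas makes a subsequent SUM reduction
    # across replicas behave like a global average.
    return tf.reduce_mean(per_example_loss) / ctx.num_replicas_in_sync

losses = tf.constant([0.5, 1.5, 1.0])
per_replica_loss = strategy.run(step_fn, args=(losses,))
global_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
```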
Methods
all_reduce
all_reduce(
reduce_op, value, experimental_hints=None
)
All-reduces the given value Tensor nest across replicas.

If all_reduce is called in any replica, it must be called in all replicas. The nested structure and Tensor shapes must be identical in all replicas.
Example with two replicas:

Replica 0 value: {'a': 1, 'b': [40, 1]}
Replica 1 value: {'a': 3, 'b': [ 2, 98]}

If reduce_op == SUM: Result (on all replicas): {'a': 4, 'b': [42, 99]}

If reduce_op == MEAN: Result (on all replicas): {'a': 2, 'b': [21, 49.5]}
Args | |
---|---|
reduce_op | Reduction type, an instance of the tf.distribute.ReduceOp enum. |
value | The nested structure of Tensors to all-reduce. The structure must be compatible with tf.nest. |
experimental_hints | A tf.distribute.experimental.CollectiveHints. Hints to perform collective operations. |
Returns | |
---|---|
A Tensor nest with the reduced values from each replica. |
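A sketch analogous to the SUM example above, assuming a MirroredStrategy; the per-replica values here are made up from the replica id rather than taken from the table:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def all_reduce_step():
    ctx = tf.distribute.get_replica_context()
    rid = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
    # Each replica contributes an identically structured nest of Tensors.
    value = {'a': rid + 1.0, 'b': [rid * 10.0, rid + 40.0]}
    # Every replica must reach this call; all replicas receive the same result.
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)

reduced = strategy.run(all_reduce_step)
```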
merge_call
merge_call(
merge_fn, args=(), kwargs=None
)
Merge args across replicas and run merge_fn in a cross-replica context.

This allows communication and coordination when there are multiple calls to the step_fn triggered by a call to strategy.run(step_fn, ...).

See tf.distribute.Strategy.run for an explanation.
If not inside a distributed scope, this is equivalent to:

strategy = tf.distribute.get_strategy()
with cross-replica-context(strategy):
    return merge_fn(strategy, *args, **kwargs)
Args | |
---|---|
merge_fn | Function that joins arguments from threads that are given as PerReplica. It accepts a tf.distribute.Strategy object as its first argument. |
args | List or tuple with positional per-thread arguments for merge_fn. |
kwargs | Dict with keyword per-thread arguments for merge_fn. |
Returns | |
---|---|
The return value of merge_fn, except for PerReplica values which are unpacked. |
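A hedged sketch of merge_call, again assuming a MirroredStrategy: each replica computes a local value, and merge_fn runs once in the cross-replica context where it can reduce the bundled per-replica values:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def merge_fn(strategy, per_replica_value):
    # Runs in the cross-replica context; per_replica_value bundles the
    # value passed in by each replica.
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_value, axis=None)

@tf.function
def step_fn():
    ctx = tf.distribute.get_replica_context()
    local_value = tf.cast(ctx.replica_id_in_sync_group, tf.float32) + 1.0
    # Every replica must reach this merge_call with the same structure.
    return ctx.merge_call(merge_fn, args=(local_value,))

result = strategy.run(step_fn)
```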
__enter__
__enter__()
__exit__
__exit__(
exception_type, exception_value, traceback
)
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/ReplicaContext