tf.distribute.DistributedIterator
An iterator over tf.distribute.DistributedDataset.
tf.distribute.DistributedIterator is the primary mechanism for enumerating elements of a tf.distribute.DistributedDataset. It supports the Python Iterator protocol, which means it can be iterated over using a for-loop or by fetching individual elements explicitly via get_next().
You can create a tf.distribute.DistributedIterator by calling iter on a tf.distribute.DistributedDataset or by writing a Python for-loop over a tf.distribute.DistributedDataset.
Visit the tutorial on distributed input for more examples and caveats.
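The two ways of obtaining and consuming an iterator described above can be sketched as follows. This is a minimal illustration that assumes eager execution and a default tf.distribute.MirroredStrategy; the helper to_lists is a hypothetical name introduced here for readability and is not part of the TensorFlow API.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(6).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

def to_lists(per_replica_batch):
    # Hypothetical helper: unpack the per-replica values into Python lists.
    local = strategy.experimental_local_results(per_replica_batch)
    return [t.numpy().tolist() for t in local]

# Option 1: a for-loop over the distributed dataset creates the iterator implicitly.
seen_by_loop = [to_lists(batch) for batch in dist_dataset]

# Option 2: create the iterator explicitly and fetch elements one at a time.
iterator = iter(dist_dataset)
first = to_lists(iterator.get_next())
second = to_lists(next(iterator))  # next() works through the Iterator protocol
```

Each entry holds one list per replica, so the exact nesting depends on how many devices the strategy uses.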
Attributes | |
---|---|
element_spec | The type specification of an element of tf.distribute.DistributedIterator. |
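As a rough sketch of how element_spec might be inspected, the snippet below reads the attribute from an iterator built over a small range dataset. The dataset and batch size are assumptions chosen for illustration, and the printed form of the spec depends on the number of replicas, so no expected output is shown.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(100).batch(2)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

# element_spec describes the structure, shape and dtype of the elements
# the iterator produces; its exact form varies with the replica count.
spec = iterator.element_spec
print(spec)
```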
Methods
get_next
get_next()
Returns the next input from the iterator for all replicas.
Example use:
import tensorflow as tf
# With no arguments, MirroredStrategy uses all visible GPUs, or the CPU if there are none.
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(100).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
dist_dataset_iterator = iter(dist_dataset)
@tf.function
def one_step(input):
  return input
step_num = 5
# Each call to get_next() consumes one global batch.
for _ in range(step_num):
  strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))
strategy.experimental_local_results(dist_dataset_iterator.get_next())
(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([10, 11])>,)
The above example corresponds to the case where you have only one device. If you have two devices, for example,
strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
Then the final line will print out:
(<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
<tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)
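The difference between the one-device and two-device outputs above comes down to how a global batch is divided among replicas. The plain-Python sketch below reproduces only that arithmetic for evenly divisible batches. It is not how TensorFlow implements the split, and split_global_batch is a hypothetical helper introduced for illustration.

```python
def split_global_batch(global_batch, num_replicas):
    """Hypothetical helper: slice an evenly divisible batch per replica."""
    if len(global_batch) % num_replicas != 0:
        raise ValueError("this sketch only covers evenly divisible batches")
    per_replica = len(global_batch) // num_replicas
    return tuple(
        global_batch[i * per_replica:(i + 1) * per_replica]
        for i in range(num_replicas)
    )

one_device = split_global_batch([10, 11], num_replicas=1)   # ([10, 11],)
two_devices = split_global_batch([10, 11], num_replicas=2)  # ([10], [11])
```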
Returns | |
---|---|
A single tf.Tensor or a tf.distribute.DistributedValues which contains the next input for all replicas. |
Raises | |
---|---|
tf.errors.OutOfRangeError : If the end of the iterator has been reached. |
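One way to handle the end of the sequence when calling get_next() directly is to catch tf.errors.OutOfRangeError, as in this minimal sketch. It assumes eager execution and a small dataset chosen for illustration; steps_run is a hypothetical counter.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(4).batch(2)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

steps_run = 0
while True:
    try:
        batch = iterator.get_next()
    except tf.errors.OutOfRangeError:
        # The iterator is exhausted; stop requesting elements.
        break
    steps_run += 1
```

A for-loop over the dataset, or get_next_as_optional(), avoids the explicit exception handling.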
get_next_as_optional
get_next_as_optional()
Returns a tf.experimental.Optional that contains the next value for all replicas.
If the tf.distribute.DistributedIterator has reached the end of the sequence, the returned tf.experimental.Optional will have no value.
Example usage:
import tensorflow as tf
strategy = tf.distribute.MirroredStrategy()
global_batch_size = 2
steps_per_loop = 2
dataset = tf.data.Dataset.range(10).batch(global_batch_size)
distributed_iterator = iter(
    strategy.experimental_distribute_dataset(dataset))
def step_fn(x):
  return x
@tf.function
def train_fn(distributed_iterator):
  for _ in tf.range(steps_per_loop):
    optional_data = distributed_iterator.get_next_as_optional()
    # Stop early if the iterator has no more data.
    if not optional_data.has_value():
      break
    tf.print(strategy.run(step_fn, args=(optional_data.get_value(),)))
train_fn(distributed_iterator)
# ([0 1],)
# ([2 3],)
Returns | |
---|---|
A tf.experimental.Optional object representing the next value from the tf.distribute.DistributedIterator (if it has one) or no value. |
__iter__
__iter__()
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator