On this page
tf.data.experimental.get_single_element
Returns the single element of the dataset as a nested structure of tensors. (deprecated)
tf.data.experimental.get_single_element(
    dataset
)
The function enables you to use a tf.data.Dataset in a stateless "tensor-in tensor-out" expression, without creating an iterator. This facilitates the ease of data transformation on tensors using the optimized tf.data.Dataset abstraction on top of them.
For example, lets consider a preprocessing_fn which would take as an input the raw features and returns the processed feature along with it's label.
def preprocessing_fn(raw_feature):
  # ... the raw_feature is preprocessed as per the use-case
  return feature
raw_features = ...  # input batch of BATCH_SIZE elements.
dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
           .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
           .batch(BATCH_SIZE))
processed_features = tf.data.experimental.get_single_element(dataset)
In the above example, the raw_features tensor of length=BATCH_SIZE was converted to a tf.data.Dataset. Next, each of the raw_feature was mapped using the preprocessing_fn and the processed features were grouped into a single batch. The final dataset contains only one element which is a batch of all the processed features.
Note: The dataset should contain only one element.
  
  Now, instead of creating an iterator for the dataset and retrieving the batch of features, the tf.data.experimental.get_single_element() function is used to skip the iterator creation process and directly output the batch of features.
This can be particularly useful when your tensor transformations are expressed as tf.data.Dataset operations, and you want to use those transformations while serving your model.
Keras
model = ... # A pre-built or custom model
class PreprocessingModel(tf.keras.Model):
  def __init__(self, model):
    super().__init__(self)
    self.model = model
  @tf.function(input_signature=[...])
  def serving_fn(self, data):
    ds = tf.data.Dataset.from_tensor_slices(data)
    ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
    ds = ds.batch(batch_size=BATCH_SIZE)
    return tf.argmax(
      self.model(tf.data.experimental.get_single_element(ds)),
      axis=-1
    )
preprocessing_model = PreprocessingModel(model)
your_exported_model_dir = ... # save the model to this path.
tf.saved_model.save(preprocessing_model, your_exported_model_dir,
              signatures={'serving_default': preprocessing_model.serving_fn})
Estimator
In the case of estimators, you need to generally define a serving_input_fn which would require the features to be processed by the model while inferencing.
def serving_input_fn():
  raw_feature_spec = ... # Spec for the raw_features
  input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  )
  serving_input_receiver = input_fn()
  raw_features = serving_input_receiver.features
  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use-case
    return feature
  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
            .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
            .batch(BATCH_SIZE))
  processed_features = tf.data.experimental.get_single_element(dataset)
  # Please note that the value of `BATCH_SIZE` should be equal to
  # the size of the leading dimension of `raw_features`. This ensures
  # that `dataset` has only element, which is a pre-requisite for
  # using `tf.data.experimental.get_single_element(dataset)`.
  return tf.estimator.export.ServingInputReceiver(
      processed_features, serving_input_receiver.receiver_tensors)
estimator = ... # A pre-built or custom estimator
estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
| Args | |
|---|---|
| dataset | A tf.data.Datasetobject containing a single element. | 
| Returns | |
|---|---|
| A nested structure of tf.Tensorobjects, corresponding to the single element ofdataset. | 
| Raises | |
|---|---|
| TypeError | if datasetis not atf.data.Datasetobject. | 
| InvalidArgumentError | (at runtime) if datasetdoes not contain exactly one element. | 
© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
 https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/data/experimental/get_single_element