Home | History | Annotate | Download | only in ops
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Experimental `dataset` API for parsing example."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.data.util import structure
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.framework import sparse_tensor
     25 from tensorflow.python.ops import gen_experimental_dataset_ops
     26 from tensorflow.python.ops import parsing_ops
     27 from tensorflow.python.util.tf_export import tf_export
     28 
     29 
     30 class _ParseExampleDataset(dataset_ops.UnaryDataset):
     31   """A `Dataset` that parses `example` dataset into a `dict` dataset."""
     32 
     33   def __init__(self, input_dataset, features, num_parallel_calls):
     34     self._input_dataset = input_dataset
     35     if not input_dataset._element_structure.is_compatible_with(  # pylint: disable=protected-access
     36         structure.TensorStructure(dtypes.string, [None])):
     37       raise TypeError("Input dataset should be a dataset of vectors of strings")
     38     self._num_parallel_calls = num_parallel_calls
     39     # pylint: disable=protected-access
     40     self._features = parsing_ops._prepend_none_dimension(features)
     41     # sparse_keys and dense_keys come back sorted here.
     42     (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
     43      dense_shapes) = parsing_ops._features_to_raw_params(
     44          self._features, [
     45              parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
     46              parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
     47          ])
     48     # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
     49     (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
     50      dense_shape_as_shape) = parsing_ops._process_raw_parameters(
     51          None, dense_defaults, sparse_keys, sparse_types, dense_keys,
     52          dense_types, dense_shapes)
     53     # pylint: enable=protected-access
     54     self._sparse_keys = sparse_keys
     55     self._sparse_types = sparse_types
     56     self._dense_keys = dense_keys
     57     self._dense_defaults = dense_defaults_vec
     58     self._dense_shapes = dense_shapes
     59     self._dense_types = dense_types
     60     input_dataset_shape = dataset_ops.get_legacy_output_shapes(
     61         self._input_dataset)
     62     dense_output_shapes = [input_dataset_shape.concatenate(shape)
     63                            for shape in dense_shape_as_shape]
     64     sparse_output_shapes = [input_dataset_shape.concatenate([None])
     65                             for _ in range(len(sparse_keys))]
     66 
     67     output_shapes = dict(
     68         zip(self._dense_keys + self._sparse_keys,
     69             dense_output_shapes + sparse_output_shapes))
     70     output_types = dict(
     71         zip(self._dense_keys + self._sparse_keys,
     72             self._dense_types + self._sparse_types))
     73     output_classes = dict(
     74         zip(self._dense_keys + self._sparse_keys,
     75             [ops.Tensor for _ in range(len(self._dense_defaults))] +
     76             [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
     77             ]))
     78     self._structure = structure.convert_legacy_structure(
     79         output_types, output_shapes, output_classes)
     80 
     81     variant_tensor = (
     82         gen_experimental_dataset_ops.experimental_parse_example_dataset(
     83             self._input_dataset._variant_tensor,  # pylint: disable=protected-access
     84             self._num_parallel_calls,
     85             self._dense_defaults,
     86             self._sparse_keys,
     87             self._dense_keys,
     88             self._sparse_types,
     89             self._dense_shapes,
     90             **dataset_ops.flat_structure(self)))
     91     super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
     92 
     93   @property
     94   def _element_structure(self):
     95     return self._structure
     96 
     97 
     98 # TODO(b/111553342): add arguments names and example names as well.
     99 @tf_export("data.experimental.parse_example_dataset")
    100 def parse_example_dataset(features, num_parallel_calls=1):
    101   """A transformation that parses `Example` protos into a `dict` of tensors.
    102 
    103   Parses a number of serialized `Example` protos given in `serialized`. We refer
    104   to `serialized` as a batch with `batch_size` many entries of individual
    105   `Example` protos.
    106 
    107   This op parses serialized examples into a dictionary mapping keys to `Tensor`
    108   and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
    109   `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
    110   and `SparseFeature` is mapped to a `SparseTensor`, and each
    111   `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more
    112   details about feature dictionaries.
    113 
    114   Args:
    115    features: A `dict` mapping feature keys to `FixedLenFeature`,
    116      `VarLenFeature`, and `SparseFeature` values.
    117    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
    118       representing the number of parsing processes to call in parallel.
    119 
    120   Returns:
    121     A dataset transformation function, which can be passed to
    122     `tf.data.Dataset.apply`.
    123 
    124   Raises:
    125     ValueError: if features argument is None.
    126   """
    127   if features is None:
    128     raise ValueError("Missing: features was %s." % features)
    129 
    130   def _apply_fn(dataset):
    131     """Function from `Dataset` to `Dataset` that applies the transformation."""
    132     out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
    133     if any(
    134         isinstance(feature, parsing_ops.SparseFeature)
    135         for _, feature in features.items()
    136     ):
    137       # pylint: disable=protected-access
    138       # pylint: disable=g-long-lambda
    139       out_dataset = out_dataset.map(
    140           lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
    141               features, x), num_parallel_calls=num_parallel_calls)
    142     return out_dataset
    143 
    144   return _apply_fn
    145