# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental `dataset` API for parsing `Example` protos."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.util.tf_export import tf_export


class _ParseExampleDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that parses an `Example` dataset into a `dict` dataset."""

  def __init__(self, input_dataset, features, num_parallel_calls):
    self._input_dataset = input_dataset
    # The underlying op expects each input element to be a vector of
    # serialized `Example` protos.
    if not input_dataset._element_structure.is_compatible_with(  # pylint: disable=protected-access
        structure.TensorStructure(dtypes.string, [None])):
      raise TypeError(
          "Input dataset should be a dataset of vectors of strings.")
    self._num_parallel_calls = num_parallel_calls
    # pylint: disable=protected-access
    self._features = parsing_ops._prepend_none_dimension(features)
    # sparse_keys and dense_keys come back sorted here.
    (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
     dense_shapes) = parsing_ops._features_to_raw_params(
         self._features, [
             parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
             parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
         ])
    # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
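    # `_process_raw_parameters` validates the feature keys and converts the
    # dense defaults into a vector of tensors; the leading `None` stands for
    # the (unused) example names argument.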
    (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
     dense_shape_as_shape) = parsing_ops._process_raw_parameters(
         None, dense_defaults, sparse_keys, sparse_types, dense_keys,
         dense_types, dense_shapes)
    # pylint: enable=protected-access
    self._sparse_keys = sparse_keys
    self._sparse_types = sparse_types
    self._dense_keys = dense_keys
    self._dense_defaults = dense_defaults_vec
    self._dense_shapes = dense_shapes
    self._dense_types = dense_types

    # Each output shape is the input (batch) shape concatenated with the
    # per-example feature shape; sparse features end in an unknown dimension.
    input_dataset_shape = dataset_ops.get_legacy_output_shapes(
        self._input_dataset)
    dense_output_shapes = [input_dataset_shape.concatenate(shape)
                           for shape in dense_shape_as_shape]
    sparse_output_shapes = [input_dataset_shape.concatenate([None])
                            for _ in range(len(sparse_keys))]

    output_shapes = dict(
        zip(self._dense_keys + self._sparse_keys,
            dense_output_shapes + sparse_output_shapes))
    output_types = dict(
        zip(self._dense_keys + self._sparse_keys,
            self._dense_types + self._sparse_types))
    output_classes = dict(
        zip(self._dense_keys + self._sparse_keys,
            [ops.Tensor for _ in range(len(self._dense_defaults))] +
            [sparse_tensor.SparseTensor
             for _ in range(len(self._sparse_keys))]))
    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    variant_tensor = (
        gen_experimental_dataset_ops.experimental_parse_example_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._num_parallel_calls,
            self._dense_defaults,
            self._sparse_keys,
            self._dense_keys,
            self._sparse_types,
            self._dense_shapes,
            **dataset_ops.flat_structure(self)))
    super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def _element_structure(self):
    return self._structure


# TODO(b/111553342): add arguments names and example names as well.
@tf_export("data.experimental.parse_example_dataset")
def parse_example_dataset(features, num_parallel_calls=1):
  """A transformation that parses `Example` protos into a `dict` of tensors.

  Parses batches of serialized `Example` protos from the input dataset. Each
  element of the input dataset must be a vector of serialized `Example`
  protos; we refer to such a vector as a batch with `batch_size` entries.

  This op parses serialized examples into a dictionary mapping keys to `Tensor`
  and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
  `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature` and
  `SparseFeature` is mapped to a `SparseTensor`, and each `FixedLenFeature` is
  mapped to a `Tensor`. See `tf.parse_example` for more details about feature
  dictionaries.

  Args:
    features: A `dict` mapping feature keys to `FixedLenFeature`,
      `VarLenFeature`, and `SparseFeature` values.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`
      representing the number of parsing calls to run in parallel.

  Returns:
    A dataset transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if the `features` argument is `None`.
  """
  if features is None:
    raise ValueError("Missing: features was %s." % features)
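
  # `Dataset.apply` invokes the function below: it wraps the input dataset in
  # `_ParseExampleDataset` and, when any `SparseFeature` is requested, adds a
  # map step that reassembles those features into `SparseTensor`s.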
  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
    if any(
        isinstance(feature, parsing_ops.SparseFeature)
        for feature in features.values()):
      # pylint: disable=protected-access
      # pylint: disable=g-long-lambda
      out_dataset = out_dataset.map(
          lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
              features, x),
          num_parallel_calls=num_parallel_calls)
    return out_dataset

  return _apply_fn
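

# A minimal usage sketch, assuming a TFRecord file "examples.tfrecord" of
# serialized `Example` protos with an int64 "label" feature and a
# variable-length string "tokens" feature (the file name and feature keys are
# hypothetical):
#
#   serialized = tf.data.TFRecordDataset("examples.tfrecord").batch(32)
#   features = {
#       "label": tf.FixedLenFeature([], tf.int64, default_value=0),
#       "tokens": tf.VarLenFeature(tf.string),
#   }
#   parsed = serialized.apply(
#       tf.data.experimental.parse_example_dataset(
#           features, num_parallel_calls=4))
#   # Each element of `parsed` is a dict mapping "label" to a `Tensor` of
#   # shape [32] and "tokens" to a `SparseTensor`.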