Home | History | Annotate | Download | only in data
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
"""Contains the TFExampleDecoder and its associated helper classes.

The TFExampleDecoder is a DataDecoder used to decode TensorFlow Example protos.
In order to do so, each requested item must be paired with one or more Example
features that are parsed to produce the Tensor-based manifestation of the item.
"""
     21 
     22 from __future__ import absolute_import
     23 from __future__ import division
     24 from __future__ import print_function
     25 
     26 import abc
     27 
     28 import six
     29 
     30 from tensorflow.contrib.slim.python.slim.data import data_decoder
     31 from tensorflow.python.framework import dtypes
     32 from tensorflow.python.framework import sparse_tensor
     33 from tensorflow.python.ops import array_ops
     34 from tensorflow.python.ops import control_flow_ops
     35 from tensorflow.python.ops import map_fn
     36 from tensorflow.python.ops import image_ops
     37 from tensorflow.python.ops import math_ops
     38 from tensorflow.python.ops import parsing_ops
     39 from tensorflow.python.ops import sparse_ops
     40 
     41 
     42 @six.add_metaclass(abc.ABCMeta)
     43 class ItemHandler(object):
     44   """Specifies the item-to-Features mapping for tf.parse_example.
     45 
     46   An ItemHandler both specifies a list of Features used for parsing an Example
     47   proto as well as a function that post-processes the results of Example
     48   parsing.
     49   """
     50 
     51   def __init__(self, keys):
     52     """Constructs the handler with the name of the tf.Feature keys to use.
     53 
     54     See third_party/tensorflow/core/example/feature.proto
     55 
     56     Args:
     57       keys: the name of the TensorFlow Example Feature.
     58     """
     59     if not isinstance(keys, (tuple, list)):
     60       keys = [keys]
     61     self._keys = keys
     62 
     63   @property
     64   def keys(self):
     65     return self._keys
     66 
     67   @abc.abstractmethod
     68   def tensors_to_item(self, keys_to_tensors):
     69     """Maps the given dictionary of tensors to the requested item.
     70 
     71     Args:
     72       keys_to_tensors: a mapping of TF-Example keys to parsed tensors.
     73 
     74     Returns:
     75       the final tensor representing the item being handled.
     76     """
     77     pass
     78 
     79 
     80 class ItemHandlerCallback(ItemHandler):
     81   """An ItemHandler that converts the parsed tensors via a given function.
     82 
     83   Unlike other ItemHandlers, the ItemHandlerCallback resolves its item via
     84   a callback function rather than using prespecified behavior.
     85   """
     86 
     87   def __init__(self, keys, func):
     88     """Initializes the ItemHandler.
     89 
     90     Args:
     91       keys: a list of TF-Example keys.
     92       func: a function that takes as an argument a dictionary from `keys` to
     93         parsed Tensors.
     94     """
     95     super(ItemHandlerCallback, self).__init__(keys)
     96     self._func = func
     97 
     98   def tensors_to_item(self, keys_to_tensors):
     99     return self._func(keys_to_tensors)
    100 
    101 
    102 class BoundingBox(ItemHandler):
    103   """An ItemHandler that concatenates a set of parsed Tensors to Bounding Boxes.
    104   """
    105 
    106   def __init__(self, keys=None, prefix=''):
    107     """Initialize the bounding box handler.
    108 
    109     Args:
    110       keys: A list of four key names representing the ymin, xmin, ymax, mmax
    111       prefix: An optional prefix for each of the bounding box keys.
    112         If provided, `prefix` is appended to each key in `keys`.
    113 
    114     Raises:
    115       ValueError: if keys is not `None` and also not a list of exactly 4 keys
    116     """
    117     if keys is None:
    118       keys = ['ymin', 'xmin', 'ymax', 'xmax']
    119     elif len(keys) != 4:
    120       raise ValueError('BoundingBox expects 4 keys but got {}'.format(
    121           len(keys)))
    122     self._prefix = prefix
    123     self._keys = keys
    124     self._full_keys = [prefix + k for k in keys]
    125     super(BoundingBox, self).__init__(self._full_keys)
    126 
    127   def tensors_to_item(self, keys_to_tensors):
    128     """Maps the given dictionary of tensors to a concatenated list of bboxes.
    129 
    130     Args:
    131       keys_to_tensors: a mapping of TF-Example keys to parsed tensors.
    132 
    133     Returns:
    134       [num_boxes, 4] tensor of bounding box coordinates,
    135         i.e. 1 bounding box per row, in order [y_min, x_min, y_max, x_max].
    136     """
    137     sides = []
    138     for key in self._full_keys:
    139       side = keys_to_tensors[key]
    140       if isinstance(side, sparse_tensor.SparseTensor):
    141         side = side.values
    142       side = array_ops.expand_dims(side, 0)
    143       sides.append(side)
    144 
    145     bounding_box = array_ops.concat(sides, 0)
    146     return array_ops.transpose(bounding_box)
    147 
    148 
    149 class Tensor(ItemHandler):
    150   """An ItemHandler that returns a parsed Tensor."""
    151 
    152   def __init__(self, tensor_key, shape_keys=None, shape=None, default_value=0):
    153     """Initializes the Tensor handler.
    154 
    155     Tensors are, by default, returned without any reshaping. However, there are
    156     two mechanisms which allow reshaping to occur at load time. If `shape_keys`
    157     is provided, both the `Tensor` corresponding to `tensor_key` and
    158     `shape_keys` is loaded and the former `Tensor` is reshaped with the values
    159     of the latter. Alternatively, if a fixed `shape` is provided, the `Tensor`
    160     corresponding to `tensor_key` is loaded and reshape appropriately.
    161     If neither `shape_keys` nor `shape` are provided, the `Tensor` will be
    162     returned without any reshaping.
    163 
    164     Args:
    165       tensor_key: the name of the `TFExample` feature to read the tensor from.
    166       shape_keys: Optional name or list of names of the TF-Example feature in
    167         which the tensor shape is stored. If a list, then each corresponds to
    168         one dimension of the shape.
    169       shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
    170         reshaped accordingly.
    171       default_value: The value used when the `tensor_key` is not found in a
    172         particular `TFExample`.
    173 
    174     Raises:
    175       ValueError: if both `shape_keys` and `shape` are specified.
    176     """
    177     if shape_keys and shape is not None:
    178       raise ValueError('Cannot specify both shape_keys and shape parameters.')
    179     if shape_keys and not isinstance(shape_keys, list):
    180       shape_keys = [shape_keys]
    181     self._tensor_key = tensor_key
    182     self._shape_keys = shape_keys
    183     self._shape = shape
    184     self._default_value = default_value
    185     keys = [tensor_key]
    186     if shape_keys:
    187       keys.extend(shape_keys)
    188     super(Tensor, self).__init__(keys)
    189 
    190   def tensors_to_item(self, keys_to_tensors):
    191     tensor = keys_to_tensors[self._tensor_key]
    192     shape = self._shape
    193     if self._shape_keys:
    194       shape_dims = []
    195       for k in self._shape_keys:
    196         shape_dim = keys_to_tensors[k]
    197         if isinstance(shape_dim, sparse_tensor.SparseTensor):
    198           shape_dim = sparse_ops.sparse_tensor_to_dense(shape_dim)
    199         shape_dims.append(shape_dim)
    200       shape = array_ops.reshape(array_ops.stack(shape_dims), [-1])
    201     if isinstance(tensor, sparse_tensor.SparseTensor):
    202       if shape is not None:
    203         tensor = sparse_ops.sparse_reshape(tensor, shape)
    204       tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
    205     else:
    206       if shape is not None:
    207         tensor = array_ops.reshape(tensor, shape)
    208     return tensor
    209 
    210 
    211 class LookupTensor(Tensor):
    212   """An ItemHandler that returns a parsed Tensor, the result of a lookup."""
    213 
    214   def __init__(self,
    215                tensor_key,
    216                table,
    217                shape_keys=None,
    218                shape=None,
    219                default_value=''):
    220     """Initializes the LookupTensor handler.
    221 
    222     See Tensor.  Simply calls a vocabulary (most often, a label mapping) lookup.
    223 
    224     Args:
    225       tensor_key: the name of the `TFExample` feature to read the tensor from.
    226       table: A tf.lookup table.
    227       shape_keys: Optional name or list of names of the TF-Example feature in
    228         which the tensor shape is stored. If a list, then each corresponds to
    229         one dimension of the shape.
    230       shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
    231         reshaped accordingly.
    232       default_value: The value used when the `tensor_key` is not found in a
    233         particular `TFExample`.
    234 
    235     Raises:
    236       ValueError: if both `shape_keys` and `shape` are specified.
    237     """
    238     self._table = table
    239     super(LookupTensor, self).__init__(tensor_key, shape_keys, shape,
    240                                        default_value)
    241 
    242   def tensors_to_item(self, keys_to_tensors):
    243     unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors)
    244     return self._table.lookup(unmapped_tensor)
    245 
    246 
    247 class BackupHandler(ItemHandler):
    248   """An ItemHandler that tries two ItemHandlers in order."""
    249 
    250   def __init__(self, handler, backup):
    251     """Initializes the BackupHandler handler.
    252 
    253     If the first Handler's tensors_to_item returns a Tensor with no elements,
    254     the second Handler is used.
    255 
    256     Args:
    257       handler: The primary ItemHandler.
    258       backup: The backup ItemHandler.
    259 
    260     Raises:
    261       ValueError: if either is not an ItemHandler.
    262     """
    263     if not isinstance(handler, ItemHandler):
    264       raise ValueError('Primary handler is of type %s instead of ItemHandler'
    265                        % type(handler))
    266     if not isinstance(backup, ItemHandler):
    267       raise ValueError('Backup handler is of type %s instead of ItemHandler'
    268                        % type(backup))
    269     self._handler = handler
    270     self._backup = backup
    271     super(BackupHandler, self).__init__(handler.keys + backup.keys)
    272 
    273   def tensors_to_item(self, keys_to_tensors):
    274     item = self._handler.tensors_to_item(keys_to_tensors)
    275     return control_flow_ops.cond(
    276         pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0),
    277         true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors),
    278         false_fn=lambda: item)
    279 
    280 
    281 class SparseTensor(ItemHandler):
    282   """An ItemHandler for SparseTensors."""
    283 
    284   def __init__(self,
    285                indices_key=None,
    286                values_key=None,
    287                shape_key=None,
    288                shape=None,
    289                densify=False,
    290                default_value=0):
    291     """Initializes the Tensor handler.
    292 
    293     Args:
    294       indices_key: the name of the TF-Example feature that contains the ids.
    295         Defaults to 'indices'.
    296       values_key: the name of the TF-Example feature that contains the values.
    297         Defaults to 'values'.
    298       shape_key: the name of the TF-Example feature that contains the shape.
    299         If provided it would be used.
    300       shape: the output shape of the SparseTensor. If `shape_key` is not
    301         provided this `shape` would be used.
    302       densify: whether to convert the SparseTensor into a dense Tensor.
    303       default_value: Scalar value to set when making dense for indices not
    304         specified in the `SparseTensor`.
    305     """
    306     indices_key = indices_key or 'indices'
    307     values_key = values_key or 'values'
    308     self._indices_key = indices_key
    309     self._values_key = values_key
    310     self._shape_key = shape_key
    311     self._shape = shape
    312     self._densify = densify
    313     self._default_value = default_value
    314     keys = [indices_key, values_key]
    315     if shape_key:
    316       keys.append(shape_key)
    317     super(SparseTensor, self).__init__(keys)
    318 
    319   def tensors_to_item(self, keys_to_tensors):
    320     indices = keys_to_tensors[self._indices_key]
    321     values = keys_to_tensors[self._values_key]
    322     if self._shape_key:
    323       shape = keys_to_tensors[self._shape_key]
    324       if isinstance(shape, sparse_tensor.SparseTensor):
    325         shape = sparse_ops.sparse_tensor_to_dense(shape)
    326     elif self._shape:
    327       shape = self._shape
    328     else:
    329       shape = indices.dense_shape
    330     indices_shape = array_ops.shape(indices.indices)
    331     rank = indices_shape[1]
    332     ids = math_ops.cast(indices.values, dtypes.int64)
    333     indices_columns_to_preserve = array_ops.slice(
    334         indices.indices, [0, 0], array_ops.stack([-1, rank - 1]))
    335     new_indices = array_ops.concat(
    336         [indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])], 1)
    337 
    338     tensor = sparse_tensor.SparseTensor(new_indices, values.values, shape)
    339     if self._densify:
    340       tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
    341     return tensor
    342 
    343 
    344 class Image(ItemHandler):
    345   """An ItemHandler that decodes a parsed Tensor as an image."""
    346 
    347   def __init__(self,
    348                image_key=None,
    349                format_key=None,
    350                shape=None,
    351                channels=3,
    352                dtype=dtypes.uint8,
    353                repeated=False,
    354                dct_method=''):
    355     """Initializes the image.
    356 
    357     Args:
    358       image_key: the name of the TF-Example feature in which the encoded image
    359         is stored.
    360       format_key: the name of the TF-Example feature in which the image format
    361         is stored.
    362       shape: the output shape of the image as 1-D `Tensor`
    363         [height, width, channels]. If provided, the image is reshaped
    364         accordingly. If left as None, no reshaping is done. A shape should
    365         be supplied only if all the stored images have the same shape.
    366       channels: the number of channels in the image.
    367       dtype: images will be decoded at this bit depth. Different formats
    368         support different bit depths.
    369           See tf.image.decode_image,
    370               tf.decode_raw,
    371       repeated: if False, decodes a single image. If True, decodes a
    372         variable number of image strings from a 1D tensor of strings.
    373       dct_method: An optional string. Defaults to empty string. It only takes
    374         effect when image format is jpeg, used to specify a hint about the
    375         algorithm used for jpeg decompression. Currently valid values
    376         are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for
    377         example, the jpeg library does not have that specific option.
    378     """
    379     if not image_key:
    380       image_key = 'image/encoded'
    381     if not format_key:
    382       format_key = 'image/format'
    383 
    384     super(Image, self).__init__([image_key, format_key])
    385     self._image_key = image_key
    386     self._format_key = format_key
    387     self._shape = shape
    388     self._channels = channels
    389     self._dtype = dtype
    390     self._repeated = repeated
    391     self._dct_method = dct_method
    392 
    393   def tensors_to_item(self, keys_to_tensors):
    394     """See base class."""
    395     image_buffer = keys_to_tensors[self._image_key]
    396     image_format = keys_to_tensors[self._format_key]
    397 
    398     if self._repeated:
    399       return map_fn.map_fn(lambda x: self._decode(x, image_format),
    400                            image_buffer, dtype=self._dtype)
    401     else:
    402       return self._decode(image_buffer, image_format)
    403 
    404   def _decode(self, image_buffer, image_format):
    405     """Decodes the image buffer.
    406 
    407     Args:
    408       image_buffer: The tensor representing the encoded image tensor.
    409       image_format: The image format for the image in `image_buffer`. If image
    410         format is `raw`, all images are expected to be in this format, otherwise
    411         this op can decode a mix of `jpg` and `png` formats.
    412 
    413     Returns:
    414       A tensor that represents decoded image of self._shape, or
    415       (?, ?, self._channels) if self._shape is not specified.
    416     """
    417 
    418     def decode_image():
    419       """Decodes a image based on the headers."""
    420       return math_ops.cast(
    421           image_ops.decode_image(image_buffer, channels=self._channels),
    422           self._dtype)
    423 
    424     def decode_jpeg():
    425       """Decodes a jpeg image with specified '_dct_method'."""
    426       return math_ops.cast(
    427           image_ops.decode_jpeg(
    428               image_buffer,
    429               channels=self._channels,
    430               dct_method=self._dct_method), self._dtype)
    431 
    432     def check_jpeg():
    433       """Checks if an image is jpeg."""
    434       # For jpeg, we directly use image_ops.decode_jpeg rather than decode_image
    435       # in order to feed the jpeg specify parameter 'dct_method'.
    436       return control_flow_ops.cond(
    437           image_ops.is_jpeg(image_buffer),
    438           decode_jpeg,
    439           decode_image,
    440           name='cond_jpeg')
    441 
    442     def decode_raw():
    443       """Decodes a raw image."""
    444       return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)
    445 
    446     pred_fn_pairs = {
    447         math_ops.logical_or(
    448             math_ops.equal(image_format, 'raw'),
    449             math_ops.equal(image_format, 'RAW')): decode_raw,
    450     }
    451     image = control_flow_ops.case(
    452         pred_fn_pairs, default=check_jpeg, exclusive=True)
    453 
    454     image.set_shape([None, None, self._channels])
    455     if self._shape is not None:
    456       image = array_ops.reshape(image, self._shape)
    457 
    458     return image
    459 
    460 
    461 class TFExampleDecoder(data_decoder.DataDecoder):
    462   """A decoder for TensorFlow Examples.
    463 
    464   Decoding Example proto buffers is comprised of two stages: (1) Example parsing
    465   and (2) tensor manipulation.
    466 
    467   In the first stage, the tf.parse_example function is called with a list of
    468   FixedLenFeatures and SparseLenFeatures. These instances tell TF how to parse
    469   the example. The output of this stage is a set of tensors.
    470 
    471   In the second stage, the resulting tensors are manipulated to provide the
    472   requested 'item' tensors.
    473 
    474   To perform this decoding operation, an ExampleDecoder is given a list of
    475   ItemHandlers. Each ItemHandler indicates the set of features for stage 1 and
    476   contains the instructions for post_processing its tensors for stage 2.
    477   """
    478 
    479   def __init__(self, keys_to_features, items_to_handlers):
    480     """Constructs the decoder.
    481 
    482     Args:
    483       keys_to_features: a dictionary from TF-Example keys to either
    484         tf.VarLenFeature or tf.FixedLenFeature instances. See tensorflow's
    485         parsing_ops.py.
    486       items_to_handlers: a dictionary from items (strings) to ItemHandler
    487         instances. Note that the ItemHandler's are provided the keys that they
    488         use to return the final item Tensors.
    489     """
    490     self._keys_to_features = keys_to_features
    491     self._items_to_handlers = items_to_handlers
    492 
    493   def list_items(self):
    494     """See base class."""
    495     return list(self._items_to_handlers.keys())
    496 
    497   def decode(self, serialized_example, items=None):
    498     """Decodes the given serialized TF-example.
    499 
    500     Args:
    501       serialized_example: a serialized TF-example tensor.
    502       items: the list of items to decode. These must be a subset of the item
    503         keys in self._items_to_handlers. If `items` is left as None, then all
    504         of the items in self._items_to_handlers are decoded.
    505 
    506     Returns:
    507       the decoded items, a list of tensor.
    508     """
    509     example = parsing_ops.parse_single_example(serialized_example,
    510                                                self._keys_to_features)
    511 
    512     # Reshape non-sparse elements just once, adding the reshape ops in
    513     # deterministic order.
    514     for k in sorted(self._keys_to_features):
    515       v = self._keys_to_features[k]
    516       if isinstance(v, parsing_ops.FixedLenFeature):
    517         example[k] = array_ops.reshape(example[k], v.shape)
    518 
    519     if not items:
    520       items = self._items_to_handlers.keys()
    521 
    522     outputs = []
    523     for item in items:
    524       handler = self._items_to_handlers[item]
    525       keys_to_tensors = {key: example[key] for key in handler.keys}
    526       outputs.append(handler.tensors_to_item(keys_to_tensors))
    527     return outputs
    528