# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Non-core ops for LabeledTensor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import types

import numpy as np
from six import string_types

from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
from tensorflow.python.ops import random_ops
from tensorflow.python.training import input  # pylint: disable=redefined-builtin


@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensor, ops.Tensor, core.Axis,
            tc.Optional(string_types))
def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None):
  """Gather entries along `axis` by transposing it to the front and back."""
  with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope:
    temp_axes = core.Axes([axis] + list(
        labeled_tensor.axes.remove(axis.name).values()))
    transposed = core.transpose(labeled_tensor, temp_axes.keys())
    indexed = core.LabeledTensor(
        array_ops.gather(transposed.tensor, indexer), temp_axes)
    return core.transpose(indexed, labeled_tensor.axes.keys(), name=scope)


@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Mapping(string_types,
                       tc.Union(slice, collections.Hashable, list)),
            tc.Optional(string_types))
def select(labeled_tensor, selection, name=None):
  """Slice out a subset of the tensor.

  Args:
    labeled_tensor: The input tensor.
    selection: A dictionary mapping an axis name to a scalar, slice or list of
      values to select. Currently supports two types of selections:
        (a) Any number of scalar and/or slice selections.
        (b) Exactly one list selection, without any scalars or slices.
    name: Optional op name.

  Returns:
    The selection as a `LabeledTensor`.

  Raises:
    ValueError: If the tensor doesn't have an axis in the selection or if
      that axis lacks labels.
    KeyError: If any labels in a selection are not found in the original axis.
    NotImplementedError: If you attempt to combine a list selection with
      scalar selection or another list selection.
  """
  with ops.name_scope(name, 'lt_select', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    slices = {}
    indexers = {}
    for axis_name, value in selection.items():
      if axis_name not in labeled_tensor.axes:
        raise ValueError(
            'The tensor does not have an axis named %s. Its axes are: %r' %
            (axis_name, labeled_tensor.axes.keys()))
      axis = labeled_tensor.axes[axis_name]
      if axis.labels is None:
        raise ValueError(
            'The axis named %s does not have labels. The axis is: %r' %
            (axis_name, axis))

      if isinstance(value, slice):
        # TODO(shoyer): consider deprecating using slices in favor of lists
        if value.start is None:
          start = None
        else:
          start = axis.index(value.start)

        if value.stop is None:
          stop = None
        else:
          # For now, follow the pandas convention of making labeled slices
          # inclusive of both bounds.
          stop = axis.index(value.stop) + 1

        if value.step is not None:
          raise NotImplementedError('slicing with a step is not yet supported')

        slices[axis_name] = slice(start, stop)

      # Needs to be after checking for slices, since slice objects claim to be
      # instances of collections.Hashable but hash() on them fails.
      elif isinstance(value, collections.Hashable):
        slices[axis_name] = axis.index(value)

      elif isinstance(value, list):
        if indexers:
          raise NotImplementedError(
              'select does not yet support more than one list selection at '
              'the same time')
        indexer = [axis.index(v) for v in value]
        indexers[axis_name] = ops.convert_to_tensor(indexer, dtype=dtypes.int64)

      else:
        # If type checking is working properly, this shouldn't be possible.
        raise TypeError('cannot handle arbitrary types')

    if indexers and slices:
      raise NotImplementedError(
          'select does not yet support combined scalar and list selection')

    # For now, handle array selection separately, because tf.gather_nd does
    # not support gradients yet. Later, using gather_nd will let us combine
    # these paths.
    if indexers:
      (axis_name, indexer), = indexers.items()
      axis = core.Axis(axis_name, selection[axis_name])
      return _gather_1d_on_axis(labeled_tensor, indexer, axis, name=scope)
    else:
      return core.slice_function(labeled_tensor, slices, name=scope)


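# A hedged usage sketch for `select`; the names below are hypothetical and
# not part of this module. Assuming a LabeledTensor `lt` with an axis
# 'channel' labeled ['red', 'green', 'blue']:
#
#   red = select(lt, {'channel': 'red'})                 # scalar selection
#   rg = select(lt, {'channel': slice('red', 'green')})  # inclusive slice
#   rb = select(lt, {'channel': ['red', 'blue']})        # one list selection
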
@tc.returns(core.LabeledTensor)
@tc.accepts(
    tc.Collection(core.LabeledTensorLike), string_types,
    tc.Optional(string_types))
def concat(labeled_tensors, axis_name, name=None):
  """Concatenate tensors along a dimension.

  See tf.concat.

  Args:
    labeled_tensors: A list of input LabeledTensors.
    axis_name: The name of the axis along which to concatenate.
    name: Optional op name.

  Returns:
    The concatenated tensor.
    The coordinate labels for the concatenation dimension are also
    concatenated, if they are available for every tensor.

  Raises:
    ValueError: If fewer than one input tensor is provided, if the tensors
      have incompatible axes, or if `axis_name` isn't the name of an axis.
  """
  with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope:
    labeled_tensors = [
        core.convert_to_labeled_tensor(lt) for lt in labeled_tensors
    ]

    if len(labeled_tensors) < 1:
      raise ValueError('concat expects at least 1 tensor, but received %s' %
                       labeled_tensors)

    # All tensors must have these axes.
    axes_0 = labeled_tensors[0].axes
    axis_names = list(axes_0.keys())

    if axis_name not in axis_names:
      raise ValueError('%s not in %s' % (axis_name, axis_names))

    shared_axes = axes_0.remove(axis_name)

    tensors = [labeled_tensors[0].tensor]
    concat_axis_list = [axes_0[axis_name]]
    for labeled_tensor in labeled_tensors[1:]:
      current_shared_axes = labeled_tensor.axes.remove(axis_name)
      if current_shared_axes != shared_axes:
        # TODO(shoyer): add more specific checks about what went wrong,
        # including raising AxisOrderError when appropriate
        raise ValueError('Mismatched shared axes: the first tensor '
                         'had axes %r but this tensor has axes %r.' %
                         (shared_axes, current_shared_axes))

      # Accumulate the axis labels, if they're available.
      concat_axis_list.append(labeled_tensor.axes[axis_name])
      tensors.append(labeled_tensor.tensor)

    concat_axis = core.concat_axes(concat_axis_list)
    concat_dimension = axis_names.index(axis_name)
    concat_tensor = array_ops.concat(tensors, concat_dimension, name=scope)
    values = list(axes_0.values())
    concat_axes = (values[:concat_dimension] + [concat_axis] +
                   values[concat_dimension + 1:])

    return core.LabeledTensor(concat_tensor, concat_axes)


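# A hedged sketch of `concat` (hypothetical names; assumes `tf` is imported in
# user code). Two tensors sharing all axes but 'time' are joined along 'time':
#
#   a = core.LabeledTensor(tf.zeros([2, 3]), [('time', 2), ('x', 3)])
#   b = core.LabeledTensor(tf.ones([4, 3]), [('time', 4), ('x', 3)])
#   ab = concat([a, b], 'time')  # resulting axes: [('time', 6), ('x', 3)]
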
# TODO(shoyer): rename pack/unpack to stack/unstack


@tc.returns(core.LabeledTensor)
@tc.accepts(
    tc.Collection(core.LabeledTensorLike),
    tc.Union(string_types, core.AxisLike), int, tc.Optional(string_types))
def pack(labeled_tensors, new_axis, axis_position=0, name=None):
  """Pack tensors along a new axis.

  See tf.pack.

  Args:
    labeled_tensors: The input tensors, which must have identical axes.
    new_axis: The name of the new axis, or a tuple containing the name
      and coordinate labels.
    axis_position: Optional integer position at which to insert the new axis.
    name: Optional op name.

  Returns:
    The packed tensors as a single LabeledTensor, with `new_axis` in the given
    `axis_position`.

  Raises:
    ValueError: If fewer than one input tensor is provided, or if the tensors
      don't have identical axes.
  """
  with ops.name_scope(name, 'lt_pack', labeled_tensors) as scope:
    labeled_tensors = [
        core.convert_to_labeled_tensor(lt) for lt in labeled_tensors
    ]

    if len(labeled_tensors) < 1:
      raise ValueError('pack expects at least 1 tensor, but received %s' %
                       labeled_tensors)

    axes_0 = labeled_tensors[0].axes
    for t in labeled_tensors:
      if t.axes != axes_0:
        raise ValueError('Non-identical axes. Expected %s but got %s' %
                         (axes_0, t.axes))

    pack_op = array_ops.stack(
        [t.tensor for t in labeled_tensors], axis=axis_position, name=scope)
    axes = list(axes_0.values())
    axes.insert(axis_position, new_axis)
    return core.LabeledTensor(pack_op, axes)


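# A hedged sketch of `pack` (hypothetical names): stacking identically-shaped
# tensors along a new leading 'batch' axis:
#
#   xs = [core.LabeledTensor(tf.constant([1.0, 2.0]), ['x'])
#         for _ in range(3)]
#   packed = pack(xs, 'batch')  # axes: [('batch', 3), ('x', 2)]
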
@tc.returns(tc.List(core.LabeledTensor))
@tc.accepts(core.LabeledTensorLike,
            tc.Optional(string_types), tc.Optional(string_types))
def unpack(labeled_tensor, axis_name=None, name=None):
  """Unpack the tensor.

  See tf.unpack.

  Args:
    labeled_tensor: The input tensor.
    axis_name: Optional name of axis to unpack. By default, the first axis is
      used.
    name: Optional op name.

  Returns:
    The list of unpacked LabeledTensors.

  Raises:
    ValueError: If `axis_name` is not an axis on the input.
  """
  with ops.name_scope(name, 'lt_unpack', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    axis_names = list(labeled_tensor.axes.keys())
    if axis_name is None:
      axis_name = axis_names[0]

    if axis_name not in axis_names:
      raise ValueError('%s not in %s' % (axis_name, axis_names))
    axis = axis_names.index(axis_name)

    unpack_ops = array_ops.unstack(labeled_tensor.tensor, axis=axis, name=scope)
    axes = [a for i, a in enumerate(labeled_tensor.axes.values()) if i != axis]
    return [core.LabeledTensor(t, axes) for t in unpack_ops]


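# A hedged sketch of `unpack` (hypothetical names), the inverse of `pack`:
#
#   lt = core.LabeledTensor(tf.zeros([3, 2]), ['batch', 'x'])
#   parts = unpack(lt, 'batch')  # a list of 3 LabeledTensors with axes ['x']
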
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Collection(string_types),
            tc.Collection(tc.Union(string_types, core.AxisLike)),
            tc.Optional(string_types))
def reshape(labeled_tensor, existing_axes, new_axes, name=None):
  """Reshape specific axes of a LabeledTensor.

  Non-indicated axes remain in their original locations.

  Args:
    labeled_tensor: The input tensor.
    existing_axes: List of axis names found on the input tensor. These must
      appear sequentially in the list of axis names on the input. In other
      words, they must be a valid slice of `list(labeled_tensor.axes.keys())`.
    new_axes: List of strings, tuples of (axis_name, axis_value) or Axis objects
      providing new axes with which to replace `existing_axes` in the reshaped
      result. At most one element of `new_axes` may be a string, indicating an
      axis with unknown size.
    name: Optional op name.

  Returns:
    The reshaped LabeledTensor.

  Raises:
    ValueError: If `existing_axes` are not all axes on the input, or if more
      than one of `new_axes` has unknown size.
    AxisOrderError: If `existing_axes` are not a slice of axis names on the
      input.
  """
  with ops.name_scope(name, 'lt_reshape', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    original_axis_names = list(labeled_tensor.axes.keys())
    existing_axes = list(existing_axes)
    if not set(existing_axes) <= set(original_axis_names):
      raise ValueError('existing_axes %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (existing_axes, original_axis_names))

    start = original_axis_names.index(existing_axes[0])
    stop = original_axis_names.index(existing_axes[-1]) + 1

    if existing_axes != original_axis_names[start:stop]:
      # We could support existing_axes that aren't a slice by using transpose,
      # but that could lead to unpredictable performance consequences because
      # transposes are not free in TensorFlow. If we did transpose
      # automatically, the user might never realize that their data is being
      # produced with the wrong order. (The latter will occur with some
      # frequency because of how broadcasting automatically chooses axis
      # order.) So for now we've taken the strict approach.
      raise core.AxisOrderError(
          'existing_axes %r are not a slice of axis names %r on the input '
          'labeled tensor. Use `transpose` or `impose_axis_order` to reorder '
          'axes on the input explicitly.' %
          (existing_axes, original_axis_names))

    if sum(isinstance(axis, string_types) for axis in new_axes) > 1:
      raise ValueError(
          'at most one axis in new_axes can have unknown size. All other '
          'axes must have an indicated integer size or labels: %r' % new_axes)

    original_values = list(labeled_tensor.axes.values())
    axis_size = lambda axis: -1 if axis.size is None else axis.size
    shape = [axis_size(axis) for axis in original_values[:start]]
    for axis_ref in new_axes:
      if isinstance(axis_ref, string_types):
        shape.append(-1)
      else:
        axis = core.as_axis(axis_ref)
        shape.append(axis_size(axis))
    shape.extend(axis_size(axis) for axis in original_values[stop:])

    reshaped_tensor = array_ops.reshape(
        labeled_tensor.tensor, shape, name=scope)
    axes = original_values[:start] + list(new_axes) + original_values[stop:]
    return core.LabeledTensor(reshaped_tensor, axes)


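# A hedged sketch of `reshape` (hypothetical names): collapsing two adjacent
# axes into one new axis of initially unknown size:
#
#   lt = core.LabeledTensor(tf.zeros([2, 3, 4]), ['batch', 'row', 'col'])
#   flat = reshape(lt, ['row', 'col'], ['pixel'])
#   # -> axes: ['batch', 'pixel'], where 'pixel' gets size 12 here
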
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike, string_types, string_types,
            tc.Optional(string_types))
def rename_axis(labeled_tensor, existing_name, new_name, name=None):
  """Rename an axis of LabeledTensor.

  Args:
    labeled_tensor: The input tensor.
    existing_name: Name for an existing axis on the input.
    new_name: Desired replacement name.
    name: Optional op name.

  Returns:
    LabeledTensor with renamed axis.

  Raises:
    ValueError: If `existing_name` is not an axis on the input.
  """
  with ops.name_scope(name, 'lt_rename_axis', [labeled_tensor]) as scope:
    if existing_name not in labeled_tensor.axes:
      raise ValueError('existing_name %r is not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (existing_name, labeled_tensor.axes.keys()))
    new_axis = core.Axis(new_name, labeled_tensor.axes[existing_name].value)
    return reshape(labeled_tensor, [existing_name], [new_axis], name=scope)


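# A hedged sketch of `rename_axis` (hypothetical names):
#
#   lt = core.LabeledTensor(tf.zeros([2, 3]), ['batch', 'x'])
#   renamed = rename_axis(lt, 'x', 'feature')  # axes: ['batch', 'feature']
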
@tc.returns(tc.List(core.LabeledTensor))
@tc.accepts(string_types, collections.Callable, int, bool,
            tc.Collection(core.LabeledTensorLike), bool,
            tc.Optional(string_types))
def _batch_helper(default_name,
                  batch_fn,
                  batch_size,
                  enqueue_many,
                  labeled_tensors,
                  allow_smaller_final_batch,
                  name=None):
  """Shared implementation of `batch` and `shuffle_batch`."""
  with ops.name_scope(name, default_name, labeled_tensors) as scope:
    labeled_tensors = [
        core.convert_to_labeled_tensor(lt) for lt in labeled_tensors
    ]

    batch_ops = batch_fn([t.tensor for t in labeled_tensors], scope)
    # TODO(shoyer): Remove this when they sanitize the TF API.
    if not isinstance(batch_ops, list):
      assert isinstance(batch_ops, ops.Tensor)
      batch_ops = [batch_ops]

    if allow_smaller_final_batch:
      batch_size = None

    @tc.returns(core.Axes)
    @tc.accepts(core.Axes)
    def output_axes(axes):
      if enqueue_many:
        if 'batch' not in axes or list(axes.keys()).index('batch') != 0:
          raise ValueError(
              'When enqueue_many is True, input tensors must have an axis '
              'called "batch" as their first dimension, '
              'but axes were %s' % axes)
        culled_axes = axes.remove('batch')
        return core.Axes([('batch', batch_size)] + list(culled_axes.values()))
      else:
        return core.Axes([('batch', batch_size)] + list(axes.values()))

    output_labeled_tensors = []
    for i, tensor in enumerate(batch_ops):
      axes = output_axes(labeled_tensors[i].axes)
      output_labeled_tensors.append(core.LabeledTensor(tensor, axes))

    return output_labeled_tensors


@tc.returns(tc.List(core.LabeledTensor))
@tc.accepts(
    tc.Collection(core.LabeledTensorLike), int, int, int, bool, bool,
    tc.Optional(string_types))
def batch(labeled_tensors,
          batch_size,
          num_threads=1,
          capacity=32,
          enqueue_many=False,
          allow_smaller_final_batch=False,
          name=None):
  """Rebatch a list of tensors.

  See tf.train.batch.

  Args:
    labeled_tensors: The input tensors.
    batch_size: The output batch size.
    num_threads: See tf.train.batch.
    capacity: See tf.train.batch.
    enqueue_many: If true, the input tensors must contain a 'batch' axis as
      their first axis.
      If false, the input tensors must not contain a 'batch' axis.
      See tf.train.batch.
    allow_smaller_final_batch: See tf.train.batch.
    name: Optional op name.

  Returns:
    The rebatched tensors.
    If enqueue_many is false, the output tensors will have a new 'batch' axis
      as their first axis.

  Raises:
    ValueError: If enqueue_many is True and the first axis of the tensors
      isn't "batch".
  """

  def fn(tensors, scope):
    return input.batch(
        tensors,
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
        enqueue_many=enqueue_many,
        allow_smaller_final_batch=allow_smaller_final_batch,
        name=scope)

  return _batch_helper('lt_batch', fn, batch_size, enqueue_many,
                       labeled_tensors, allow_smaller_final_batch, name)


@tc.returns(tc.List(core.LabeledTensor))
@tc.accepts(
    tc.Collection(core.LabeledTensorLike), int, int, int, bool, int,
    tc.Optional(int), bool, tc.Optional(string_types))
def shuffle_batch(labeled_tensors,
                  batch_size,
                  num_threads=1,
                  capacity=32,
                  enqueue_many=False,
                  min_after_dequeue=0,
                  seed=None,
                  allow_smaller_final_batch=False,
                  name=None):
  """Rebatch a list of tensors, with shuffling.

  See tf.train.shuffle_batch.

  Args:
    labeled_tensors: The input tensors.
    batch_size: The output batch size.
    num_threads: See tf.train.shuffle_batch.
    capacity: See tf.train.shuffle_batch.
    enqueue_many: If true, the input tensors must contain a 'batch' axis as
      their first axis.
      If false, the input tensors must not contain a 'batch' axis.
      See tf.train.shuffle_batch.
    min_after_dequeue: Minimum number of elements in the queue after a dequeue,
      used to ensure mixing.
    seed: Optional random seed.
    allow_smaller_final_batch: See tf.train.shuffle_batch.
    name: Optional op name.

  Returns:
    The rebatched tensors.
    If enqueue_many is false, the output tensors will have a new 'batch' axis
      as their first axis.

  Raises:
    ValueError: If enqueue_many is True and the first axis of the tensors
      isn't "batch".
  """

  def fn(tensors, scope):
    return input.shuffle_batch(
        tensors,
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
        enqueue_many=enqueue_many,
        min_after_dequeue=min_after_dequeue,
        seed=seed,
        allow_smaller_final_batch=allow_smaller_final_batch,
        name=scope)

  return _batch_helper('lt_shuffle_batch', fn, batch_size, enqueue_many,
                       labeled_tensors, allow_smaller_final_batch, name)


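# A hedged sketch of `batch`/`shuffle_batch` (hypothetical names). These ops
# are queue-based, so user code must also start queue runners (e.g. with
# tf.train.start_queue_runners) before evaluating the outputs:
#
#   example = core.LabeledTensor(input_tensor, ['x'])  # one unbatched example
#   [batched] = batch([example], batch_size=32)
#   # batched has a new leading axis: [('batch', 32), 'x']
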
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Mapping(string_types, int),
            tc.Optional(int), tc.Optional(string_types))
def random_crop(labeled_tensor, shape_map, seed=None, name=None):
  """Randomly crops a tensor to a given size.

  See tf.random_crop.

  Args:
    labeled_tensor: The input tensor.
    shape_map: A dictionary mapping axis names to the size of the random crop
      for that dimension.
    seed: An optional random seed.
    name: An optional op name.

  Returns:
    A tensor of the same rank as `labeled_tensor`, cropped randomly in the
    selected dimensions.

  Raises:
    ValueError: If the shape map contains an axis name not in the input tensor.
  """
  with ops.name_scope(name, 'lt_random_crop', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    for axis_name in shape_map:
      if axis_name not in labeled_tensor.axes:
        raise ValueError('Selection axis %s not in axes %s' %
                         (axis_name, labeled_tensor.axes))

    shape = []
    axes = []
    for axis in labeled_tensor.axes.values():
      if axis.name in shape_map:
        size = shape_map[axis.name]
        shape.append(size)
        # We lose labels for the axes we crop, leaving just the size.
        axes.append((axis.name, size))
      else:
        shape.append(len(axis))
        axes.append(axis)

    crop_op = random_ops.random_crop(
        labeled_tensor.tensor, shape, seed=seed, name=scope)

    return core.LabeledTensor(crop_op, axes)


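# A hedged sketch of `random_crop` (hypothetical names); uncropped axes keep
# their original size:
#
#   lt = core.LabeledTensor(tf.zeros([10, 8]), ['x', 'y'])
#   cropped = random_crop(lt, {'x': 4})  # axes: [('x', 4), ('y', 8)]
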
# TODO(shoyer): Allow the user to select the axis over which to map.
@tc.returns(core.LabeledTensor)
@tc.accepts(collections.Callable, core.LabeledTensorLike,
            tc.Optional(string_types))
def map_fn(fn, labeled_tensor, name=None):
  """Map on the list of tensors unpacked from labeled_tensor.

  See tf.map_fn.

  Args:
    fn: The function to apply to each unpacked LabeledTensor.
      It should have type LabeledTensor -> LabeledTensor.
    labeled_tensor: The input tensor.
    name: Optional op name.

  Returns:
    A tensor that packs the results of applying fn to the list of tensors
    unpacked from labeled_tensor.
  """
  with ops.name_scope(name, 'lt_map_fn', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    unpack_lts = unpack(labeled_tensor)

    # TODO(ericmc): Fix this upstream.
    if labeled_tensor.dtype == dtypes.string:
      # We must construct the full graph here, because functional_ops.map_fn
      # doesn't work for string-valued tensors.
      # Constructing the full graph may be slow.
      map_lts = [fn(t) for t in unpack_lts]
      return pack(map_lts, list(labeled_tensor.axes.values())[0], name=scope)
    else:
      # Figure out what the axis labels should be, but use tf.map_fn to
      # construct the graph because it's efficient.
      # It may be slow to construct the full graph, so we infer the labels from
      # the first element.
      # TODO(ericmc): This builds a subgraph which then gets thrown away.
      # Find a more elegant solution.
      first_map_lt = fn(unpack_lts[0])
      final_axes = list(labeled_tensor.axes.values())[:1] + list(
          first_map_lt.axes.values())

      @tc.returns(ops.Tensor)
      @tc.accepts(ops.Tensor)
      def tf_fn(tensor):
        original_axes = list(labeled_tensor.axes.values())[1:]
        tensor_lt = core.LabeledTensor(tensor, original_axes)
        return fn(tensor_lt).tensor

      map_op = functional_ops.map_fn(tf_fn, labeled_tensor.tensor)
      map_lt = core.LabeledTensor(map_op, final_axes)

      return core.identity(map_lt, name=scope)


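# A hedged sketch of `map_fn` (hypothetical names; assumes the arithmetic
# operator overloads defined on LabeledTensor in `core`):
#
#   lt = core.LabeledTensor(tf.ones([3, 2]), ['batch', 'x'])
#   doubled = map_fn(lambda t: t * 2.0, lt)  # same axes as `lt`
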
@tc.returns(core.LabeledTensor)
@tc.accepts(collections.Callable, core.LabeledTensorLike,
            core.LabeledTensorLike, tc.Optional(string_types))
def foldl(fn, labeled_tensor, initial_value, name=None):
  """Left fold on the list of tensors unpacked from labeled_tensor.

  See tf.foldl.

  Args:
    fn: The function to apply to each unpacked LabeledTensor.
      It should have type (LabeledTensor, LabeledTensor) -> LabeledTensor.
      Its arguments are (accumulated_value, next_value).
    labeled_tensor: The input tensor.
    initial_value: The initial value of the accumulator.
    name: Optional op name.

  Returns:
    The accumulated value.
  """
  with ops.name_scope(name, 'lt_foldl',
                      [labeled_tensor, initial_value]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
    initial_value = core.convert_to_labeled_tensor(initial_value)

    @tc.returns(ops.Tensor)
    @tc.accepts(ops.Tensor, ops.Tensor)
    def tf_fn(accumulator, next_element):
      accumulator_lt = core.LabeledTensor(accumulator, initial_value.axes)
      next_element_lt = core.LabeledTensor(
          next_element, list(labeled_tensor.axes.values())[1:])
      return fn(accumulator_lt, next_element_lt).tensor

    foldl_op = functional_ops.foldl(
        tf_fn, labeled_tensor.tensor, initializer=initial_value.tensor)
    foldl_lt = core.LabeledTensor(foldl_op, initial_value.axes)

    return core.identity(foldl_lt, name=scope)


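# A hedged sketch of `foldl` (hypothetical names): a running sum over the
# first axis, again assuming the arithmetic overloads from `core`:
#
#   elems = core.LabeledTensor(tf.constant([1.0, 2.0, 3.0]), ['x'])
#   init = core.LabeledTensor(tf.constant(0.0), [])
#   total = foldl(lambda acc, nxt: acc + nxt, elems, init)  # -> 6.0
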
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Optional(tc.Collection(string_types)), tc.Optional(string_types))
def squeeze(labeled_tensor, axis_names=None, name=None):
  """Remove size-1 dimensions.

  See tf.squeeze.

  Args:
    labeled_tensor: The input tensor.
    axis_names: The names of the dimensions to remove, or None to remove
      all size-1 dimensions.
    name: Optional op name.

  Returns:
    A tensor with the specified dimensions removed.

  Raises:
    ValueError: If the named axes are not in the tensor, or if they are
      not size-1.
  """
  with ops.name_scope(name, 'lt_squeeze', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    if axis_names is None:
      axis_names = [a.name for a in labeled_tensor.axes.values() if len(a) == 1]

    for axis_name in axis_names:
      if axis_name not in labeled_tensor.axes:
        raise ValueError('axis %s is not in tensor axes %s' %
                         (axis_name, labeled_tensor.axes))
      elif len(labeled_tensor.axes[axis_name]) != 1:
        raise ValueError(
            'cannot squeeze axis with size greater than 1: (%s, %s)' %
            (axis_name, labeled_tensor.axes[axis_name]))

    squeeze_dimensions = []
    axes = []
    for i, axis in enumerate(labeled_tensor.axes.values()):
      if axis.name in axis_names:
        squeeze_dimensions.append(i)
      else:
        axes.append(axis)

    if squeeze_dimensions:
      squeeze_op = array_ops.squeeze(
          labeled_tensor.tensor, squeeze_dimensions, name=scope)
    else:
      squeeze_op = array_ops.identity(labeled_tensor.tensor, name=scope)

    return core.LabeledTensor(squeeze_op, axes)


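# A hedged sketch of `squeeze` (hypothetical names):
#
#   lt = core.LabeledTensor(tf.zeros([1, 3]), ['batch', 'x'])
#   squeezed = squeeze(lt, ['batch'])  # axes: [('x', 3)]
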
# pylint: disable=invalid-name
ReduceAxis = tc.Union(string_types,
                      tc.Tuple(string_types, collections.Hashable))
ReduceAxes = tc.Optional(tc.Union(ReduceAxis, tc.Collection(ReduceAxis)))
# pylint: enable=invalid-name


@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike,
            tc.Optional(string_types))
def matmul(a, b, name=None):
  """Matrix multiply two tensors with rank 1 or 2.

  If both tensors have rank 2, a matrix-matrix product is performed.
  If one tensor has rank 1 and the other has rank 2, then a matrix-vector
  product is performed.
  If both tensors have rank 1, then a vector dot-product is performed.
  (This behavior matches that of `numpy.dot`.)

  Both tensors must share exactly one dimension in common, which is the
  dimension the operation is summed along. The inputs will be automatically
  transposed if necessary as part of the matmul op.

  We intend to eventually support `matmul` on higher rank input, and also
  eventually support summing over any number of shared dimensions (via an
  `axis` argument), but neither of these features has been implemented yet.

  Args:
    a: First LabeledTensor.
    b: Second LabeledTensor.
    name: Optional op name.

  Returns:
    LabeledTensor with the result of matrix multiplication. Axes are ordered by
    the current axis_order_scope, if set, or in order of appearance on the
    inputs.

  Raises:
    NotImplementedError: If inputs have rank >2 or share multiple axes.
    ValueError: If the inputs have rank 0 or do not share any axes.
  """
  with ops.name_scope(name, 'lt_matmul', [a, b]) as scope:

    a = core.convert_to_labeled_tensor(a)
    b = core.convert_to_labeled_tensor(b)

    if len(a.axes) > 2 or len(b.axes) > 2:
      # We could pass batched inputs to tf.matmul to make this work, but we
      # would also need to use tf.tile and/or tf.transpose. These are more
      # expensive than doing reshapes, so it's not clear if it's a good idea to
      # do this automatically.
      raise NotImplementedError(
          'matmul currently requires inputs with rank 2 or less, but '
          'inputs have ranks %r and %r' % (len(a.axes), len(b.axes)))

    if not a.axes or not b.axes:
      raise ValueError(
          'matmul currently requires inputs with at least rank 1, but '
          'inputs have ranks %r and %r' % (len(a.axes), len(b.axes)))

    shared_axes = set(a.axes) & set(b.axes)
    if len(shared_axes) > 1:
      raise NotImplementedError(
          'matmul does not yet support summing over multiple shared axes: %r. '
          'Use transpose and reshape to create a single shared axis to sum '
          'over.' % shared_axes)
    if not shared_axes:
      raise ValueError('there must be exactly one axis in common between '
                       'inputs to matmul: %r, %r' %
                       (a.axes.keys(), b.axes.keys()))
    shared_axis, = shared_axes

    if a.axes[shared_axis] != b.axes[shared_axis]:
      raise ValueError('axis %r does not match on input arguments: %r vs %r' %
                       (shared_axis, a.axes[shared_axis].value,
                        b.axes[shared_axis].value))

    result_axes = []
    for axes in [a.axes, b.axes]:
      for axis in axes.values():
        if axis.name != shared_axis:
          result_axes.append(axis)

    axis_scope_order = core.get_axis_order()
    if axis_scope_order is not None:
      result_axis_names = [axis.name for axis in result_axes]
      new_axis_names = [
          name for name in axis_scope_order if name in result_axis_names
      ]
      if new_axis_names != result_axis_names:
        # switch a and b
        b, a = a, b
        # result_axes is a list of length 1 or 2
        result_axes = result_axes[::-1]

    squeeze_dims = []

    if len(a.axes) == 1:
      a_tensor = array_ops.reshape(a.tensor, (1, -1))
      squeeze_dims.append(0)
      transpose_a = False
    else:
      a_tensor = a.tensor
      transpose_a = list(a.axes.keys()).index(shared_axis) == 0

    if len(b.axes) == 1:
      b_tensor = array_ops.reshape(b.tensor, (-1, 1))
      squeeze_dims.append(1)
      transpose_b = False
    else:
      b_tensor = b.tensor
      transpose_b = list(b.axes.keys()).index(shared_axis) == 1

    result_op = math_ops.matmul(
        a_tensor, b_tensor, transpose_a=transpose_a, transpose_b=transpose_b)

    if squeeze_dims:
      result_op = array_ops.squeeze(result_op, squeeze_dims)
    result_op = array_ops.identity(result_op, name=scope)

    return core.LabeledTensor(result_op, result_axes)


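# A hedged sketch of `matmul` (hypothetical names): a matrix-vector product
# summing over the shared 'x' axis:
#
#   m = core.LabeledTensor(tf.zeros([2, 3]), ['y', 'x'])
#   v = core.LabeledTensor(tf.zeros([3]), ['x'])
#   mv = matmul(m, v)  # axes: [('y', 2)]
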
@tc.returns(types.FunctionType)
@tc.accepts(string_types, collections.Callable)
def define_reduce_op(op_name, reduce_fn):
  """Define a reduction op for labeled tensors.

  Args:
    op_name: string name of the TensorFlow op.
    reduce_fn: function to call to evaluate the op on a tf.Tensor.

  Returns:
    Function defining the given reduction op that acts on a LabeledTensor.
  """

  default_name = 'lt_%s' % op_name

  @tc.returns(core.LabeledTensor)
  @tc.accepts(core.LabeledTensorLike, ReduceAxes, tc.Optional(string_types))
  def op(labeled_tensor, axes=None, name=None):
    """Computes the given reduction across the given axes of a LabeledTensor.

    See `tf.{op_name}` for full details.

    Args:
      labeled_tensor: The input tensor.
      axes: A set of axes or None.
        If None, all axes will be reduced.
        Axes must all be strings, in which case those dimensions will be
        removed, or pairs of (name, None) or (name, label), in which case those
        dimensions will be kept.
      name: Optional op name.

    Returns:
      The reduced LabeledTensor.

    Raises:
      ValueError: if any of the axes to reduce over are not found on
        `labeled_tensor`.
    """
    with ops.name_scope(name, default_name, [labeled_tensor]) as scope:
      labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

      if axes is None:
        axes = labeled_tensor.axes.keys()

      if isinstance(axes, (string_types, tuple)):
        axes = [axes]

      reduction_axes = {}
      axes_to_squeeze = []
      for a in axes:
        if isinstance(a, string_types):
          # We squeeze out this axis.
          reduction_axes[a] = a
          axes_to_squeeze.append(a)
        else:
          # We keep this axis, with the user-provided labels.
          (axis_name, label) = a
          if label is not None:
            # The input was a single label, so make it a list so it can be
            # turned into an Axis.
            label = [label]
          reduction_axes[axis_name] = (axis_name, label)

      for axis_name in reduction_axes:
        if axis_name not in labeled_tensor.axes:
          raise ValueError('Axis %s not in axes %s' %
                           (axis_name, labeled_tensor.axes))

      intermediate_axes = []
      reduction_dimensions = []
      for i, axis in enumerate(labeled_tensor.axes.values()):
        if axis.name in reduction_axes:
          intermediate_axes.append(reduction_axes[axis.name])
          reduction_dimensions.append(i)
        else:
          intermediate_axes.append(axis)

      reduce_op = reduce_fn(
          labeled_tensor.tensor, reduction_dimensions, keepdims=True)
      reduce_lt = core.LabeledTensor(reduce_op, intermediate_axes)

      return squeeze(reduce_lt, axes_to_squeeze, name=scope)

  op.__doc__ = op.__doc__.format(op_name=op_name)
  op.__name__ = op_name

  return op


reduce_all = define_reduce_op('reduce_all', math_ops.reduce_all)
reduce_any = define_reduce_op('reduce_any', math_ops.reduce_any)
reduce_logsumexp = define_reduce_op('reduce_logsumexp',
                                    math_ops.reduce_logsumexp)
reduce_max = define_reduce_op('reduce_max', math_ops.reduce_max)
reduce_mean = define_reduce_op('reduce_mean', math_ops.reduce_mean)
reduce_min = define_reduce_op('reduce_min', math_ops.reduce_min)
reduce_prod = define_reduce_op('reduce_prod', math_ops.reduce_prod)
reduce_sum = define_reduce_op('reduce_sum', math_ops.reduce_sum)


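# A hedged sketch of the generated reduction ops (hypothetical names):
#
#   lt = core.LabeledTensor(tf.ones([2, 3]), ['x', 'y'])
#   total = reduce_sum(lt)       # reduces over all axes -> rank-0 result
#   per_x = reduce_sum(lt, 'y')  # reduces over 'y' -> axes: [('x', 2)]
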
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Mapping(str, tc.Union(int, ops.Tensor)),
            tc.Optional(string_types))
def tile(labeled_tensor, multiples, name=None):
  """Constructs a tensor by tiling a given tensor.

  Only axes without tick-labels can be tiled. (Otherwise, axis labels on tiled
  tensors would no longer be unique.)

  See tf.tile.

  Args:
    labeled_tensor: The input tensor.
    multiples: A mapping where the keys are axis names and the values are the
      integer number of times to tile along that axis. Only axes with a multiple
      different from 1 need be included.
    name: Optional op name.

  Returns:
    A tensor with the indicated axes tiled.

  Raises:
    ValueError: If the tiled axes are not axes in the input tensor, or if any
      axes in multiples have tick labels.
  """
  with ops.name_scope(name, 'lt_tile', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    if not set(multiples.keys()) <= set(labeled_tensor.axes.keys()):
      raise ValueError('tile axes %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (multiples.keys(), labeled_tensor.axes))

    labeled_axes = [
        name for name in multiples
        if labeled_tensor.axes[name].labels is not None
    ]
    if labeled_axes:
      raise ValueError('cannot tile axes with tick labels: %r' % labeled_axes)

    multiples_list = [multiples.get(name, 1) for name in labeled_tensor.axes]
    tile_op = array_ops.tile(labeled_tensor.tensor, multiples_list, name=scope)

    new_axes = [
        axis.name if axis.labels is None else axis
        for axis in labeled_tensor.axes.values()
    ]
    return core.LabeledTensor(tile_op, new_axes)


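# A hedged sketch of `tile` (hypothetical names); only unlabeled axes may be
# tiled:
#
#   lt = core.LabeledTensor(tf.zeros([2, 3]), ['x', 'y'])
#   tiled = tile(lt, {'x': 3})  # axes: [('x', 6), ('y', 3)]
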
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Mapping(str, tc.Tuple(core.AxisValue, core.AxisValue)),
            string_types, tc.Optional(string_types))
def pad(labeled_tensor, paddings, mode='CONSTANT', name=None):
  """Pads a tensor.

  See tf.pad.

  Args:
    labeled_tensor: The input tensor.
    paddings: A mapping where the keys are axis names and the values are
      tuples where the first element is the padding to insert at the beginning
      of the axis and the second is the padding to insert at the end of the
      axis.
    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC".
    name: Optional op name.

  Returns:
    A tensor with the indicated axes padded, optionally with those axes extended
    with the provided labels.

  Raises:
    ValueError: If the padded axes are not axes in the input tensor.
  """
  with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    if not set(paddings.keys()) <= set(labeled_tensor.axes.keys()):
      raise ValueError('pad axes %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (paddings.keys(), labeled_tensor.axes))

    new_axes = []
    padding_pairs = []
    for name, axis in labeled_tensor.axes.items():
      if name in paddings:
        padding_before, padding_after = paddings[name]
        axis_before = core.Axis(name, padding_before)
        axis_after = core.Axis(name, padding_after)
        new_axes.append(core.concat_axes([axis_before, axis, axis_after]))
        padding_pairs.append((len(axis_before), len(axis_after)))
      else:
        new_axes.append(axis)
        padding_pairs.append((0, 0))

    pad_op = array_ops.pad(labeled_tensor.tensor,
                           padding_pairs,
                           mode,
                           name=scope)

    return core.LabeledTensor(pad_op, new_axes)


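# A hedged sketch of `pad` (hypothetical names): one step of padding before
# and two after along 'x':
#
#   lt = core.LabeledTensor(tf.zeros([2, 3]), ['x', 'y'])
#   padded = pad(lt, {'x': (1, 2)})  # 'x' grows from size 2 to size 5
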
@tc.returns(core.LabeledTensor)
@tc.accepts(
    tc.Union(np.ndarray, list, tuple, core.Scalar),
    tc.Optional(dtypes.DType),
    tc.Optional(
        tc.Union(core.Axes, tc.Collection(
            tc.Union(string_types, core.AxisLike)))), tc.Optional(string_types))
def constant(value, dtype=None, axes=None, name=None):
  """Creates a constant tensor.

  If `axes` includes any strings, shape is inferred from `value`. Otherwise,
  the sizes of the given `axes` are used to set `shape` for `tf.constant`.

  See tf.constant for more details.

  Args:
    value: The input tensor.
    dtype: The type of the returned tensor.
    axes: Optional Axes, list of strings or list of objects coercible to Axis
      objects. By default, axes are assumed to be an empty list (i.e., `value`
      is treated as a scalar).
    name: Optional op name.

  Returns:
    The constant LabeledTensor.
  """
  with ops.name_scope(name, 'lt_constant', [value]) as scope:

    if axes is None:
      axes = []

    if isinstance(axes, core.Axes):
      axes = axes.values()

    if any(isinstance(ax, string_types) for ax in axes):
      # need to infer shape
      shape = None
    else:
      # axes already indicate shape
      axes = [core.as_axis(a) for a in axes]
      shape = [a.size for a in axes]

    op = array_ops.constant(value, dtype=dtype, shape=shape, name=scope)
    return core.LabeledTensor(op, axes)


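# A hedged sketch of `constant` (hypothetical names):
#
#   c = constant([1, 2, 3], axes=['x'])  # shape inferred from `value`
#   z = constant(0.0, axes=[('x', 3)])   # shape [3] taken from the axis size
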
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Optional(dtypes.DType), tc.Optional(string_types))
def zeros_like(labeled_tensor, dtype=None, name=None):
  """Creates an identical tensor with all elements set to zero.

  Args:
    labeled_tensor: The input tensor.
    dtype: The type of the returned tensor.
    name: Optional op name.

  Returns:
    The tensor with elements set to zero.
  """
  with ops.name_scope(name, 'lt_zeros_like', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
    op = array_ops.zeros_like(labeled_tensor.tensor, dtype=dtype, name=scope)
    return core.LabeledTensor(op, labeled_tensor.axes)


@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Optional(dtypes.DType), tc.Optional(string_types))
def ones_like(labeled_tensor, dtype=None, name=None):
  """Creates an identical tensor with all elements set to one.

  Args:
    labeled_tensor: The input tensor.
    dtype: The type of the returned tensor.
    name: Optional op name.

  Returns:
    The tensor with elements set to one.
  """
  with ops.name_scope(name, 'lt_ones_like', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
    op = array_ops.ones_like(labeled_tensor.tensor, dtype=dtype, name=scope)
    return core.LabeledTensor(op, labeled_tensor.axes)


@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Optional(dtypes.DType), tc.Optional(string_types))
def cast(labeled_tensor, dtype=None, name=None):
  """Casts a labeled tensor to a new type.

  Args:
    labeled_tensor: The input tensor.
    dtype: The type of the returned tensor.
    name: Optional op name.

  Returns:
    A labeled tensor with the new dtype.
  """
  with ops.name_scope(name, 'lt_cast', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
    op = math_ops.cast(labeled_tensor.tensor, dtype=dtype, name=scope)
    return core.LabeledTensor(op, labeled_tensor.axes)


@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike, string_types, tc.Optional(string_types))
def verify_tensor_all_finite(labeled_tensor, message, name=None):
  """Asserts a tensor doesn't contain NaNs or Infs.

  See tf.verify_tensor_all_finite.

  Args:
    labeled_tensor: The input tensor.
    message: Message to log on failure.
    name: Optional op name.

  Returns:
    The input tensor.
  """
  with ops.name_scope(name, 'lt_verify_tensor_all_finite',
                      [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
    op = numerics.verify_tensor_all_finite(
        labeled_tensor.tensor, msg=message, name=scope)
    return core.LabeledTensor(op, labeled_tensor.axes)


@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike,
            tc.Optional(string_types))
def boolean_mask(labeled_tensor, mask, name=None):
  """Apply a boolean mask to a labeled tensor.

  Unlike `tf.boolean_mask`, this currently only works on 1-dimensional masks.
  The mask is applied to the first axis of `labeled_tensor`. Labels on the first
  axis are removed, because True indices in `mask` may not be known dynamically.

  Args:
    labeled_tensor: The input tensor.
    mask: A 1D LabeledTensor of type `bool`, whose axis must match the first
      axis of `labeled_tensor`.
    name: Optional op name.

  Returns:
    The masked labeled tensor.

  Raises:
    ValueError: if the first axis of the mask does not match the first axis
      of `labeled_tensor`.
    NotImplementedError: if the mask has more than one axis.
  """
  with ops.name_scope(name, 'lt_boolean_mask', [labeled_tensor, mask]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
    mask = core.convert_to_labeled_tensor(mask)

    if len(mask.axes) > 1:
      raise NotImplementedError(
          "LabeledTensor's boolean_mask currently only supports 1D masks")
    mask_axis = list(mask.axes.values())[0]
    lt_axis = list(labeled_tensor.axes.values())[0]
    if mask_axis != lt_axis:
      raise ValueError('the first axis of the labeled tensor and the mask '
                       'are not equal:\n%r\n%r' % (lt_axis, mask_axis))
    op = array_ops.boolean_mask(labeled_tensor.tensor, mask.tensor, name=scope)
    # TODO(shoyer): attempt to infer labels for the masked values, by calling
    # tf.contrib.util.constant_value on the mask?
    axes = [lt_axis.name] + list(labeled_tensor.axes.values())[1:]
    return core.LabeledTensor(op, axes)


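# A hedged sketch of `boolean_mask` (hypothetical names); the mask must share
# the first axis of the input:
#
#   lt = core.LabeledTensor(tf.constant([1.0, 2.0, 3.0]), ['x'])
#   mask = core.LabeledTensor(tf.constant([True, False, True]), ['x'])
#   masked = boolean_mask(lt, mask)  # first axis keeps its name, not labels
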
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike,
            core.LabeledTensorLike, tc.Optional(string_types))
def where(condition, x, y, name=None):
  """Return elements from x or y depending on condition.

  See `tf.where` for more details. This function currently only implements the
  three argument version of where.

  Args:
    condition: LabeledTensor of type `bool`.
    x: LabeledTensor for values where condition is true.
    y: LabeledTensor for values where condition is false.
    name: Optional op name.

  Returns:
    The labeled tensor with values according to condition.

  Raises:
    ValueError: if `condition`, `x`, and `y` do not have identical axes.
  """
  with ops.name_scope(name, 'lt_where', [condition, x, y]) as scope:
    condition = core.convert_to_labeled_tensor(condition)
    x = core.convert_to_labeled_tensor(x)
    y = core.convert_to_labeled_tensor(y)

    if not condition.axes == x.axes == y.axes:
      raise ValueError('all inputs to `where` must have equal axes')

    op = array_ops.where(condition.tensor, x.tensor, y.tensor, name=scope)
    return core.LabeledTensor(op, x.axes)
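

# A hedged sketch of `where` (hypothetical names); all three inputs must have
# identical axes:
#
#   cond = core.LabeledTensor(tf.constant([True, False]), ['x'])
#   x = core.LabeledTensor(tf.constant([1.0, 2.0]), ['x'])
#   y = core.LabeledTensor(tf.constant([10.0, 20.0]), ['x'])
#   out = where(cond, x, y)  # -> [1.0, 20.0]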