Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 from six.moves import range  # pylint: disable=redefined-builtin
     22 
     23 from tensorflow.contrib.labeled_tensor.python.ops import core
     24 from tensorflow.contrib.labeled_tensor.python.ops import ops
     25 from tensorflow.contrib.labeled_tensor.python.ops import test_util
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import errors_impl
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import math_ops
     31 from tensorflow.python.ops import string_ops
     32 from tensorflow.python.platform import test as test_lib
     33 
     34 
     35 class Base(test_util.Base):
     36 
     37   def setUp(self):
     38     super(Base, self).setUp()
     39 
     40     self.x_size = 7
     41     self.channel_size = 3
     42     self.z_size = 4
     43     self.probs_size = 11
     44 
     45     tensor = math_ops.range(0, self.x_size * self.channel_size * self.z_size *
     46                             self.probs_size)
     47     tensor = array_ops.reshape(
     48         tensor, [self.x_size, self.channel_size, self.z_size, self.probs_size])
     49     a0 = ('x', range(self.x_size))
     50     a1 = ('channel', ['red', 'green', 'blue'])
     51     a2 = 'z'
     52     a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))
     53 
     54     self.tensor = tensor
     55     self.a0 = a0
     56     self.a1 = a1
     57     self.a2 = a2
     58     self.a2_resolved = ('z', self.z_size)
     59     self.a3 = a3
     60     self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])
     61 
     62     self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0})
     63     self.x_probs_lt = ops.select(self.x_probs_lt, {'channel': 'red'})
     64     self.channel_probs_lt = core.slice_function(self.original_lt,
     65                                                 {'x': 3,
     66                                                  'z': 0})
     67 
     68 
     69 class SelectTest(Base):
     70 
     71   def test_name(self):
     72     select_lt = ops.select(self.original_lt, {'channel': 'green'})
     73     self.assertIn('lt_select', select_lt.name)
     74 
     75   def test_scalar(self):
     76     select_lt = ops.select(self.original_lt, {'channel': 'green'})
     77     golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :],
     78                                    [self.a0, self.a2, self.a3])
     79     self.assertLabeledTensorsEqual(select_lt, golden_lt)
     80 
     81   def test_slice(self):
     82     select_lt = ops.select(self.original_lt, {'channel': slice('red', 'green')})
     83     a1_sliced = ('channel', ['red', 'green'])
     84     golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
     85                                    [self.a0, a1_sliced, self.a2, self.a3])
     86     self.assertLabeledTensorsEqual(select_lt, golden_lt)
     87 
     88   def test_slices(self):
     89     select_lt = ops.select(self.original_lt,
     90                            {'x': slice(1, 4),
     91                             'channel': slice('green', None)})
     92 
     93     a0_sliced = ('x', range(1, 5))
     94     a1_sliced = ('channel', ['green', 'blue'])
     95     golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],
     96                                    [a0_sliced, a1_sliced, self.a2, self.a3])
     97     self.assertLabeledTensorsEqual(select_lt, golden_lt)
     98 
     99   def test_list(self):
    100     select_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
    101     a1_sliced = ('channel', ['red', 'green'])
    102     golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
    103                                    [self.a0, a1_sliced, self.a2, self.a3])
    104     self.assertLabeledTensorsEqual(select_lt, golden_lt)
    105 
    106   def test_list_one_item(self):
    107     select_lt = ops.select(self.original_lt, {'channel': ['red']})
    108     a1_sliced = ('channel', ['red'])
    109     golden_lt = core.LabeledTensor(self.tensor[:, :1, :, :],
    110                                    [self.a0, a1_sliced, self.a2, self.a3])
    111     self.assertLabeledTensorsEqual(select_lt, golden_lt)
    112 
    113   def test_list_zero_items(self):
    114     select_lt = ops.select(self.original_lt, {'channel': []})
    115     golden_lt = core.LabeledTensor(self.tensor[:, :0, :, :],
    116                                    [self.a0, 'channel', self.a2, self.a3])
    117     self.assertLabeledTensorsEqual(select_lt, golden_lt)
    118 
    119   def test_scalars(self):
    120     select_lt = ops.select(self.original_lt, {'x': 1, 'channel': 'green'})
    121     golden_lt = core.LabeledTensor(self.tensor[1, 1, :, :], [self.a2, self.a3])
    122     self.assertLabeledTensorsEqual(select_lt, golden_lt)
    123 
    124   def test_tuple(self):
    125     original_lt = core.LabeledTensor(constant_op.constant([5, 6]),
    126                                      [('x', [(1, 2), (3, 4)])])
    127     select_lt = ops.select(original_lt, {'x': (1, 2)})
    128     golden_lt = core.LabeledTensor(constant_op.constant(5), [])
    129     self.assertLabeledTensorsEqual(select_lt, golden_lt)
    130 
    131   def test_invalid_input(self):
    132     with self.assertRaises(ValueError):
    133       ops.select(self.original_lt, {'foo': 1})
    134     with self.assertRaises(ValueError):
    135       ops.select(self.original_lt, {'z': 1})
    136     with self.assertRaises(KeyError):
    137       ops.select(self.original_lt, {'channel': 'purple'})
    138     with self.assertRaises(KeyError):
    139       ops.select(self.original_lt, {'channel': ['red', 'purple']})
    140     with self.assertRaises(NotImplementedError):
    141       ops.select(self.original_lt, {'channel': ['red'], 'x': [1]})
    142     with self.assertRaises(NotImplementedError):
    143       ops.select(self.original_lt, {'channel': ['red'], 'x': 1})
    144     with self.assertRaises(NotImplementedError):
    145       ops.select(self.original_lt, {'channel': slice('red', 'green', 2)})
    146 
    147 
    148 class ConcatTest(Base):
    149 
    150   def setUp(self):
    151     super(ConcatTest, self).setUp()
    152 
    153     self.red_lt = ops.select(self.original_lt, {'channel': ['red']})
    154     self.green_lt = ops.select(self.original_lt, {'channel': ['green']})
    155     self.blue_lt = ops.select(self.original_lt, {'channel': ['blue']})
    156 
    157   def test_name(self):
    158     concat_lt = ops.concat([self.red_lt, self.blue_lt], 'channel')
    159     self.assertIn('lt_concat', concat_lt.name)
    160 
    161   def test(self):
    162     concat_lt = ops.concat([self.red_lt, self.green_lt], 'channel')
    163     golden_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
    164 
    165     self.assertLabeledTensorsEqual(concat_lt, golden_lt)
    166 
    167   def test_transposed(self):
    168     green_transposed = core.transpose(self.green_lt,
    169                                       ['probs', 'channel', 'z', 'x'])
    170     with self.assertRaises(ValueError):
    171       ops.concat([self.red_lt, green_transposed], 'channel')
    172 
    173   def test_invalid_input(self):
    174     with self.assertRaises(ValueError):
    175       ops.concat([], 'channel')
    176     with self.assertRaises(ValueError):
    177       ops.concat([self.red_lt, self.red_lt], 'channel')
    178     with self.assertRaises(ValueError):
    179       ops.concat([self.red_lt, self.red_lt], 'foo')
    180 
    181 
    182 class PackTest(Base):
    183 
    184   def test_name(self):
    185     pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch')
    186     self.assertIn('lt_pack', pack_lt.name)
    187 
    188   def test(self):
    189     pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch')
    190     golden_lt = core.LabeledTensor(
    191         array_ops.stack([self.original_lt.tensor, self.original_lt.tensor]),
    192         ['batch', self.a0, self.a1, self.a2, self.a3])
    193 
    194     self.assertLabeledTensorsEqual(pack_lt, golden_lt)
    195 
    196   def test_axis(self):
    197     pack_lt = ops.pack(
    198         [self.original_lt, self.original_lt], new_axis='batch', axis_position=4)
    199     golden_lt = core.LabeledTensor(
    200         array_ops.stack(
    201             [self.original_lt.tensor, self.original_lt.tensor], axis=4),
    202         [self.a0, self.a1, self.a2, self.a3, 'batch'])
    203 
    204     self.assertLabeledTensorsEqual(pack_lt, golden_lt)
    205 
    206   def test_invalid_input(self):
    207     with self.assertRaises(ValueError):
    208       ops.pack([self.original_lt, self.original_lt], 'channel')
    209 
    210 
    211 class UnpackTest(Base):
    212 
    213   def test_name(self):
    214     unpack_lts = ops.unpack(self.original_lt)
    215     for t in unpack_lts:
    216       self.assertIn('lt_unpack', t.name)
    217 
    218   def test(self):
    219     unpack_lt = ops.unpack(self.original_lt)[0]
    220     golden_lt = core.LabeledTensor(
    221         array_ops.unstack(self.original_lt.tensor)[0],
    222         [self.a1, self.a2, self.a3])
    223 
    224     self.assertLabeledTensorsEqual(unpack_lt, golden_lt)
    225 
    226   def test_axis(self):
    227     unpack_lt = ops.unpack(self.original_lt, axis_name='z')[0]
    228     golden_lt = core.LabeledTensor(
    229         array_ops.unstack(
    230             self.original_lt.tensor, axis=2)[0], [self.a0, self.a1, self.a3])
    231 
    232     self.assertLabeledTensorsEqual(unpack_lt, golden_lt)
    233 
    234   def test_invalid_input(self):
    235     with self.assertRaises(ValueError):
    236       ops.unpack(self.original_lt, axis_name='not_found')
    237 
    238 
    239 class ReshapeTest(Base):
    240 
    241   def test_name(self):
    242     reshape_lt = ops.reshape(self.original_lt, ['channel'], ['foo'])
    243     self.assertIn('lt_reshape', reshape_lt.name)
    244 
    245   def test_identity(self):
    246     reshape_lt = ops.reshape(self.original_lt,
    247                              self.original_lt.axes.keys(),
    248                              self.original_lt.axes.values())
    249     self.assertLabeledTensorsEqual(reshape_lt, self.original_lt)
    250 
    251   def test_known_size(self):
    252     new_dim_size = self.channel_size * self.z_size * self.probs_size
    253     reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
    254                              [('new_dim', new_dim_size)])
    255     golden_lt = core.LabeledTensor(
    256         array_ops.reshape(self.original_lt.tensor, [self.x_size, -1]),
    257         [self.original_lt.axes['x'], 'new_dim'])
    258     self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
    259 
    260   def test_unknown_size(self):
    261     reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
    262                              ['new_dim'])
    263     golden_lt = core.LabeledTensor(
    264         array_ops.reshape(self.original_lt.tensor, [self.x_size, -1]),
    265         [self.original_lt.axes['x'], 'new_dim'])
    266     self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
    267 
    268   def test_unknown_dimension(self):
    269     orig_lt = core.LabeledTensor(
    270         array_ops.placeholder(dtypes.float32, [None]), ['x'])
    271     reshape_lt = ops.reshape(orig_lt, ['x'], ['y', ('z', 1)])
    272     self.assertEqual(reshape_lt.axes, core.Axes([('y', None), ('z', 1)]))
    273     with self.test_session() as sess:
    274       result = sess.run(reshape_lt, feed_dict={orig_lt.tensor: [1, 2]})
    275       np.testing.assert_array_equal(result, [[1], [2]])
    276 
    277   def test_with_labels(self):
    278     new_dim_size = self.channel_size * self.z_size * self.probs_size
    279     reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
    280                              [('new_dim', range(new_dim_size))])
    281     golden_lt = core.LabeledTensor(
    282         array_ops.reshape(self.original_lt.tensor, [self.x_size, -1]),
    283         [self.original_lt.axes['x'], ('new_dim', range(new_dim_size))])
    284     self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
    285 
    286   def test_invalid_input(self):
    287     with self.assertRaisesRegexp(ValueError, 'not contained in the set'):
    288       ops.reshape(self.original_lt, ['foo'], ['bar'])
    289     with self.assertRaisesRegexp(core.AxisOrderError,
    290                                  'not a slice of axis names'):
    291       ops.reshape(self.original_lt, ['probs', 'z'], ['bar'])
    292     with self.assertRaisesRegexp(ValueError, 'at most one axis in new_axes'):
    293       ops.reshape(self.original_lt, ['probs'], ['foo', 'bar'])
    294 
    295 
    296 class RenameAxisTest(Base):
    297 
    298   def test_name(self):
    299     rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')
    300     self.assertIn('lt_rename_axis', rename_axis_lt.name)
    301 
    302   def test_identity(self):
    303     rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'channel')
    304     self.assertLabeledTensorsEqual(rename_axis_lt, self.original_lt)
    305 
    306   def test_new_name(self):
    307     rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')
    308     expected_axes = [(name if name != 'channel' else 'foo', axis.value)
    309                      for name, axis in self.original_lt.axes.items()]
    310     expected_lt = core.LabeledTensor(self.original_lt.tensor, expected_axes)
    311     self.assertLabeledTensorsEqual(rename_axis_lt, expected_lt)
    312 
    313   def test_invalid_input(self):
    314     with self.assertRaisesRegexp(ValueError, 'not contained in the set'):
    315       ops.rename_axis(self.original_lt, 'foo', 'bar')
    316 
    317 
    318 class BatchTest(Base):
    319 
    320   def setUp(self):
    321     super(BatchTest, self).setUp()
    322 
    323     tensors = []
    324     for i in range(10):
    325       offset_lt = core.LabeledTensor(constant_op.constant(i), [])
    326       tensors.append(core.add(self.original_lt, offset_lt))
    327     self.pack_lt = ops.pack(tensors, 'batch')
    328 
    329   def test_name(self):
    330     batch_ops = ops.batch(
    331         [self.pack_lt, self.pack_lt], batch_size=2, enqueue_many=True)
    332     for bo in batch_ops:
    333       self.assertIn('lt_batch', bo.name)
    334 
    335   def test_enqueue_many(self):
    336     [batch_2_op] = ops.batch([self.pack_lt], batch_size=2, enqueue_many=True)
    337     self.assertEqual(len(batch_2_op.axes['batch']), 2)
    338 
    339     [batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True)
    340 
    341     self.assertLabeledTensorsEqual(self.pack_lt, batch_10_op)
    342 
    343   def test_no_enqueue_many(self):
    344     [batch_2_op] = ops.batch([self.original_lt], batch_size=2)
    345     self.assertEqual(len(batch_2_op.axes['batch']), 2)
    346 
    347     [batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True)
    348 
    349     self.assertLabeledTensorsEqual(
    350         ops.pack(10 * [self.original_lt], 'batch'), batch_10_op)
    351 
    352   def test_invalid_input(self):
    353     with self.assertRaises(ValueError):
    354       ops.batch([self.original_lt], 3, enqueue_many=True)
    355 
    356   def test_allow_smaller_final_batch(self):
    357     [batch_2_op] = ops.batch(
    358         [self.original_lt], batch_size=2, allow_smaller_final_batch=True)
    359     self.assertEqual(batch_2_op.axes['batch'].size, None)
    360 
    361 
    362 class ShuffleBatchTest(Base):
    363 
    364   def setUp(self):
    365     super(ShuffleBatchTest, self).setUp()
    366 
    367     tensors = []
    368     for i in range(10):
    369       offset_lt = core.LabeledTensor(constant_op.constant(i), [])
    370       tensors.append(core.add(self.original_lt, offset_lt))
    371     self.pack_lt = ops.pack(tensors, 'batch')
    372 
    373   def test_name(self):
    374     batch_lts = ops.shuffle_batch(
    375         [self.pack_lt, self.pack_lt], batch_size=2, enqueue_many=True)
    376     for blt in batch_lts:
    377       self.assertIn('lt_shuffle_batch', blt.name)
    378 
    379   def test_enqueue_many(self):
    380     [batch_2_lt] = ops.shuffle_batch(
    381         [self.pack_lt],
    382         batch_size=2,
    383         enqueue_many=True,
    384         min_after_dequeue=8,
    385         seed=0)
    386     self.assertEqual(len(batch_2_lt.axes['batch']), 2)
    387 
    388     [batch_10_lt] = ops.batch([batch_2_lt], batch_size=10, enqueue_many=True)
    389 
    390     self.assertEqual(batch_10_lt.axes, self.pack_lt.axes)
    391     [batch_10, pack] = self.eval([batch_10_lt.tensor, self.pack_lt.tensor])
    392     self.assertFalse((batch_10 == pack).all())
    393 
    394   def test_allow_smaller_final_batch(self):
    395     [batch_2_op] = ops.shuffle_batch(
    396         [self.original_lt], batch_size=2, allow_smaller_final_batch=True)
    397     self.assertEqual(batch_2_op.axes['batch'].size, None)
    398 
    399 
    400 class RandomCropTest(Base):
    401 
    402   def test_name(self):
    403     crop_lt = ops.random_crop(self.original_lt, {'probs': 3})
    404     self.assertIn('lt_random_crop', crop_lt.name)
    405 
    406   def test_single(self):
    407     crop_lt = ops.random_crop(self.original_lt, {'probs': 3})
    408 
    409     self.assertEqual(
    410         core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 3)]),
    411         crop_lt.axes)
    412 
    413   def test_double(self):
    414     crop_lt = ops.random_crop(self.original_lt, {'probs': 3, 'channel': 2})
    415 
    416     self.assertEqual(
    417         core.Axes([self.a0, ('channel', 2), self.a2_resolved, ('probs', 3)]),
    418         crop_lt.axes)
    419 
    420   def test_size1(self):
    421     crop_lt = ops.random_crop(self.original_lt, {'probs': 1})
    422 
    423     self.assertEqual(
    424         core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 1)]),
    425         crop_lt.axes)
    426 
    427   def test_different_seeds(self):
    428     crop_0_lt = ops.random_crop(
    429         self.original_lt, {'probs': 3,
    430                            'channel': 2}, seed=0)
    431     crop_1_lt = ops.random_crop(
    432         self.original_lt, {'probs': 3,
    433                            'channel': 2}, seed=1)
    434 
    435     self.assertEqual(crop_0_lt.axes, crop_1_lt.axes)
    436     [crop_0, crop_1] = self.eval([crop_0_lt.tensor, crop_1_lt.tensor])
    437     self.assertFalse((crop_0 == crop_1).all())
    438 
    439   def test_identical_seeds(self):
    440     crop_0_lt = ops.random_crop(
    441         self.original_lt, {'probs': 3,
    442                            'channel': 2}, seed=0)
    443     crop_1_lt = ops.random_crop(
    444         self.original_lt, {'probs': 3,
    445                            'channel': 2}, seed=0)
    446 
    447     self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt)
    448 
    449   def test_crop_idempotent(self):
    450     crop_0_lt = ops.random_crop(
    451         self.original_lt, {'probs': 3,
    452                            'channel': 2}, seed=0)
    453     crop_1_lt = ops.random_crop(crop_0_lt, {'probs': 3, 'channel': 2}, seed=1)
    454 
    455     self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt)
    456 
    457   def test_invalid_input(self):
    458     with self.assertRaises(ValueError):
    459       ops.random_crop(self.original_lt, {'foobar': 2})
    460 
    461 
    462 class MapFnTest(Base):
    463 
    464   def test_name(self):
    465     map_lt = ops.map_fn(core.identity, self.original_lt)
    466     self.assertIn('lt_map_fn', map_lt.name)
    467 
    468   def test_identity(self):
    469     map_lt = ops.map_fn(core.identity, self.original_lt)
    470     self.assertLabeledTensorsEqual(map_lt, self.original_lt)
    471 
    472   def test_callable_object(self):
    473 
    474     class Identity(object):
    475 
    476       def __call__(self, other):
    477         return other
    478 
    479     map_lt = ops.map_fn(Identity(), self.original_lt)
    480     self.assertLabeledTensorsEqual(map_lt, self.original_lt)
    481 
    482   def test_slice(self):
    483     map_lt = ops.map_fn(lambda t: core.slice_function(t, {'channel': 1}),
    484                         self.original_lt)
    485     slice_lt = core.slice_function(self.original_lt, {'channel': 1})
    486     self.assertLabeledTensorsEqual(map_lt, slice_lt)
    487 
    488   def test_string(self):
    489 
    490     def fn(entry_lt):
    491       op = string_ops.string_join([entry_lt, 'world'])
    492       return core.LabeledTensor(op, [])
    493 
    494     tensor_lt = ops.constant(['hi', 'bye'], axes=['batch'])
    495     map_lt = ops.map_fn(fn, tensor_lt)
    496     golden_lt = ops.constant(['hiworld', 'byeworld'], axes=['batch'])
    497 
    498     self.assertLabeledTensorsEqual(map_lt, golden_lt)
    499 
    500 
    501 class FoldlTest(Base):
    502 
    503   def test_name(self):
    504     foldl_lt = ops.foldl(core.add, self.original_lt,
    505                          core.slice_function(self.original_lt, {'x': 0}))
    506     self.assertIn('lt_foldl', foldl_lt.name)
    507 
    508   def test_sum(self):
    509     initializer_lt = ops.constant([0, 10], axes=['y'])
    510     tensor_lt = ops.constant([[1, 2], [3, 4], [5, 6]], axes=['x', 'y'])
    511     foldl_lt = ops.foldl(core.add, tensor_lt, initializer_lt)
    512     golden_lt = ops.constant([9, 22], axes=['y'])
    513     self.assertLabeledTensorsEqual(foldl_lt, golden_lt)
    514 
    515 
    516 class SqueezeTest(Base):
    517 
    518   def setUp(self):
    519     super(SqueezeTest, self).setUp()
    520 
    521     self.squeezable_lt = core.slice_function(
    522         self.original_lt, {'channel': slice(0, 1),
    523                            'probs': slice(0, 1)})
    524 
    525   def test_name(self):
    526     squeeze_lt = ops.squeeze(self.squeezable_lt)
    527     self.assertIn('lt_squeeze', squeeze_lt.name)
    528 
    529   def test_none(self):
    530     none_lt = ops.squeeze(self.squeezable_lt, None)
    531     axes_lt = ops.squeeze(self.squeezable_lt, ['channel', 'probs'])
    532     self.assertLabeledTensorsEqual(none_lt, axes_lt)
    533 
    534   def test(self):
    535     squeeze_lt = ops.squeeze(self.squeezable_lt, ['probs'])
    536     golden_lt = core.slice_function(self.squeezable_lt, {'probs': 0})
    537     self.assertLabeledTensorsEqual(squeeze_lt, golden_lt)
    538 
    539   def test_invalid_input(self):
    540     with self.assertRaises(ValueError):
    541       ops.squeeze(self.original_lt, ['channel'])
    542     with self.assertRaises(ValueError):
    543       ops.squeeze(self.squeezable_lt, ['foo'])
    544 
    545 
    546 class MatMulTest(Base):
    547 
    548   def test_name(self):
    549     x_lt = core.LabeledTensor(array_ops.ones((3,)), ['x'])
    550     matmul_lt = ops.matmul(x_lt, x_lt)
    551     self.assertIn('lt_matmul', matmul_lt.name)
    552 
    553   def test_vector_vector(self):
    554     x_lt = core.LabeledTensor(math_ops.range(3), ['x'])
    555     matmul_lt = ops.matmul(x_lt, x_lt)
    556     golden_lt = core.convert_to_labeled_tensor(5)
    557     self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    558 
    559   def test_matrix_vector(self):
    560     xy_lt = core.LabeledTensor(
    561         array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])
    562     y_lt = core.LabeledTensor(math_ops.range(3), ['y'])
    563 
    564     matmul_lt = ops.matmul(xy_lt, y_lt)
    565     golden_lt = core.LabeledTensor(
    566         math_ops.matmul(xy_lt.tensor, array_ops.reshape(y_lt.tensor,
    567                                                         (-1, 1)))[:, 0], ['x'])
    568     self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    569 
    570     matmul_lt = ops.matmul(y_lt, xy_lt)
    571     self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    572 
    573   def test_matrix_matrix(self):
    574     xy_lt = core.LabeledTensor(
    575         array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])
    576     yz_lt = core.LabeledTensor(
    577         array_ops.reshape(math_ops.range(12), (3, 4)), ['y', 'z'])
    578 
    579     matmul_lt = ops.matmul(xy_lt, yz_lt)
    580     golden_lt = core.LabeledTensor(
    581         math_ops.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
    582     self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    583 
    584     transpose = lambda x: core.transpose(x, list(x.axes.keys())[::-1])
    585 
    586     matmul_lt = ops.matmul(xy_lt, transpose(yz_lt))
    587     self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    588 
    589     matmul_lt = ops.matmul(transpose(xy_lt), yz_lt)
    590     self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    591 
    592     matmul_lt = ops.matmul(transpose(xy_lt), transpose(yz_lt))
    593     self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    594 
    595     matmul_lt = ops.matmul(yz_lt, xy_lt)
    596     self.assertLabeledTensorsEqual(matmul_lt, transpose(golden_lt))
    597 
    598   def test_matrix_matrix_axis_order(self):
    599     xy_lt = core.LabeledTensor(
    600         array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])
    601     yz_lt = core.LabeledTensor(
    602         array_ops.reshape(math_ops.range(12), (3, 4)), ['y', 'z'])
    603 
    604     golden_lt = core.LabeledTensor(
    605         math_ops.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
    606 
    607     with core.axis_order_scope(['x', 'y', 'z']):
    608 
    609       matmul_lt = ops.matmul(xy_lt, yz_lt)
    610       self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    611 
    612       matmul_lt = ops.matmul(yz_lt, xy_lt)
    613       self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    614 
    615   def test_invalid(self):
    616     scalar_lt = core.LabeledTensor(array_ops.ones(()), [])
    617     x_lt = core.LabeledTensor(array_ops.ones((2,)), ['x'])
    618     x2_lt = core.LabeledTensor(array_ops.ones((3,)), ['x'])
    619     y_lt = core.LabeledTensor(array_ops.ones((3,)), ['y'])
    620     xy_lt = core.LabeledTensor(array_ops.ones((2, 3)), ['x', 'y'])
    621     xyz_lt = core.LabeledTensor(array_ops.ones((2, 3, 1)), ['x', 'y', 'z'])
    622 
    623     with self.assertRaisesRegexp(ValueError, 'inputs with at least rank'):
    624       ops.matmul(x_lt, scalar_lt)
    625 
    626     with self.assertRaises(NotImplementedError):
    627       ops.matmul(x_lt, xyz_lt)
    628 
    629     with self.assertRaisesRegexp(ValueError, 'exactly one axis in common'):
    630       ops.matmul(x_lt, y_lt)
    631 
    632     with self.assertRaises(NotImplementedError):
    633       ops.matmul(xy_lt, xy_lt)
    634 
    635     with self.assertRaisesRegexp(ValueError, 'does not match'):
    636       ops.matmul(x_lt, x2_lt)
    637 
    638 
    639 class ReduceSumTest(Base):
    640 
    641   def test_name(self):
    642     sum_lt = ops.reduce_sum(self.original_lt, {'channel'})
    643     self.assertIn('lt_reduce_sum', sum_lt.name)
    644 
    645   def test_drop_axis(self):
    646     sum_lt = ops.reduce_sum(self.original_lt, {'channel'})
    647     golden_lt = core.LabeledTensor(
    648         math_ops.reduce_sum(self.original_lt.tensor, 1),
    649         [self.a0, self.a2, self.a3])
    650     self.assertLabeledTensorsEqual(sum_lt, golden_lt)
    651 
    652   def test_drop_scalar_axis(self):
    653     sum_lt = ops.reduce_sum(self.original_lt, 'channel')
    654     golden_lt = core.LabeledTensor(
    655         math_ops.reduce_sum(self.original_lt.tensor, 1),
    656         [self.a0, self.a2, self.a3])
    657     self.assertLabeledTensorsEqual(sum_lt, golden_lt)
    658 
    659   def test_keep_axis(self):
    660     sum_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')})
    661     golden_lt = core.LabeledTensor(
    662         math_ops.reduce_sum(
    663             self.original_lt.tensor, 1, keep_dims=True),
    664         [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])
    665     self.assertLabeledTensorsEqual(sum_lt, golden_lt)
    666 
    667   def test_keep_scalar_axis(self):
    668     sum_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou'))
    669     golden_lt = core.LabeledTensor(
    670         math_ops.reduce_sum(
    671             self.original_lt.tensor, 1, keep_dims=True),
    672         [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])
    673     self.assertLabeledTensorsEqual(sum_lt, golden_lt)
    674 
    675   def test_scalar(self):
    676     scalar_lt = core.LabeledTensor(constant_op.constant(42), [])
    677     reduce_lt = ops.reduce_sum(scalar_lt, [])
    678     self.assertLabeledTensorsEqual(reduce_lt, scalar_lt)
    679 
    680   def test_empty_list(self):
    681     reduce_lt = ops.reduce_sum(self.original_lt, [])
    682     self.assertLabeledTensorsEqual(reduce_lt, self.original_lt)
    683 
    684   def test_none(self):
    685     sum_lt = ops.reduce_sum(self.original_lt)
    686     golden_lt = core.LabeledTensor(
    687         math_ops.reduce_sum(self.original_lt.tensor), [])
    688     self.assertLabeledTensorsEqual(sum_lt, golden_lt)
    689 
    690   def test_function_docstring_and_name(self):
    691     self.assertIn('tf.reduce_sum', ops.reduce_sum.__doc__)
    692     self.assertEqual('reduce_sum', ops.reduce_sum.__name__)
    693 
    694 
    695 class ReduceMeanTest(Base):
    696 
    697   def test_name(self):
    698     actual_lt = ops.reduce_mean(self.original_lt, {'channel'})
    699     self.assertIn('lt_reduce_mean', actual_lt.name)
    700 
    701   def test(self):
    702     actual_lt = ops.reduce_mean(self.original_lt, {'channel'})
    703     golden_lt = core.LabeledTensor(
    704         math_ops.reduce_mean(self.original_lt.tensor, 1),
    705         [self.a0, self.a2, self.a3])
    706     self.assertLabeledTensorsEqual(actual_lt, golden_lt)
    707 
    708 
    class ReduceProdTest(Base):
      """Tests for ops.reduce_prod over the 'channel' axis."""

      def _prod_over_channel(self):
        return ops.reduce_prod(self.original_lt, {'channel'})

      def test_name(self):
        self.assertIn('lt_reduce_prod', self._prod_over_channel().name)

      def test(self):
        reduced_lt = self._prod_over_channel()
        # 'channel' is axis 1 of the wrapped tensor and is dropped from the axes.
        expected_lt = core.LabeledTensor(
            math_ops.reduce_prod(self.original_lt.tensor, 1),
            [self.a0, self.a2, self.a3])
        self.assertLabeledTensorsEqual(reduced_lt, expected_lt)
    721 
    722 
    class ReduceMinTest(Base):
      """Tests for ops.reduce_min over the 'channel' axis."""

      def _min_over_channel(self):
        return ops.reduce_min(self.original_lt, {'channel'})

      def test_name(self):
        self.assertIn('lt_reduce_min', self._min_over_channel().name)

      def test(self):
        reduced_lt = self._min_over_channel()
        # 'channel' is axis 1 of the wrapped tensor and is dropped from the axes.
        expected_lt = core.LabeledTensor(
            math_ops.reduce_min(self.original_lt.tensor, 1),
            [self.a0, self.a2, self.a3])
        self.assertLabeledTensorsEqual(reduced_lt, expected_lt)
    735 
    736 
    class ReduceMaxTest(Base):
      """Tests for ops.reduce_max over the 'channel' axis."""

      def _max_over_channel(self):
        return ops.reduce_max(self.original_lt, {'channel'})

      def test_name(self):
        self.assertIn('lt_reduce_max', self._max_over_channel().name)

      def test(self):
        reduced_lt = self._max_over_channel()
        # 'channel' is axis 1 of the wrapped tensor and is dropped from the axes.
        expected_lt = core.LabeledTensor(
            math_ops.reduce_max(self.original_lt.tensor, 1),
            [self.a0, self.a2, self.a3])
        self.assertLabeledTensorsEqual(reduced_lt, expected_lt)
    749 
    750 
    class BaseReduceBoolean(Base):
      """Fixture for boolean reductions.

      Adds `bool_tensor`, a boolean view of `original_lt.tensor`, and `bool_lt`,
      the same values labeled with `original_lt`'s axes.
      """

      def setUp(self):
        super(BaseReduceBoolean, self).setUp()
        # The comparison already yields a bool tensor, so no cast is needed.
        self.bool_tensor = self.original_lt.tensor > 5
        self.bool_lt = core.LabeledTensor(self.bool_tensor, self.original_lt.axes)
    757 
    758 
    class ReduceAllTest(BaseReduceBoolean):
      """Tests for ops.reduce_all over the 'channel' axis."""

      def _all_over_channel(self):
        return ops.reduce_all(self.bool_lt, {'channel'})

      def test_name(self):
        self.assertIn('lt_reduce_all', self._all_over_channel().name)

      def test(self):
        reduced_lt = self._all_over_channel()
        # 'channel' is axis 1 of the wrapped tensor and is dropped from the axes.
        expected_axes = [self.a0, self.a2, self.a3]
        expected_lt = core.LabeledTensor(
            math_ops.reduce_all(self.bool_tensor, 1), expected_axes)
        self.assertLabeledTensorsEqual(reduced_lt, expected_lt)
    770 
    771 
    class ReduceAnyTest(BaseReduceBoolean):
      """Tests for ops.reduce_any over the 'channel' axis."""

      def _any_over_channel(self):
        return ops.reduce_any(self.bool_lt, {'channel'})

      def test_name(self):
        self.assertIn('lt_reduce_any', self._any_over_channel().name)

      def test(self):
        reduced_lt = self._any_over_channel()
        # 'channel' is axis 1 of the wrapped tensor and is dropped from the axes.
        expected_axes = [self.a0, self.a2, self.a3]
        expected_lt = core.LabeledTensor(
            math_ops.reduce_any(self.bool_tensor, 1), expected_axes)
        self.assertLabeledTensorsEqual(reduced_lt, expected_lt)
    783 
    784 
    class TileTest(Base):
      """Tests for ops.tile."""

      def test_name(self):
        tiled_lt = ops.tile(self.original_lt, {'z': 2})
        self.assertIn('lt_tile', tiled_lt.name)

      def test(self):
        # Both a Python int and a scalar tensor are accepted as the multiple.
        for multiple in (2, constant_op.constant(2)):
          tiled_lt = ops.tile(self.original_lt, {'z': multiple})
          expected_tensor = array_ops.tile(self.original_lt.tensor,
                                           [1, 1, multiple, 1])
          # The tiled 'z' axis is expected by name only, since its size changes;
          # every other axis is carried over unchanged.
          expected_axes = []
          for axis in self.original_lt.axes.values():
            expected_axes.append('z' if axis.name == 'z' else axis)
          self.assertLabeledTensorsEqual(
              tiled_lt, core.LabeledTensor(expected_tensor, expected_axes))

      def test_invalid_input(self):
        # (multiples, expected error text): an unknown axis name, then tiling
        # 'x', which is expected to be rejected for having tick labels.
        bad_cases = [
            ({'foo': 5}, 'are not contained in the set'),
            ({'x': 5}, 'axes with tick labels'),
        ]
        for multiples, message in bad_cases:
          with self.assertRaisesRegexp(ValueError, message):
            ops.tile(self.original_lt, multiples)
    807 
    808 
    class PadTest(Base):
      """Tests for ops.pad."""

      def _pad_x_and_channel(self):
        # 'x' gets one extra entry on each side (given as counts); 'channel'
        # gets one trailing entry labeled 'alpha' (given as labels).
        paddings = {'x': (1, 1), 'channel': ([], ['alpha'])}
        return ops.pad(self.original_lt, paddings)

      def test_name(self):
        self.assertIn('lt_pad', self._pad_x_and_channel().name)

      def test(self):
        padded_lt = self._pad_x_and_channel()

        # The same padding expressed as raw per-axis (before, after) amounts.
        expected_tensor = array_ops.pad(self.original_lt.tensor,
                                        [[1, 1], [0, 1], [0, 0], [0, 0]])
        expected_axes = [
            ('x', self.x_size + 2),
            ('channel', ['red', 'green', 'blue', 'alpha']),
            self.a2,
            self.a3,
        ]
        self.assertLabeledTensorsEqual(
            padded_lt, core.LabeledTensor(expected_tensor, expected_axes))

      def test_invalid_input(self):
        # 'foo' is not one of original_lt's axes.
        with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):
          ops.pad(self.original_lt, {'foo': (1, 1), 'channel': ([], ['alpha'])})
    833 
    834 
    class ConstantTest(Base):
      """Tests for ops.constant."""

      def test_name(self):
        self.assertIn('lt_constant', ops.constant(1).name)

      def test_scalar(self):
        # With no axes the result is a rank-0 labeled tensor.
        actual_lt = ops.constant(1)
        expected_lt = core.LabeledTensor(constant_op.constant(1), [])
        self.assertLabeledTensorsEqual(actual_lt, expected_lt)

      def test_infer_shape(self):
        # Only the axis name is given; its size comes from the value.
        actual_lt = ops.constant([1, 2], axes=['x'])
        expected_lt = core.LabeledTensor(constant_op.constant([1, 2]), ['x'])
        self.assertLabeledTensorsEqual(actual_lt, expected_lt)

      def test_specify_shape(self):
        # A scalar value fills the (3,)-shaped tensor described by the axes.
        actual_lt = ops.constant(1, axes=[('x', 3)])
        expected_lt = core.LabeledTensor(
            constant_op.constant(1, shape=(3,)), ['x'])
        self.assertLabeledTensorsEqual(actual_lt, expected_lt)

      def test_existing_axes(self):
        # The axes object of another labeled tensor can be passed directly.
        expected_lt = core.LabeledTensor(constant_op.constant([1, 2]), ['x'])
        actual_lt = ops.constant([1, 2], axes=expected_lt.axes)
        self.assertLabeledTensorsEqual(actual_lt, expected_lt)
    860 
    861 
    class ZerosLikeTest(Base):
      """Tests for ops.zeros_like."""

      def test_name(self):
        self.assertIn('lt_zeros_like', ops.zeros_like(self.original_lt).name)

      def test(self):
        source_lt = self.original_lt
        actual_lt = ops.zeros_like(source_lt)
        # Same axes as the input, with every value replaced by zero.
        expected_lt = core.LabeledTensor(
            array_ops.zeros_like(source_lt.tensor), source_lt.axes)
        self.assertLabeledTensorsEqual(actual_lt, expected_lt)
    873 
    874 
    class OnesLikeTest(Base):
      """Tests for ops.ones_like."""

      def test_name(self):
        self.assertIn('lt_ones_like', ops.ones_like(self.original_lt).name)

      def test(self):
        source_lt = self.original_lt
        actual_lt = ops.ones_like(source_lt)
        # Same axes as the input, with every value replaced by one.
        expected_lt = core.LabeledTensor(
            array_ops.ones_like(source_lt.tensor), source_lt.axes)
        self.assertLabeledTensorsEqual(actual_lt, expected_lt)
    886 
    887 
    class CastTest(Base):
      """Tests for ops.cast."""

      def _to_float16(self):
        return ops.cast(self.original_lt, dtypes.float16)

      def test_name(self):
        self.assertIn('lt_cast', self._to_float16().name)

      def test(self):
        cast_lt = self._to_float16()
        # Only the dtype changes; the axes are carried over unchanged.
        expected_lt = core.LabeledTensor(
            math_ops.cast(self.original_lt.tensor, dtypes.float16),
            self.original_lt.axes)
        self.assertLabeledTensorsEqual(cast_lt, expected_lt)
    900 
    901 
    class VerifyTensorAllFiniteTest(Base):
      """Tests for ops.verify_tensor_all_finite."""

      def setUp(self):
        super(VerifyTensorAllFiniteTest, self).setUp()

        # Rank-0 inputs: one finite value and one NaN.
        self.finite_lt, self.nan_lt = [
            core.LabeledTensor(constant_op.constant(value), [])
            for value in (42.0, np.nan)
        ]
        # Wrapping the NaN input does not raise here; the failure only shows
        # up when the checked tensor is evaluated (see test_nan).
        self.checked_finite_lt, self.checked_nan_lt = [
            ops.verify_tensor_all_finite(unchecked_lt, '')
            for unchecked_lt in (self.finite_lt, self.nan_lt)
        ]

      def test_name(self):
        for checked_lt in (self.checked_finite_lt, self.checked_nan_lt):
          self.assertIn('lt_verify_tensor_all_finite', checked_lt.name)

      def test_finite(self):
        # A finite input passes through unchanged.
        self.assertLabeledTensorsEqual(self.finite_lt, self.checked_finite_lt)

      def test_nan(self):
        with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                     'Tensor had NaN values'):
          self.eval([self.checked_nan_lt])
    924 
    925 
    class BooleanMaskTest(Base):
      """Tests for ops.boolean_mask.

      Mask sizes are derived from the fixture's `x_size` / `channel_size`
      instead of being hard-coded, so the tests keep matching `original_lt`
      if `Base.setUp` changes those sizes.
      """

      def _make_x_mask(self, axes):
        """Return a rank-1 boolean mask of length x_size, True for x > 3.

        Args:
          axes: Axes spec for the mask's single dimension.
        """
        return core.LabeledTensor(math_ops.range(self.x_size) > 3, axes)

      def test_name(self):
        mask = self._make_x_mask([self.a0])
        masked_lt = ops.boolean_mask(self.original_lt, mask)
        self.assertIn('lt_boolean_mask', masked_lt.name)

      def test(self):
        mask = self._make_x_mask([self.a0])
        masked_lt = ops.boolean_mask(self.original_lt, mask)
        # The masked 'x' axis is expected by name only, since its size depends
        # on the mask values; the remaining axes are unchanged.
        golden_lt = core.LabeledTensor(
            array_ops.boolean_mask(self.original_lt.tensor, mask.tensor),
            ['x', self.a1, self.a2, self.a3])
        self.assertLabeledTensorsEqual(masked_lt, golden_lt)

      def test_invalid_rank(self):
        # A rank-2 mask over ('x', 'channel') is expected to be rejected.
        mask = core.LabeledTensor(
            array_ops.ones((self.x_size, self.channel_size)) > 3,
            [self.a0, self.a1])
        with self.assertRaises(NotImplementedError):
          ops.boolean_mask(self.original_lt, mask)

      def test_mismatched_axis(self):
        # Same length as 'x' but a different axis name.
        mask = self._make_x_mask(['foo'])
        with self.assertRaisesRegexp(ValueError, 'not equal'):
          ops.boolean_mask(self.original_lt, mask)
    950 
    951 
    class WhereTest(Base):
      """Tests for ops.where."""

      def _condition(self):
        # Length-5 boolean along 'x': True at the first three positions.
        return core.LabeledTensor(math_ops.range(5) < 3, ['x'])

      def test_name(self):
        condition = self._condition()
        where_lt = ops.where(condition, condition, condition)
        self.assertIn('lt_where', where_lt.name)

      def test(self):
        condition = self._condition()
        ones_lt = core.LabeledTensor(array_ops.ones(5), ['x'])
        zeros_lt = core.LabeledTensor(array_ops.zeros(5), ['x'])
        where_lt = ops.where(condition, ones_lt, zeros_lt)

        # Ones where the condition holds (first three), zeros elsewhere.
        expected_lt = core.LabeledTensor(
            array_ops.concat([array_ops.ones(3), array_ops.zeros(2)], 0), ['x'])
        self.assertLabeledTensorsEqual(where_lt, expected_lt)

      def test_mismatched_axes(self):
        condition = self._condition()
        # Shorten the second, then the third, argument in turn.
        for short_position in (1, 2):
          args = [condition, condition, condition]
          args[short_position] = condition[:3]
          with self.assertRaisesRegexp(ValueError, 'equal axes'):
            ops.where(*args)
    975 
    976 
    if __name__ == '__main__':
      # Hand control to TensorFlow's test runner (imported as test_lib).
      test_lib.main()
    979