Home | History | Annotate | Download | only in ragged
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for ragged_map_ops.map_fn."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from absl.testing import parameterized
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import sparse_tensor
     25 from tensorflow.python.framework import test_util
     26 from tensorflow.python.keras import backend
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import math_ops as mo
     29 from tensorflow.python.ops import string_ops
     30 from tensorflow.python.ops.ragged import ragged_factory_ops
     31 from tensorflow.python.ops.ragged import ragged_functional_ops
     32 from tensorflow.python.ops.ragged import ragged_map_ops
     33 from tensorflow.python.ops.ragged import ragged_math_ops
     34 from tensorflow.python.ops.ragged import ragged_tensor
     35 from tensorflow.python.ops.ragged import ragged_test_util
     36 from tensorflow.python.platform import googletest
     37 
     38 
     39 @test_util.run_all_in_graph_and_eager_modes
     40 class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
     41                       parameterized.TestCase):
     42 
     43   @parameterized.parameters([
     44       # The following test sets map over a RaggedTensor and apply a
     45       # transformation that returns with shape:
     46       # [d1, (d2)] -> [d1]
     47       dict(
     48           fn=mo.reduce_mean,
     49           elems=[[1, 2, 3], [4, 5], [6, 7]],
     50           expected_output=[2, 4, 6],
     51       ),
     52       dict(
     53           fn=string_ops.reduce_join,
     54           elems=[['foo', 'bar', 'baz'], ['a'], ['b', 'c']],
     55           expected_output=[b'foobarbaz', b'a', b'bc'],
     56           dtype=dtypes.string,
     57       ),
     58       # [d1, (d2)] -> [d1, 2]
     59       dict(
     60           fn=lambda x: array_ops.stack([mo.reduce_mean(x), mo.reduce_sum(x)]),
     61           # fn=self.stack_mean_and_sum,
     62           elems=[[1, 2, 3], [4, 5], [6, 7]],
     63           expected_output=[[2, 6], [4.5, 9], [6.5, 13]],
     64           dtype=dtypes.float32,
     65           expected_ragged_rank=0,
     66       ),
     67       # [d1, (d2)] -> [d1, (d2)]
     68       dict(
     69           fn=lambda x: x + np.int64(1),
     70           elems=[[1, 2, 3], [4, 5], [6, 7]],
     71           expected_output=[[2, 3, 4], [5, 6], [7, 8]],
     72           dtype=dtypes.int64,
     73           result_dtype=ragged_tensor.RaggedTensorType(
     74               dtype=dtypes.int64, ragged_rank=1),
     75       ),
     76       # [d1, (d2), d3] -> [d1, (d2), d3]
     77       dict(
     78           fn=lambda x: x + np.int64(1),
     79           elems=[[[1, 2], [3, 4]], [], [[5, 6], [7, 8], [9, 0]]],
     80           elems_ragged_rank=1,
     81           expected_ragged_rank=1,
     82           result_dtype=ragged_tensor.RaggedTensorType(
     83               dtype=dtypes.int64, ragged_rank=1),
     84           expected_output=[[[2, 3], [4, 5]], [], [[6, 7], [8, 9], [10, 1]]],
     85       ),
     86       # [d1, (d2)] -> [d1, (d2), (d3)]
     87       dict(
     88           fn=lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0]),
     89           elems=[[1, 2, 3], [4, 5], [6, 7]],
     90           expected_output=[[[1, 2, 3]], [[4, 5]], [[6, 7]]],
     91           result_dtype=ragged_tensor.RaggedTensorType(
     92               dtype=dtypes.int64, ragged_rank=2),
     93       ),
     94       # [d1, (d2), (d3)] -> [d1, (d2), (d3)]
     95       dict(
     96           fn=lambda x: ragged_functional_ops.map_flat_values(mo.add, x, 1),
     97           elems=[[[1, 2, 3]], [[4, 5], [6, 7]]],
     98           expected_output=[[[2, 3, 4]], [[5, 6], [7, 8]]],
     99           result_dtype=ragged_tensor.RaggedTensorType(
    100               dtype=dtypes.int64, ragged_rank=2),
    101       ),
    102       # [d1, (d2), (d3)] -> [d1, (d2)]
    103       dict(
    104           fn=lambda x: ragged_math_ops.reduce_sum(x, axis=1),
    105           elems=[[[1, 2, 3]], [[4, 5], [6, 7]]],
    106           expected_output=[[6], [9, 13]],
    107           result_dtype=ragged_tensor.RaggedTensorType(
    108               dtype=dtypes.int64, ragged_rank=1),
    109       ),
    110       # [d1, (d2), (d3)] -> [d1, (d3)]
    111       dict(
    112           fn=lambda x: ragged_math_ops.reduce_sum(x, axis=0),
    113           elems=[[[1, 2, 3]], [[4, 5], [6, 7]]],
    114           expected_output=[[1, 2, 3], [10, 12]],
    115           result_dtype=ragged_tensor.RaggedTensorType(
    116               dtype=dtypes.int64, ragged_rank=1),
    117       ),
    118       # [d1, (d2), (d3)] -> [d1]
    119       dict(
    120           fn=ragged_math_ops.reduce_sum,
    121           elems=[[[1, 2, 3]], [[4, 5], [6, 7]]],
    122           expected_output=[6, 22],
    123           result_dtype=dtypes.int64,
    124       ),
    125       # [d1] -> [d1, (d2)]
    126       dict(
    127           fn=mo.range,
    128           elems=[4, 0, 2],
    129           expected_output=[[0, 1, 2, 3], [], [0, 1]],
    130           result_dtype=ragged_tensor.RaggedTensorType(
    131               dtype=dtypes.int64, ragged_rank=1),
    132       ),
    133       # [d1] -> [d1, (d2), (d3)]
    134       dict(
    135           fn=lambda x: ragged_math_ops.range(mo.range(x)),
    136           elems=[5, 0, 3],
    137           expected_output=[[[], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]], [],
    138                            [[], [0], [0, 1]]],
    139           result_dtype=ragged_tensor.RaggedTensorType(
    140               dtype=dtypes.int64, ragged_rank=2),
    141       ),
    142       # [d1, (d2), (d3), (d4a), (d5)] ->  [d1, (d2), (d3), (d4b), (d5)]
    143       dict(
    144           fn=lambda x: x + np.int64(1),
    145           elems=[[[[[1, 2, 3]], [[4], [5]]]], [[[[6, 7]]], [[[8], []]]]],
    146           expected_output=[[[[[2, 3, 4]], [[5], [6]]]], [[[[7, 8]]], [[[9],
    147                                                                        []]]]],
    148           result_dtype=ragged_tensor.RaggedTensorType(
    149               dtype=dtypes.int64, ragged_rank=4),
    150       ),
    151   ])
    152 
    153   def testRaggedMap(
    154       self,
    155       fn,
    156       elems,
    157       expected_output,
    158       expected_ragged_rank=None,
    159       result_ragged_rank=None,
    160       elems_ragged_rank=None,
    161       dtype=dtypes.int64,
    162       result_dtype=None,
    163       infer_shape=False,
    164   ):
    165     elems = ragged_factory_ops.constant(elems, dtype, elems_ragged_rank)
    166     output = ragged_map_ops.map_fn(
    167         fn=fn, elems=elems, dtype=result_dtype, infer_shape=infer_shape)
    168 
    169     expected_rt = ragged_factory_ops.constant(
    170         expected_output, ragged_rank=expected_ragged_rank)
    171     self.assertRaggedEqual(expected_rt, output)
    172 
    173   def testRaggedMapOnStructure(self):
    174     batman = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6, 7]])
    175     # [[10, 20, 30], [40], [50, 60, 70]]
    176     robin = ragged_functional_ops.map_flat_values(mo.multiply, batman, 10)
    177 
    178     features = {'batman': batman, 'robin': robin}
    179 
    180     def _reduce_sum_from_all(f):
    181       return mo.reduce_sum(f['batman']) + mo.reduce_sum(f['robin'])
    182 
    183     output = ragged_map_ops.map_fn(
    184         fn=_reduce_sum_from_all,
    185         elems=features,
    186         dtype=dtypes.int32,
    187     )
    188 
    189     self.assertRaggedEqual(output, [66, 44, 198])
    190 
    191   # Test mapping over a dict of RTs can produce a dict of RTs.
    192   def testRaggedMapOnStructure_RaggedOutputs(self):
    193     batman = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6, 7]])
    194     # [[10, 20, 30], [40], [50, 60, 70]]
    195     robin = ragged_functional_ops.map_flat_values(mo.multiply, batman, 10)
    196 
    197     features = {'batman': batman, 'robin': robin}
    198 
    199     def _increment(f):
    200       return {
    201           'batman': f['batman'] + 1,
    202           'robin': f['robin'] + 1,
    203       }
    204 
    205     output = ragged_map_ops.map_fn(
    206         fn=_increment,
    207         elems=features,
    208         infer_shape=False,
    209         dtype={
    210             'batman':
    211                 ragged_tensor.RaggedTensorType(
    212                     dtype=dtypes.int32, ragged_rank=1),
    213             'robin':
    214                 ragged_tensor.RaggedTensorType(
    215                     dtype=dtypes.int32, ragged_rank=1)
    216         },
    217     )
    218 
    219     self.assertRaggedEqual(output['batman'], [[2, 3, 4], [5], [6, 7, 8]])
    220     self.assertRaggedEqual(output['robin'], [[11, 21, 31], [41], [51, 61, 71]])
    221 
    222   def testZip(self):
    223     x = ragged_factory_ops.constant(
    224         [[10, 20], [30, 40], [50, 60], [70], [80, 90, 100]], dtypes.int64)
    225     y = array_ops.expand_dims(mo.range(x.nrows(), dtype=dtypes.int64), axis=1)
    226 
    227     def _zip(foo):
    228       y_val, x_val = foo
    229       bar = backend.tile(y_val, array_ops.shape(x_val))
    230       return array_ops.stack([bar, x_val], axis=1)
    231 
    232     output = ragged_map_ops.map_fn(
    233         _zip, (y, x),
    234         dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64, ragged_rank=1),
    235         infer_shape=False)
    236 
    237     self.assertRaggedEqual(
    238         output, [[[0, 10], [0, 20]], [[1, 30], [1, 40]], [[2, 50], [2, 60]],
    239                  [[3, 70]], [[4, 80], [4, 90], [4, 100]]])
    240 
    241   def testBatchGather(self):
    242     tokens = ragged_factory_ops.constant([['hello', '.', 'there'], ['merhaba'],
    243                                           ['bonjour', '.', 'ca va', '?']])
    244     indices = ragged_factory_ops.constant([[0, 2], [0], [0, 2]])
    245 
    246     def gather(x):
    247       tokens_val, indices_val = x
    248       return array_ops.gather(tokens_val, indices_val)
    249 
    250     data = tokens, indices
    251     out = ragged_map_ops.map_fn(
    252         gather,
    253         data,
    254         dtype=ragged_tensor.RaggedTensorType(
    255             dtype=dtypes.string, ragged_rank=1),
    256         infer_shape=False)
    257 
    258     self.assertRaggedEqual(
    259         out, [[b'hello', b'there'], [b'merhaba'], [b'bonjour', b'ca va']])
    260 
    261   def testMismatchRaggedRank(self):
    262     elems = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
    263     fn = lambda x: ragged_math_ops.reduce_sum(x, axis=0)
    264     with self.assertRaisesWithLiteralMatch(
    265         ValueError, r'The declared ragged rank (23) mismatches the result (1)'):
    266       _ = ragged_map_ops.map_fn(
    267           fn,
    268           elems,
    269           dtype=ragged_tensor.RaggedTensorType(
    270               dtype=dtypes.int64, ragged_rank=23))
    271 
    272   def testMismatchRaggedRank2(self):
    273     elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [6, 7]])
    274     fn = lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0])
    275     with self.assertRaisesWithLiteralMatch(
    276         ValueError, r'The declared ragged rank (10) mismatches the result (1)'):
    277       _ = ragged_map_ops.map_fn(
    278           fn,
    279           elems,
    280           dtype=ragged_tensor.RaggedTensorType(
    281               dtype=dtypes.int64, ragged_rank=10))
    282 
    283   def testMapOnSparseTensor(self):
    284     s = sparse_tensor.SparseTensor(
    285         indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
    286         values=[0, 5, 0, 4],
    287         dense_shape=[2, 2],
    288     )
    289     t2 = ragged_tensor.RaggedTensor.from_sparse(s)
    290     id_t2 = ragged_map_ops.map_fn(
    291         lambda x: x, t2,
    292     )
    293     self.assertRaggedEqual(id_t2, [[0, 5], [0, 4]])
    294 
    295 
    296 if __name__ == '__main__':
    297   googletest.main()
    298