1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for ragged_map_ops.map_fn.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 from absl.testing import parameterized 21 import numpy as np 22 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.framework import sparse_tensor 25 from tensorflow.python.framework import test_util 26 from tensorflow.python.keras import backend 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import math_ops as mo 29 from tensorflow.python.ops import string_ops 30 from tensorflow.python.ops.ragged import ragged_factory_ops 31 from tensorflow.python.ops.ragged import ragged_functional_ops 32 from tensorflow.python.ops.ragged import ragged_map_ops 33 from tensorflow.python.ops.ragged import ragged_math_ops 34 from tensorflow.python.ops.ragged import ragged_tensor 35 from tensorflow.python.ops.ragged import ragged_test_util 36 from tensorflow.python.platform import googletest 37 38 39 @test_util.run_all_in_graph_and_eager_modes 40 class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase, 41 parameterized.TestCase): 42 43 @parameterized.parameters([ 44 # The following test sets map over a RaggedTensor and apply a 45 # transformation that returns with shape: 46 # [d1, (d2)] -> [d1] 47 dict( 48 fn=mo.reduce_mean, 49 elems=[[1, 2, 3], [4, 5], [6, 7]], 50 expected_output=[2, 4, 6], 51 ), 52 dict( 53 fn=string_ops.reduce_join, 54 elems=[['foo', 'bar', 'baz'], ['a'], ['b', 'c']], 55 expected_output=[b'foobarbaz', b'a', b'bc'], 56 dtype=dtypes.string, 57 ), 58 # [d1, (d2)] -> [d1, 2] 59 dict( 60 fn=lambda x: array_ops.stack([mo.reduce_mean(x), mo.reduce_sum(x)]), 61 # fn=self.stack_mean_and_sum, 62 elems=[[1, 2, 3], [4, 5], [6, 7]], 63 expected_output=[[2, 6], [4.5, 9], [6.5, 13]], 64 dtype=dtypes.float32, 65 expected_ragged_rank=0, 66 ), 67 # [d1, (d2)] -> [d1, (d2)] 68 dict( 69 fn=lambda x: x + np.int64(1), 70 elems=[[1, 2, 3], [4, 5], [6, 7]], 71 expected_output=[[2, 3, 4], [5, 6], [7, 8]], 72 dtype=dtypes.int64, 73 result_dtype=ragged_tensor.RaggedTensorType( 74 dtype=dtypes.int64, ragged_rank=1), 75 ), 76 # [d1, (d2), d3] -> [d1, (d2), d3] 77 dict( 78 fn=lambda x: x + np.int64(1), 79 elems=[[[1, 2], [3, 4]], [], [[5, 6], [7, 8], [9, 0]]], 80 elems_ragged_rank=1, 81 expected_ragged_rank=1, 82 result_dtype=ragged_tensor.RaggedTensorType( 83 dtype=dtypes.int64, ragged_rank=1), 84 expected_output=[[[2, 3], [4, 5]], [], [[6, 7], [8, 9], [10, 1]]], 85 ), 86 # [d1, (d2)] -> [d1, (d2), (d3)] 87 dict( 88 fn=lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0]), 89 elems=[[1, 2, 3], [4, 5], [6, 7]], 90 expected_output=[[[1, 2, 3]], [[4, 5]], [[6, 7]]], 91 result_dtype=ragged_tensor.RaggedTensorType( 92 dtype=dtypes.int64, ragged_rank=2), 93 ), 94 # [d1, (d2), (d3)] -> [d1, (d2), (d3)] 95 dict( 96 fn=lambda x: ragged_functional_ops.map_flat_values(mo.add, x, 1), 97 elems=[[[1, 2, 3]], [[4, 5], [6, 7]]], 98 expected_output=[[[2, 3, 4]], [[5, 6], [7, 8]]], 99 result_dtype=ragged_tensor.RaggedTensorType( 100 dtype=dtypes.int64, ragged_rank=2), 101 ), 102 # [d1, (d2), (d3)] -> [d1, (d2)] 103 dict( 104 fn=lambda x: ragged_math_ops.reduce_sum(x, axis=1), 105 elems=[[[1, 2, 3]], [[4, 5], [6, 7]]], 106 expected_output=[[6], [9, 13]], 107 result_dtype=ragged_tensor.RaggedTensorType( 108 dtype=dtypes.int64, ragged_rank=1), 109 ), 110 # [d1, (d2), (d3)] -> [d1, (d3)] 111 dict( 112 fn=lambda x: ragged_math_ops.reduce_sum(x, axis=0), 113 elems=[[[1, 2, 3]], [[4, 5], [6, 7]]], 114 expected_output=[[1, 2, 3], [10, 12]], 115 result_dtype=ragged_tensor.RaggedTensorType( 116 dtype=dtypes.int64, ragged_rank=1), 117 ), 118 # [d1, (d2), (d3)] -> [d1] 119 dict( 120 fn=ragged_math_ops.reduce_sum, 121 elems=[[[1, 2, 3]], [[4, 5], [6, 7]]], 122 expected_output=[6, 22], 123 result_dtype=dtypes.int64, 124 ), 125 # [d1] -> [d1, (d2)] 126 dict( 127 fn=mo.range, 128 elems=[4, 0, 2], 129 expected_output=[[0, 1, 2, 3], [], [0, 1]], 130 result_dtype=ragged_tensor.RaggedTensorType( 131 dtype=dtypes.int64, ragged_rank=1), 132 ), 133 # [d1] -> [d1, (d2), (d3)] 134 dict( 135 fn=lambda x: ragged_math_ops.range(mo.range(x)), 136 elems=[5, 0, 3], 137 expected_output=[[[], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]], [], 138 [[], [0], [0, 1]]], 139 result_dtype=ragged_tensor.RaggedTensorType( 140 dtype=dtypes.int64, ragged_rank=2), 141 ), 142 # [d1, (d2), (d3), (d4a), (d5)] -> [d1, (d2), (d3), (d4b), (d5)] 143 dict( 144 fn=lambda x: x + np.int64(1), 145 elems=[[[[[1, 2, 3]], [[4], [5]]]], [[[[6, 7]]], [[[8], []]]]], 146 expected_output=[[[[[2, 3, 4]], [[5], [6]]]], [[[[7, 8]]], [[[9], 147 []]]]], 148 result_dtype=ragged_tensor.RaggedTensorType( 149 dtype=dtypes.int64, ragged_rank=4), 150 ), 151 ]) 152 153 def testRaggedMap( 154 self, 155 fn, 156 elems, 157 expected_output, 158 expected_ragged_rank=None, 159 result_ragged_rank=None, 160 elems_ragged_rank=None, 161 dtype=dtypes.int64, 162 result_dtype=None, 163 infer_shape=False, 164 ): 165 elems = ragged_factory_ops.constant(elems, dtype, elems_ragged_rank) 166 output = ragged_map_ops.map_fn( 167 fn=fn, elems=elems, dtype=result_dtype, infer_shape=infer_shape) 168 169 expected_rt = ragged_factory_ops.constant( 170 expected_output, ragged_rank=expected_ragged_rank) 171 self.assertRaggedEqual(expected_rt, output) 172 173 def testRaggedMapOnStructure(self): 174 batman = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6, 7]]) 175 # [[10, 20, 30], [40], [50, 60, 70]] 176 robin = ragged_functional_ops.map_flat_values(mo.multiply, batman, 10) 177 178 features = {'batman': batman, 'robin': robin} 179 180 def _reduce_sum_from_all(f): 181 return mo.reduce_sum(f['batman']) + mo.reduce_sum(f['robin']) 182 183 output = ragged_map_ops.map_fn( 184 fn=_reduce_sum_from_all, 185 elems=features, 186 dtype=dtypes.int32, 187 ) 188 189 self.assertRaggedEqual(output, [66, 44, 198]) 190 191 # Test mapping over a dict of RTs can produce a dict of RTs. 192 def testRaggedMapOnStructure_RaggedOutputs(self): 193 batman = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6, 7]]) 194 # [[10, 20, 30], [40], [50, 60, 70]] 195 robin = ragged_functional_ops.map_flat_values(mo.multiply, batman, 10) 196 197 features = {'batman': batman, 'robin': robin} 198 199 def _increment(f): 200 return { 201 'batman': f['batman'] + 1, 202 'robin': f['robin'] + 1, 203 } 204 205 output = ragged_map_ops.map_fn( 206 fn=_increment, 207 elems=features, 208 infer_shape=False, 209 dtype={ 210 'batman': 211 ragged_tensor.RaggedTensorType( 212 dtype=dtypes.int32, ragged_rank=1), 213 'robin': 214 ragged_tensor.RaggedTensorType( 215 dtype=dtypes.int32, ragged_rank=1) 216 }, 217 ) 218 219 self.assertRaggedEqual(output['batman'], [[2, 3, 4], [5], [6, 7, 8]]) 220 self.assertRaggedEqual(output['robin'], [[11, 21, 31], [41], [51, 61, 71]]) 221 222 def testZip(self): 223 x = ragged_factory_ops.constant( 224 [[10, 20], [30, 40], [50, 60], [70], [80, 90, 100]], dtypes.int64) 225 y = array_ops.expand_dims(mo.range(x.nrows(), dtype=dtypes.int64), axis=1) 226 227 def _zip(foo): 228 y_val, x_val = foo 229 bar = backend.tile(y_val, array_ops.shape(x_val)) 230 return array_ops.stack([bar, x_val], axis=1) 231 232 output = ragged_map_ops.map_fn( 233 _zip, (y, x), 234 dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64, ragged_rank=1), 235 infer_shape=False) 236 237 self.assertRaggedEqual( 238 output, [[[0, 10], [0, 20]], [[1, 30], [1, 40]], [[2, 50], [2, 60]], 239 [[3, 70]], [[4, 80], [4, 90], [4, 100]]]) 240 241 def testBatchGather(self): 242 tokens = ragged_factory_ops.constant([['hello', '.', 'there'], ['merhaba'], 243 ['bonjour', '.', 'ca va', '?']]) 244 indices = ragged_factory_ops.constant([[0, 2], [0], [0, 2]]) 245 246 def gather(x): 247 tokens_val, indices_val = x 248 return array_ops.gather(tokens_val, indices_val) 249 250 data = tokens, indices 251 out = ragged_map_ops.map_fn( 252 gather, 253 data, 254 dtype=ragged_tensor.RaggedTensorType( 255 dtype=dtypes.string, ragged_rank=1), 256 infer_shape=False) 257 258 self.assertRaggedEqual( 259 out, [[b'hello', b'there'], [b'merhaba'], [b'bonjour', b'ca va']]) 260 261 def testMismatchRaggedRank(self): 262 elems = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]]) 263 fn = lambda x: ragged_math_ops.reduce_sum(x, axis=0) 264 with self.assertRaisesWithLiteralMatch( 265 ValueError, r'The declared ragged rank (23) mismatches the result (1)'): 266 _ = ragged_map_ops.map_fn( 267 fn, 268 elems, 269 dtype=ragged_tensor.RaggedTensorType( 270 dtype=dtypes.int64, ragged_rank=23)) 271 272 def testMismatchRaggedRank2(self): 273 elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [6, 7]]) 274 fn = lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0]) 275 with self.assertRaisesWithLiteralMatch( 276 ValueError, r'The declared ragged rank (10) mismatches the result (1)'): 277 _ = ragged_map_ops.map_fn( 278 fn, 279 elems, 280 dtype=ragged_tensor.RaggedTensorType( 281 dtype=dtypes.int64, ragged_rank=10)) 282 283 def testMapOnSparseTensor(self): 284 s = sparse_tensor.SparseTensor( 285 indices=[[0, 0], [0, 1], [1, 0], [1, 1]], 286 values=[0, 5, 0, 4], 287 dense_shape=[2, 2], 288 ) 289 t2 = ragged_tensor.RaggedTensor.from_sparse(s) 290 id_t2 = ragged_map_ops.map_fn( 291 lambda x: x, t2, 292 ) 293 self.assertRaggedEqual(id_t2, [[0, 5], [0, 4]]) 294 295 296 if __name__ == '__main__': 297 googletest.main() 298