# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ragged_concat_ops.concat."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest


@test_util.run_all_in_graph_and_eager_modes
class RaggedConcatOpTest(ragged_test_util.RaggedTensorTestCase,
                         parameterized.TestCase):
  """Tests `ragged_concat_ops.concat` on ragged and dense (uniform) inputs."""

  def _rt_inputs_to_tensors(self, rt_inputs, ragged_ranks=None):
    """Converts nested Python lists into the tensors passed to `concat`.

    Args:
      rt_inputs: A sequence of nested Python lists, one per concat input.
      ragged_ranks: Optional sequence with one entry per input.  An entry of
        0 builds a dense tensor with `constant_op.constant`; any other entry
        (including None) is forwarded as `ragged_rank` to
        `ragged_factory_ops.constant`.  If omitted, every input is built with
        `ragged_rank=None`.

    Returns:
      A list with one converted tensor per input, in the order of `rt_inputs`.
    """
    if ragged_ranks is None:
      ragged_ranks = [None] * len(rt_inputs)
    return [  # pylint: disable=g-long-ternary
        ragged_factory_ops.constant(rt_input, ragged_rank=rrank)
        if rrank != 0 else constant_op.constant(rt_input)
        for (rt_input, rrank) in zip(rt_inputs, ragged_ranks)
    ]

  @parameterized.parameters(
      dict(
          descr='Two rank-2 inputs with empty value axis=1',
          rt_inputs=([[]], [[]]),
          axis=1,
          expected=[[]]),
      dict(
          descr='Two rank-2 inputs (ragged_rank=1), axis=0',
          rt_inputs=(
              [['a00', 'a01'], [], ['a20', 'a21']],   # shape=(3, None)
              [['b00'], ['b10']]),                    # shape=(2, None)
          axis=0,
          expected=[[b'a00', b'a01'], [], [b'a20', b'a21'], [b'b00'],
                    [b'b10']]),
      dict(
          descr='Two rank-2 inputs (ragged_rank=1), axis=1',
          rt_inputs=(
              [['a00', 'a01'], [], ['a20', 'a21', 'a22']],  # shape=(3, None)
              [['b00'], ['b10', 'b11', 'b12'], ['b20']]),   # shape=(3, None)
          axis=1,
          expected=[
              [b'a00', b'a01', b'b00'],
              [b'b10', b'b11', b'b12'],
              [b'a20', b'a21', b'a22', b'b20']]),
      dict(
          descr='Two rank-2 inputs (ragged_rank=1), axis=-2',
          rt_inputs=(
              [['a00', 'a01'], [], ['a20', 'a21']],   # shape=(3, None)
              [['b00'], ['b10']]),                    # shape=(2, None)
          axis=-2,
          expected=[[b'a00', b'a01'], [], [b'a20', b'a21'], [b'b00'],
                    [b'b10']]),
      dict(
          descr='Two rank-2 inputs (ragged_rank=1), axis=-1',
          rt_inputs=(
              [['a00', 'a01'], [], ['a20', 'a21', 'a22']],  # shape=(3, None)
              [['b00'], ['b10', 'b11', 'b12'], ['b20']]),   # shape=(3, None)
          axis=-1,
          expected=[
              [b'a00', b'a01', b'b00'],
              [b'b10', b'b11', b'b12'],
              [b'a20', b'a21', b'a22', b'b20']],
          expected_shape=[3, None]),
      dict(
          descr='Three rank-2 inputs (ragged_rank=1), axis=0',
          rt_inputs=(
              [['a00', 'a01'], [], ['a20', 'a21', 'a22']],  # shape=(3, None)
              [['b00'], ['b10']],                           # shape=(2, None)
              [['c00'], ['c10', 'c11'], ['c21']]),          # shape=(3, None)
          axis=0,
          expected=[[b'a00', b'a01'], [], [b'a20', b'a21', b'a22'], [b'b00'],
                    [b'b10'], [b'c00'], [b'c10', b'c11'], [b'c21']]),
      dict(
          descr='Three rank-2 inputs (ragged_rank=1), axis=1',
          rt_inputs=(
              [['a00', 'a01'], [], ['a20', 'a21', 'a22']],  # shape=(3, None)
              [['b00'], ['b10', 'b11', 'b12'], ['b20']],    # shape=(3, None)
              [[], ['c10', 'c11'], ['c20', 'c21']]),        # shape=(3, None)
          axis=1,
          expected=[
              [b'a00', b'a01', b'b00'],
              [b'b10', b'b11', b'b12', b'c10', b'c11'],
              [b'a20', b'a21', b'a22', b'b20', b'c20', b'c21']]),
      dict(
          descr='Three rank-3 inputs (ragged_rank=2), axis=0',
          rt_inputs=(
              [[['a000', 'a001'], ['a010']],
               [['a100', 'a101', 'a102'], ['a110', 'a111']]],
              [[['b000']], [['b100', 'b101'], ['b110']]],
              [[], [['c100', 'c101', 'c102', 'c103']], [[], ['c210', 'c211']]]),
          axis=0,
          expected=[
              [[b'a000', b'a001'], [b'a010']],
              [[b'a100', b'a101', b'a102'], [b'a110', b'a111']],
              [[b'b000']],
              [[b'b100', b'b101'], [b'b110']],
              [],
              [[b'c100', b'c101', b'c102', b'c103']],
              [[], [b'c210', b'c211']]]),
      dict(
          descr='Three rank-3 inputs (ragged_rank=2), axis=1',
          rt_inputs=(
              [[['a000', 'a001'], ['a010']],
               [['a100', 'a101', 'a102'], ['a110', 'a111']]],
              [[['b000']], [['b100', 'b101'], ['b110']]],
              [[], [[], ['c110', 'c111']]]),
          axis=1,
          expected=[
              [[b'a000', b'a001'], [b'a010'], [b'b000']],
              [[b'a100', b'a101', b'a102'], [b'a110', b'a111'],
               [b'b100', b'b101'], [b'b110'], [], [b'c110', b'c111']]]),
      dict(
          descr='Three rank-3 inputs (ragged_rank=2), axis=2',
          rt_inputs=(
              [[['a000', 'a001'], ['a010']],
               [['a100', 'a101', 'a102'], ['a110', 'a111']]],
              [[[], ['b010', 'b011']], [['b100', 'b101'], ['b110']]],
              [[['c000'], ['c010']], [[], ['c110', 'c111']]]),
          axis=2,
          expected=[
              [[b'a000', b'a001', b'c000'],
               [b'a010', b'b010', b'b011', b'c010']],
              [[b'a100', b'a101', b'a102', b'b100', b'b101'],
               [b'a110', b'a111', b'b110', b'c110', b'c111']]]),
      dict(
          descr='Three rank-3 inputs (ragged_rank=2), axis=-1',
          rt_inputs=(
              [[['a000', 'a001'], ['a010']],
               [['a100', 'a101', 'a102'], ['a110', 'a111']]],
              [[[], ['b010', 'b011']], [['b100', 'b101'], ['b110']]],
              [[['c000'], ['c010']], [[], ['c110', 'c111']]]),
          axis=-1,
          expected=[
              [[b'a000', b'a001', b'c000'],
               [b'a010', b'b010', b'b011', b'c010']],
              [[b'a100', b'a101', b'a102', b'b100', b'b101'],
               [b'a110', b'a111', b'b110', b'c110', b'c111']]]),
      dict(
          descr='ragged_concat([uniform, ragged, uniform], axis=1)',
          ragged_ranks=[0, 1, 0],
          rt_inputs=(
              [['0('], ['1('], ['2(']],                   # shape=(3, 1)
              [['b00'], ['b10', 'b11', 'b12'], ['b20']],  # shape=(3, None)
              [[')0'], [')1'], [')2']]),                  # shape=(3, 1)
          axis=1,
          expected=[
              [b'0(', b'b00', b')0'],
              [b'1(', b'b10', b'b11', b'b12', b')1'],
              [b'2(', b'b20', b')2']]),
      dict(
          descr='ragged_concat([uniform, uniform], axis=0)',
          ragged_ranks=[0, 0],
          rt_inputs=(
              [['a00', 'a01'], ['a10', 'a11'], ['a20', 'a21']],  # shape=(3, 2)
              [['b00', 'b01', 'b02'], ['b10', 'b11', 'b12']]),   # shape=(2, 3)
          axis=0,
          expected=[
              [b'a00', b'a01'], [b'a10', b'a11'], [b'a20', b'a21'],
              [b'b00', b'b01', b'b02'], [b'b10', b'b11', b'b12']],
          expected_ragged_rank=1),
      dict(
          descr='ragged_concat([uniform, ragged], axis=0)',
          ragged_ranks=[0, 1],
          rt_inputs=(
              [['a00', 'a01'], ['a10', 'a11'], ['a20', 'a21']],  # shape=(3, 2)
              [['b00', 'b01', 'b02'], ['b10', 'b11', 'b12']]),   # shape=(2, 3)
          axis=0,
          expected=[
              [b'a00', b'a01'], [b'a10', b'a11'], [b'a20', b'a21'],
              [b'b00', b'b01', b'b02'], [b'b10', b'b11', b'b12']]),
      dict(
          descr='ragged_concat([uniform, ragged], axis=0) with rank-3 inputs',
          ragged_ranks=[0, 2],
          rt_inputs=(
              [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],  # shape = (2, 2, 2)
              [[[8], [8, 8]]]),                      # shape = (1, None, None)
          axis=0,
          expected=[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], [8, 8]]]),
      dict(
          descr='Two rank-3 inputs with ragged_rank=1, axis=-1',
          ragged_ranks=[1, 1],
          rt_inputs=(
              [[[0, 1], [2, 3], [4, 5]], [], [[6, 7], [8, 9]]],
              [[[9, 8], [7, 6], [5, 4]], [], [[3, 2], [1, 0]]]),
          axis=-1,
          expected=[
              [[0, 1, 9, 8], [2, 3, 7, 6], [4, 5, 5, 4]], [],
              [[6, 7, 3, 2], [8, 9, 1, 0]]],
          expected_ragged_rank=1),
      dict(
          descr='ragged_concat([vector, vector], axis=0)',
          ragged_ranks=[0, 0],
          rt_inputs=([1, 2, 3], [4, 5, 6]),
          axis=0,
          expected=[1, 2, 3, 4, 5, 6]),
      dict(
          descr='One input (so ragged_concat is a noop)',
          rt_inputs=([['a00', 'a01'], [], ['a20', 'a21']],),
          axis=0,
          expected=[[b'a00', b'a01'], [], [b'a20', b'a21']]),
  )   # pyformat: disable
  def testRaggedConcat(self,
                       descr,
                       rt_inputs,
                       axis,
                       expected,
                       ragged_ranks=None,
                       expected_ragged_rank=None,
                       expected_shape=None):
    """Checks `concat` output values and, optionally, ragged_rank and shape.

    Args:
      descr: Human-readable case description (not used by the test body).
      rt_inputs: Nested Python lists converted by `_rt_inputs_to_tensors`.
      axis: The axis passed to `ragged_concat_ops.concat`.
      expected: The expected concatenated value, as nested lists.
      ragged_ranks: Optional per-input ragged ranks (0 means dense).
      expected_ragged_rank: If not None, the required `ragged_rank` of the
        result.
      expected_shape: If not None, the required static shape of the result,
        as a list (None for unknown dimensions).
    """
    rt_inputs = self._rt_inputs_to_tensors(rt_inputs, ragged_ranks)
    concatenated = ragged_concat_ops.concat(rt_inputs, axis)
    if expected_ragged_rank is not None:
      self.assertEqual(concatenated.ragged_rank, expected_ragged_rank)
    if expected_shape is not None:
      self.assertEqual(concatenated.shape.as_list(), expected_shape)
    self.assertRaggedEqual(concatenated, expected)

  @parameterized.parameters(
      dict(
          rt_inputs=(),
          axis=0,
          error=ValueError,
          message=r'rt_inputs may not be empty\.'),
      dict(
          rt_inputs=([[1, 2]], [[3, 4]]),
          axis=r'foo',
          error=TypeError,
          message='axis must be an int'),
      dict(
          rt_inputs=([[1, 2]], [[3, 4]]),
          axis=-3,
          error=ValueError,
          message='axis=-3 out of bounds: expected -2<=axis<2'),
      dict(
          rt_inputs=([[1, 2]], [[3, 4]]),
          axis=2,
          error=ValueError,
          message='axis=2 out of bounds: expected -2<=axis<2'),
      dict(
          ragged_ranks=(0, 0),
          rt_inputs=([[1, 2]], [[3, 4], [5, 6]]),
          axis=1,
          error=(ValueError, errors.InvalidArgumentError)),
  )
  def testStaticError(self,
                      rt_inputs,
                      axis,
                      error,
                      message=None,
                      ragged_ranks=None):
    """Checks that invalid arguments raise `error` when `concat` is called.

    Args:
      rt_inputs: Nested Python lists converted by `_rt_inputs_to_tensors`.
      axis: The axis passed to `ragged_concat_ops.concat`.
      error: The expected exception type, or a tuple of types.
      message: Optional regex the exception message must match; None skips
        the message check.
      ragged_ranks: Optional per-input ragged ranks (0 means dense).
    """
    rt_inputs = self._rt_inputs_to_tensors(rt_inputs, ragged_ranks)
    self.assertRaisesRegexp(error, message, ragged_concat_ops.concat, rt_inputs,
                            axis)

  @parameterized.parameters([
      dict(
          ragged_ranks=(1, 1),
          rt_inputs=([[1, 2]], [[3, 4], [5, 6]]),
          axis=1,
          error=errors.InvalidArgumentError,
          message='Input tensors have incompatible shapes'),
  ])
  def testRuntimeError(self, rt_inputs, axis, error, message,
                       ragged_ranks=None):
    """Checks that shape mismatches hidden from the graph fail at run time.

    Args:
      rt_inputs: Nested Python lists; each is fed through
        `placeholder_with_default(..., shape=None)` so its static shape is
        not available when `concat` builds the graph.
      axis: The axis passed to `ragged_concat_ops.concat`.
      error: The exception type expected when the result is evaluated.
      message: Regex the exception message must match.
      ragged_ranks: Not used by the test body.  It is kept only so the
        parameter dicts stay uniform; every input goes through
        `placeholder_with_default` instead of `_rt_inputs_to_tensors`.
    """
    # The placeholders are meant to hide static shape information.  That is
    # only meaningful when building a graph, so the eager run is skipped.
    if context.executing_eagerly():
      return
    rt_inputs = [
        array_ops.placeholder_with_default(rt, shape=None) for rt in rt_inputs
    ]
    concatenated = ragged_concat_ops.concat(rt_inputs, axis)
    with self.assertRaisesRegexp(error, message):
      self.evaluate(concatenated)

  def testNegativeAxisWithUnknownRankError(self):
    """Checks that a negative axis is rejected when input rank is unknown."""
    # `array_ops.placeholder` is a graph-mode construct, so the eager run is
    # skipped.
    if context.executing_eagerly():
      return
    rt_inputs = [
        array_ops.placeholder(dtypes.int64),
        array_ops.placeholder(dtypes.int64)
    ]
    # Escape the final '.' so the regex matches a literal period.
    self.assertRaisesRegexp(
        ValueError, r'axis may only be negative if ndims is statically known\.',
        ragged_concat_ops.concat, rt_inputs, -1)

  def testSingleTensorInput(self):
    """Tests ragged_concat with a single tensor input.

    Usually, we pass a list of values in for rt_inputs.  However, you can
    also pass in a single value (as with tf.concat), in which case it simply
    returns that tensor.  This test exercises that path.
    """
    rt_inputs = ragged_factory_ops.constant([[1, 2], [3, 4]])
    concatenated = ragged_concat_ops.concat(rt_inputs, 0)
    self.assertRaggedEqual(concatenated, [[1, 2], [3, 4]])


if __name__ == '__main__':
  googletest.main()