1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for utilities working with arbitrarily nested structures.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import collections 22 23 import numpy as np 24 25 from tensorflow.python.framework import constant_op 26 from tensorflow.python.framework import dtypes 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import math_ops 29 from tensorflow.python.platform import test 30 from tensorflow.python.util import nest 31 32 33 class NestTest(test.TestCase): 34 35 def testFlattenAndPack(self): 36 structure = ((3, 4), 5, (6, 7, (9, 10), 8)) 37 flat = ["a", "b", "c", "d", "e", "f", "g", "h"] 38 self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) 39 self.assertEqual( 40 nest.pack_sequence_as(structure, flat), (("a", "b"), "c", 41 ("d", "e", ("f", "g"), "h"))) 42 point = collections.namedtuple("Point", ["x", "y"]) 43 structure = (point(x=4, y=2), ((point(x=1, y=0),),)) 44 flat = [4, 2, 1, 0] 45 self.assertEqual(nest.flatten(structure), flat) 46 restructured_from_flat = nest.pack_sequence_as(structure, flat) 47 self.assertEqual(restructured_from_flat, structure) 48 self.assertEqual(restructured_from_flat[0].x, 4) 49 self.assertEqual(restructured_from_flat[0].y, 2) 50 self.assertEqual(restructured_from_flat[1][0][0].x, 1) 51 self.assertEqual(restructured_from_flat[1][0][0].y, 0) 52 53 self.assertEqual([5], nest.flatten(5)) 54 self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) 55 56 self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) 57 self.assertEqual( 58 np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) 59 60 with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): 61 nest.pack_sequence_as("scalar", [4, 5]) 62 63 with self.assertRaisesRegexp(TypeError, "flat_sequence"): 64 nest.pack_sequence_as([4, 5], "bad_sequence") 65 66 with self.assertRaises(ValueError): 67 nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) 68 69 def testFlattenDictOrder(self): 70 """`flatten` orders dicts by key, including OrderedDicts.""" 71 ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) 72 plain = {"d": 3, "b": 1, "a": 0, "c": 2} 73 ordered_flat = nest.flatten(ordered) 74 plain_flat = nest.flatten(plain) 75 self.assertEqual([0, 1, 2, 3], ordered_flat) 76 self.assertEqual([0, 1, 2, 3], plain_flat) 77 78 def testPackDictOrder(self): 79 """Packing orders dicts by key, including OrderedDicts.""" 80 ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) 81 plain = {"d": 0, "b": 0, "a": 0, "c": 0} 82 seq = [0, 1, 2, 3] 83 ordered_reconstruction = nest.pack_sequence_as(ordered, seq) 84 plain_reconstruction = nest.pack_sequence_as(plain, seq) 85 self.assertEqual( 86 collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), 87 ordered_reconstruction) 88 self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) 89 90 def testFlattenAndPack_withDicts(self): 91 # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. 92 named_tuple = collections.namedtuple("A", ("b", "c")) 93 mess = [ 94 "z", 95 named_tuple(3, 4), 96 { 97 "c": [ 98 1, 99 collections.OrderedDict([ 100 ("b", 3), 101 ("a", 2), 102 ]), 103 ], 104 "b": 5 105 }, 106 17 107 ] 108 109 flattened = nest.flatten(mess) 110 self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17]) 111 112 structure_of_mess = [ 113 14, 114 named_tuple("a", True), 115 { 116 "c": [ 117 0, 118 collections.OrderedDict([ 119 ("b", 9), 120 ("a", 8), 121 ]), 122 ], 123 "b": 3 124 }, 125 "hi everybody", 126 ] 127 128 unflattened = nest.pack_sequence_as(structure_of_mess, flattened) 129 self.assertEqual(unflattened, mess) 130 131 # Check also that the OrderedDict was created, with the correct key order. 132 unflattened_ordered_dict = unflattened[2]["c"][1] 133 self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) 134 self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) 135 136 def testFlatten_numpyIsNotFlattened(self): 137 structure = np.array([1, 2, 3]) 138 flattened = nest.flatten(structure) 139 self.assertEqual(len(flattened), 1) 140 141 def testFlatten_stringIsNotFlattened(self): 142 structure = "lots of letters" 143 flattened = nest.flatten(structure) 144 self.assertEqual(len(flattened), 1) 145 unflattened = nest.pack_sequence_as("goodbye", flattened) 146 self.assertEqual(structure, unflattened) 147 148 def testPackSequenceAs_notIterableError(self): 149 with self.assertRaisesRegexp(TypeError, 150 "flat_sequence must be a sequence"): 151 nest.pack_sequence_as("hi", "bye") 152 153 def testPackSequenceAs_wrongLengthsError(self): 154 with self.assertRaisesRegexp( 155 ValueError, 156 "Structure had 2 elements, but flat_sequence had 3 elements."): 157 nest.pack_sequence_as(["hello", "world"], 158 ["and", "goodbye", "again"]) 159 160 def testIsSequence(self): 161 self.assertFalse(nest.is_sequence("1234")) 162 self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) 163 self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) 164 self.assertTrue(nest.is_sequence([])) 165 self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) 166 self.assertFalse(nest.is_sequence(set([1, 2]))) 167 ones = array_ops.ones([2, 3]) 168 self.assertFalse(nest.is_sequence(ones)) 169 self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) 170 self.assertFalse(nest.is_sequence(np.ones((4, 5)))) 171 172 def testFlattenDictItems(self): 173 dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} 174 flat = {4: "a", 5: "b", 6: "c", 8: "d"} 175 self.assertEqual(nest.flatten_dict_items(dictionary), flat) 176 177 with self.assertRaises(TypeError): 178 nest.flatten_dict_items(4) 179 180 bad_dictionary = {(4, 5, (4, 8)): ("a", "b", ("c", "d"))} 181 with self.assertRaisesRegexp(ValueError, "not unique"): 182 nest.flatten_dict_items(bad_dictionary) 183 184 another_bad_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))} 185 with self.assertRaisesRegexp( 186 ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): 187 nest.flatten_dict_items(another_bad_dictionary) 188 189 def testAssertSameStructure(self): 190 structure1 = (((1, 2), 3), 4, (5, 6)) 191 structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) 192 structure_different_num_elements = ("spam", "eggs") 193 structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) 194 nest.assert_same_structure(structure1, structure2) 195 nest.assert_same_structure("abc", 1.0) 196 nest.assert_same_structure("abc", np.array([0, 1])) 197 nest.assert_same_structure("abc", constant_op.constant([0, 1])) 198 199 with self.assertRaisesRegexp( 200 ValueError, 201 ("don't have the same number of elements\\.\n\n" 202 "First structure \\(6 elements\\):.*?" 203 "\n\nSecond structure \\(2 elements\\):")): 204 nest.assert_same_structure(structure1, structure_different_num_elements) 205 206 with self.assertRaisesRegexp( 207 ValueError, 208 ("don't have the same number of elements\\.\n\n" 209 "First structure \\(2 elements\\):.*?" 210 "\n\nSecond structure \\(1 elements\\):")): 211 nest.assert_same_structure([0, 1], np.array([0, 1])) 212 213 with self.assertRaisesRegexp( 214 ValueError, 215 ("don't have the same number of elements\\.\n\n" 216 "First structure \\(1 elements\\):.*" 217 "\n\nSecond structure \\(2 elements\\):")): 218 nest.assert_same_structure(0, [0, 1]) 219 220 self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) 221 222 with self.assertRaisesRegexp( 223 ValueError, 224 ("don't have the same nested structure\\.\n\n" 225 "First structure: .*?\n\nSecond structure: ")): 226 nest.assert_same_structure(structure1, structure_different_nesting) 227 228 named_type_0 = collections.namedtuple("named_0", ("a", "b")) 229 named_type_1 = collections.namedtuple("named_1", ("a", "b")) 230 self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), 231 named_type_0("a", "b")) 232 233 nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b")) 234 235 self.assertRaises(TypeError, nest.assert_same_structure, 236 named_type_0(3, 4), named_type_1(3, 4)) 237 238 with self.assertRaisesRegexp( 239 ValueError, 240 ("don't have the same nested structure\\.\n\n" 241 "First structure: .*?\n\nSecond structure: ")): 242 nest.assert_same_structure(named_type_0(3, 4), named_type_0([3], 4)) 243 244 with self.assertRaisesRegexp( 245 ValueError, 246 ("don't have the same nested structure\\.\n\n" 247 "First structure: .*?\n\nSecond structure: ")): 248 nest.assert_same_structure([[3], 4], [3, [4]]) 249 250 structure1_list = [[[1, 2], 3], 4, [5, 6]] 251 with self.assertRaisesRegexp(TypeError, 252 "don't have the same sequence type"): 253 nest.assert_same_structure(structure1, structure1_list) 254 nest.assert_same_structure(structure1, structure2, check_types=False) 255 nest.assert_same_structure(structure1, structure1_list, check_types=False) 256 257 with self.assertRaisesRegexp(ValueError, 258 "don't have the same set of keys"): 259 nest.assert_same_structure({"a": 1}, {"b": 1}) 260 261 same_name_type_0 = collections.namedtuple("same_name", ("a", "b")) 262 same_name_type_1 = collections.namedtuple("same_name", ("a", "b")) 263 nest.assert_same_structure(same_name_type_0(0, 1), same_name_type_1(2, 3)) 264 265 # This assertion is expected to pass: two namedtuples with the same 266 # name and field names are considered to be identical. 267 same_name_type_2 = collections.namedtuple("same_name_1", ("x", "y")) 268 same_name_type_3 = collections.namedtuple("same_name_1", ("x", "y")) 269 nest.assert_same_structure( 270 same_name_type_0(same_name_type_2(0, 1), 2), 271 same_name_type_1(same_name_type_3(2, 3), 4)) 272 273 expected_message = "The two structures don't have the same.*" 274 with self.assertRaisesRegexp(ValueError, expected_message): 275 nest.assert_same_structure(same_name_type_0(0, same_name_type_1(1, 2)), 276 same_name_type_1(same_name_type_0(0, 1), 2)) 277 278 same_name_type_1 = collections.namedtuple("not_same_name", ("a", "b")) 279 self.assertRaises(TypeError, nest.assert_same_structure, 280 same_name_type_0(0, 1), same_name_type_1(2, 3)) 281 282 same_name_type_1 = collections.namedtuple("same_name", ("x", "y")) 283 self.assertRaises(TypeError, nest.assert_same_structure, 284 same_name_type_0(0, 1), same_name_type_1(2, 3)) 285 286 class SameNamedType1(collections.namedtuple("same_name", ("a", "b"))): 287 pass 288 self.assertRaises(TypeError, nest.assert_same_structure, 289 same_name_type_0(0, 1), SameNamedType1(2, 3)) 290 291 def testMapStructure(self): 292 structure1 = (((1, 2), 3), 4, (5, 6)) 293 structure2 = (((7, 8), 9), 10, (11, 12)) 294 structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) 295 nest.assert_same_structure(structure1, structure1_plus1) 296 self.assertAllEqual( 297 [2, 3, 4, 5, 6, 7], 298 nest.flatten(structure1_plus1)) 299 structure1_plus_structure2 = nest.map_structure( 300 lambda x, y: x + y, structure1, structure2) 301 self.assertEqual( 302 (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), 303 structure1_plus_structure2) 304 305 self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) 306 307 self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) 308 309 # Empty structures 310 self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) 311 self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) 312 self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) 313 empty_nt = collections.namedtuple("empty_nt", "") 314 self.assertEqual(empty_nt(), nest.map_structure(lambda x: x + 1, 315 empty_nt())) 316 317 # This is checking actual equality of types, empty list != empty tuple 318 self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) 319 320 with self.assertRaisesRegexp(TypeError, "callable"): 321 nest.map_structure("bad", structure1_plus1) 322 323 with self.assertRaisesRegexp(ValueError, "at least one structure"): 324 nest.map_structure(lambda x: x) 325 326 with self.assertRaisesRegexp(ValueError, "same number of elements"): 327 nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) 328 329 with self.assertRaisesRegexp(ValueError, "same nested structure"): 330 nest.map_structure(lambda x, y: None, 3, (3,)) 331 332 with self.assertRaisesRegexp(TypeError, "same sequence type"): 333 nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) 334 335 with self.assertRaisesRegexp(ValueError, "same nested structure"): 336 nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) 337 338 structure1_list = [[[1, 2], 3], 4, [5, 6]] 339 with self.assertRaisesRegexp(TypeError, "same sequence type"): 340 nest.map_structure(lambda x, y: None, structure1, structure1_list) 341 342 nest.map_structure(lambda x, y: None, structure1, structure1_list, 343 check_types=False) 344 345 with self.assertRaisesRegexp(ValueError, "same nested structure"): 346 nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), 347 check_types=False) 348 349 with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): 350 nest.map_structure(lambda x: None, structure1, foo="a") 351 352 with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): 353 nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") 354 355 def testMapStructureWithStrings(self): 356 ab_tuple = collections.namedtuple("ab_tuple", "a, b") 357 inp_a = ab_tuple(a="foo", b=("bar", "baz")) 358 inp_b = ab_tuple(a=2, b=(1, 3)) 359 out = nest.map_structure(lambda string, repeats: string * repeats, 360 inp_a, 361 inp_b) 362 self.assertEqual("foofoo", out.a) 363 self.assertEqual("bar", out.b[0]) 364 self.assertEqual("bazbazbaz", out.b[1]) 365 366 nt = ab_tuple(a=("something", "something_else"), 367 b="yet another thing") 368 rev_nt = nest.map_structure(lambda x: x[::-1], nt) 369 # Check the output is the correct structure, and all strings are reversed. 370 nest.assert_same_structure(nt, rev_nt) 371 self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) 372 self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) 373 self.assertEqual(nt.b[::-1], rev_nt.b) 374 375 def testMapStructureOverPlaceholders(self): 376 inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), 377 array_ops.placeholder(dtypes.float32, shape=[3, 7])) 378 inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), 379 array_ops.placeholder(dtypes.float32, shape=[3, 7])) 380 381 output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) 382 383 nest.assert_same_structure(output, inp_a) 384 self.assertShapeEqual(np.zeros((3, 4)), output[0]) 385 self.assertShapeEqual(np.zeros((3, 7)), output[1]) 386 387 feed_dict = { 388 inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), 389 inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) 390 } 391 392 with self.test_session() as sess: 393 output_np = sess.run(output, feed_dict=feed_dict) 394 self.assertAllClose(output_np[0], 395 feed_dict[inp_a][0] + feed_dict[inp_b][0]) 396 self.assertAllClose(output_np[1], 397 feed_dict[inp_a][1] + feed_dict[inp_b][1]) 398 399 def testAssertShallowStructure(self): 400 inp_ab = ["a", "b"] 401 inp_abc = ["a", "b", "c"] 402 expected_message = ( 403 "The two structures don't have the same sequence length. Input " 404 "structure has length 2, while shallow structure has length 3.") 405 with self.assertRaisesRegexp(ValueError, expected_message): 406 nest.assert_shallow_structure(inp_abc, inp_ab) 407 408 inp_ab1 = [(1, 1), (2, 2)] 409 inp_ab2 = [[1, 1], [2, 2]] 410 expected_message = ( 411 "The two structures don't have the same sequence type. Input structure " 412 "has type <(type|class) 'tuple'>, while shallow structure has type " 413 "<(type|class) 'list'>.") 414 with self.assertRaisesRegexp(TypeError, expected_message): 415 nest.assert_shallow_structure(inp_ab2, inp_ab1) 416 nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) 417 418 inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} 419 inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} 420 expected_message = ( 421 r"The two structures don't have the same keys. Input " 422 r"structure has keys \['c'\], while shallow structure has " 423 r"keys \['d'\].") 424 425 with self.assertRaisesRegexp(ValueError, expected_message): 426 nest.assert_shallow_structure(inp_ab2, inp_ab1) 427 428 inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) 429 inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) 430 nest.assert_shallow_structure(inp_ab, inp_ba) 431 432 # This assertion is expected to pass: two namedtuples with the same 433 # name and field names are considered to be identical. 434 same_name_type_0 = collections.namedtuple("same_name", ("a", "b")) 435 same_name_type_1 = collections.namedtuple("same_name", ("a", "b")) 436 inp_shallow = same_name_type_0(1, 2) 437 inp_deep = same_name_type_1(1, [1, 2, 3]) 438 nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) 439 nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) 440 441 def testFlattenUpTo(self): 442 # Shallow tree ends at scalar. 443 input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] 444 shallow_tree = [[True, True], [False, True]] 445 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 446 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 447 self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) 448 self.assertEqual(flattened_shallow_tree, [True, True, False, True]) 449 450 # Shallow tree ends at string. 451 input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] 452 shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] 453 input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, 454 input_tree) 455 input_tree_flattened = nest.flatten(input_tree) 456 self.assertEqual(input_tree_flattened_as_shallow_tree, 457 [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) 458 self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) 459 460 # Make sure dicts are correctly flattened, yielding values, not keys. 461 input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} 462 shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} 463 input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, 464 input_tree) 465 self.assertEqual(input_tree_flattened_as_shallow_tree, 466 [1, {"c": 2}, 3, (4, 5)]) 467 468 # Namedtuples. 469 ab_tuple = collections.namedtuple("ab_tuple", "a, b") 470 input_tree = ab_tuple(a=[0, 1], b=2) 471 shallow_tree = ab_tuple(a=0, b=1) 472 input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, 473 input_tree) 474 self.assertEqual(input_tree_flattened_as_shallow_tree, 475 [[0, 1], 2]) 476 477 # Nested dicts, OrderedDicts and namedtuples. 478 input_tree = collections.OrderedDict( 479 [("a", ab_tuple(a=[0, {"b": 1}], b=2)), 480 ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) 481 shallow_tree = input_tree 482 input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, 483 input_tree) 484 self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) 485 shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) 486 input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, 487 input_tree) 488 self.assertEqual(input_tree_flattened_as_shallow_tree, 489 [ab_tuple(a=[0, {"b": 1}], b=2), 490 3, 491 collections.OrderedDict([("f", 4)])]) 492 shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) 493 input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, 494 input_tree) 495 self.assertEqual(input_tree_flattened_as_shallow_tree, 496 [ab_tuple(a=[0, {"b": 1}], b=2), 497 {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) 498 499 ## Shallow non-list edge-case. 500 # Using iterable elements. 501 input_tree = ["input_tree"] 502 shallow_tree = "shallow_tree" 503 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 504 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 505 self.assertEqual(flattened_input_tree, [input_tree]) 506 self.assertEqual(flattened_shallow_tree, [shallow_tree]) 507 508 input_tree = ["input_tree_0", "input_tree_1"] 509 shallow_tree = "shallow_tree" 510 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 511 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 512 self.assertEqual(flattened_input_tree, [input_tree]) 513 self.assertEqual(flattened_shallow_tree, [shallow_tree]) 514 515 # Using non-iterable elements. 516 input_tree = [0] 517 shallow_tree = 9 518 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 519 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 520 self.assertEqual(flattened_input_tree, [input_tree]) 521 self.assertEqual(flattened_shallow_tree, [shallow_tree]) 522 523 input_tree = [0, 1] 524 shallow_tree = 9 525 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 526 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 527 self.assertEqual(flattened_input_tree, [input_tree]) 528 self.assertEqual(flattened_shallow_tree, [shallow_tree]) 529 530 ## Both non-list edge-case. 531 # Using iterable elements. 532 input_tree = "input_tree" 533 shallow_tree = "shallow_tree" 534 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 535 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 536 self.assertEqual(flattened_input_tree, [input_tree]) 537 self.assertEqual(flattened_shallow_tree, [shallow_tree]) 538 539 # Using non-iterable elements. 540 input_tree = 0 541 shallow_tree = 0 542 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 543 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 544 self.assertEqual(flattened_input_tree, [input_tree]) 545 self.assertEqual(flattened_shallow_tree, [shallow_tree]) 546 547 ## Input non-list edge-case. 548 # Using iterable elements. 549 input_tree = "input_tree" 550 shallow_tree = ["shallow_tree"] 551 expected_message = ("If shallow structure is a sequence, input must also " 552 "be a sequence. Input has type: <(type|class) 'str'>.") 553 with self.assertRaisesRegexp(TypeError, expected_message): 554 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 555 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 556 self.assertEqual(flattened_shallow_tree, shallow_tree) 557 558 input_tree = "input_tree" 559 shallow_tree = ["shallow_tree_9", "shallow_tree_8"] 560 with self.assertRaisesRegexp(TypeError, expected_message): 561 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 562 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 563 self.assertEqual(flattened_shallow_tree, shallow_tree) 564 565 # Using non-iterable elements. 566 input_tree = 0 567 shallow_tree = [9] 568 expected_message = ("If shallow structure is a sequence, input must also " 569 "be a sequence. Input has type: <(type|class) 'int'>.") 570 with self.assertRaisesRegexp(TypeError, expected_message): 571 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 572 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 573 self.assertEqual(flattened_shallow_tree, shallow_tree) 574 575 input_tree = 0 576 shallow_tree = [9, 8] 577 with self.assertRaisesRegexp(TypeError, expected_message): 578 flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) 579 flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) 580 self.assertEqual(flattened_shallow_tree, shallow_tree) 581 582 def testMapStructureUpTo(self): 583 # Named tuples. 584 ab_tuple = collections.namedtuple("ab_tuple", "a, b") 585 op_tuple = collections.namedtuple("op_tuple", "add, mul") 586 inp_val = ab_tuple(a=2, b=3) 587 inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) 588 out = nest.map_structure_up_to( 589 inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) 590 self.assertEqual(out.a, 6) 591 self.assertEqual(out.b, 15) 592 593 # Lists. 594 data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] 595 name_list = ["evens", ["odds", "primes"]] 596 out = nest.map_structure_up_to( 597 name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), 598 name_list, data_list) 599 self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) 600 601 def testGetTraverseShallowStructure(self): 602 scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] 603 scalar_traverse_r = nest.get_traverse_shallow_structure( 604 lambda s: not isinstance(s, tuple), 605 scalar_traverse_input) 606 self.assertEqual(scalar_traverse_r, 607 [True, True, False, [True, True], {"a": False}, []]) 608 nest.assert_shallow_structure(scalar_traverse_r, 609 scalar_traverse_input) 610 611 structure_traverse_input = [(1, [2]), ([1], 2)] 612 structure_traverse_r = nest.get_traverse_shallow_structure( 613 lambda s: (True, False) if isinstance(s, tuple) else True, 614 structure_traverse_input) 615 self.assertEqual(structure_traverse_r, 616 [(True, False), ([True], False)]) 617 nest.assert_shallow_structure(structure_traverse_r, 618 structure_traverse_input) 619 620 with self.assertRaisesRegexp(TypeError, "returned structure"): 621 nest.get_traverse_shallow_structure(lambda _: [True], 0) 622 623 with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): 624 nest.get_traverse_shallow_structure(lambda _: 1, [1]) 625 626 with self.assertRaisesRegexp( 627 TypeError, "didn't return a depth=1 structure of bools"): 628 nest.get_traverse_shallow_structure(lambda _: [1], [1]) 629 630 def testYieldFlatStringPaths(self): 631 for inputs_expected in ({"inputs": [], "expected": []}, 632 {"inputs": 3, "expected": [()]}, 633 {"inputs": [3], "expected": [(0,)]}, 634 {"inputs": {"a": 3}, "expected": [("a",)]}, 635 {"inputs": {"a": {"b": 4}}, 636 "expected": [("a", "b")]}, 637 {"inputs": [{"a": 2}], "expected": [(0, "a")]}, 638 {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, 639 {"inputs": [{"a": [(23, 42)]}], 640 "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, 641 {"inputs": [{"a": ([23], 42)}], 642 "expected": [(0, "a", 0, 0), (0, "a", 1)]}, 643 {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, 644 "expected": [("a", "a"), ("c", 0, 0, 0)]}, 645 {"inputs": {"0": [{"1": 23}]}, 646 "expected": [("0", 0, "1")]}): 647 inputs = inputs_expected["inputs"] 648 expected = inputs_expected["expected"] 649 self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) 650 651 def testFlattenWithStringPaths(self): 652 for inputs_expected in ( 653 {"inputs": [], "expected": []}, 654 {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, 655 {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): 656 inputs = inputs_expected["inputs"] 657 expected = inputs_expected["expected"] 658 self.assertEqual( 659 nest.flatten_with_joined_string_paths(inputs, separator="/"), 660 expected) 661 662 # Need a separate test for namedtuple as we can't declare tuple definitions 663 # in the @parameterized arguments. 664 def testFlattenNamedTuple(self): 665 # pylint: disable=invalid-name 666 Foo = collections.namedtuple("Foo", ["a", "b"]) 667 Bar = collections.namedtuple("Bar", ["c", "d"]) 668 # pylint: enable=invalid-name 669 test_cases = [ 670 (Foo(a=3, b=Bar(c=23, d=42)), 671 [("a", 3), ("b/c", 23), ("b/d", 42)]), 672 (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")), 673 [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), 674 (Bar(c=42, d=43), 675 [("c", 42), ("d", 43)]), 676 (Bar(c=[42], d=43), 677 [("c/0", 42), ("d", 43)]), 678 ] 679 for inputs, expected in test_cases: 680 self.assertEqual( 681 list(nest.flatten_with_joined_string_paths(inputs)), expected) 682 683 684 if __name__ == "__main__": 685 test.main() 686