Home | History | Annotate | Download | only in util
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for utilities working with arbitrarily nested structures."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import math_ops
     29 from tensorflow.python.platform import test
     30 from tensorflow.python.util import nest
     31 
     32 
     33 class NestTest(test.TestCase):
     34 
     35   def testFlattenAndPack(self):
     36     structure = ((3, 4), 5, (6, 7, (9, 10), 8))
     37     flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
     38     self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
     39     self.assertEqual(
     40         nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
     41                                                  ("d", "e", ("f", "g"), "h")))
     42     point = collections.namedtuple("Point", ["x", "y"])
     43     structure = (point(x=4, y=2), ((point(x=1, y=0),),))
     44     flat = [4, 2, 1, 0]
     45     self.assertEqual(nest.flatten(structure), flat)
     46     restructured_from_flat = nest.pack_sequence_as(structure, flat)
     47     self.assertEqual(restructured_from_flat, structure)
     48     self.assertEqual(restructured_from_flat[0].x, 4)
     49     self.assertEqual(restructured_from_flat[0].y, 2)
     50     self.assertEqual(restructured_from_flat[1][0][0].x, 1)
     51     self.assertEqual(restructured_from_flat[1][0][0].y, 0)
     52 
     53     self.assertEqual([5], nest.flatten(5))
     54     self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
     55 
     56     self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
     57     self.assertEqual(
     58         np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))
     59 
     60     with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
     61       nest.pack_sequence_as("scalar", [4, 5])
     62 
     63     with self.assertRaisesRegexp(TypeError, "flat_sequence"):
     64       nest.pack_sequence_as([4, 5], "bad_sequence")
     65 
     66     with self.assertRaises(ValueError):
     67       nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
     68 
     69   def testFlattenDictOrder(self):
     70     """`flatten` orders dicts by key, including OrderedDicts."""
     71     ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
     72     plain = {"d": 3, "b": 1, "a": 0, "c": 2}
     73     ordered_flat = nest.flatten(ordered)
     74     plain_flat = nest.flatten(plain)
     75     self.assertEqual([0, 1, 2, 3], ordered_flat)
     76     self.assertEqual([0, 1, 2, 3], plain_flat)
     77 
     78   def testPackDictOrder(self):
     79     """Packing orders dicts by key, including OrderedDicts."""
     80     ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
     81     plain = {"d": 0, "b": 0, "a": 0, "c": 0}
     82     seq = [0, 1, 2, 3]
     83     ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
     84     plain_reconstruction = nest.pack_sequence_as(plain, seq)
     85     self.assertEqual(
     86         collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
     87         ordered_reconstruction)
     88     self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
     89 
     90   def testFlattenAndPack_withDicts(self):
     91     # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
     92     named_tuple = collections.namedtuple("A", ("b", "c"))
     93     mess = [
     94         "z",
     95         named_tuple(3, 4),
     96         {
     97             "c": [
     98                 1,
     99                 collections.OrderedDict([
    100                     ("b", 3),
    101                     ("a", 2),
    102                 ]),
    103             ],
    104             "b": 5
    105         },
    106         17
    107     ]
    108 
    109     flattened = nest.flatten(mess)
    110     self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])
    111 
    112     structure_of_mess = [
    113         14,
    114         named_tuple("a", True),
    115         {
    116             "c": [
    117                 0,
    118                 collections.OrderedDict([
    119                     ("b", 9),
    120                     ("a", 8),
    121                 ]),
    122             ],
    123             "b": 3
    124         },
    125         "hi everybody",
    126     ]
    127 
    128     unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
    129     self.assertEqual(unflattened, mess)
    130 
    131     # Check also that the OrderedDict was created, with the correct key order.
    132     unflattened_ordered_dict = unflattened[2]["c"][1]
    133     self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
    134     self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
    135 
    136   def testFlatten_numpyIsNotFlattened(self):
    137     structure = np.array([1, 2, 3])
    138     flattened = nest.flatten(structure)
    139     self.assertEqual(len(flattened), 1)
    140 
    141   def testFlatten_stringIsNotFlattened(self):
    142     structure = "lots of letters"
    143     flattened = nest.flatten(structure)
    144     self.assertEqual(len(flattened), 1)
    145     unflattened = nest.pack_sequence_as("goodbye", flattened)
    146     self.assertEqual(structure, unflattened)
    147 
    148   def testPackSequenceAs_notIterableError(self):
    149     with self.assertRaisesRegexp(TypeError,
    150                                  "flat_sequence must be a sequence"):
    151       nest.pack_sequence_as("hi", "bye")
    152 
    153   def testPackSequenceAs_wrongLengthsError(self):
    154     with self.assertRaisesRegexp(
    155         ValueError,
    156         "Structure had 2 elements, but flat_sequence had 3 elements."):
    157       nest.pack_sequence_as(["hello", "world"],
    158                             ["and", "goodbye", "again"])
    159 
    160   def testIsSequence(self):
    161     self.assertFalse(nest.is_sequence("1234"))
    162     self.assertTrue(nest.is_sequence([1, 3, [4, 5]]))
    163     self.assertTrue(nest.is_sequence(((7, 8), (5, 6))))
    164     self.assertTrue(nest.is_sequence([]))
    165     self.assertTrue(nest.is_sequence({"a": 1, "b": 2}))
    166     self.assertFalse(nest.is_sequence(set([1, 2])))
    167     ones = array_ops.ones([2, 3])
    168     self.assertFalse(nest.is_sequence(ones))
    169     self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
    170     self.assertFalse(nest.is_sequence(np.ones((4, 5))))
    171 
    172   def testFlattenDictItems(self):
    173     dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
    174     flat = {4: "a", 5: "b", 6: "c", 8: "d"}
    175     self.assertEqual(nest.flatten_dict_items(dictionary), flat)
    176 
    177     with self.assertRaises(TypeError):
    178       nest.flatten_dict_items(4)
    179 
    180     bad_dictionary = {(4, 5, (4, 8)): ("a", "b", ("c", "d"))}
    181     with self.assertRaisesRegexp(ValueError, "not unique"):
    182       nest.flatten_dict_items(bad_dictionary)
    183 
    184     another_bad_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))}
    185     with self.assertRaisesRegexp(
    186         ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
    187       nest.flatten_dict_items(another_bad_dictionary)
    188 
    189   def testAssertSameStructure(self):
    190     structure1 = (((1, 2), 3), 4, (5, 6))
    191     structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    192     structure_different_num_elements = ("spam", "eggs")
    193     structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
    194     nest.assert_same_structure(structure1, structure2)
    195     nest.assert_same_structure("abc", 1.0)
    196     nest.assert_same_structure("abc", np.array([0, 1]))
    197     nest.assert_same_structure("abc", constant_op.constant([0, 1]))
    198 
    199     with self.assertRaisesRegexp(
    200         ValueError,
    201         ("don't have the same number of elements\\.\n\n"
    202          "First structure \\(6 elements\\):.*?"
    203          "\n\nSecond structure \\(2 elements\\):")):
    204       nest.assert_same_structure(structure1, structure_different_num_elements)
    205 
    206     with self.assertRaisesRegexp(
    207         ValueError,
    208         ("don't have the same number of elements\\.\n\n"
    209          "First structure \\(2 elements\\):.*?"
    210          "\n\nSecond structure \\(1 elements\\):")):
    211       nest.assert_same_structure([0, 1], np.array([0, 1]))
    212 
    213     with self.assertRaisesRegexp(
    214         ValueError,
    215         ("don't have the same number of elements\\.\n\n"
    216          "First structure \\(1 elements\\):.*"
    217          "\n\nSecond structure \\(2 elements\\):")):
    218       nest.assert_same_structure(0, [0, 1])
    219 
    220     self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
    221 
    222     with self.assertRaisesRegexp(
    223         ValueError,
    224         ("don't have the same nested structure\\.\n\n"
    225          "First structure: .*?\n\nSecond structure: ")):
    226       nest.assert_same_structure(structure1, structure_different_nesting)
    227 
    228     named_type_0 = collections.namedtuple("named_0", ("a", "b"))
    229     named_type_1 = collections.namedtuple("named_1", ("a", "b"))
    230     self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
    231                       named_type_0("a", "b"))
    232 
    233     nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))
    234 
    235     self.assertRaises(TypeError, nest.assert_same_structure,
    236                       named_type_0(3, 4), named_type_1(3, 4))
    237 
    238     with self.assertRaisesRegexp(
    239         ValueError,
    240         ("don't have the same nested structure\\.\n\n"
    241          "First structure: .*?\n\nSecond structure: ")):
    242       nest.assert_same_structure(named_type_0(3, 4), named_type_0([3], 4))
    243 
    244     with self.assertRaisesRegexp(
    245         ValueError,
    246         ("don't have the same nested structure\\.\n\n"
    247          "First structure: .*?\n\nSecond structure: ")):
    248       nest.assert_same_structure([[3], 4], [3, [4]])
    249 
    250     structure1_list = [[[1, 2], 3], 4, [5, 6]]
    251     with self.assertRaisesRegexp(TypeError,
    252                                  "don't have the same sequence type"):
    253       nest.assert_same_structure(structure1, structure1_list)
    254     nest.assert_same_structure(structure1, structure2, check_types=False)
    255     nest.assert_same_structure(structure1, structure1_list, check_types=False)
    256 
    257     with self.assertRaisesRegexp(ValueError,
    258                                  "don't have the same set of keys"):
    259       nest.assert_same_structure({"a": 1}, {"b": 1})
    260 
    261     same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
    262     same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
    263     nest.assert_same_structure(same_name_type_0(0, 1), same_name_type_1(2, 3))
    264 
    265     # This assertion is expected to pass: two namedtuples with the same
    266     # name and field names are considered to be identical.
    267     same_name_type_2 = collections.namedtuple("same_name_1", ("x", "y"))
    268     same_name_type_3 = collections.namedtuple("same_name_1", ("x", "y"))
    269     nest.assert_same_structure(
    270         same_name_type_0(same_name_type_2(0, 1), 2),
    271         same_name_type_1(same_name_type_3(2, 3), 4))
    272 
    273     expected_message = "The two structures don't have the same.*"
    274     with self.assertRaisesRegexp(ValueError, expected_message):
    275       nest.assert_same_structure(same_name_type_0(0, same_name_type_1(1, 2)),
    276                                  same_name_type_1(same_name_type_0(0, 1), 2))
    277 
    278     same_name_type_1 = collections.namedtuple("not_same_name", ("a", "b"))
    279     self.assertRaises(TypeError, nest.assert_same_structure,
    280                       same_name_type_0(0, 1), same_name_type_1(2, 3))
    281 
    282     same_name_type_1 = collections.namedtuple("same_name", ("x", "y"))
    283     self.assertRaises(TypeError, nest.assert_same_structure,
    284                       same_name_type_0(0, 1), same_name_type_1(2, 3))
    285 
    286     class SameNamedType1(collections.namedtuple("same_name", ("a", "b"))):
    287       pass
    288     self.assertRaises(TypeError, nest.assert_same_structure,
    289                       same_name_type_0(0, 1), SameNamedType1(2, 3))
    290 
    291   def testMapStructure(self):
    292     structure1 = (((1, 2), 3), 4, (5, 6))
    293     structure2 = (((7, 8), 9), 10, (11, 12))
    294     structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
    295     nest.assert_same_structure(structure1, structure1_plus1)
    296     self.assertAllEqual(
    297         [2, 3, 4, 5, 6, 7],
    298         nest.flatten(structure1_plus1))
    299     structure1_plus_structure2 = nest.map_structure(
    300         lambda x, y: x + y, structure1, structure2)
    301     self.assertEqual(
    302         (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
    303         structure1_plus_structure2)
    304 
    305     self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
    306 
    307     self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
    308 
    309     # Empty structures
    310     self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
    311     self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
    312     self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
    313     empty_nt = collections.namedtuple("empty_nt", "")
    314     self.assertEqual(empty_nt(), nest.map_structure(lambda x: x + 1,
    315                                                     empty_nt()))
    316 
    317     # This is checking actual equality of types, empty list != empty tuple
    318     self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
    319 
    320     with self.assertRaisesRegexp(TypeError, "callable"):
    321       nest.map_structure("bad", structure1_plus1)
    322 
    323     with self.assertRaisesRegexp(ValueError, "at least one structure"):
    324       nest.map_structure(lambda x: x)
    325 
    326     with self.assertRaisesRegexp(ValueError, "same number of elements"):
    327       nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
    328 
    329     with self.assertRaisesRegexp(ValueError, "same nested structure"):
    330       nest.map_structure(lambda x, y: None, 3, (3,))
    331 
    332     with self.assertRaisesRegexp(TypeError, "same sequence type"):
    333       nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
    334 
    335     with self.assertRaisesRegexp(ValueError, "same nested structure"):
    336       nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
    337 
    338     structure1_list = [[[1, 2], 3], 4, [5, 6]]
    339     with self.assertRaisesRegexp(TypeError, "same sequence type"):
    340       nest.map_structure(lambda x, y: None, structure1, structure1_list)
    341 
    342     nest.map_structure(lambda x, y: None, structure1, structure1_list,
    343                        check_types=False)
    344 
    345     with self.assertRaisesRegexp(ValueError, "same nested structure"):
    346       nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
    347                          check_types=False)
    348 
    349     with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    350       nest.map_structure(lambda x: None, structure1, foo="a")
    351 
    352     with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    353       nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
    354 
    355   def testMapStructureWithStrings(self):
    356     ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    357     inp_a = ab_tuple(a="foo", b=("bar", "baz"))
    358     inp_b = ab_tuple(a=2, b=(1, 3))
    359     out = nest.map_structure(lambda string, repeats: string * repeats,
    360                              inp_a,
    361                              inp_b)
    362     self.assertEqual("foofoo", out.a)
    363     self.assertEqual("bar", out.b[0])
    364     self.assertEqual("bazbazbaz", out.b[1])
    365 
    366     nt = ab_tuple(a=("something", "something_else"),
    367                   b="yet another thing")
    368     rev_nt = nest.map_structure(lambda x: x[::-1], nt)
    369     # Check the output is the correct structure, and all strings are reversed.
    370     nest.assert_same_structure(nt, rev_nt)
    371     self.assertEqual(nt.a[0][::-1], rev_nt.a[0])
    372     self.assertEqual(nt.a[1][::-1], rev_nt.a[1])
    373     self.assertEqual(nt.b[::-1], rev_nt.b)
    374 
    375   def testMapStructureOverPlaceholders(self):
    376     inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
    377              array_ops.placeholder(dtypes.float32, shape=[3, 7]))
    378     inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
    379              array_ops.placeholder(dtypes.float32, shape=[3, 7]))
    380 
    381     output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)
    382 
    383     nest.assert_same_structure(output, inp_a)
    384     self.assertShapeEqual(np.zeros((3, 4)), output[0])
    385     self.assertShapeEqual(np.zeros((3, 7)), output[1])
    386 
    387     feed_dict = {
    388         inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
    389         inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
    390     }
    391 
    392     with self.test_session() as sess:
    393       output_np = sess.run(output, feed_dict=feed_dict)
    394     self.assertAllClose(output_np[0],
    395                         feed_dict[inp_a][0] + feed_dict[inp_b][0])
    396     self.assertAllClose(output_np[1],
    397                         feed_dict[inp_a][1] + feed_dict[inp_b][1])
    398 
    399   def testAssertShallowStructure(self):
    400     inp_ab = ["a", "b"]
    401     inp_abc = ["a", "b", "c"]
    402     expected_message = (
    403         "The two structures don't have the same sequence length. Input "
    404         "structure has length 2, while shallow structure has length 3.")
    405     with self.assertRaisesRegexp(ValueError, expected_message):
    406       nest.assert_shallow_structure(inp_abc, inp_ab)
    407 
    408     inp_ab1 = [(1, 1), (2, 2)]
    409     inp_ab2 = [[1, 1], [2, 2]]
    410     expected_message = (
    411         "The two structures don't have the same sequence type. Input structure "
    412         "has type <(type|class) 'tuple'>, while shallow structure has type "
    413         "<(type|class) 'list'>.")
    414     with self.assertRaisesRegexp(TypeError, expected_message):
    415       nest.assert_shallow_structure(inp_ab2, inp_ab1)
    416     nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
    417 
    418     inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
    419     inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
    420     expected_message = (
    421         r"The two structures don't have the same keys. Input "
    422         r"structure has keys \['c'\], while shallow structure has "
    423         r"keys \['d'\].")
    424 
    425     with self.assertRaisesRegexp(ValueError, expected_message):
    426       nest.assert_shallow_structure(inp_ab2, inp_ab1)
    427 
    428     inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
    429     inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
    430     nest.assert_shallow_structure(inp_ab, inp_ba)
    431 
    432     # This assertion is expected to pass: two namedtuples with the same
    433     # name and field names are considered to be identical.
    434     same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
    435     same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
    436     inp_shallow = same_name_type_0(1, 2)
    437     inp_deep = same_name_type_1(1, [1, 2, 3])
    438     nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
    439     nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
    440 
    441   def testFlattenUpTo(self):
    442     # Shallow tree ends at scalar.
    443     input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
    444     shallow_tree = [[True, True], [False, True]]
    445     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    446     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    447     self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
    448     self.assertEqual(flattened_shallow_tree, [True, True, False, True])
    449 
    450     # Shallow tree ends at string.
    451     input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
    452     shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
    453     input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
    454                                                               input_tree)
    455     input_tree_flattened = nest.flatten(input_tree)
    456     self.assertEqual(input_tree_flattened_as_shallow_tree,
    457                      [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
    458     self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
    459 
    460     # Make sure dicts are correctly flattened, yielding values, not keys.
    461     input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
    462     shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
    463     input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
    464                                                               input_tree)
    465     self.assertEqual(input_tree_flattened_as_shallow_tree,
    466                      [1, {"c": 2}, 3, (4, 5)])
    467 
    468     # Namedtuples.
    469     ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    470     input_tree = ab_tuple(a=[0, 1], b=2)
    471     shallow_tree = ab_tuple(a=0, b=1)
    472     input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
    473                                                               input_tree)
    474     self.assertEqual(input_tree_flattened_as_shallow_tree,
    475                      [[0, 1], 2])
    476 
    477     # Nested dicts, OrderedDicts and namedtuples.
    478     input_tree = collections.OrderedDict(
    479         [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
    480          ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
    481     shallow_tree = input_tree
    482     input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
    483                                                               input_tree)
    484     self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
    485     shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
    486     input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
    487                                                               input_tree)
    488     self.assertEqual(input_tree_flattened_as_shallow_tree,
    489                      [ab_tuple(a=[0, {"b": 1}], b=2),
    490                       3,
    491                       collections.OrderedDict([("f", 4)])])
    492     shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
    493     input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
    494                                                               input_tree)
    495     self.assertEqual(input_tree_flattened_as_shallow_tree,
    496                      [ab_tuple(a=[0, {"b": 1}], b=2),
    497                       {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
    498 
    499     ## Shallow non-list edge-case.
    500     # Using iterable elements.
    501     input_tree = ["input_tree"]
    502     shallow_tree = "shallow_tree"
    503     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    504     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    505     self.assertEqual(flattened_input_tree, [input_tree])
    506     self.assertEqual(flattened_shallow_tree, [shallow_tree])
    507 
    508     input_tree = ["input_tree_0", "input_tree_1"]
    509     shallow_tree = "shallow_tree"
    510     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    511     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    512     self.assertEqual(flattened_input_tree, [input_tree])
    513     self.assertEqual(flattened_shallow_tree, [shallow_tree])
    514 
    515     # Using non-iterable elements.
    516     input_tree = [0]
    517     shallow_tree = 9
    518     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    519     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    520     self.assertEqual(flattened_input_tree, [input_tree])
    521     self.assertEqual(flattened_shallow_tree, [shallow_tree])
    522 
    523     input_tree = [0, 1]
    524     shallow_tree = 9
    525     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    526     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    527     self.assertEqual(flattened_input_tree, [input_tree])
    528     self.assertEqual(flattened_shallow_tree, [shallow_tree])
    529 
    530     ## Both non-list edge-case.
    531     # Using iterable elements.
    532     input_tree = "input_tree"
    533     shallow_tree = "shallow_tree"
    534     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    535     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    536     self.assertEqual(flattened_input_tree, [input_tree])
    537     self.assertEqual(flattened_shallow_tree, [shallow_tree])
    538 
    539     # Using non-iterable elements.
    540     input_tree = 0
    541     shallow_tree = 0
    542     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    543     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    544     self.assertEqual(flattened_input_tree, [input_tree])
    545     self.assertEqual(flattened_shallow_tree, [shallow_tree])
    546 
    547     ## Input non-list edge-case.
    548     # Using iterable elements.
    549     input_tree = "input_tree"
    550     shallow_tree = ["shallow_tree"]
    551     expected_message = ("If shallow structure is a sequence, input must also "
    552                         "be a sequence. Input has type: <(type|class) 'str'>.")
    553     with self.assertRaisesRegexp(TypeError, expected_message):
    554       flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    555     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    556     self.assertEqual(flattened_shallow_tree, shallow_tree)
    557 
    558     input_tree = "input_tree"
    559     shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
    560     with self.assertRaisesRegexp(TypeError, expected_message):
    561       flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    562     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    563     self.assertEqual(flattened_shallow_tree, shallow_tree)
    564 
    565     # Using non-iterable elements.
    566     input_tree = 0
    567     shallow_tree = [9]
    568     expected_message = ("If shallow structure is a sequence, input must also "
    569                         "be a sequence. Input has type: <(type|class) 'int'>.")
    570     with self.assertRaisesRegexp(TypeError, expected_message):
    571       flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    572     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    573     self.assertEqual(flattened_shallow_tree, shallow_tree)
    574 
    575     input_tree = 0
    576     shallow_tree = [9, 8]
    577     with self.assertRaisesRegexp(TypeError, expected_message):
    578       flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    579     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    580     self.assertEqual(flattened_shallow_tree, shallow_tree)
    581 
    582   def testMapStructureUpTo(self):
    583     # Named tuples.
    584     ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    585     op_tuple = collections.namedtuple("op_tuple", "add, mul")
    586     inp_val = ab_tuple(a=2, b=3)
    587     inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
    588     out = nest.map_structure_up_to(
    589         inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
    590     self.assertEqual(out.a, 6)
    591     self.assertEqual(out.b, 15)
    592 
    593     # Lists.
    594     data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
    595     name_list = ["evens", ["odds", "primes"]]
    596     out = nest.map_structure_up_to(
    597         name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
    598         name_list, data_list)
    599     self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
    600 
    601   def testGetTraverseShallowStructure(self):
    602     scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
    603     scalar_traverse_r = nest.get_traverse_shallow_structure(
    604         lambda s: not isinstance(s, tuple),
    605         scalar_traverse_input)
    606     self.assertEqual(scalar_traverse_r,
    607                      [True, True, False, [True, True], {"a": False}, []])
    608     nest.assert_shallow_structure(scalar_traverse_r,
    609                                   scalar_traverse_input)
    610 
    611     structure_traverse_input = [(1, [2]), ([1], 2)]
    612     structure_traverse_r = nest.get_traverse_shallow_structure(
    613         lambda s: (True, False) if isinstance(s, tuple) else True,
    614         structure_traverse_input)
    615     self.assertEqual(structure_traverse_r,
    616                      [(True, False), ([True], False)])
    617     nest.assert_shallow_structure(structure_traverse_r,
    618                                   structure_traverse_input)
    619 
    620     with self.assertRaisesRegexp(TypeError, "returned structure"):
    621       nest.get_traverse_shallow_structure(lambda _: [True], 0)
    622 
    623     with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
    624       nest.get_traverse_shallow_structure(lambda _: 1, [1])
    625 
    626     with self.assertRaisesRegexp(
    627         TypeError, "didn't return a depth=1 structure of bools"):
    628       nest.get_traverse_shallow_structure(lambda _: [1], [1])
    629 
    630   def testYieldFlatStringPaths(self):
    631     for inputs_expected in ({"inputs": [], "expected": []},
    632                             {"inputs": 3, "expected": [()]},
    633                             {"inputs": [3], "expected": [(0,)]},
    634                             {"inputs": {"a": 3}, "expected": [("a",)]},
    635                             {"inputs": {"a": {"b": 4}},
    636                              "expected": [("a", "b")]},
    637                             {"inputs": [{"a": 2}], "expected": [(0, "a")]},
    638                             {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
    639                             {"inputs": [{"a": [(23, 42)]}],
    640                              "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
    641                             {"inputs": [{"a": ([23], 42)}],
    642                              "expected": [(0, "a", 0, 0), (0, "a", 1)]},
    643                             {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
    644                              "expected": [("a", "a"), ("c", 0, 0, 0)]},
    645                             {"inputs": {"0": [{"1": 23}]},
    646                              "expected": [("0", 0, "1")]}):
    647       inputs = inputs_expected["inputs"]
    648       expected = inputs_expected["expected"]
    649       self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
    650 
    651   def testFlattenWithStringPaths(self):
    652     for inputs_expected in (
    653         {"inputs": [], "expected": []},
    654         {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
    655         {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
    656       inputs = inputs_expected["inputs"]
    657       expected = inputs_expected["expected"]
    658       self.assertEqual(
    659           nest.flatten_with_joined_string_paths(inputs, separator="/"),
    660           expected)
    661 
    662   # Need a separate test for namedtuple as we can't declare tuple definitions
    663   # in the @parameterized arguments.
    664   def testFlattenNamedTuple(self):
    665     # pylint: disable=invalid-name
    666     Foo = collections.namedtuple("Foo", ["a", "b"])
    667     Bar = collections.namedtuple("Bar", ["c", "d"])
    668     # pylint: enable=invalid-name
    669     test_cases = [
    670         (Foo(a=3, b=Bar(c=23, d=42)),
    671          [("a", 3), ("b/c", 23), ("b/d", 42)]),
    672         (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")),
    673          [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
    674         (Bar(c=42, d=43),
    675          [("c", 42), ("d", 43)]),
    676         (Bar(c=[42], d=43),
    677          [("c/0", 42), ("d", 43)]),
    678     ]
    679     for inputs, expected in test_cases:
    680       self.assertEqual(
    681           list(nest.flatten_with_joined_string_paths(inputs)), expected)
    682 
    683 
    684 if __name__ == "__main__":
    685   test.main()
    686