Home | History | Annotate | Download | only in util
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """## Functions for working with arbitrarily nested sequences of elements.
     17 
     18 This module can perform operations on nested structures. A nested structure is a
     19 Python sequence, tuple (including `namedtuple`), or dict that can contain
     20 further sequences, tuples, and dicts.
     21 
     22 The utilities here assume (and do not check) that the nested structures form a
     23 'tree', i.e., no references in the structure of the input of these functions
     24 should be recursive.
     25 
     26 Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0),
     27   (np.array([3, 4]), tf.constant([3, 4])))`
     28 """
     29 
     30 from __future__ import absolute_import
     31 from __future__ import division
     32 from __future__ import print_function
     33 
     34 import collections as _collections
     35 
     36 import six as _six
     37 
     38 from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
     39 from tensorflow.python.util.all_util import remove_undocumented
     40 
     41 
     42 def _sorted(dict_):
     43   """Returns a sorted list of the dict keys, with error if keys not sortable."""
     44   try:
     45     return sorted(_six.iterkeys(dict_))
     46   except TypeError:
     47     raise TypeError("nest only supports dicts with sortable keys.")
     48 
     49 
     50 def _is_namedtuple(instance, strict=False):
     51   """Returns True iff `instance` is a `namedtuple`.
     52 
     53   Args:
     54     instance: An instance of a Python object.
     55     strict: If True, `instance` is considered to be a `namedtuple` only if
     56         it is a "plain" namedtuple. For instance, a class inheriting
     57         from a `namedtuple` will be considered to be a `namedtuple`
     58         iff `strict=False`.
     59 
     60   Returns:
     61     True if `instance` is a `namedtuple`.
     62   """
     63   # Attemp to limit the test to plain namedtuple (not stuff inheriting from it).
     64   if not isinstance(instance, tuple):
     65     return False
     66   if strict and instance.__class__.__base__ != tuple:
     67     return False
     68   return (
     69       hasattr(instance, "_fields") and
     70       isinstance(instance._fields, _collections.Sequence) and
     71       all(isinstance(f, _six.string_types) for f in instance._fields))
     72 
     73 
     74 def _sequence_like(instance, args):
     75   """Converts the sequence `args` to the same type as `instance`.
     76 
     77   Args:
     78     instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, or
     79         `collections.OrderedDict`.
     80     args: elements to be converted to the `instance` type.
     81 
     82   Returns:
     83     `args` with the type of `instance`.
     84   """
     85   if isinstance(instance, dict):
     86     # Pack dictionaries in a deterministic order by sorting the keys.
     87     # Notice this means that we ignore the original order of `OrderedDict`
     88     # instances. This is intentional, to avoid potential bugs caused by mixing
     89     # ordered and plain dicts (e.g., flattening a dict but using a
     90     # corresponding `OrderedDict` to pack it back).
     91     result = dict(zip(_sorted(instance), args))
     92     return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
     93   elif _is_namedtuple(instance):
     94     return type(instance)(*args)
     95   else:
     96     # Not a namedtuple
     97     return type(instance)(args)
     98 
     99 
    100 def _yield_value(iterable):
    101   if isinstance(iterable, dict):
    102     # Iterate through dictionaries in a deterministic order by sorting the
    103     # keys. Notice this means that we ignore the original order of `OrderedDict`
    104     # instances. This is intentional, to avoid potential bugs caused by mixing
    105     # ordered and plain dicts (e.g., flattening a dict but using a
    106     # corresponding `OrderedDict` to pack it back).
    107     for key in _sorted(iterable):
    108       yield iterable[key]
    109   else:
    110     for value in iterable:
    111       yield value
    112 
    113 
def is_sequence(seq):
  """Returns True if its input is a collections.Sequence (except strings).

  The check itself is delegated to `_pywrap_tensorflow.IsSequence`.

  Args:
    seq: an input sequence.

  Returns:
    True if the sequence is not a string and is a collections.Sequence or a
    dict.
  """
  return _pywrap_tensorflow.IsSequence(seq)
    125 
    126 
def flatten(nest):
  """Returns a flat list from a given nested structure.

  The flattening itself is delegated to `_pywrap_tensorflow.Flatten`.

  If `nest` is not a sequence, tuple, or dict, then returns a single-element
  list: `[nest]`.

  In the case of dict instances, the sequence consists of the values, sorted by
  key to ensure deterministic behavior. This is true also for `OrderedDict`
  instances: their sequence order is ignored, the sorting order of keys is
  used instead. The same convention is followed in `pack_sequence_as`. This
  correctly repacks dicts and `OrderedDict`s after they have been flattened,
  and also allows flattening an `OrderedDict` and then repacking it back using
  a corresponding plain dict, or vice-versa.
  Dictionaries with non-sortable keys cannot be flattened.

  Users must not modify any collections used in `nest` while this function is
  running.

  Args:
    nest: an arbitrarily nested structure or a scalar object. Note, numpy
        arrays are considered scalars.

  Returns:
    A Python list, the flattened version of the input.

  Raises:
    TypeError: The nest is or contains a dict with non-sortable keys.
  """
  return _pywrap_tensorflow.Flatten(nest)
    156 
    157 
    158 def _same_namedtuples(nest1, nest2):
    159   """Returns True if the two namedtuples have the same name and fields."""
    160   if nest1._fields != nest2._fields:
    161     return False
    162   if nest1.__class__.__name__ != nest2.__class__.__name__:
    163     return False
    164   return True
    165 
    166 
    167 def _recursive_assert_same_structure(nest1, nest2, check_types):
    168   """Helper function for `assert_same_structure`.
    169 
    170   See `assert_same_structure` for further information about namedtuples.
    171 
    172   Args:
    173     nest1: An arbitrarily nested structure.
    174     nest2: An arbitrarily nested structure.
    175     check_types: If `True` (default) types of sequences are checked as
    176         well, including the keys of dictionaries. If set to `False`, for example
    177         a list and a tuple of objects will look the same if they have the same
    178         size. Note that namedtuples with identical name and fields are always
    179         considered to have the same shallow structure.
    180 
    181   Returns:
    182     True if `nest1` and `nest2` have the same structure.
    183 
    184   Raises:
    185     ValueError: If the two structure don't have the same nested structre.
    186     TypeError: If the two structure don't have the same sequence type.
    187     ValueError: If the two dictionaries don't have the same set of keys.
    188   """
    189   is_sequence_nest1 = is_sequence(nest1)
    190   if is_sequence_nest1 != is_sequence(nest2):
    191     raise ValueError(
    192         "The two structures don't have the same nested structure.\n\n"
    193         "First structure: %s\n\nSecond structure: %s." % (nest1, nest2))
    194 
    195   if not is_sequence_nest1:
    196     return  # finished checking
    197 
    198   if check_types:
    199     type_nest1 = type(nest1)
    200     type_nest2 = type(nest2)
    201 
    202     # Duck-typing means that nest should be fine with two different namedtuples
    203     # with identical name and fields.
    204     if _is_namedtuple(nest1, True) and _is_namedtuple(nest2, True):
    205       if not _same_namedtuples(nest1, nest2):
    206         raise TypeError(
    207             "The two namedtuples don't have the same sequence type. First "
    208             "structure has type %s, while second structure has type %s."
    209             % (type_nest1, type_nest2))
    210     else:
    211       if type_nest1 != type_nest2:
    212         raise TypeError(
    213             "The two structures don't have the same sequence type. First "
    214             "structure has type %s, while second structure has type %s."
    215             % (type_nest1, type_nest2))
    216 
    217     if isinstance(nest1, dict):
    218       keys1 = set(_six.iterkeys(nest1))
    219       keys2 = set(_six.iterkeys(nest2))
    220       if keys1 != keys2:
    221         raise ValueError(
    222             "The two dictionaries don't have the same set of keys. First "
    223             "structure has keys {}, while second structure has keys {}."
    224             .format(keys1, keys2))
    225 
    226   nest1_as_sequence = [n for n in _yield_value(nest1)]
    227   nest2_as_sequence = [n for n in _yield_value(nest2)]
    228   for n1, n2 in zip(nest1_as_sequence, nest2_as_sequence):
    229     _recursive_assert_same_structure(n1, n2, check_types)
    230 
    231 
    232 def assert_same_structure(nest1, nest2, check_types=True):
    233   """Asserts that two structures are nested in the same way.
    234 
    235   Note that namedtuples with identical name and fields are always considered
    236   to have the same shallow structure (even with `check_types=True`).
    237   For intance, this code will print `True`:
    238 
    239   ```python
    240   def nt(a, b):
    241     return collections.namedtuple('foo', 'a b')(a, b)
    242   print(assert_same_structure(nt(0, 1), nt(2, 3)))
    243   ```
    244 
    245   Args:
    246     nest1: an arbitrarily nested structure.
    247     nest2: an arbitrarily nested structure.
    248     check_types: if `True` (default) types of sequences are checked as
    249         well, including the keys of dictionaries. If set to `False`, for example
    250         a list and a tuple of objects will look the same if they have the same
    251         size. Note that namedtuples with identical name and fields are always
    252         considered to have the same shallow structure.
    253 
    254   Raises:
    255     ValueError: If the two structures do not have the same number of elements or
    256       if the two structures are not nested in the same way.
    257     TypeError: If the two structures differ in the type of sequence in any of
    258       their substructures. Only possible if `check_types` is `True`.
    259   """
    260   len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1
    261   len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1
    262   if len_nest1 != len_nest2:
    263     raise ValueError("The two structures don't have the same number of "
    264                      "elements.\n\nFirst structure (%i elements): %s\n\n"
    265                      "Second structure (%i elements): %s"
    266                      % (len_nest1, nest1, len_nest2, nest2))
    267   _recursive_assert_same_structure(nest1, nest2, check_types)
    268 
    269 
    270 def flatten_dict_items(dictionary):
    271   """Returns a dictionary with flattened keys and values.
    272 
    273   This function flattens the keys and values of a dictionary, which can be
    274   arbitrarily nested structures, and returns the flattened version of such
    275   structures:
    276 
    277   ```python
    278   example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
    279   result = {4: "a", 5: "b", 6: "c", 8: "d"}
    280   flatten_dict_items(example_dictionary) == result
    281   ```
    282 
    283   The input dictionary must satisfy two properties:
    284 
    285   1. Its keys and values should have the same exact nested structure.
    286   2. The set of all flattened keys of the dictionary must not contain repeated
    287      keys.
    288 
    289   Args:
    290     dictionary: the dictionary to zip
    291 
    292   Returns:
    293     The zipped dictionary.
    294 
    295   Raises:
    296     TypeError: If the input is not a dictionary.
    297     ValueError: If any key and value have not the same structure, or if keys are
    298       not unique.
    299   """
    300   if not isinstance(dictionary, dict):
    301     raise TypeError("input must be a dictionary")
    302   flat_dictionary = {}
    303   for i, v in _six.iteritems(dictionary):
    304     if not is_sequence(i):
    305       if i in flat_dictionary:
    306         raise ValueError(
    307             "Could not flatten dictionary: key %s is not unique." % i)
    308       flat_dictionary[i] = v
    309     else:
    310       flat_i = flatten(i)
    311       flat_v = flatten(v)
    312       if len(flat_i) != len(flat_v):
    313         raise ValueError(
    314             "Could not flatten dictionary. Key had %d elements, but value had "
    315             "%d elements. Key: %s, value: %s."
    316             % (len(flat_i), len(flat_v), flat_i, flat_v))
    317       for new_i, new_v in zip(flat_i, flat_v):
    318         if new_i in flat_dictionary:
    319           raise ValueError(
    320               "Could not flatten dictionary: key %s is not unique."
    321               % (new_i))
    322         flat_dictionary[new_i] = new_v
    323   return flat_dictionary
    324 
    325 
    326 def _packed_nest_with_indices(structure, flat, index):
    327   """Helper function for pack_sequence_as.
    328 
    329   Args:
    330     structure: Substructure (list / tuple / dict) to mimic.
    331     flat: Flattened values to output substructure for.
    332     index: Index at which to start reading from flat.
    333 
    334   Returns:
    335     The tuple (new_index, child), where:
    336       * new_index - the updated index into `flat` having processed `structure`.
    337       * packed - the subset of `flat` corresponding to `structure`,
    338                  having started at `index`, and packed into the same nested
    339                  format.
    340 
    341   Raises:
    342     ValueError: if `structure` contains more elements than `flat`
    343       (assuming indexing starts from `index`).
    344   """
    345   packed = []
    346   for s in _yield_value(structure):
    347     if is_sequence(s):
    348       new_index, child = _packed_nest_with_indices(s, flat, index)
    349       packed.append(_sequence_like(s, child))
    350       index = new_index
    351     else:
    352       packed.append(flat[index])
    353       index += 1
    354   return index, packed
    355 
    356 
    357 def pack_sequence_as(structure, flat_sequence):
    358   """Returns a given flattened sequence packed into a given structure.
    359 
    360   If `structure` is a scalar, `flat_sequence` must be a single-element list;
    361   in this case the return value is `flat_sequence[0]`.
    362 
    363   If `structure` is or contains a dict instance, the keys will be sorted to
    364   pack the flat sequence in deterministic order. This is true also for
    365   `OrderedDict` instances: their sequence order is ignored, the sorting order of
    366   keys is used instead. The same convention is followed in `flatten`.
    367   This correctly repacks dicts and `OrderedDict`s after they have been
    368   flattened, and also allows flattening an `OrderedDict` and then repacking it
    369   back using a corresponding plain dict, or vice-versa.
    370   Dictionaries with non-sortable keys cannot be flattened.
    371 
    372   Args:
    373     structure: Nested structure, whose structure is given by nested lists,
    374         tuples, and dicts. Note: numpy arrays and strings are considered
    375         scalars.
    376     flat_sequence: flat sequence to pack.
    377 
    378   Returns:
    379     packed: `flat_sequence` converted to have the same recursive structure as
    380       `structure`.
    381 
    382   Raises:
    383     ValueError: If `flat_sequence` and `structure` have different
    384       element counts.
    385     TypeError: `structure` is or contains a dict with non-sortable keys.
    386   """
    387   if not is_sequence(flat_sequence):
    388     raise TypeError("flat_sequence must be a sequence")
    389 
    390   if not is_sequence(structure):
    391     if len(flat_sequence) != 1:
    392       raise ValueError("Structure is a scalar but len(flat_sequence) == %d > 1"
    393                        % len(flat_sequence))
    394     return flat_sequence[0]
    395 
    396   flat_structure = flatten(structure)
    397   if len(flat_structure) != len(flat_sequence):
    398     raise ValueError(
    399         "Could not pack sequence. Structure had %d elements, but flat_sequence "
    400         "had %d elements.  Structure: %s, flat_sequence: %s."
    401         % (len(flat_structure), len(flat_sequence), structure, flat_sequence))
    402 
    403   _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
    404   return _sequence_like(structure, packed)
    405 
    406 
    407 def map_structure(func, *structure, **check_types_dict):
    408   """Applies `func` to each entry in `structure` and returns a new structure.
    409 
    410   Applies `func(x[0], x[1], ...)` where x[i] is an entry in
    411   `structure[i]`.  All structures in `structure` must have the same arity,
    412   and the return value will contain the results in the same structure.
    413 
    414   Args:
    415     func: A callable that accepts as many arguments as there are structures.
    416     *structure: scalar, or tuple or list of constructed scalars and/or other
    417       tuples/lists, or scalars.  Note: numpy arrays are considered as scalars.
    418     **check_types_dict: only valid keyword argument is `check_types`. If set to
    419       `True` (default) the types of iterables within the structures have to be
    420       same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError`
    421       exception). To allow this set this argument to `False`.
    422       Note that namedtuples with identical name and fields are always
    423       considered to have the same shallow structure.
    424 
    425   Returns:
    426     A new structure with the same arity as `structure`, whose values correspond
    427     to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
    428     location in `structure[i]`. If there are different sequence types and
    429     `check_types` is `False` the sequence types of the first structure will be
    430     used.
    431 
    432   Raises:
    433     TypeError: If `func` is not callable or if the structures do not match
    434       each other by depth tree.
    435     ValueError: If no structure is provided or if the structures do not match
    436       each other by type.
    437     ValueError: If wrong keyword arguments are provided.
    438   """
    439   if not callable(func):
    440     raise TypeError("func must be callable, got: %s" % func)
    441 
    442   if not structure:
    443     raise ValueError("Must provide at least one structure")
    444 
    445   if check_types_dict:
    446     if "check_types" not in check_types_dict or len(check_types_dict) > 1:
    447       raise ValueError("Only valid keyword argument is check_types")
    448     check_types = check_types_dict["check_types"]
    449   else:
    450     check_types = True
    451 
    452   for other in structure[1:]:
    453     assert_same_structure(structure[0], other, check_types=check_types)
    454 
    455   flat_structure = [flatten(s) for s in structure]
    456   entries = zip(*flat_structure)
    457 
    458   return pack_sequence_as(
    459       structure[0], [func(*x) for x in entries])
    460 
    461 
    462 def _yield_flat_up_to(shallow_tree, input_tree):
    463   """Yields elements `input_tree` partially flattened up to `shallow_tree`."""
    464   if is_sequence(shallow_tree):
    465     for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
    466                                             _yield_value(input_tree)):
    467       for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
    468         yield input_leaf
    469   else:
    470     yield input_tree
    471 
    472 
    473 def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
    474   """Asserts that `shallow_tree` is a shallow structure of `input_tree`.
    475 
    476   That is, this function tests if the `input_tree` structure can be created from
    477   the `shallow_tree` structure by replacing its leaf nodes with deeper
    478   tree structures.
    479 
    480   Examples:
    481 
    482   The following code will raise an exception:
    483   ```python
    484     shallow_tree = ["a", "b"]
    485     input_tree = ["c", ["d", "e"], "f"]
    486     assert_shallow_structure(shallow_tree, input_tree)
    487   ```
    488 
    489   The following code will not raise an exception:
    490   ```python
    491     shallow_tree = ["a", "b"]
    492     input_tree = ["c", ["d", "e"]]
    493     assert_shallow_structure(shallow_tree, input_tree)
    494   ```
    495 
    496   Args:
    497     shallow_tree: an arbitrarily nested structure.
    498     input_tree: an arbitrarily nested structure.
    499     check_types: if `True` (default) the sequence types of `shallow_tree` and
    500       `input_tree` have to be the same. Note that even with check_types==True,
    501       this function will consider two different namedtuple classes with the same
    502       name and _fields attribute to be the same class.
    503 
    504   Raises:
    505     TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    506     TypeError: If the sequence types of `shallow_tree` are different from
    507       `input_tree`. Only raised if `check_types` is `True`.
    508     ValueError: If the sequence lengths of `shallow_tree` are different from
    509       `input_tree`.
    510   """
    511   if is_sequence(shallow_tree):
    512     if not is_sequence(input_tree):
    513       raise TypeError(
    514           "If shallow structure is a sequence, input must also be a sequence. "
    515           "Input has type: %s." % type(input_tree))
    516 
    517     if check_types and not isinstance(input_tree, type(shallow_tree)):
    518       # Duck-typing means that nest should be fine with two different
    519       # namedtuples with identical name and fields.
    520       shallow_is_namedtuple = _is_namedtuple(shallow_tree, False)
    521       input_is_namedtuple = _is_namedtuple(input_tree, False)
    522       if shallow_is_namedtuple and input_is_namedtuple:
    523         if not _same_namedtuples(shallow_tree, input_tree):
    524           raise TypeError(
    525               "The two namedtuples don't have the same sequence type. Input "
    526               "structure has type %s, while shallow structure has type %s."
    527               % (type(input_tree), type(shallow_tree)))
    528       else:
    529         raise TypeError(
    530             "The two structures don't have the same sequence type. Input "
    531             "structure has type %s, while shallow structure has type %s."
    532             % (type(input_tree), type(shallow_tree)))
    533 
    534     if len(input_tree) != len(shallow_tree):
    535       raise ValueError(
    536           "The two structures don't have the same sequence length. Input "
    537           "structure has length %s, while shallow structure has length %s."
    538           % (len(input_tree), len(shallow_tree)))
    539 
    540     if check_types and isinstance(shallow_tree, dict):
    541       if set(input_tree) != set(shallow_tree):
    542         raise ValueError(
    543             "The two structures don't have the same keys. Input "
    544             "structure has keys %s, while shallow structure has keys %s." %
    545             (list(_six.iterkeys(input_tree)),
    546              list(_six.iterkeys(shallow_tree))))
    547 
    548       input_tree = list(sorted(_six.iteritems(input_tree)))
    549       shallow_tree = list(sorted(_six.iteritems(shallow_tree)))
    550 
    551     for shallow_branch, input_branch in zip(shallow_tree, input_tree):
    552       assert_shallow_structure(shallow_branch, input_branch,
    553                                check_types=check_types)
    554 
    555 
    556 def flatten_up_to(shallow_tree, input_tree):
    557   """Flattens `input_tree` up to `shallow_tree`.
    558 
    559   Any further depth in structure in `input_tree` is retained as elements in the
    560   partially flatten output.
    561 
    562   If `shallow_tree` and `input_tree` are not sequences, this returns a
    563   single-element list: `[input_tree]`.
    564 
    565   Use Case:
    566 
    567   Sometimes we may wish to partially flatten a nested sequence, retaining some
    568   of the nested structure. We achieve this by specifying a shallow structure,
    569   `shallow_tree`, we wish to flatten up to.
    570 
    571   The input, `input_tree`, can be thought of as having the same structure as
    572   `shallow_tree`, but with leaf nodes that are themselves tree structures.
    573 
    574   Examples:
    575 
    576   ```python
    577   input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
    578   shallow_tree = [[True, True], [False, True]]
    579 
    580   flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
    581   flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)
    582 
    583   # Output is:
    584   # [[2, 2], [3, 3], [4, 9], [5, 5]]
    585   # [True, True, False, True]
    586   ```
    587 
    588   ```python
    589   input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
    590   shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]
    591 
    592   input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
    593   input_tree_flattened = flatten(input_tree)
    594 
    595   # Output is:
    596   # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    597   # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
    598   ```
    599 
    600   Non-Sequence Edge Cases:
    601 
    602   ```python
    603   flatten_up_to(0, 0)  # Output: [0]
    604   flatten_up_to(0, [0, 1, 2])  # Output: [[0, 1, 2]]
    605   flatten_up_to([0, 1, 2], 0)  # Output: TypeError
    606   flatten_up_to([0, 1, 2], [0, 1, 2])  # Output: [0, 1, 2]
    607   ```
    608 
    609   Args:
    610     shallow_tree: a possibly pruned structure of input_tree.
    611     input_tree: an arbitrarily nested structure or a scalar object.
    612       Note, numpy arrays are considered scalars.
    613 
    614   Returns:
    615     A Python list, the partially flattened version of `input_tree` according to
    616     the structure of `shallow_tree`.
    617 
    618   Raises:
    619     TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    620     TypeError: If the sequence types of `shallow_tree` are different from
    621       `input_tree`.
    622     ValueError: If the sequence lengths of `shallow_tree` are different from
    623       `input_tree`.
    624   """
    625   assert_shallow_structure(shallow_tree, input_tree)
    626   return list(_yield_flat_up_to(shallow_tree, input_tree))
    627 
    628 
    629 def map_structure_up_to(shallow_tree, func, *inputs):
    630   """Applies a function or op to a number of partially flattened inputs.
    631 
    632   The `inputs` are flattened up to `shallow_tree` before being mapped.
    633 
    634   Use Case:
    635 
    636   Sometimes we wish to apply a function to a partially flattened
    637   sequence (for example when the function itself takes sequence inputs). We
    638   achieve this by specifying a shallow structure, `shallow_tree` we wish to
    639   flatten up to.
    640 
    641   The `inputs`, can be thought of as having the same structure as
    642   `shallow_tree`, but with leaf nodes that are themselves tree structures.
    643 
    644   This function therefore will return something with the same base structure as
    645   `shallow_tree`.
    646 
    647   Examples:
    648 
    649   ```python
    650   ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    651   op_tuple = collections.namedtuple("op_tuple", "add, mul")
    652   inp_val = ab_tuple(a=2, b=3)
    653   inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
    654   out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
    655                             inp_val, inp_ops)
    656 
    657   # Output is: ab_tuple(a=6, b=15)
    658   ```
    659 
    660   ```python
    661   data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
    662   name_list = ['evens', ['odds', 'primes']]
    663   out = map_structure_up_to(
    664       name_list,
    665       lambda name, sec: "first_{}_{}".format(len(sec), name),
    666       name_list, data_list)
    667 
    668   # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
    669   ```
    670 
    671   Args:
    672     shallow_tree: a shallow tree, common to all the inputs.
    673     func: callable which will be applied to each input individually.
    674     *inputs: arbitrarily nested combination of objects that are compatible with
    675         shallow_tree. The function `func` is applied to corresponding
    676         partially flattened elements of each input, so the function must support
    677         arity of `len(inputs)`.
    678 
    679   Raises:
    680     TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    681     TypeError: If the sequence types of `shallow_tree` are different from
    682       `input_tree`.
    683     ValueError: If the sequence lengths of `shallow_tree` are different from
    684       `input_tree`.
    685 
    686   Returns:
    687     result of repeatedly applying `func`, with same structure as
    688     `shallow_tree`.
    689   """
    690   if not inputs:
    691     raise ValueError("Cannot map over no sequences")
    692   for input_tree in inputs:
    693     assert_shallow_structure(shallow_tree, input_tree)
    694 
    695   # Flatten each input separately, apply the function to corresponding elements,
    696   # then repack based on the structure of the first input.
    697   all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree)
    698                          for input_tree in inputs]
    699   results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
    700   return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
    701 
    702 
    703 def get_traverse_shallow_structure(traverse_fn, structure):
    704   """Generates a shallow structure from a `traverse_fn` and `structure`.
    705 
    706   `traverse_fn` must accept any possible subtree of `structure` and return
    707   a depth=1 structure containing `True` or `False` values, describing which
    708   of the top-level subtrees may be traversed.  It may also
    709   return scalar `True` or `False` "traversal is OK / not OK for all subtrees."
    710 
    711   Examples are available in the unit tests (nest_test.py).
    712 
    713   Args:
    714     traverse_fn: Function taking a substructure and returning either a scalar
    715       `bool` (whether to traverse that substructure or not) or a depth=1
    716       shallow structure of the same type, describing which parts of the
    717       substructure to traverse.
    718     structure: The structure to traverse.
    719 
    720   Returns:
    721     A shallow structure containing python bools, which can be passed to
    722     `map_structure_up_to` and `flatten_up_to`.
    723 
    724   Raises:
    725     TypeError: if `traverse_fn` returns a sequence for a non-sequence input,
    726       or a structure with depth higher than 1 for a sequence input,
    727       or if any leaf values in the returned structure or scalar are not type
    728       `bool`.
    729   """
    730   to_traverse = traverse_fn(structure)
    731   if not is_sequence(structure):
    732     if not isinstance(to_traverse, bool):
    733       raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
    734                       % (to_traverse, structure))
    735     return to_traverse
    736   level_traverse = []
    737   if isinstance(to_traverse, bool):
    738     if not to_traverse:
    739       # Do not traverse this substructure at all.  Exit early.
    740       return False
    741     else:
    742       # Traverse the entire substructure.
    743       for branch in _yield_value(structure):
    744         level_traverse.append(
    745             get_traverse_shallow_structure(traverse_fn, branch))
    746   elif not is_sequence(to_traverse):
    747     raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
    748                     % (to_traverse, structure))
    749   else:
    750     # Traverse some subset of this substructure.
    751     assert_shallow_structure(to_traverse, structure)
    752     for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)):
    753       if not isinstance(t, bool):
    754         raise TypeError(
    755             "traverse_fn didn't return a depth=1 structure of bools.  saw: %s "
    756             " for structure: %s" % (to_traverse, structure))
    757       if t:
    758         level_traverse.append(
    759             get_traverse_shallow_structure(traverse_fn, branch))
    760       else:
    761         level_traverse.append(False)
    762   return _sequence_like(structure, level_traverse)
    763 
    764 
    765 def yield_flat_paths(nest):
    766   """Yields paths for some nested structure.
    767 
    768   Paths are lists of objects which can be str-converted, which may include
    769   integers or other types which are used as indices in a dict.
    770 
    771   The flat list will be in the corresponding order as if you called
    772   `snt.nest.flatten` on the structure. This is handy for naming Tensors such
    773   the TF scope structure matches the tuple structure.
    774 
    775   E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))`
    776 
    777   ```shell
    778   >>> nest.flatten(value)
    779   [3, 23, 42]
    780   >>> list(nest.yield_flat_paths(value))
    781   [('a',), ('b', 'c'), ('b', 'd')]
    782   ```
    783 
    784   ```shell
    785   >>> list(nest.yield_flat_paths({'a': [3]}))
    786   [('a', 0)]
    787   >>> list(nest.yield_flat_paths({'a': 3}))
    788   [('a',)]
    789   ```
    790 
    791   Args:
    792     nest: the value to produce a flattened paths list for.
    793 
    794   Yields:
    795     Tuples containing index or key values which form the path to a specific
    796       leaf value in the nested structure.
    797   """
    798 
    799   # The _maybe_add_final_path_element function is used below in order to avoid
    800   # adding trailing slashes when the sub-element recursed into is a leaf.
    801   if isinstance(nest, dict):
    802     for key in _sorted(nest):
    803       value = nest[key]
    804       for sub_path in yield_flat_paths(value):
    805         yield (key,) + sub_path
    806   elif _is_namedtuple(nest):
    807     for key in nest._fields:
    808       value = getattr(nest, key)
    809       for sub_path in yield_flat_paths(value):
    810         yield (key,) + sub_path
    811   elif isinstance(nest, _six.string_types):
    812     yield ()
    813   elif isinstance(nest, _collections.Sequence):
    814     for idx, value in enumerate(nest):
    815       for sub_path in yield_flat_paths(value):
    816         yield (idx,) + sub_path
    817   else:
    818     yield ()
    819 
    820 
    821 def flatten_with_joined_string_paths(structure, separator="/"):
    822   """Returns a list of (string path, data element) tuples.
    823 
    824   The order of tuples produced matches that of `nest.flatten`. This allows you
    825   to flatten a nested structure while keeping information about where in the
    826   structure each data element was located. See `nest.yield_flat_paths`
    827   for more information.
    828 
    829   Args:
    830     structure: the nested structure to flatten.
    831     separator: string to separate levels of hierarchy in the results, defaults
    832       to '/'.
    833 
    834   Returns:
    835     A list of (string, data element) tuples.
    836   """
    837   flat_paths = yield_flat_paths(structure)
    838   def stringify_and_join(path_elements):
    839     return separator.join(str(path_element) for path_element in path_elements)
    840   flat_string_paths = [stringify_and_join(path) for path in flat_paths]
    841   return list(zip(flat_string_paths, flatten(structure)))
    842 
    843 
# Hand `collections.Sequence` to the native extension at import time.
# NOTE(review): presumably this is the class the native sequence checks test
# instances against -- confirm in pywrap_tensorflow.
_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence)


# Names of the functions this module intends to expose; passed to
# `remove_undocumented` below.
_allowed_symbols = [
    "assert_same_structure",
    "is_sequence",
    "flatten",
    "flatten_dict_items",
    "pack_sequence_as",
    "map_structure",
    "assert_shallow_structure",
    "flatten_up_to",
    "map_structure_up_to",
    "get_traverse_shallow_structure",
    "yield_flat_paths",
    "flatten_with_joined_string_paths",
]

# NOTE(review): presumably removes module attributes that are not listed in
# `_allowed_symbols` -- confirm against all_util.remove_undocumented.
remove_undocumented(__name__, _allowed_symbols)
    863