1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """## Functions for working with arbitrarily nested sequences of elements. 17 18 This module can perform operations on nested structures. A nested structure is a 19 Python sequence, tuple (including `namedtuple`), or dict that can contain 20 further sequences, tuples, and dicts. 21 22 The utilities here assume (and do not check) that the nested structures form a 23 'tree', i.e., no references in the structure of the input of these functions 24 should be recursive. 25 26 Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0), 27 (np.array([3, 4]), tf.constant([3, 4])))` 28 """ 29 30 from __future__ import absolute_import 31 from __future__ import division 32 from __future__ import print_function 33 34 import collections as _collections 35 36 import six as _six 37 38 from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow 39 from tensorflow.python.util.all_util import remove_undocumented 40 41 42 def _sorted(dict_): 43 """Returns a sorted list of the dict keys, with error if keys not sortable.""" 44 try: 45 return sorted(_six.iterkeys(dict_)) 46 except TypeError: 47 raise TypeError("nest only supports dicts with sortable keys.") 48 49 50 def _is_namedtuple(instance, strict=False): 51 """Returns True iff `instance` is a `namedtuple`. 52 53 Args: 54 instance: An instance of a Python object. 55 strict: If True, `instance` is considered to be a `namedtuple` only if 56 it is a "plain" namedtuple. For instance, a class inheriting 57 from a `namedtuple` will be considered to be a `namedtuple` 58 iff `strict=False`. 59 60 Returns: 61 True if `instance` is a `namedtuple`. 62 """ 63 # Attemp to limit the test to plain namedtuple (not stuff inheriting from it). 64 if not isinstance(instance, tuple): 65 return False 66 if strict and instance.__class__.__base__ != tuple: 67 return False 68 return ( 69 hasattr(instance, "_fields") and 70 isinstance(instance._fields, _collections.Sequence) and 71 all(isinstance(f, _six.string_types) for f in instance._fields)) 72 73 74 def _sequence_like(instance, args): 75 """Converts the sequence `args` to the same type as `instance`. 76 77 Args: 78 instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, or 79 `collections.OrderedDict`. 80 args: elements to be converted to the `instance` type. 81 82 Returns: 83 `args` with the type of `instance`. 84 """ 85 if isinstance(instance, dict): 86 # Pack dictionaries in a deterministic order by sorting the keys. 87 # Notice this means that we ignore the original order of `OrderedDict` 88 # instances. This is intentional, to avoid potential bugs caused by mixing 89 # ordered and plain dicts (e.g., flattening a dict but using a 90 # corresponding `OrderedDict` to pack it back). 91 result = dict(zip(_sorted(instance), args)) 92 return type(instance)((key, result[key]) for key in _six.iterkeys(instance)) 93 elif _is_namedtuple(instance): 94 return type(instance)(*args) 95 else: 96 # Not a namedtuple 97 return type(instance)(args) 98 99 100 def _yield_value(iterable): 101 if isinstance(iterable, dict): 102 # Iterate through dictionaries in a deterministic order by sorting the 103 # keys. Notice this means that we ignore the original order of `OrderedDict` 104 # instances. This is intentional, to avoid potential bugs caused by mixing 105 # ordered and plain dicts (e.g., flattening a dict but using a 106 # corresponding `OrderedDict` to pack it back). 107 for key in _sorted(iterable): 108 yield iterable[key] 109 else: 110 for value in iterable: 111 yield value 112 113 114 def is_sequence(seq): 115 """Returns a true if its input is a collections.Sequence (except strings). 116 117 Args: 118 seq: an input sequence. 119 120 Returns: 121 True if the sequence is a not a string and is a collections.Sequence or a 122 dict. 123 """ 124 return _pywrap_tensorflow.IsSequence(seq) 125 126 127 def flatten(nest): 128 """Returns a flat list from a given nested structure. 129 130 If `nest` is not a sequence, tuple, or dict, then returns a single-element 131 list: `[nest]`. 132 133 In the case of dict instances, the sequence consists of the values, sorted by 134 key to ensure deterministic behavior. This is true also for `OrderedDict` 135 instances: their sequence order is ignored, the sorting order of keys is 136 used instead. The same convention is followed in `pack_sequence_as`. This 137 correctly repacks dicts and `OrderedDict`s after they have been flattened, 138 and also allows flattening an `OrderedDict` and then repacking it back using 139 a corresponding plain dict, or vice-versa. 140 Dictionaries with non-sortable keys cannot be flattened. 141 142 Users must not modify any collections used in `nest` while this function is 143 running. 144 145 Args: 146 nest: an arbitrarily nested structure or a scalar object. Note, numpy 147 arrays are considered scalars. 148 149 Returns: 150 A Python list, the flattened version of the input. 151 152 Raises: 153 TypeError: The nest is or contains a dict with non-sortable keys. 154 """ 155 return _pywrap_tensorflow.Flatten(nest) 156 157 158 def _same_namedtuples(nest1, nest2): 159 """Returns True if the two namedtuples have the same name and fields.""" 160 if nest1._fields != nest2._fields: 161 return False 162 if nest1.__class__.__name__ != nest2.__class__.__name__: 163 return False 164 return True 165 166 167 def _recursive_assert_same_structure(nest1, nest2, check_types): 168 """Helper function for `assert_same_structure`. 169 170 See `assert_same_structure` for further information about namedtuples. 171 172 Args: 173 nest1: An arbitrarily nested structure. 174 nest2: An arbitrarily nested structure. 175 check_types: If `True` (default) types of sequences are checked as 176 well, including the keys of dictionaries. If set to `False`, for example 177 a list and a tuple of objects will look the same if they have the same 178 size. Note that namedtuples with identical name and fields are always 179 considered to have the same shallow structure. 180 181 Returns: 182 True if `nest1` and `nest2` have the same structure. 183 184 Raises: 185 ValueError: If the two structure don't have the same nested structre. 186 TypeError: If the two structure don't have the same sequence type. 187 ValueError: If the two dictionaries don't have the same set of keys. 188 """ 189 is_sequence_nest1 = is_sequence(nest1) 190 if is_sequence_nest1 != is_sequence(nest2): 191 raise ValueError( 192 "The two structures don't have the same nested structure.\n\n" 193 "First structure: %s\n\nSecond structure: %s." % (nest1, nest2)) 194 195 if not is_sequence_nest1: 196 return # finished checking 197 198 if check_types: 199 type_nest1 = type(nest1) 200 type_nest2 = type(nest2) 201 202 # Duck-typing means that nest should be fine with two different namedtuples 203 # with identical name and fields. 204 if _is_namedtuple(nest1, True) and _is_namedtuple(nest2, True): 205 if not _same_namedtuples(nest1, nest2): 206 raise TypeError( 207 "The two namedtuples don't have the same sequence type. First " 208 "structure has type %s, while second structure has type %s." 209 % (type_nest1, type_nest2)) 210 else: 211 if type_nest1 != type_nest2: 212 raise TypeError( 213 "The two structures don't have the same sequence type. First " 214 "structure has type %s, while second structure has type %s." 215 % (type_nest1, type_nest2)) 216 217 if isinstance(nest1, dict): 218 keys1 = set(_six.iterkeys(nest1)) 219 keys2 = set(_six.iterkeys(nest2)) 220 if keys1 != keys2: 221 raise ValueError( 222 "The two dictionaries don't have the same set of keys. First " 223 "structure has keys {}, while second structure has keys {}." 224 .format(keys1, keys2)) 225 226 nest1_as_sequence = [n for n in _yield_value(nest1)] 227 nest2_as_sequence = [n for n in _yield_value(nest2)] 228 for n1, n2 in zip(nest1_as_sequence, nest2_as_sequence): 229 _recursive_assert_same_structure(n1, n2, check_types) 230 231 232 def assert_same_structure(nest1, nest2, check_types=True): 233 """Asserts that two structures are nested in the same way. 234 235 Note that namedtuples with identical name and fields are always considered 236 to have the same shallow structure (even with `check_types=True`). 237 For intance, this code will print `True`: 238 239 ```python 240 def nt(a, b): 241 return collections.namedtuple('foo', 'a b')(a, b) 242 print(assert_same_structure(nt(0, 1), nt(2, 3))) 243 ``` 244 245 Args: 246 nest1: an arbitrarily nested structure. 247 nest2: an arbitrarily nested structure. 248 check_types: if `True` (default) types of sequences are checked as 249 well, including the keys of dictionaries. If set to `False`, for example 250 a list and a tuple of objects will look the same if they have the same 251 size. Note that namedtuples with identical name and fields are always 252 considered to have the same shallow structure. 253 254 Raises: 255 ValueError: If the two structures do not have the same number of elements or 256 if the two structures are not nested in the same way. 257 TypeError: If the two structures differ in the type of sequence in any of 258 their substructures. Only possible if `check_types` is `True`. 259 """ 260 len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1 261 len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1 262 if len_nest1 != len_nest2: 263 raise ValueError("The two structures don't have the same number of " 264 "elements.\n\nFirst structure (%i elements): %s\n\n" 265 "Second structure (%i elements): %s" 266 % (len_nest1, nest1, len_nest2, nest2)) 267 _recursive_assert_same_structure(nest1, nest2, check_types) 268 269 270 def flatten_dict_items(dictionary): 271 """Returns a dictionary with flattened keys and values. 272 273 This function flattens the keys and values of a dictionary, which can be 274 arbitrarily nested structures, and returns the flattened version of such 275 structures: 276 277 ```python 278 example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} 279 result = {4: "a", 5: "b", 6: "c", 8: "d"} 280 flatten_dict_items(example_dictionary) == result 281 ``` 282 283 The input dictionary must satisfy two properties: 284 285 1. Its keys and values should have the same exact nested structure. 286 2. The set of all flattened keys of the dictionary must not contain repeated 287 keys. 288 289 Args: 290 dictionary: the dictionary to zip 291 292 Returns: 293 The zipped dictionary. 294 295 Raises: 296 TypeError: If the input is not a dictionary. 297 ValueError: If any key and value have not the same structure, or if keys are 298 not unique. 299 """ 300 if not isinstance(dictionary, dict): 301 raise TypeError("input must be a dictionary") 302 flat_dictionary = {} 303 for i, v in _six.iteritems(dictionary): 304 if not is_sequence(i): 305 if i in flat_dictionary: 306 raise ValueError( 307 "Could not flatten dictionary: key %s is not unique." % i) 308 flat_dictionary[i] = v 309 else: 310 flat_i = flatten(i) 311 flat_v = flatten(v) 312 if len(flat_i) != len(flat_v): 313 raise ValueError( 314 "Could not flatten dictionary. Key had %d elements, but value had " 315 "%d elements. Key: %s, value: %s." 316 % (len(flat_i), len(flat_v), flat_i, flat_v)) 317 for new_i, new_v in zip(flat_i, flat_v): 318 if new_i in flat_dictionary: 319 raise ValueError( 320 "Could not flatten dictionary: key %s is not unique." 321 % (new_i)) 322 flat_dictionary[new_i] = new_v 323 return flat_dictionary 324 325 326 def _packed_nest_with_indices(structure, flat, index): 327 """Helper function for pack_sequence_as. 328 329 Args: 330 structure: Substructure (list / tuple / dict) to mimic. 331 flat: Flattened values to output substructure for. 332 index: Index at which to start reading from flat. 333 334 Returns: 335 The tuple (new_index, child), where: 336 * new_index - the updated index into `flat` having processed `structure`. 337 * packed - the subset of `flat` corresponding to `structure`, 338 having started at `index`, and packed into the same nested 339 format. 340 341 Raises: 342 ValueError: if `structure` contains more elements than `flat` 343 (assuming indexing starts from `index`). 344 """ 345 packed = [] 346 for s in _yield_value(structure): 347 if is_sequence(s): 348 new_index, child = _packed_nest_with_indices(s, flat, index) 349 packed.append(_sequence_like(s, child)) 350 index = new_index 351 else: 352 packed.append(flat[index]) 353 index += 1 354 return index, packed 355 356 357 def pack_sequence_as(structure, flat_sequence): 358 """Returns a given flattened sequence packed into a given structure. 359 360 If `structure` is a scalar, `flat_sequence` must be a single-element list; 361 in this case the return value is `flat_sequence[0]`. 362 363 If `structure` is or contains a dict instance, the keys will be sorted to 364 pack the flat sequence in deterministic order. This is true also for 365 `OrderedDict` instances: their sequence order is ignored, the sorting order of 366 keys is used instead. The same convention is followed in `flatten`. 367 This correctly repacks dicts and `OrderedDict`s after they have been 368 flattened, and also allows flattening an `OrderedDict` and then repacking it 369 back using a corresponding plain dict, or vice-versa. 370 Dictionaries with non-sortable keys cannot be flattened. 371 372 Args: 373 structure: Nested structure, whose structure is given by nested lists, 374 tuples, and dicts. Note: numpy arrays and strings are considered 375 scalars. 376 flat_sequence: flat sequence to pack. 377 378 Returns: 379 packed: `flat_sequence` converted to have the same recursive structure as 380 `structure`. 381 382 Raises: 383 ValueError: If `flat_sequence` and `structure` have different 384 element counts. 385 TypeError: `structure` is or contains a dict with non-sortable keys. 386 """ 387 if not is_sequence(flat_sequence): 388 raise TypeError("flat_sequence must be a sequence") 389 390 if not is_sequence(structure): 391 if len(flat_sequence) != 1: 392 raise ValueError("Structure is a scalar but len(flat_sequence) == %d > 1" 393 % len(flat_sequence)) 394 return flat_sequence[0] 395 396 flat_structure = flatten(structure) 397 if len(flat_structure) != len(flat_sequence): 398 raise ValueError( 399 "Could not pack sequence. Structure had %d elements, but flat_sequence " 400 "had %d elements. Structure: %s, flat_sequence: %s." 401 % (len(flat_structure), len(flat_sequence), structure, flat_sequence)) 402 403 _, packed = _packed_nest_with_indices(structure, flat_sequence, 0) 404 return _sequence_like(structure, packed) 405 406 407 def map_structure(func, *structure, **check_types_dict): 408 """Applies `func` to each entry in `structure` and returns a new structure. 409 410 Applies `func(x[0], x[1], ...)` where x[i] is an entry in 411 `structure[i]`. All structures in `structure` must have the same arity, 412 and the return value will contain the results in the same structure. 413 414 Args: 415 func: A callable that accepts as many arguments as there are structures. 416 *structure: scalar, or tuple or list of constructed scalars and/or other 417 tuples/lists, or scalars. Note: numpy arrays are considered as scalars. 418 **check_types_dict: only valid keyword argument is `check_types`. If set to 419 `True` (default) the types of iterables within the structures have to be 420 same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` 421 exception). To allow this set this argument to `False`. 422 Note that namedtuples with identical name and fields are always 423 considered to have the same shallow structure. 424 425 Returns: 426 A new structure with the same arity as `structure`, whose values correspond 427 to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding 428 location in `structure[i]`. If there are different sequence types and 429 `check_types` is `False` the sequence types of the first structure will be 430 used. 431 432 Raises: 433 TypeError: If `func` is not callable or if the structures do not match 434 each other by depth tree. 435 ValueError: If no structure is provided or if the structures do not match 436 each other by type. 437 ValueError: If wrong keyword arguments are provided. 438 """ 439 if not callable(func): 440 raise TypeError("func must be callable, got: %s" % func) 441 442 if not structure: 443 raise ValueError("Must provide at least one structure") 444 445 if check_types_dict: 446 if "check_types" not in check_types_dict or len(check_types_dict) > 1: 447 raise ValueError("Only valid keyword argument is check_types") 448 check_types = check_types_dict["check_types"] 449 else: 450 check_types = True 451 452 for other in structure[1:]: 453 assert_same_structure(structure[0], other, check_types=check_types) 454 455 flat_structure = [flatten(s) for s in structure] 456 entries = zip(*flat_structure) 457 458 return pack_sequence_as( 459 structure[0], [func(*x) for x in entries]) 460 461 462 def _yield_flat_up_to(shallow_tree, input_tree): 463 """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" 464 if is_sequence(shallow_tree): 465 for shallow_branch, input_branch in zip(_yield_value(shallow_tree), 466 _yield_value(input_tree)): 467 for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): 468 yield input_leaf 469 else: 470 yield input_tree 471 472 473 def assert_shallow_structure(shallow_tree, input_tree, check_types=True): 474 """Asserts that `shallow_tree` is a shallow structure of `input_tree`. 475 476 That is, this function tests if the `input_tree` structure can be created from 477 the `shallow_tree` structure by replacing its leaf nodes with deeper 478 tree structures. 479 480 Examples: 481 482 The following code will raise an exception: 483 ```python 484 shallow_tree = ["a", "b"] 485 input_tree = ["c", ["d", "e"], "f"] 486 assert_shallow_structure(shallow_tree, input_tree) 487 ``` 488 489 The following code will not raise an exception: 490 ```python 491 shallow_tree = ["a", "b"] 492 input_tree = ["c", ["d", "e"]] 493 assert_shallow_structure(shallow_tree, input_tree) 494 ``` 495 496 Args: 497 shallow_tree: an arbitrarily nested structure. 498 input_tree: an arbitrarily nested structure. 499 check_types: if `True` (default) the sequence types of `shallow_tree` and 500 `input_tree` have to be the same. Note that even with check_types==True, 501 this function will consider two different namedtuple classes with the same 502 name and _fields attribute to be the same class. 503 504 Raises: 505 TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 506 TypeError: If the sequence types of `shallow_tree` are different from 507 `input_tree`. Only raised if `check_types` is `True`. 508 ValueError: If the sequence lengths of `shallow_tree` are different from 509 `input_tree`. 510 """ 511 if is_sequence(shallow_tree): 512 if not is_sequence(input_tree): 513 raise TypeError( 514 "If shallow structure is a sequence, input must also be a sequence. " 515 "Input has type: %s." % type(input_tree)) 516 517 if check_types and not isinstance(input_tree, type(shallow_tree)): 518 # Duck-typing means that nest should be fine with two different 519 # namedtuples with identical name and fields. 520 shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) 521 input_is_namedtuple = _is_namedtuple(input_tree, False) 522 if shallow_is_namedtuple and input_is_namedtuple: 523 if not _same_namedtuples(shallow_tree, input_tree): 524 raise TypeError( 525 "The two namedtuples don't have the same sequence type. Input " 526 "structure has type %s, while shallow structure has type %s." 527 % (type(input_tree), type(shallow_tree))) 528 else: 529 raise TypeError( 530 "The two structures don't have the same sequence type. Input " 531 "structure has type %s, while shallow structure has type %s." 532 % (type(input_tree), type(shallow_tree))) 533 534 if len(input_tree) != len(shallow_tree): 535 raise ValueError( 536 "The two structures don't have the same sequence length. Input " 537 "structure has length %s, while shallow structure has length %s." 538 % (len(input_tree), len(shallow_tree))) 539 540 if check_types and isinstance(shallow_tree, dict): 541 if set(input_tree) != set(shallow_tree): 542 raise ValueError( 543 "The two structures don't have the same keys. Input " 544 "structure has keys %s, while shallow structure has keys %s." % 545 (list(_six.iterkeys(input_tree)), 546 list(_six.iterkeys(shallow_tree)))) 547 548 input_tree = list(sorted(_six.iteritems(input_tree))) 549 shallow_tree = list(sorted(_six.iteritems(shallow_tree))) 550 551 for shallow_branch, input_branch in zip(shallow_tree, input_tree): 552 assert_shallow_structure(shallow_branch, input_branch, 553 check_types=check_types) 554 555 556 def flatten_up_to(shallow_tree, input_tree): 557 """Flattens `input_tree` up to `shallow_tree`. 558 559 Any further depth in structure in `input_tree` is retained as elements in the 560 partially flatten output. 561 562 If `shallow_tree` and `input_tree` are not sequences, this returns a 563 single-element list: `[input_tree]`. 564 565 Use Case: 566 567 Sometimes we may wish to partially flatten a nested sequence, retaining some 568 of the nested structure. We achieve this by specifying a shallow structure, 569 `shallow_tree`, we wish to flatten up to. 570 571 The input, `input_tree`, can be thought of as having the same structure as 572 `shallow_tree`, but with leaf nodes that are themselves tree structures. 573 574 Examples: 575 576 ```python 577 input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] 578 shallow_tree = [[True, True], [False, True]] 579 580 flattened_input_tree = flatten_up_to(shallow_tree, input_tree) 581 flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) 582 583 # Output is: 584 # [[2, 2], [3, 3], [4, 9], [5, 5]] 585 # [True, True, False, True] 586 ``` 587 588 ```python 589 input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] 590 shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] 591 592 input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) 593 input_tree_flattened = flatten(input_tree) 594 595 # Output is: 596 # [('a', 1), ('b', 2), ('c', 3), ('d', 4)] 597 # ['a', 1, 'b', 2, 'c', 3, 'd', 4] 598 ``` 599 600 Non-Sequence Edge Cases: 601 602 ```python 603 flatten_up_to(0, 0) # Output: [0] 604 flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] 605 flatten_up_to([0, 1, 2], 0) # Output: TypeError 606 flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] 607 ``` 608 609 Args: 610 shallow_tree: a possibly pruned structure of input_tree. 611 input_tree: an arbitrarily nested structure or a scalar object. 612 Note, numpy arrays are considered scalars. 613 614 Returns: 615 A Python list, the partially flattened version of `input_tree` according to 616 the structure of `shallow_tree`. 617 618 Raises: 619 TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 620 TypeError: If the sequence types of `shallow_tree` are different from 621 `input_tree`. 622 ValueError: If the sequence lengths of `shallow_tree` are different from 623 `input_tree`. 624 """ 625 assert_shallow_structure(shallow_tree, input_tree) 626 return list(_yield_flat_up_to(shallow_tree, input_tree)) 627 628 629 def map_structure_up_to(shallow_tree, func, *inputs): 630 """Applies a function or op to a number of partially flattened inputs. 631 632 The `inputs` are flattened up to `shallow_tree` before being mapped. 633 634 Use Case: 635 636 Sometimes we wish to apply a function to a partially flattened 637 sequence (for example when the function itself takes sequence inputs). We 638 achieve this by specifying a shallow structure, `shallow_tree` we wish to 639 flatten up to. 640 641 The `inputs`, can be thought of as having the same structure as 642 `shallow_tree`, but with leaf nodes that are themselves tree structures. 643 644 This function therefore will return something with the same base structure as 645 `shallow_tree`. 646 647 Examples: 648 649 ```python 650 ab_tuple = collections.namedtuple("ab_tuple", "a, b") 651 op_tuple = collections.namedtuple("op_tuple", "add, mul") 652 inp_val = ab_tuple(a=2, b=3) 653 inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) 654 out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, 655 inp_val, inp_ops) 656 657 # Output is: ab_tuple(a=6, b=15) 658 ``` 659 660 ```python 661 data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] 662 name_list = ['evens', ['odds', 'primes']] 663 out = map_structure_up_to( 664 name_list, 665 lambda name, sec: "first_{}_{}".format(len(sec), name), 666 name_list, data_list) 667 668 # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] 669 ``` 670 671 Args: 672 shallow_tree: a shallow tree, common to all the inputs. 673 func: callable which will be applied to each input individually. 674 *inputs: arbitrarily nested combination of objects that are compatible with 675 shallow_tree. The function `func` is applied to corresponding 676 partially flattened elements of each input, so the function must support 677 arity of `len(inputs)`. 678 679 Raises: 680 TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 681 TypeError: If the sequence types of `shallow_tree` are different from 682 `input_tree`. 683 ValueError: If the sequence lengths of `shallow_tree` are different from 684 `input_tree`. 685 686 Returns: 687 result of repeatedly applying `func`, with same structure as 688 `shallow_tree`. 689 """ 690 if not inputs: 691 raise ValueError("Cannot map over no sequences") 692 for input_tree in inputs: 693 assert_shallow_structure(shallow_tree, input_tree) 694 695 # Flatten each input separately, apply the function to corresponding elements, 696 # then repack based on the structure of the first input. 697 all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) 698 for input_tree in inputs] 699 results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] 700 return pack_sequence_as(structure=shallow_tree, flat_sequence=results) 701 702 703 def get_traverse_shallow_structure(traverse_fn, structure): 704 """Generates a shallow structure from a `traverse_fn` and `structure`. 705 706 `traverse_fn` must accept any possible subtree of `structure` and return 707 a depth=1 structure containing `True` or `False` values, describing which 708 of the top-level subtrees may be traversed. It may also 709 return scalar `True` or `False` "traversal is OK / not OK for all subtrees." 710 711 Examples are available in the unit tests (nest_test.py). 712 713 Args: 714 traverse_fn: Function taking a substructure and returning either a scalar 715 `bool` (whether to traverse that substructure or not) or a depth=1 716 shallow structure of the same type, describing which parts of the 717 substructure to traverse. 718 structure: The structure to traverse. 719 720 Returns: 721 A shallow structure containing python bools, which can be passed to 722 `map_structure_up_to` and `flatten_up_to`. 723 724 Raises: 725 TypeError: if `traverse_fn` returns a sequence for a non-sequence input, 726 or a structure with depth higher than 1 for a sequence input, 727 or if any leaf values in the returned structure or scalar are not type 728 `bool`. 729 """ 730 to_traverse = traverse_fn(structure) 731 if not is_sequence(structure): 732 if not isinstance(to_traverse, bool): 733 raise TypeError("traverse_fn returned structure: %s for non-structure: %s" 734 % (to_traverse, structure)) 735 return to_traverse 736 level_traverse = [] 737 if isinstance(to_traverse, bool): 738 if not to_traverse: 739 # Do not traverse this substructure at all. Exit early. 740 return False 741 else: 742 # Traverse the entire substructure. 743 for branch in _yield_value(structure): 744 level_traverse.append( 745 get_traverse_shallow_structure(traverse_fn, branch)) 746 elif not is_sequence(to_traverse): 747 raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" 748 % (to_traverse, structure)) 749 else: 750 # Traverse some subset of this substructure. 751 assert_shallow_structure(to_traverse, structure) 752 for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)): 753 if not isinstance(t, bool): 754 raise TypeError( 755 "traverse_fn didn't return a depth=1 structure of bools. saw: %s " 756 " for structure: %s" % (to_traverse, structure)) 757 if t: 758 level_traverse.append( 759 get_traverse_shallow_structure(traverse_fn, branch)) 760 else: 761 level_traverse.append(False) 762 return _sequence_like(structure, level_traverse) 763 764 765 def yield_flat_paths(nest): 766 """Yields paths for some nested structure. 767 768 Paths are lists of objects which can be str-converted, which may include 769 integers or other types which are used as indices in a dict. 770 771 The flat list will be in the corresponding order as if you called 772 `snt.nest.flatten` on the structure. This is handy for naming Tensors such 773 the TF scope structure matches the tuple structure. 774 775 E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` 776 777 ```shell 778 >>> nest.flatten(value) 779 [3, 23, 42] 780 >>> list(nest.yield_flat_paths(value)) 781 [('a',), ('b', 'c'), ('b', 'd')] 782 ``` 783 784 ```shell 785 >>> list(nest.yield_flat_paths({'a': [3]})) 786 [('a', 0)] 787 >>> list(nest.yield_flat_paths({'a': 3})) 788 [('a',)] 789 ``` 790 791 Args: 792 nest: the value to produce a flattened paths list for. 793 794 Yields: 795 Tuples containing index or key values which form the path to a specific 796 leaf value in the nested structure. 797 """ 798 799 # The _maybe_add_final_path_element function is used below in order to avoid 800 # adding trailing slashes when the sub-element recursed into is a leaf. 801 if isinstance(nest, dict): 802 for key in _sorted(nest): 803 value = nest[key] 804 for sub_path in yield_flat_paths(value): 805 yield (key,) + sub_path 806 elif _is_namedtuple(nest): 807 for key in nest._fields: 808 value = getattr(nest, key) 809 for sub_path in yield_flat_paths(value): 810 yield (key,) + sub_path 811 elif isinstance(nest, _six.string_types): 812 yield () 813 elif isinstance(nest, _collections.Sequence): 814 for idx, value in enumerate(nest): 815 for sub_path in yield_flat_paths(value): 816 yield (idx,) + sub_path 817 else: 818 yield () 819 820 821 def flatten_with_joined_string_paths(structure, separator="/"): 822 """Returns a list of (string path, data element) tuples. 823 824 The order of tuples produced matches that of `nest.flatten`. This allows you 825 to flatten a nested structure while keeping information about where in the 826 structure each data element was located. See `nest.yield_flat_paths` 827 for more information. 828 829 Args: 830 structure: the nested structure to flatten. 831 separator: string to separate levels of hierarchy in the results, defaults 832 to '/'. 833 834 Returns: 835 A list of (string, data element) tuples. 836 """ 837 flat_paths = yield_flat_paths(structure) 838 def stringify_and_join(path_elements): 839 return separator.join(str(path_element) for path_element in path_elements) 840 flat_string_paths = [stringify_and_join(path) for path in flat_paths] 841 return list(zip(flat_string_paths, flatten(structure))) 842 843 844 _pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence) 845 846 847 _allowed_symbols = [ 848 "assert_same_structure", 849 "is_sequence", 850 "flatten", 851 "flatten_dict_items", 852 "pack_sequence_as", 853 "map_structure", 854 "assert_shallow_structure", 855 "flatten_up_to", 856 "map_structure_up_to", 857 "get_traverse_shallow_structure", 858 "yield_flat_paths", 859 "flatten_with_joined_string_paths", 860 ] 861 862 remove_undocumented(__name__, _allowed_symbols) 863