1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Utility functions for the graph_editor. 16 """ 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import collections 23 import re 24 from six import iteritems 25 from tensorflow.python.framework import ops as tf_ops 26 from tensorflow.python.ops import array_ops as tf_array_ops 27 28 __all__ = [ 29 "make_list_of_op", 30 "get_tensors", 31 "make_list_of_t", 32 "get_generating_ops", 33 "get_consuming_ops", 34 "ControlOutputs", 35 "placeholder_name", 36 "make_placeholder_from_tensor", 37 "make_placeholder_from_dtype_and_shape", 38 ] 39 40 41 def concatenate_unique(la, lb): 42 """Add all the elements of `lb` to `la` if they are not there already. 43 44 The elements added to `la` maintain ordering with respect to `lb`. 45 46 Args: 47 la: List of Python objects. 48 lb: List of Python objects. 49 Returns: 50 `la`: The list `la` with missing elements from `lb`. 51 """ 52 la_set = set(la) 53 for l in lb: 54 if l not in la_set: 55 la.append(l) 56 la_set.add(l) 57 return la 58 59 60 # TODO(fkp): very generic code, it should be moved in a more generic place. 61 class ListView(object): 62 """Immutable list wrapper. 63 64 This class is strongly inspired by the one in tf.Operation. 
65 """ 66 67 def __init__(self, list_): 68 if not isinstance(list_, list): 69 raise TypeError("Expected a list, got: {}.".format(type(list_))) 70 self._list = list_ 71 72 def __iter__(self): 73 return iter(self._list) 74 75 def __len__(self): 76 return len(self._list) 77 78 def __bool__(self): 79 return bool(self._list) 80 81 # Python 3 wants __bool__, Python 2.7 wants __nonzero__ 82 __nonzero__ = __bool__ 83 84 def __getitem__(self, i): 85 return self._list[i] 86 87 def __add__(self, other): 88 if not isinstance(other, list): 89 other = list(other) 90 return list(self) + other 91 92 93 # TODO(fkp): very generic code, it should be moved in a more generic place. 94 def is_iterable(obj): 95 """Return true if the object is iterable.""" 96 if isinstance(obj, tf_ops.Tensor): 97 return False 98 try: 99 _ = iter(obj) 100 except Exception: # pylint: disable=broad-except 101 return False 102 return True 103 104 105 def flatten_tree(tree, leaves=None): 106 """Flatten a tree into a list. 107 108 Args: 109 tree: iterable or not. If iterable, its elements (child) can also be 110 iterable or not. 111 leaves: list to which the tree leaves are appended (None by default). 112 Returns: 113 A list of all the leaves in the tree. 114 """ 115 if leaves is None: 116 leaves = [] 117 if isinstance(tree, dict): 118 for _, child in iteritems(tree): 119 flatten_tree(child, leaves) 120 elif is_iterable(tree): 121 for child in tree: 122 flatten_tree(child, leaves) 123 else: 124 leaves.append(tree) 125 return leaves 126 127 128 def transform_tree(tree, fn, iterable_type=tuple): 129 """Transform all the nodes of a tree. 130 131 Args: 132 tree: iterable or not. If iterable, its elements (child) can also be 133 iterable or not. 134 fn: function to apply to each leaves. 135 iterable_type: type use to construct the resulting tree for unknown 136 iterable, typically `list` or `tuple`. 137 Returns: 138 A tree whose leaves has been transformed by `fn`. 
139 The hierarchy of the output tree mimics the one of the input tree. 140 """ 141 if is_iterable(tree): 142 if isinstance(tree, dict): 143 res = tree.__new__(type(tree)) 144 res.__init__( 145 (k, transform_tree(child, fn)) for k, child in iteritems(tree)) 146 return res 147 elif isinstance(tree, tuple): 148 # NamedTuple? 149 if hasattr(tree, "_asdict"): 150 res = tree.__new__(type(tree), **transform_tree(tree._asdict(), fn)) 151 else: 152 res = tree.__new__(type(tree), 153 (transform_tree(child, fn) for child in tree)) 154 return res 155 elif isinstance(tree, collections.Sequence): 156 res = tree.__new__(type(tree)) 157 res.__init__(transform_tree(child, fn) for child in tree) 158 return res 159 else: 160 return iterable_type(transform_tree(child, fn) for child in tree) 161 else: 162 return fn(tree) 163 164 165 def check_graphs(*args): 166 """Check that all the element in args belong to the same graph. 167 168 Args: 169 *args: a list of object with a obj.graph property. 170 Raises: 171 ValueError: if all the elements do not belong to the same graph. 172 """ 173 graph = None 174 for i, sgv in enumerate(args): 175 if graph is None and sgv.graph is not None: 176 graph = sgv.graph 177 elif sgv.graph is not None and sgv.graph is not graph: 178 raise ValueError("Argument[{}]: Wrong graph!".format(i)) 179 180 181 def get_unique_graph(tops, check_types=None, none_if_empty=False): 182 """Return the unique graph used by the all the elements in tops. 183 184 Args: 185 tops: list of elements to check (usually a list of tf.Operation and/or 186 tf.Tensor). Or a tf.Graph. 187 check_types: check that the element in tops are of given type(s). If None, 188 the types (tf.Operation, tf.Tensor) are used. 189 none_if_empty: don't raise an error if tops is an empty list, just return 190 None. 191 Returns: 192 The unique graph used by all the tops. 193 Raises: 194 TypeError: if tops is not a iterable of tf.Operation. 195 ValueError: if the graph is not unique. 
196 """ 197 if isinstance(tops, tf_ops.Graph): 198 return tops 199 if not is_iterable(tops): 200 raise TypeError("{} is not iterable".format(type(tops))) 201 if check_types is None: 202 check_types = (tf_ops.Operation, tf_ops.Tensor) 203 elif not is_iterable(check_types): 204 check_types = (check_types,) 205 g = None 206 for op in tops: 207 if not isinstance(op, check_types): 208 raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str( 209 t) for t in check_types]), type(op))) 210 if g is None: 211 g = op.graph 212 elif g is not op.graph: 213 raise ValueError("Operation {} does not belong to given graph".format(op)) 214 if g is None and not none_if_empty: 215 raise ValueError("Can't find the unique graph of an empty list") 216 return g 217 218 219 def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False): 220 """Convert ops to a list of `tf.Operation`. 221 222 Args: 223 ops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single 224 operation. 225 check_graph: if `True` check if all the operations belong to the same graph. 226 allow_graph: if `False` a `tf.Graph` cannot be converted. 227 ignore_ts: if True, silently ignore `tf.Tensor`. 228 Returns: 229 A newly created list of `tf.Operation`. 230 Raises: 231 TypeError: if ops cannot be converted to a list of `tf.Operation` or, 232 if `check_graph` is `True`, if all the ops do not belong to the 233 same graph. 234 """ 235 if isinstance(ops, tf_ops.Graph): 236 if allow_graph: 237 return ops.get_operations() 238 else: 239 raise TypeError("allow_graph is False: cannot convert a tf.Graph.") 240 else: 241 if not is_iterable(ops): 242 ops = [ops] 243 if not ops: 244 return [] 245 if check_graph: 246 check_types = None if ignore_ts else tf_ops.Operation 247 get_unique_graph(ops, check_types=check_types) 248 return [op for op in ops if isinstance(op, tf_ops.Operation)] 249 250 251 # TODO(fkp): move this function in tf.Graph? 
def get_tensors(graph):
  """Get all the tensors which are the output of an op in the graph.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A list of `tf.Tensor`, in the order of `graph.get_operations()`.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, tf_ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  ts = []
  for op in graph.get_operations():
    ts += op.outputs
  return ts


def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or,
     if `check_graph` is `True`, if all the ops do not belong to the same graph.
  """
  if isinstance(ts, tf_ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ts):
      ts = [ts]
    else:
      # Materialize once: the graph check below iterates `ts`, which would
      # exhaust a generator/iterator and make the returned list empty.
      ts = list(ts)
    if not ts:
      return []
    if check_graph:
      check_types = None if ignore_ops else tf_ops.Tensor
      get_unique_graph(ts, check_types=check_types)
    return [t for t in ts if isinstance(t, tf_ops.Tensor)]


def get_generating_ops(ts):
  """Return all the generating ops of the tensors in `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the generating `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  return [t.op for t in ts]


def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Each op appears once, in order of first encounter.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  ops = []
  # A set gives O(1) duplicate detection instead of scanning `ops` each time.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        ops.append(op)
  return ops


class ControlOutputs(object):
  """The control outputs topology.

  Maps each `tf.Operation` to the list of ops which have it as one of their
  control inputs. The mapping is built at construction time and rebuilt by
  `update()` when the graph version has changed.
  """

  def __init__(self, graph):
    """Create a dictionary of control-output dependencies.

    Args:
      graph: a `tf.Graph`.
    Raises:
      TypeError: graph is not a `tf.Graph`.
    """
    if not isinstance(graph, tf_ops.Graph):
      raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
    self._control_outputs = {}
    self._graph = graph
    self._version = None
    self._build()

  def update(self):
    """Update the control outputs if the graph has changed."""
    if self._version != self._graph.version:
      self._build()
    return self

  def _build(self):
    """Build the control outputs dictionary."""
    self._control_outputs.clear()
    ops = self._graph.get_operations()
    for op in ops:
      for control_input in op.control_inputs:
        if control_input not in self._control_outputs:
          self._control_outputs[control_input] = []
        if op not in self._control_outputs[control_input]:
          self._control_outputs[control_input].append(op)
    self._version = self._graph.version

  def get_all(self):
    """Return the internal op -> control-outputs dictionary (not a copy)."""
    return self._control_outputs

  def get(self, op):
    """Return the control outputs of op (an empty tuple if it has none)."""
    if op in self._control_outputs:
      return self._control_outputs[op]
    else:
      return ()

  @property
  def graph(self):
    return self._graph


def scope_finalize(scope):
  """Append a trailing "/" to a non-empty scope that lacks one."""
  if scope and scope[-1] != "/":
    scope += "/"
  return scope


def scope_dirname(scope):
  """Return everything up to and including the last "/" ("" if none)."""
  slash = scope.rfind("/")
  if slash == -1:
    return ""
  return scope[:slash + 1]


def scope_basename(scope):
  """Return everything after the last "/" (the whole string if none)."""
  slash = scope.rfind("/")
  if slash == -1:
    return scope
  return scope[slash + 1:]


def placeholder_name(t=None, scope=None):
  """Create placeholder name for the graph editor.

  Args:
    t: optional tensor on which the placeholder operation's name will be based
      on
    scope: absolute scope with which to prefix the placeholder's name. None
      means that the scope of t is preserved. "" means the root scope.
  Returns:
    A new placeholder name prefixed by "geph". Note that "geph" stands for
      Graph Editor PlaceHolder. This convention allows to quickly identify the
      placeholder generated by the Graph Editor.
  Raises:
    TypeError: if t is not None or a tf.Tensor.
  """
  if scope is not None:
    scope = scope_finalize(scope)
  if t is not None:
    if not isinstance(t, tf_ops.Tensor):
      raise TypeError("Expected a tf.Tensor, got: {}".format(type(t)))
    op_dirname = scope_dirname(t.op.name)
    op_basename = scope_basename(t.op.name)
    if scope is None:
      scope = op_dirname

    # Don't stack prefixes when `t` was itself produced by a geph placeholder.
    if op_basename.startswith("geph__"):
      ph_name = op_basename
    else:
      ph_name = "geph__{}_{}".format(op_basename, t.value_index)

    return scope + ph_name
  else:
    if scope is None:
      scope = ""
    return scope + "geph"


def make_placeholder_from_tensor(t, scope=None):
  """Create a `tf.placeholder` for the Graph Editor.

  Note that the correct graph scope must be set by the calling function.

  Args:
    t: a `tf.Tensor` whose name will be used to create the placeholder
      (see function placeholder_name).
    scope: absolute scope within which to create the placeholder. None
      means that the scope of `t` is preserved. `""` means the root scope.
  Returns:
    A newly created `tf.placeholder`.
  Raises:
    TypeError: if `t` is not `None` or a `tf.Tensor`.
  """
  return tf_array_ops.placeholder(
      dtype=t.dtype, shape=t.get_shape(), name=placeholder_name(
          t, scope=scope))


def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
  """Create a tf.placeholder for the Graph Editor.

  Note that the correct graph scope must be set by the calling function.
  The placeholder is named using the function placeholder_name (with no
  tensor argument).

  Args:
    dtype: the tensor type.
    shape: the tensor shape (optional).
    scope: absolute scope with which to prefix the placeholder's name. None
      or "" means that no scope prefix is added to the name.
  Returns:
    A newly created tf.placeholder.
  """
  return tf_array_ops.placeholder(
      dtype=dtype, shape=shape, name=placeholder_name(scope=scope))


_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")


def get_predefined_collection_names():
  """Return all the predefined collection names."""
  return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys)
          if not _INTERNAL_VARIABLE_RE.match(key)]


def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
  """Find corresponding op/tensor in a different graph.

  Args:
    target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph.
    dst_graph: The graph in which the corresponding graph element must be found.
    dst_scope: A scope which is prepended to the name to look for.
    src_scope: A scope which is removed from the original of `target` name.

  Returns:
    The corresponding `tf.Tensor` or `tf.Operation`.

  Raises:
    ValueError: if `src_name` does not start with `src_scope`.
    TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`
    KeyError: If the corresponding graph element cannot be found.
  """
  # Validate up front so bad input raises the documented TypeError instead of
  # an AttributeError on `.name`.
  if not isinstance(target, (tf_ops.Tensor, tf_ops.Operation)):
    raise TypeError(
        "Expected tf.Tensor or tf.Operation, got: {}".format(type(target)))

  src_name = target.name
  if src_scope:
    src_scope = scope_finalize(src_scope)
    if not src_name.startswith(src_scope):
      raise ValueError("{} does not start with {}".format(src_name, src_scope))
    src_name = src_name[len(src_scope):]

  dst_name = src_name
  if dst_scope:
    dst_scope = scope_finalize(dst_scope)
    dst_name = dst_scope + dst_name

  if isinstance(target, tf_ops.Tensor):
    return dst_graph.get_tensor_by_name(dst_name)
  return dst_graph.get_operation_by_name(dst_name)


def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
  """Find corresponding ops/tensors in a different graph.

  `targets` is a Python tree, that is, a nested structure of iterable
  (list, tuple, dictionary) whose leaves are instances of
  `tf.Tensor` or `tf.Operation`

  Args:
    targets: A Python tree containing `tf.Tensor` or `tf.Operation`
      belonging to the original graph.
    dst_graph: The graph in which the corresponding graph element must be found.
    dst_scope: A scope which is prepended to the name to look for.
    src_scope: A scope which is removed from the original name of each leaf.

  Returns:
    A Python tree containing the corresponding `tf.Tensor` or `tf.Operation`.

  Raises:
    ValueError: if a leaf name does not start with `src_scope`.
    TypeError: if a leaf is not a `tf.Tensor` or a `tf.Operation`
    KeyError: If the corresponding graph element cannot be found.
  """
  def func(top):
    return find_corresponding_elem(top, dst_graph, dst_scope, src_scope)
  return transform_tree(targets, func)