1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Library to compute order of computations in a graph. 16 """ 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import collections 23 import math 24 from tensorflow.contrib.receptive_field.python.util import parse_layer_parameters 25 from tensorflow.python.platform import tf_logging as logging 26 27 28 def parse_graph_nodes(graph_def): 29 """Helper function to parse GraphDef's nodes. 30 31 It returns a dict mapping from node name to NodeDef. 32 33 Args: 34 graph_def: A GraphDef object. 35 36 Returns: 37 name_to_node: Dict keyed by node name, each entry containing the node's 38 NodeDef. 39 """ 40 name_to_node = {} 41 for node_def in graph_def.node: 42 name_to_node[node_def.name] = node_def 43 return name_to_node 44 45 46 # Named tuple used to collect information from each node in a computation graph. 47 _node_info = collections.namedtuple( 48 'NodeInfo', field_names=['order', 'node', 'input_size', 'output_size']) 49 50 51 def _compute_output_resolution(input_spatial_resolution, kernel_size, stride, 52 total_padding): 53 """Computes output resolution, given input resolution and layer parameters. 54 55 Note that this computation is done only over one dimension (eg, x or y). 56 If any of the inputs is None, returns None. 57 58 Args: 59 input_spatial_resolution: Input spatial resolution (int). 60 kernel_size: Kernel size (int). 61 stride: Stride (int). 62 total_padding: Total padding to be applied (int). 63 Returns: 64 output_resolution: Output dimension (int) or None. 65 """ 66 if (input_spatial_resolution is None) or (kernel_size is None) or ( 67 stride is None) or (total_padding is None): 68 return None 69 return int( 70 math.ceil(( 71 input_spatial_resolution + total_padding - kernel_size + 1) / stride)) 72 73 74 def _get_computed_nodes(name_to_node, 75 current, 76 node_info, 77 input_node_name='', 78 input_node_size=None): 79 """Traverses the graph recursively to compute its topological order. 80 81 Optionally, the function may also compute the input and output feature map 82 resolutions at each node. In this case, input_node_name and input_node_size 83 must be set. Note that if a node's op type is unknown, the input and output 84 resolutions are ignored and set to None. 85 86 Args: 87 name_to_node: Dict keyed by node name, each entry containing the node's 88 NodeDef. 89 current: Current node name. 90 node_info: Map of nodes we've already traversed, containing their _node_info 91 information. 92 input_node_name: Name of node with fixed input resolution (optional). 93 input_node_size: Fixed input resolution to use (optional). 94 Returns: 95 order: Order in topological sort for 'current'. 96 input_size: Tensor spatial resolution at input of current node. 97 output_size: Tensor spatial resolution at output of current node. 98 """ 99 if current in node_info: 100 return (node_info[current].order, node_info[current].input_size, 101 node_info[current].output_size) 102 103 node_def = name_to_node[current] 104 105 if current == input_node_name: 106 order = 0 107 input_size = None 108 output_size = input_node_size 109 node_info[current] = _node_info(order, node_def, input_size, output_size) 110 return (order, input_size, output_size) 111 112 input_size = None 113 output_size = None 114 115 order = 0 116 number_inputs = 0 117 for each in node_def.input: 118 # Parses name of input node. 119 if each.startswith('^'): 120 # The character '^' denotes a control dependency, so this input node can 121 # be safely ignored. 122 continue 123 each = each.split(':')[0] 124 # Recursively computes ordering. 125 (parent_order, _, parent_output_size) = _get_computed_nodes( 126 name_to_node, each, node_info, input_node_name, input_node_size) 127 order = max(order, parent_order + 1) 128 if number_inputs == 0: 129 # For all the types of nodes we consider, the first input corresponds to 130 # the feature map. 131 input_size = parent_output_size 132 number_inputs += 1 133 134 # Figure out output size for this layer. 135 logging.vlog(3, 'input_size = %s', input_size) 136 if input_size is None: 137 output_size = None 138 else: 139 (kernel_size_x, kernel_size_y, stride_x, stride_y, _, _, total_padding_x, 140 total_padding_y) = ( 141 parse_layer_parameters.get_layer_params( 142 node_def, name_to_node, input_size, force=True)) 143 logging.vlog(3, 'kernel_size_x = %s, kernel_size_y = %s, ' 144 'stride_x = %s, stride_y = %s, ' 145 'total_padding_x = %s, total_padding_y = %s' % 146 (kernel_size_x, kernel_size_y, stride_x, stride_y, 147 total_padding_x, total_padding_y)) 148 output_size = [None] * 2 149 output_size[0] = _compute_output_resolution(input_size[0], kernel_size_x, 150 stride_x, total_padding_x) 151 output_size[1] = _compute_output_resolution(input_size[1], kernel_size_y, 152 stride_y, total_padding_y) 153 154 logging.vlog(3, 'output_size = %s', output_size) 155 node_info[current] = _node_info(order, node_def, input_size, output_size) 156 157 return order, input_size, output_size 158 159 160 def get_compute_order(graph_def, input_node_name='', input_node_size=None): 161 """Computes order of computation for a given CNN graph. 162 163 Optionally, the function may also compute the input and output feature map 164 resolutions at each node. In this case, input_node_name and input_node_size 165 must be set. Note that if a node's op type is unknown, the input and output 166 resolutions are ignored and set to None. 167 168 Args: 169 graph_def: GraphDef object. 170 input_node_name: Name of node with fixed input resolution (optional). This 171 is usually the node name for the input image in a CNN. 172 input_node_size: 2D list of integers, fixed input resolution to use 173 (optional). This is usually the input resolution used for the input image 174 in a CNN (common examples are: [224, 224], [299, 299], [321, 321]). 175 Returns: 176 node_info: Default dict keyed by node name, mapping to a named tuple with 177 the following fields: 178 - order: Integer denoting topological order; 179 - node: NodeDef for the given node; 180 - input_size: 2D list of integers, denoting the input spatial resolution 181 to the node; 182 - output_size: 2D list of integers, denoting the output spatial resolution 183 of the node. 184 name_to_node: Dict keyed by node name, each entry containing the node's 185 NodeDef. 186 """ 187 name_to_node = parse_graph_nodes(graph_def) 188 node_info = collections.defaultdict(_node_info) 189 for each in graph_def.node: 190 _get_computed_nodes(name_to_node, each.name, node_info, input_node_name, 191 input_node_size) 192 return node_info, name_to_node 193