Home | History | Annotate | Download | only in util
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Library to compute order of computations in a graph.
     16 """
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import collections
     23 import math
     24 from tensorflow.contrib.receptive_field.python.util import parse_layer_parameters
     25 from tensorflow.python.platform import tf_logging as logging
     26 
     27 
     28 def parse_graph_nodes(graph_def):
     29   """Helper function to parse GraphDef's nodes.
     30 
     31   It returns a dict mapping from node name to NodeDef.
     32 
     33   Args:
     34     graph_def: A GraphDef object.
     35 
     36   Returns:
     37     name_to_node: Dict keyed by node name, each entry containing the node's
     38       NodeDef.
     39   """
     40   name_to_node = {}
     41   for node_def in graph_def.node:
     42     name_to_node[node_def.name] = node_def
     43   return name_to_node
     44 
     45 
     46 # Named tuple used to collect information from each node in a computation graph.
     47 _node_info = collections.namedtuple(
     48     'NodeInfo', field_names=['order', 'node', 'input_size', 'output_size'])
     49 
     50 
     51 def _compute_output_resolution(input_spatial_resolution, kernel_size, stride,
     52                                total_padding):
     53   """Computes output resolution, given input resolution and layer parameters.
     54 
     55   Note that this computation is done only over one dimension (eg, x or y).
     56   If any of the inputs is None, returns None.
     57 
     58   Args:
     59     input_spatial_resolution: Input spatial resolution (int).
     60     kernel_size: Kernel size (int).
     61     stride: Stride (int).
     62     total_padding: Total padding to be applied (int).
     63   Returns:
     64     output_resolution: Output dimension (int) or None.
     65   """
     66   if (input_spatial_resolution is None) or (kernel_size is None) or (
     67       stride is None) or (total_padding is None):
     68     return None
     69   return int(
     70       math.ceil((
     71           input_spatial_resolution + total_padding - kernel_size + 1) / stride))
     72 
     73 
     74 def _get_computed_nodes(name_to_node,
     75                         current,
     76                         node_info,
     77                         input_node_name='',
     78                         input_node_size=None):
     79   """Traverses the graph recursively to compute its topological order.
     80 
     81   Optionally, the function may also compute the input and output feature map
     82   resolutions at each node. In this case, input_node_name and input_node_size
     83   must be set. Note that if a node's op type is unknown, the input and output
     84   resolutions are ignored and set to None.
     85 
     86   Args:
     87     name_to_node: Dict keyed by node name, each entry containing the node's
     88       NodeDef.
     89     current: Current node name.
     90     node_info: Map of nodes we've already traversed, containing their _node_info
     91       information.
     92     input_node_name: Name of node with fixed input resolution (optional).
     93     input_node_size: Fixed input resolution to use (optional).
     94   Returns:
     95     order: Order in topological sort for 'current'.
     96     input_size: Tensor spatial resolution at input of current node.
     97     output_size: Tensor spatial resolution at output of current node.
     98   """
     99   if current in node_info:
    100     return (node_info[current].order, node_info[current].input_size,
    101             node_info[current].output_size)
    102 
    103   node_def = name_to_node[current]
    104 
    105   if current == input_node_name:
    106     order = 0
    107     input_size = None
    108     output_size = input_node_size
    109     node_info[current] = _node_info(order, node_def, input_size, output_size)
    110     return (order, input_size, output_size)
    111 
    112   input_size = None
    113   output_size = None
    114 
    115   order = 0
    116   number_inputs = 0
    117   for each in node_def.input:
    118     # Parses name of input node.
    119     if each.startswith('^'):
    120       # The character '^' denotes a control dependency, so this input node can
    121       # be safely ignored.
    122       continue
    123     each = each.split(':')[0]
    124     # Recursively computes ordering.
    125     (parent_order, _, parent_output_size) = _get_computed_nodes(
    126         name_to_node, each, node_info, input_node_name, input_node_size)
    127     order = max(order, parent_order + 1)
    128     if number_inputs == 0:
    129       # For all the types of nodes we consider, the first input corresponds to
    130       # the feature map.
    131       input_size = parent_output_size
    132     number_inputs += 1
    133 
    134   # Figure out output size for this layer.
    135   logging.vlog(3, 'input_size = %s', input_size)
    136   if input_size is None:
    137     output_size = None
    138   else:
    139     (kernel_size_x, kernel_size_y, stride_x, stride_y, _, _, total_padding_x,
    140      total_padding_y) = (
    141          parse_layer_parameters.get_layer_params(
    142              node_def, name_to_node, input_size, force=True))
    143     logging.vlog(3, 'kernel_size_x = %s, kernel_size_y = %s, '
    144                  'stride_x = %s, stride_y = %s, '
    145                  'total_padding_x = %s, total_padding_y = %s' %
    146                  (kernel_size_x, kernel_size_y, stride_x, stride_y,
    147                   total_padding_x, total_padding_y))
    148     output_size = [None] * 2
    149     output_size[0] = _compute_output_resolution(input_size[0], kernel_size_x,
    150                                                 stride_x, total_padding_x)
    151     output_size[1] = _compute_output_resolution(input_size[1], kernel_size_y,
    152                                                 stride_y, total_padding_y)
    153 
    154   logging.vlog(3, 'output_size = %s', output_size)
    155   node_info[current] = _node_info(order, node_def, input_size, output_size)
    156 
    157   return order, input_size, output_size
    158 
    159 
    160 def get_compute_order(graph_def, input_node_name='', input_node_size=None):
    161   """Computes order of computation for a given CNN graph.
    162 
    163   Optionally, the function may also compute the input and output feature map
    164   resolutions at each node. In this case, input_node_name and input_node_size
    165   must be set. Note that if a node's op type is unknown, the input and output
    166   resolutions are ignored and set to None.
    167 
    168   Args:
    169     graph_def: GraphDef object.
    170     input_node_name: Name of node with fixed input resolution (optional). This
    171       is usually the node name for the input image in a CNN.
    172     input_node_size: 2D list of integers, fixed input resolution to use
    173       (optional). This is usually the input resolution used for the input image
    174       in a CNN (common examples are: [224, 224], [299, 299], [321, 321]).
    175   Returns:
    176     node_info: Default dict keyed by node name, mapping to a named tuple with
    177       the following fields:
    178       - order: Integer denoting topological order;
    179       - node: NodeDef for the given node;
    180       - input_size: 2D list of integers, denoting the input spatial resolution
    181         to the node;
    182       - output_size: 2D list of integers, denoting the output spatial resolution
    183         of the node.
    184     name_to_node: Dict keyed by node name, each entry containing the node's
    185       NodeDef.
    186   """
    187   name_to_node = parse_graph_nodes(graph_def)
    188   node_info = collections.defaultdict(_node_info)
    189   for each in graph_def.node:
    190     _get_computed_nodes(name_to_node, each.name, node_info, input_node_name,
    191                         input_node_size)
    192   return node_info, name_to_node
    193