Home | History | Annotate | Download | only in util
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functions to compute receptive field of a fully-convolutional network.
     16 
     17 Please refer to the following g3doc for detailed explanation on how this
     18 computation is performed, and why it is important:
     19 g3doc/photos/vision/features/delf/g3doc/rf_computation.md
     20 """
     21 
     22 from __future__ import absolute_import
     23 from __future__ import division
     24 from __future__ import print_function
     25 
     26 import numpy as np
     27 from tensorflow.contrib.receptive_field.python.util import graph_compute_order
     28 from tensorflow.contrib.receptive_field.python.util import parse_layer_parameters
     29 from tensorflow.python.framework import ops as framework_ops
     30 from tensorflow.python.platform import tf_logging as logging
     31 
     32 
     33 def _get_rf_size_node_input(stride, kernel_size, rf_size_output):
     34   """Computes RF size at the input of a given layer.
     35 
     36   Args:
     37     stride: Stride of given layer (integer).
     38     kernel_size: Kernel size of given layer (integer).
     39     rf_size_output: RF size at output of given layer (integer).
     40 
     41   Returns:
     42     rf_size_input: RF size at input of given layer (integer).
     43   """
     44   return stride * rf_size_output + kernel_size - stride
     45 
     46 
     47 def _get_effective_stride_node_input(stride, effective_stride_output):
     48   """Computes effective stride at the input of a given layer.
     49 
     50   Args:
     51     stride: Stride of given layer (integer).
     52     effective_stride_output: Effective stride at output of given layer
     53       (integer).
     54 
     55   Returns:
     56     effective_stride_input: Effective stride at input of given layer
     57       (integer).
     58   """
     59   return stride * effective_stride_output
     60 
     61 
     62 def _get_effective_padding_node_input(stride, padding,
     63                                       effective_padding_output):
     64   """Computes effective padding at the input of a given layer.
     65 
     66   Args:
     67     stride: Stride of given layer (integer).
     68     padding: Padding of given layer (integer).
     69     effective_padding_output: Effective padding at output of given layer
     70       (integer).
     71 
     72   Returns:
     73     effective_padding_input: Effective padding at input of given layer
     74       (integer).
     75   """
     76   return stride * effective_padding_output + padding
     77 
     78 
     79 class ReceptiveField(object):
     80   """Receptive field of a convolutional neural network.
     81 
     82   Args:
     83     size: Receptive field size.
     84     stride: Effective stride.
     85     padding: Effective padding.
     86   """
     87 
     88   def __init__(self, size, stride, padding):
     89     self.size = np.asarray(size)
     90     self.stride = np.asarray(stride)
     91     self.padding = np.asarray(padding)
     92 
     93   def compute_input_center_coordinates(self, y, axis=None):
     94     """Computes the center of the receptive field that generated a feature.
     95 
     96     Args:
     97       y: An array of feature coordinates with shape `(..., d)`, where `d` is the
     98         number of dimensions of the coordinates.
     99       axis: The dimensions for which to compute the input center coordinates.
    100         If `None` (the default), compute the input center coordinates for all
    101         dimensions.
    102 
    103     Returns:
    104       x: Center of the receptive field that generated the features, at the input
    105         of the network.
    106 
    107     Raises:
    108       ValueError: If the number of dimensions of the feature coordinates does
    109         not match the number of elements in `axis`.
    110     """
    111     # Use all dimensions.
    112     if axis is None:
    113       axis = range(self.size.size)
    114     # Ensure axis is a list because tuples have different indexing behavior.
    115     axis = list(axis)
    116     y = np.asarray(y)
    117     if y.shape[-1] != len(axis):
    118       raise ValueError("Dimensionality of the feature coordinates `y` (%d) "
    119                        "does not match dimensionality of `axis` (%d)" %
    120                        (y.shape[-1], len(axis)))
    121     return -self.padding[axis] + y * self.stride[axis] + (
    122         self.size[axis] - 1) / 2
    123 
    124   def compute_feature_coordinates(self, x, axis=None):
    125     """Computes the position of a feature given the center of a receptive field.
    126 
    127     Args:
    128       x: An array of input center coordinates with shape `(..., d)`, where `d`
    129         is the number of dimensions of the coordinates.
    130       axis: The dimensions for which to compute the feature coordinates.
    131         If `None` (the default), compute the feature coordinates for all
    132         dimensions.
    133 
    134     Returns:
    135       y: Coordinates of the features.
    136 
    137     Raises:
    138       ValueError: If the number of dimensions of the input center coordinates
    139         does not match the number of elements in `axis`.
    140     """
    141     # Use all dimensions.
    142     if axis is None:
    143       axis = range(self.size.size)
    144     # Ensure axis is a list because tuples have different indexing behavior.
    145     axis = list(axis)
    146     x = np.asarray(x)
    147     if x.shape[-1] != len(axis):
    148       raise ValueError("Dimensionality of the input center coordinates `x` "
    149                        "(%d) does not match dimensionality of `axis` (%d)" %
    150                        (x.shape[-1], len(axis)))
    151     return (x + self.padding[axis] +
    152             (1 - self.size[axis]) / 2) / self.stride[axis]
    153 
    154   def __iter__(self):
    155     return iter(np.concatenate([self.size, self.stride, self.padding]))
    156 
    157 
    158 def compute_receptive_field_from_graph_def(graph_def,
    159                                            input_node,
    160                                            output_node,
    161                                            stop_propagation=None,
    162                                            input_resolution=None):
    163   """Computes receptive field (RF) parameters from a Graph or GraphDef object.
    164 
    165   The algorithm stops the calculation of the receptive field whenever it
    166   encounters an operation in the list `stop_propagation`. Stopping the
    167   calculation early can be useful to calculate the receptive field of a
    168   subgraph such as a single branch of the
    169   [inception network](https://arxiv.org/abs/1512.00567).
    170 
    171   Args:
    172     graph_def: Graph or GraphDef object.
    173     input_node: Name of the input node or Tensor object from graph.
    174     output_node: Name of the output node or Tensor object from graph.
    175     stop_propagation: List of operations or scope names for which to stop the
    176       propagation of the receptive field.
    177     input_resolution: 2D list. If the input resolution to the model is fixed and
    178       known, this may be set. This is helpful for cases where the RF parameters
    179       vary depending on the input resolution (this happens since SAME padding in
    180       tensorflow depends on input resolution in general). If this is None, it is
    181       assumed that the input resolution is unknown, so some RF parameters may be
    182       unknown (depending on the model architecture).
    183 
    184   Returns:
    185     rf_size_x: Receptive field size of network in the horizontal direction, with
    186       respect to specified input and output.
    187     rf_size_y: Receptive field size of network in the vertical direction, with
    188       respect to specified input and output.
    189     effective_stride_x: Effective stride of network in the horizontal direction,
    190       with respect to specified input and output.
    191     effective_stride_y: Effective stride of network in the vertical direction,
    192       with respect to specified input and output.
    193     effective_padding_x: Effective padding of network in the horizontal
    194       direction, with respect to specified input and output.
    195     effective_padding_y: Effective padding of network in the vertical
    196       direction, with respect to specified input and output.
    197 
    198   Raises:
    199     ValueError: If network is not aligned or if either input or output nodes
    200       cannot be found. For network criterion alignment, see
    201       photos/vision/features/delf/g3doc/rf_computation.md
    202   """
    203   # Convert a graph to graph_def if necessary.
    204   if isinstance(graph_def, framework_ops.Graph):
    205     graph_def = graph_def.as_graph_def()
    206 
    207   # Convert tensors to names.
    208   if isinstance(input_node, framework_ops.Tensor):
    209     input_node = input_node.op.name
    210   if isinstance(output_node, framework_ops.Tensor):
    211     output_node = output_node.op.name
    212 
    213   stop_propagation = stop_propagation or []
    214 
    215   # Computes order of computation for a given graph.
    216   node_info, name_to_node = graph_compute_order.get_compute_order(
    217       graph_def=graph_def,
    218       input_node_name=input_node,
    219       input_node_size=input_resolution)
    220 
    221   # Sort in reverse topological order.
    222   ordered_node_info = sorted(node_info.items(), key=lambda x: -x[1].order)
    223 
    224   # Dictionaries to keep track of receptive field, effective stride and
    225   # effective padding of different nodes.
    226   rf_sizes_x = {}
    227   rf_sizes_y = {}
    228   effective_strides_x = {}
    229   effective_strides_y = {}
    230   effective_paddings_x = {}
    231   effective_paddings_y = {}
    232 
    233   # Initialize dicts for output_node.
    234   rf_sizes_x[output_node] = 1
    235   rf_sizes_y[output_node] = 1
    236   effective_strides_x[output_node] = 1
    237   effective_strides_y[output_node] = 1
    238   effective_paddings_x[output_node] = 0
    239   effective_paddings_y[output_node] = 0
    240 
    241   # Flag to denote if we found output node yet. If we have not, we skip nodes
    242   # until the output node is found.
    243   found_output_node = False
    244 
    245   # Flag to denote if padding is undefined. This happens when SAME padding mode
    246   # is used in conjunction with stride and kernel sizes which make it such that
    247   # the padding to be applied would depend on the input size. In this case,
    248   # alignment checks are skipped, and the effective padding is None.
    249   undefined_padding = False
    250 
    251   for _, (o, node, _, _) in ordered_node_info:
    252     if node:
    253       logging.vlog(3, "%10d %-100s %-20s" % (o, node.name[:90], node.op))
    254     else:
    255       continue
    256 
    257     # When we find input node, we can stop.
    258     if node.name == input_node:
    259       break
    260 
    261     # Loop until we find the output node. All nodes before finding the output
    262     # one are irrelevant, so they can be skipped.
    263     if not found_output_node:
    264       if node.name == output_node:
    265         found_output_node = True
    266 
    267     if found_output_node:
    268       if node.name not in rf_sizes_x:
    269         assert node.name not in rf_sizes_y, ("Node %s is in rf_sizes_y, but "
    270                                              "not in rf_sizes_x" % node.name)
    271         # In this case, node is not relevant since it's not part of the
    272         # computation we're interested in.
    273         logging.vlog(3, "Irrelevant node %s, skipping it...", node.name)
    274         continue
    275 
    276       # Get params for this layer.
    277       (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x,
    278        padding_y, _, _) = parse_layer_parameters.get_layer_params(
    279            node, name_to_node, node_info[node.name].input_size)
    280       logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, "
    281                    "stride_x = %s, stride_y = %s, "
    282                    "padding_x = %s, padding_y = %s, input size = %s" %
    283                    (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x,
    284                     padding_y, node_info[node.name].input_size))
    285       if padding_x is None or padding_y is None:
    286         undefined_padding = True
    287 
    288       # Get parameters at input of this layer which may or may not be propagated
    289       # to the input layers.
    290       rf_size_input_x = _get_rf_size_node_input(stride_x, kernel_size_x,
    291                                                 rf_sizes_x[node.name])
    292       rf_size_input_y = _get_rf_size_node_input(stride_y, kernel_size_y,
    293                                                 rf_sizes_y[node.name])
    294       effective_stride_input_x = _get_effective_stride_node_input(
    295           stride_x, effective_strides_x[node.name])
    296       effective_stride_input_y = _get_effective_stride_node_input(
    297           stride_y, effective_strides_y[node.name])
    298       if not undefined_padding:
    299         effective_padding_input_x = _get_effective_padding_node_input(
    300             stride_x, padding_x, effective_paddings_x[node.name])
    301         effective_padding_input_y = _get_effective_padding_node_input(
    302             stride_y, padding_y, effective_paddings_y[node.name])
    303       else:
    304         effective_padding_input_x = None
    305         effective_padding_input_y = None
    306       logging.vlog(
    307           4, "rf_size_input_x = %s, rf_size_input_y = %s, "
    308           "effective_stride_input_x = %s, effective_stride_input_y = %s, "
    309           "effective_padding_input_x = %s, effective_padding_input_y = %s" %
    310           (rf_size_input_x, rf_size_input_y, effective_stride_input_x,
    311            effective_stride_input_y, effective_padding_input_x,
    312            effective_padding_input_y))
    313 
    314       # Loop over this node's inputs and potentially propagate information down.
    315       for inp_name in node.input:
    316         # Stop the propagation of the receptive field.
    317         if any(inp_name.startswith(stop) for stop in stop_propagation):
    318           logging.vlog(3, "Skipping explicitly ignored node %s.", inp_name)
    319           continue
    320 
    321         logging.vlog(4, "inp_name = %s", inp_name)
    322         if inp_name.startswith("^"):
    323           # The character "^" denotes a control dependency, so this input node
    324           # can be safely ignored.
    325           continue
    326 
    327         inp_node = name_to_node[inp_name]
    328         logging.vlog(4, "inp_node = \n%s", inp_node)
    329         if inp_name in rf_sizes_x:
    330           assert inp_name in rf_sizes_y, ("Node %s is in rf_sizes_x, but "
    331                                           "not in rf_sizes_y" % inp_name)
    332           logging.vlog(
    333               4, "rf_sizes_x[inp_name] = %s,"
    334               " rf_sizes_y[inp_name] = %s, "
    335               "effective_strides_x[inp_name] = %s,"
    336               " effective_strides_y[inp_name] = %s, "
    337               "effective_paddings_x[inp_name] = %s,"
    338               " effective_paddings_y[inp_name] = %s" %
    339               (rf_sizes_x[inp_name], rf_sizes_y[inp_name],
    340                effective_strides_x[inp_name], effective_strides_y[inp_name],
    341                effective_paddings_x[inp_name], effective_paddings_y[inp_name]))
    342           # This node was already discovered through a previous path, so we need
    343           # to make sure that graph is aligned. This alignment check is skipped
    344           # if the padding is not defined, since in this case alignment cannot
    345           # be checked.
    346           if not undefined_padding:
    347             if effective_strides_x[inp_name] != effective_stride_input_x:
    348               raise ValueError(
    349                   "Graph is not aligned since effective stride from different "
    350                   "paths is different in horizontal direction")
    351             if effective_strides_y[inp_name] != effective_stride_input_y:
    352               raise ValueError(
    353                   "Graph is not aligned since effective stride from different "
    354                   "paths is different in vertical direction")
    355             if (rf_sizes_x[inp_name] - 1
    356                ) / 2 - effective_paddings_x[inp_name] != (
    357                    rf_size_input_x - 1) / 2 - effective_padding_input_x:
    358               raise ValueError(
    359                   "Graph is not aligned since center shift from different "
    360                   "paths is different in horizontal direction")
    361             if (rf_sizes_y[inp_name] - 1
    362                ) / 2 - effective_paddings_y[inp_name] != (
    363                    rf_size_input_y - 1) / 2 - effective_padding_input_y:
    364               raise ValueError(
    365                   "Graph is not aligned since center shift from different "
    366                   "paths is different in vertical direction")
    367           # Keep track of path with largest RF, for both directions.
    368           if rf_sizes_x[inp_name] < rf_size_input_x:
    369             rf_sizes_x[inp_name] = rf_size_input_x
    370             effective_strides_x[inp_name] = effective_stride_input_x
    371             effective_paddings_x[inp_name] = effective_padding_input_x
    372           if rf_sizes_y[inp_name] < rf_size_input_y:
    373             rf_sizes_y[inp_name] = rf_size_input_y
    374             effective_strides_y[inp_name] = effective_stride_input_y
    375             effective_paddings_y[inp_name] = effective_padding_input_y
    376         else:
    377           assert inp_name not in rf_sizes_y, ("Node %s is in rf_sizes_y, but "
    378                                               "not in rf_sizes_x" % inp_name)
    379           # In this case, it is the first time we encounter this node. So we
    380           # propagate the RF parameters.
    381           rf_sizes_x[inp_name] = rf_size_input_x
    382           rf_sizes_y[inp_name] = rf_size_input_y
    383           effective_strides_x[inp_name] = effective_stride_input_x
    384           effective_strides_y[inp_name] = effective_stride_input_y
    385           effective_paddings_x[inp_name] = effective_padding_input_x
    386           effective_paddings_y[inp_name] = effective_padding_input_y
    387 
    388   if not found_output_node:
    389     raise ValueError("Output node was not found")
    390   if input_node not in rf_sizes_x:
    391     raise ValueError("Input node was not found")
    392   return ReceptiveField(
    393       (rf_sizes_x[input_node], rf_sizes_y[input_node]),
    394       (effective_strides_x[input_node], effective_strides_y[input_node]),
    395       (effective_paddings_x[input_node], effective_paddings_y[input_node]))
    396