# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Register flops statistics for various TensorFlow operations.
     16 """
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import graph_util
     22 from tensorflow.python.framework import ops
     23 
     24 
# Set of all ops that have flops statistics implemented.
IMPLEMENTED_OPS = set([
    # Unary ops
    "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
    "L2Loss", "Softmax",
    # Binary ops
    "Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
    "GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
    "SquaredDifference",
    # Reduction ops
    "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
    # Convolution and pooling
    "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
    "Conv2DBackpropFilter",
    # Other ops
    "AddN",
    # Ops implemented in core tensorflow:
    "MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
])
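
# A usage sketch (illustrative, not part of this module): statistics registered
# below are typically looked up through ops.get_stats_for_node_def, e.g.
#
#   g = ops.Graph()
#   with g.as_default():
#     ...  # build a model
#   for node in g.as_graph_def().node:
#     stats = ops.get_stats_for_node_def(g, node, "flops")
#     if stats.value is not None:
#       print(node.name, stats.value)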


def _zero_flops(graph, node):
  """Returns zero flops."""
  del graph, node  # graph and node are unused
  return ops.OpStats("flops", 0)


def _list_product(lst):
  """Computes the product of the elements of a list."""
  result = 1
  for item in lst:
    result *= item
  return result
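
# Example (illustrative): _list_product([1, 2, 2, 1]) == 4, which is how
# pooling kernel areas are computed below from the "ksize" attribute.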

################################################################################
# Unary operations
################################################################################


def _unary_op_flops(graph, node, ops_per_element=1):
  """Common code which computes flops for unary operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * ops_per_element)
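
# Worked example (illustrative shapes): for an input of shape [32, 128],
# Rsqrt (ops_per_element=2, see below) is counted as 32 * 128 * 2 = 8192 flops.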


@ops.RegisterStatistics("Reciprocal", "flops")
def _reciprocal_flops(graph, node):
  """Compute flops for Reciprocal operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Square", "flops")
def _square_flops(graph, node):
  """Compute flops for Square operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Rsqrt", "flops")
def _rsqrt_flops(graph, node):
  """Compute flops for Rsqrt operation."""
  # Rsqrt(x) = 1 / sqrt(x)
  return _unary_op_flops(graph, node, ops_per_element=2)


@ops.RegisterStatistics("Log", "flops")
def _log_flops(graph, node):
  """Compute flops for Log operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Neg", "flops")
def _neg_flops(graph, node):
  """Compute flops for Neg operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignSub", "flops")
def _assign_sub_flops(graph, node):
  """Compute flops for AssignSub operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignAdd", "flops")
def _assign_add_flops(graph, node):
  """Compute flops for AssignAdd operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("L2Loss", "flops")
def _l2_loss_flops(graph, node):
  """Compute flops for L2Loss operation."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  # TensorFlow uses an inefficient implementation, with (3*N-1) flops:
  # an optimal implementation would need only 2*N flops.
  return ops.OpStats("flops", in_shape.num_elements() * 3 - 1)


@ops.RegisterStatistics("Softmax", "flops")
def _softmax_flops(graph, node):
  """Compute flops for Softmax operation."""
  # Softmax implementation:
  #
  # Approximate flops breakdown:
  #   2*n          -- compute shifted logits
  #   n            -- exp of shifted logits
  #   2*n          -- compute softmax from exp of shifted logits
  # Total: 5*n, i.e. 5 flops per element.
  return _unary_op_flops(graph, node, ops_per_element=5)

################################################################################
# Binary operations
################################################################################


def _binary_per_element_op_flops(graph, node, ops_per_element=1):
  """Common code which computes flops for binary operations."""
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  return ops.OpStats("flops", out_shape.num_elements() * ops_per_element)
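
# Worked example (illustrative shapes): the helper reads the *output* shape, so
# broadcasting is accounted for: Add of [32, 1] and [1, 128] produces a
# [32, 128] output and is counted as 32 * 128 = 4096 flops.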


@ops.RegisterStatistics("Add", "flops")
def _add_flops(graph, node):
  """Compute flops for Add operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Sub", "flops")
def _sub_flops(graph, node):
  """Compute flops for Sub operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Mul", "flops")
def _mul_flops(graph, node):
  """Compute flops for Mul operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RealDiv", "flops")
def _real_div_flops(graph, node):
  """Compute flops for RealDiv operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Maximum", "flops")
def _maximum_flops(graph, node):
  """Compute flops for Maximum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Minimum", "flops")
def _minimum_flops(graph, node):
  """Compute flops for Minimum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Pow", "flops")
def _pow_flops(graph, node):
  """Compute flops for Pow operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RsqrtGrad", "flops")
def _rsqrt_grad_flops(graph, node):
  """Compute flops for RsqrtGrad operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=4)


@ops.RegisterStatistics("GreaterEqual", "flops")
def _greater_equal_flops(graph, node):
  """Compute flops for GreaterEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Greater", "flops")
def _greater_flops(graph, node):
  """Compute flops for Greater operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("LessEqual", "flops")
def _less_equal_flops(graph, node):
  """Compute flops for LessEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Less", "flops")
def _less_flops(graph, node):
  """Compute flops for Less operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Equal", "flops")
def _equal_flops(graph, node):
  """Compute flops for Equal operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("NotEqual", "flops")
def _not_equal_flops(graph, node):
  """Compute flops for NotEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("SquaredDifference", "flops")
def _squared_difference_flops(graph, node):
  """Compute flops for SquaredDifference operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=2)

################################################################################
# Reduction ops
################################################################################


def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0):
  """Common code which computes flops for reduction operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  num_flops = (in_shape.num_elements() * reduce_flops
               + out_shape.num_elements() * (finalize_flops - reduce_flops))
  return ops.OpStats("flops", num_flops)
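
# Worked example (illustrative shapes): Sum reducing [32, 128] down to [32]
# performs 32 * 128 - 32 = 4064 additions, matching
# in_elements * 1 + out_elements * (0 - 1). Mean additionally divides each of
# the 32 outputs once (finalize_flops=1), giving 4096 flops.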


@ops.RegisterStatistics("Mean", "flops")
def _mean_flops(graph, node):
  """Compute flops for Mean operation."""
  # reduction - sum, finalization - divide
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1)


@ops.RegisterStatistics("Sum", "flops")
def _sum_flops(graph, node):
  """Compute flops for Sum operation."""
  # reduction - sum, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMax", "flops")
def _arg_max_flops(graph, node):
  """Compute flops for ArgMax operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMin", "flops")
def _arg_min_flops(graph, node):
  """Compute flops for ArgMin operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("BiasAddGrad", "flops")
def _bias_add_grad_flops(graph, node):
  """Compute flops for BiasAddGrad operation."""
  # BiasAddGrad is essentially a reduce-sum followed by a reshape,
  # so its flops are computed the same way as for "Sum".
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)

################################################################################
# Convolution and pooling
# Note: all flops statistics are implemented only for the NHWC data format
################################################################################


def _verify_conv_data_format(node):
  """Verifies data format for pooling and convolutional operations."""
  # TODO(xpan): P1: Support NCHW
  if node.attr["data_format"].s != b"NHWC":
    raise ValueError("Only NHWC format is supported in flops computations")


def _pool_flops(graph, node):
  """Common code which computes flops for pooling operations."""
  # Computes flops for average and max pooling.
  _verify_conv_data_format(node)
  #
  # Pooling declaration:
  #   Inputs:
  #     - value
  #   Outputs:
  #     - output
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  #
  # Pooling implementation:
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  return ops.OpStats("flops", kernel_area * out_shape.num_elements())
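
# Worked example (illustrative shapes): MaxPool with ksize [1, 2, 2, 1] and an
# output of shape [32, 112, 112, 64] has kernel_area = 4 and is counted as
# 4 * 32 * 112 * 112 * 64 = 102,760,448 flops (one comparison per window
# element for every output element).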


@ops.RegisterStatistics("AvgPool", "flops")
def _avg_pool_flops(graph, node):
  """Compute flops for AvgPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("MaxPool", "flops")
def _max_pool_flops(graph, node):
  """Compute flops for MaxPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("AvgPoolGrad", "flops")
def _avg_pool_grad_flops(graph, node):
  """Compute flops for AvgPoolGrad operation."""
  _verify_conv_data_format(node)
  # Pooling gradient implementation:
  out_backprop_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                                  node.input[1])
  out_backprop_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  # TensorFlow multiplies each element of the pooling window by a coefficient
  # and then sums them up, i.e. 2 flops per element. A more optimal
  # implementation would perform the division once, after the summation.
  return ops.OpStats("flops",
                     kernel_area * out_backprop_shape.num_elements() * 2)


@ops.RegisterStatistics("MaxPoolGrad", "flops")
def _max_pool_grad_flops(graph, node):
  """Compute flops for MaxPoolGrad operation."""
  _verify_conv_data_format(node)
  #
  # MaxPoolGrad declaration:
  #   Inputs:
  #     - orig_input -- original input tensor (of max_pool)
  #     - orig_output -- original output tensor (of max_pool)
  #     - grad -- gradient with respect to output of max_pool
  #   Outputs:
  #     - output -- gradient with respect to input of max_pool
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  # It computes the MaxPool first, then spends one flop per element of the
  # original output.
  #
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                              node.input[1])
  orig_out_shape.assert_is_fully_defined()
  max_pool_ops = kernel_area * orig_out_shape.num_elements()
  return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements())
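
# Worked example (illustrative shapes): with ksize [1, 2, 2, 1] and an original
# output of shape [32, 112, 112, 64], this yields 4 * 25,690,112 comparisons
# plus 25,690,112 flops for routing the gradient, about 1.28e8 flops in total.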


@ops.RegisterStatistics("Conv2DBackpropInput", "flops")
def _conv_2d_backprop_input_flops(graph, node):
  """Compute flops for Conv2DBackpropInput operation."""
  # Formula:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  # Where:
  # image_x_dim, image_y_dim and input_depth --- sizes of the input to the
  #   source (forward, non-backprop) convolution; in other words, the sizes of
  #   the backprop output.
  # output_depth --- number of filters in the original convolution, thus
  #   depth of the backprop input.
  # kernel_x_dim and kernel_y_dim --- sizes of the filter in the spatial
  #   dimensions.
  # image_x_stride and image_y_stride --- strides of the convolution.
  #
  _verify_conv_data_format(node)
  # out_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                            node.input[1])
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * out_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (out_shape[-1].value * strides_product)))
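
# Worked example (illustrative shapes): for out_shape [1, 28, 28, 64],
# kernel_shape [3, 3, 64, 128] and strides [1, 1, 1, 1]:
#   2 * (1*28*28*64) * (3*3*64*128) / (64 * 1) = 115,605,504 flops,
# which matches the formula 1 * 28 * 28 * 3 * 3 * 64 * 128 * 2.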


@ops.RegisterStatistics("Conv2DBackpropFilter", "flops")
def _conv_2d_backprop_filter_flops(graph, node):
  """Compute flops for Conv2DBackpropFilter operation."""
  # Formula same as for Conv2DBackpropInput:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  _verify_conv_data_format(node)
  # image_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  image_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  image_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * image_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (image_shape[-1].value * strides_product)))

################################################################################
# Other ops
################################################################################


@ops.RegisterStatistics("AddN", "flops")
def _add_n_flops(graph, node):
  """Compute flops for AddN operation."""
  if not node.input:
    return _zero_flops(graph, node)
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1))
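
# Worked example (illustrative shapes): AddN over 4 inputs of shape [32, 128]
# needs (4 - 1) passes of 32 * 128 additions each, i.e. 3 * 4096 = 12288 flops.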
    448