1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Register flops statistics for various TensorFlow operations. 16 """ 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 from tensorflow.python.framework import graph_util 22 from tensorflow.python.framework import ops 23 24 25 # List of all ops which have implemented flops statistics. 
IMPLEMENTED_OPS = {
    # Unary ops
    "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
    "L2Loss", "Softmax",
    # Binary ops
    "Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
    "GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
    "SquaredDifference",
    # Reduction ops
    "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
    # Convolution and pooling
    "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
    "Conv2DBackpropFilter",
    # Other ops
    "AddN",
    # Ops implemented in core tensorflow:
    "MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
}


def _zero_flops(graph, node):
  """Returns zero flops."""
  del graph, node  # graph and node are unused
  return ops.OpStats("flops", 0)


def _list_product(lst):
  """Computes product of the elements of the list."""
  result = 1
  for item in lst:
    result *= item
  return result

################################################################################
# Unary operations
################################################################################


def _unary_op_flops(graph, node, ops_per_element=1):
  """Common code which compute flops for unary operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * ops_per_element)


@ops.RegisterStatistics("Reciprocal", "flops")
def _reciprocal_flops(graph, node):
  """Compute flops for Reciprocal operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Square", "flops")
def _square_flops(graph, node):
  """Compute flops for Square operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Rsqrt", "flops")
def _rsqrt_flops(graph, node):
  """Compute flops for Rsqrt operation."""
  # Rsqrt(x) = 1 / sqrt(x)
  return _unary_op_flops(graph, node, ops_per_element=2)


@ops.RegisterStatistics("Log", "flops")
def _log_flops(graph, node):
  """Compute flops for Log operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Neg", "flops")
def _neg_flops(graph, node):
  """Compute flops for Neg operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignSub", "flops")
def _assign_sub_flops(graph, node):
  """Compute flops for AssignSub operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignAdd", "flops")
def _assign_add_flops(graph, node):
  """Compute flops for AssignAdd operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("L2Loss", "flops")
def _l2_loss_flops(graph, node):
  """Compute flops for L2Loss operation."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  # Tensorflow uses inefficient implementation, with (3*N-1) flops:
  # Optimal implementation is 2*N flops
  return ops.OpStats("flops", in_shape.num_elements() * 3 - 1)


@ops.RegisterStatistics("Softmax", "flops")
def _softmax_flops(graph, node):
  """Compute flops for Softmax operation."""
  # Softmax implementation:
  #
  # Approximate flops breakdown:
  #   2*n -- compute shifted logits
  #   n   -- exp of shifted logits
  #   2*n -- compute softmax from exp of shifted logits
  return _unary_op_flops(graph, node, ops_per_element=5)

################################################################################
# Binary operations
################################################################################


def _binary_per_element_op_flops(graph, node, ops_per_element=1):
  """Common code which compute flops for binary operations."""
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  return ops.OpStats("flops", out_shape.num_elements() * ops_per_element)


@ops.RegisterStatistics("Add", "flops")
def _add_flops(graph, node):
  """Compute flops for Add operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Sub", "flops")
def _sub_flops(graph, node):
  """Compute flops for Sub operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Mul", "flops")
def _mul_flops(graph, node):
  """Compute flops for Mul operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RealDiv", "flops")
def _real_div_flops(graph, node):
  """Compute flops for RealDiv operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Maximum", "flops")
def _maximum_flops(graph, node):
  """Compute flops for Maximum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Minimum", "flops")
def _minimum_flops(graph, node):
  """Compute flops for Minimum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Pow", "flops")
def _pow_flops(graph, node):
  """Compute flops for Pow operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RsqrtGrad", "flops")
def _rsqrt_grad_flops(graph, node):
  """Compute flops for RsqrtGrad operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=4)


@ops.RegisterStatistics("GreaterEqual", "flops")
def _greater_equal_flops(graph, node):
  """Compute flops for GreaterEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Greater", "flops")
def _greater_flops(graph, node):
  """Compute flops for Greater operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("LessEqual", "flops")
def _less_equal_flops(graph, node):
  """Compute flops for LessEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Less", "flops")
def _less_flops(graph, node):
  """Compute flops for Less operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Equal", "flops")
def _equal_flops(graph, node):
  """Compute flops for Equal operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("NotEqual", "flops")
def _not_equal_flops(graph, node):
  """Compute flops for NotEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("SquaredDifference", "flops")
def _squared_difference_flops(graph, node):
  """Compute flops for SquaredDifference operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=2)

################################################################################
# Reduction ops
################################################################################


def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0):
  """Common code which compute flops for reduction operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  num_flops = (in_shape.num_elements() * reduce_flops
               + out_shape.num_elements() * (finalize_flops - reduce_flops))
  return ops.OpStats("flops", num_flops)


@ops.RegisterStatistics("Mean", "flops")
def _mean_flops(graph, node):
  """Compute flops for Mean operation."""
  # reduction - sum, finalization - divide
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1)


@ops.RegisterStatistics("Sum", "flops")
def _sum_flops(graph, node):
  """Compute flops for Sum operation."""
  # reduction - sum, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMax", "flops")
def _arg_max_flops(graph, node):
  """Compute flops for ArgMax operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMin", "flops")
def _arg_min_flops(graph, node):
  """Compute flops for ArgMin operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("BiasAddGrad", "flops")
def _bias_add_grad_flops(graph, node):
  """Compute flops for BiasAddGrad operation."""
  # Implementation of BiasAddGrad, essentially it's a reduce sum and reshaping:
  # So computing flops same way as for "Sum"
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)

################################################################################
# Convolution and pooling
# Note: all flops statistics are implemented only for NHWC data format
################################################################################


def _verify_conv_data_format(node):
  """Verifies data format for pooling and convolutional operations."""
  # TODO(xpan): P1: Support NCHW
  if node.attr["data_format"].s != b"NHWC":
    raise ValueError("Only NHWC format is supported in flops computations")


def _pool_flops(graph, node):
  """Common code which compute flops for pooling operations."""
  # compute flops for average and max pooling
  _verify_conv_data_format(node)
  #
  # Pooling declaration:
  #   Inputs:
  #     - value
  #   Outputs:
  #     - output
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  #
  # Pooling implementation:
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  return ops.OpStats("flops", kernel_area * out_shape.num_elements())


@ops.RegisterStatistics("AvgPool", "flops")
def _avg_pool_flops(graph, node):
  """Compute flops for AvgPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("MaxPool", "flops")
def _max_pool_flops(graph, node):
  """Compute flops for MaxPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("AvgPoolGrad", "flops")
def _avg_pool_grad_flops(graph, node):
  """Compute flops for AvgPoolGrad operation."""
  _verify_conv_data_format(node)
  # Pooling gradient implementation:
  out_backprop_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  out_backprop_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  # TensorFlow multiply each element of pooling window by coefficient,
  # then sum up all of them, thus we have 2 flops per element:
  # More optimal implementation - if division is done after.
  return ops.OpStats("flops",
                     kernel_area * out_backprop_shape.num_elements() * 2)


@ops.RegisterStatistics("MaxPoolGrad", "flops")
def _max_pool_grad_flops(graph, node):
  """Compute flops for MaxPoolGrad operation."""
  _verify_conv_data_format(node)
  #
  # MaxPoolGrad declaration:
  #   Inputs:
  #     - orig_input  -- original input tensor (of max_pool)
  #     - orig_output -- original output tensor (of max_pool)
  #     - grad        -- gradient with respect to output of max_pool
  #   Outputs:
  #     - output      -- gradient with respect to input of max_pool
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  # It computes MaxPool first, then one flop per each element of original output
  #
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  orig_out_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  orig_out_shape.assert_is_fully_defined()
  max_pool_ops = kernel_area * orig_out_shape.num_elements()
  return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements())


@ops.RegisterStatistics("Conv2DBackpropInput", "flops")
def _conv_2d_backprop_input_flops(graph, node):
  """Compute flops for Conv2DBackpropInput operation."""
  # Formula:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  # Where:
  # image_x_dim, image_y_dim and input_depth --- size of input to source (no
  #   backprop) convolution, in other words they are sizes of backprop output.
  # output_depth --- number of filters in the original convolution, thus
  #   depth of backprop input.
  # kernel_x_dim and kernel_y_dim --- sizes of filter in spatial dimension
  # image_x_stride and image_y_stride --- strides of the convolution
  #
  _verify_conv_data_format(node)
  # out_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * out_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (out_shape[-1].value * strides_product)))


@ops.RegisterStatistics("Conv2DBackpropFilter", "flops")
def _conv_2d_backprop_filter_flops(graph, node):
  """Compute flops for Conv2DBackpropFilter operation."""
  # Formula same as for Conv2DBackpropInput:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  _verify_conv_data_format(node)
  # image_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  image_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  image_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * image_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (image_shape[-1].value * strides_product)))

################################################################################
# Other ops
################################################################################


@ops.RegisterStatistics("AddN", "flops")
def _add_n_flops(graph, node):
  """Compute flops for AddN operation."""
  if not node.input:
    return _zero_flops(graph, node)
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1))