# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for analyzing the operations and variables in a TensorFlow graph.

To analyze the operations in a graph:

  images, labels = LoadData(...)
  predictions = MyModel(images)

  slim.model_analyzer.analyze_ops(tf.get_default_graph(), print_info=True)

To analyze the model variables in a graph:

  variables = tf.model_variables()
  slim.model_analyzer.analyze_vars(variables, print_info=False)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def tensor_description(var):
  """Returns a compact and informative string about a tensor.

  Args:
    var: A tensor or variable; must expose `dtype.name` and `get_shape()`.

  Returns:
    A string with type and size, e.g.: (float32 1x8x8x1024). A shape whose
    rank is unknown is rendered as (float32 <unknown>) instead of raising.
  """
  shape = var.get_shape()
  if shape.ndims is None:
    # Iterating over (or taking len() of) an unknown-rank TensorShape raises
    # ValueError, so describe it explicitly rather than crashing the report.
    sizes = '<unknown>'
  else:
    # Each dimension is stringified as-is, so partially unknown dimensions
    # keep whatever str() gives them.
    sizes = 'x'.join(str(size) for size in shape)
  return '(%s %s)' % (var.dtype.name, sizes)


def analyze_ops(graph, print_info=False):
  """Computes the estimated size of the ops' outputs in the graph.

  Args:
    graph: the graph containing the operations.
    print_info: Optional, if true print ops and their outputs.

  Returns:
    Total number of elements across all op outputs; outputs whose shape is
    not fully defined contribute 0.
  """
  if print_info:
    print('---------')
    print('Operations: name -> (type shapes) [size]')
    print('---------')
  total_size = 0
  for op in graph.get_operations():
    op_size = 0
    shapes = []
    for output in op.outputs:
      shape = output.get_shape()
      # num_elements() is None when the shape is not fully defined; count such
      # outputs as size 0. (A scalar shape [] counts as 1 element.)
      op_size += shape.num_elements() or 0
      # Falsy shapes (for tf.TensorShape: unknown rank) are left out of the
      # printed shape list.
      if shape:
        shapes.append(tensor_description(output))
    if print_info:
      print(op.name, '\t->', ', '.join(shapes), '[' + str(op_size) + ']')
    total_size += op_size
  return total_size


def analyze_vars(variables, print_info=False):
  """Computes the total element count and byte size of the variables.

  Args:
    variables: list of variables, for example tf.global_variables().
    print_info: Optional, if true print each variable's name, shape and size,
      followed by the totals.

  Returns:
    (total size of the variables, total bytes of the variables)
  """
  if print_info:
    print('---------')
    print('Variables: name (type shape) [size]')
    print('---------')
  total_size = 0
  total_bytes = 0
  for var in variables:
    # num_elements() is None when the shape is not fully defined; count such
    # variables as size 0. (A scalar shape [] counts as 1 element.)
    var_size = var.get_shape().num_elements() or 0
    # dtype.size is the per-element byte width.
    var_bytes = var_size * var.dtype.size
    total_size += var_size
    total_bytes += var_bytes
    if print_info:
      print(var.name, tensor_description(var), '[%d, bytes: %d]' %
            (var_size, var_bytes))
  if print_info:
    print('Total size of variables: %d' % total_size)
    print('Total bytes of variables: %d' % total_bytes)
  return total_size, total_bytes