Home | History | Annotate | Download | only in slim
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tools for analyzing the operations and variables in a TensorFlow graph.
     16 
     17 To analyze the operations in a graph:
     18 
     19   images, labels = LoadData(...)
     20   predictions = MyModel(images)
     21 
     22   slim.model_analyzer.analyze_ops(tf.get_default_graph(), print_info=True)
     23 
     24 To analyze the model variables in a graph:
     25 
     26   variables = tf.model_variables()
     27   slim.model_analyzer.analyze_vars(variables, print_info=False)
     28 """
     29 from __future__ import absolute_import
     30 from __future__ import division
     31 from __future__ import print_function
     32 
     33 
     34 def tensor_description(var):
     35   """Returns a compact and informative string about a tensor.
     36 
     37   Args:
     38     var: A tensor variable.
     39 
     40   Returns:
     41     a string with type and size, e.g.: (float32 1x8x8x1024).
     42   """
     43   description = '(' + str(var.dtype.name) + ' '
     44   sizes = var.get_shape()
     45   for i, size in enumerate(sizes):
     46     description += str(size)
     47     if i < len(sizes) - 1:
     48       description += 'x'
     49   description += ')'
     50   return description
     51 
     52 
     53 def analyze_ops(graph, print_info=False):
     54   """Compute the estimated size of the ops.outputs in the graph.
     55 
     56   Args:
     57     graph: the graph containing the operations.
     58     print_info: Optional, if true print ops and their outputs.
     59 
     60   Returns:
     61     total size of the ops.outputs
     62   """
     63   if print_info:
     64     print('---------')
     65     print('Operations: name -> (type shapes) [size]')
     66     print('---------')
     67   total_size = 0
     68   for op in graph.get_operations():
     69     op_size = 0
     70     shapes = []
     71     for output in op.outputs:
     72       # if output.num_elements() is None or [] assume size 0.
     73       output_size = output.get_shape().num_elements() or 0
     74       if output.get_shape():
     75         shapes.append(tensor_description(output))
     76       op_size += output_size
     77     if print_info:
     78       print(op.name, '\t->', ', '.join(shapes), '[' + str(op_size) + ']')
     79     total_size += op_size
     80   return total_size
     81 
     82 
     83 def analyze_vars(variables, print_info=False):
     84   """Prints the names and shapes of the variables.
     85 
     86   Args:
     87     variables: list of variables, for example tf.global_variables().
     88     print_info: Optional, if true print variables and their shape.
     89 
     90   Returns:
     91     (total size of the variables, total bytes of the variables)
     92   """
     93   if print_info:
     94     print('---------')
     95     print('Variables: name (type shape) [size]')
     96     print('---------')
     97   total_size = 0
     98   total_bytes = 0
     99   for var in variables:
    100     # if var.num_elements() is None or [] assume size 0.
    101     var_size = var.get_shape().num_elements() or 0
    102     var_bytes = var_size * var.dtype.size
    103     total_size += var_size
    104     total_bytes += var_bytes
    105     if print_info:
    106       print(var.name, tensor_description(var), '[%d, bytes: %d]' %
    107             (var_size, var_bytes))
    108   if print_info:
    109     print('Total size of variables: %d' % total_size)
    110     print('Total bytes of variables: %d' % total_bytes)
    111   return total_size, total_bytes
    112