# (removed: code-viewer navigation line, not part of the source file)
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 r"""Computes a header file to be used with SELECTIVE_REGISTRATION.
     16 
     17 See the executable wrapper, print_selective_registration_header.py, for more
     18 information.
     19 """
     20 
     21 from __future__ import absolute_import
     22 from __future__ import division
     23 from __future__ import print_function
     24 
     25 import os
     26 import sys
     27 
     28 from google.protobuf import text_format
     29 
     30 from tensorflow.core.framework import graph_pb2
     31 from tensorflow.python import pywrap_tensorflow
     32 from tensorflow.python.platform import gfile
     33 from tensorflow.python.platform import tf_logging
     34 
     35 
     36 def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
     37   """Gets the ops and kernels needed from the model files."""
     38   ops = set()
     39 
     40   for proto_file in proto_files:
     41     tf_logging.info('Loading proto file %s', proto_file)
     42     # Load GraphDef.
     43     file_data = gfile.GFile(proto_file, 'rb').read()
     44     if proto_fileformat == 'rawproto':
     45       graph_def = graph_pb2.GraphDef.FromString(file_data)
     46     else:
     47       assert proto_fileformat == 'textproto'
     48       graph_def = text_format.Parse(file_data, graph_pb2.GraphDef())
     49 
     50     # Find all ops and kernels used by the graph.
     51     for node_def in graph_def.node:
     52       if not node_def.device:
     53         node_def.device = '/cpu:0'
     54       kernel_class = pywrap_tensorflow.TryFindKernelClass(
     55           node_def.SerializeToString())
     56       if kernel_class:
     57         op_and_kernel = (str(node_def.op), str(kernel_class.decode('utf-8')))
     58         if op_and_kernel not in ops:
     59           ops.add(op_and_kernel)
     60       else:
     61         print(
     62             'Warning: no kernel found for op %s' % node_def.op, file=sys.stderr)
     63 
     64   # Add default ops.
     65   if default_ops_str and default_ops_str != 'all':
     66     for s in default_ops_str.split(','):
     67       op, kernel = s.split(':')
     68       op_and_kernel = (op, kernel)
     69       if op_and_kernel not in ops:
     70         ops.add(op_and_kernel)
     71 
     72   return list(sorted(ops))
     73 
     74 
     75 def get_header_from_ops_and_kernels(ops_and_kernels,
     76                                     include_all_ops_and_kernels):
     77   """Returns a header for use with tensorflow SELECTIVE_REGISTRATION.
     78 
     79   Args:
     80     ops_and_kernels: a set of (op_name, kernel_class_name) pairs to include.
     81     include_all_ops_and_kernels: if True, ops_and_kernels is ignored and all op
     82     kernels are included.
     83 
     84   Returns:
     85     the string of the header that should be written as ops_to_register.h.
     86   """
     87   ops = set([op for op, _ in ops_and_kernels])
     88   result_list = []
     89 
     90   def append(s):
     91     result_list.append(s)
     92 
     93   _, script_name = os.path.split(sys.argv[0])
     94   append('// This file was autogenerated by %s' % script_name)
     95   append('#ifndef OPS_TO_REGISTER')
     96   append('#define OPS_TO_REGISTER')
     97 
     98   if include_all_ops_and_kernels:
     99     append('#define SHOULD_REGISTER_OP(op) true')
    100     append('#define SHOULD_REGISTER_OP_KERNEL(clz) true')
    101     append('#define SHOULD_REGISTER_OP_GRADIENT true')
    102   else:
    103     line = '''
    104     namespace {
    105       constexpr const char* skip(const char* x) {
    106         return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
    107       }
    108 
    109       constexpr bool isequal(const char* x, const char* y) {
    110         return (*skip(x) && *skip(y))
    111                    ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1))
    112                    : (!*skip(x) && !*skip(y));
    113       }
    114 
    115       template<int N>
    116       struct find_in {
    117         static constexpr bool f(const char* x, const char* const y[N]) {
    118           return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1);
    119         }
    120       };
    121 
    122       template<>
    123       struct find_in<0> {
    124         static constexpr bool f(const char* x, const char* const y[]) {
    125           return false;
    126         }
    127       };
    128     }  // end namespace
    129     '''
    130     line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n'
    131     for _, kernel_class in ops_and_kernels:
    132       line += '"%s",\n' % kernel_class
    133     line += '};'
    134     append(line)
    135     append('#define SHOULD_REGISTER_OP_KERNEL(clz) '
    136            '(find_in<sizeof(kNecessaryOpKernelClasses) '
    137            '/ sizeof(*kNecessaryOpKernelClasses)>::f(clz, '
    138            'kNecessaryOpKernelClasses))')
    139     append('')
    140 
    141     append('constexpr inline bool ShouldRegisterOp(const char op[]) {')
    142     append('  return false')
    143     for op in sorted(ops):
    144       append('     || isequal(op, "%s")' % op)
    145     append('  ;')
    146     append('}')
    147     append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
    148     append('')
    149 
    150     append('#define SHOULD_REGISTER_OP_GRADIENT ' + (
    151         'true' if 'SymbolicGradient' in ops else 'false'))
    152 
    153   append('#endif')
    154   return '\n'.join(result_list)
    155 
    156 
    157 def get_header(graphs,
    158                proto_fileformat='rawproto',
    159                default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'):
    160   """Computes a header for use with tensorflow SELECTIVE_REGISTRATION.
    161 
    162   Args:
    163     graphs: a list of paths to GraphDef files to include.
    164     proto_fileformat: optional format of proto file, either 'textproto' or
    165       'rawproto' (default).
    166     default_ops: optional comma-separated string of operator:kernel pairs to
    167       always include implementation for. Pass 'all' to have all operators and
    168       kernels included. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'.
    169   Returns:
    170     the string of the header that should be written as ops_to_register.h.
    171   """
    172   ops_and_kernels = get_ops_and_kernels(proto_fileformat, graphs, default_ops)
    173   if not ops_and_kernels:
    174     print('Error reading graph!')
    175     return 1
    176 
    177   return get_header_from_ops_and_kernels(ops_and_kernels, default_ops == 'all')
    178