1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 r"""Computes a header file to be used with SELECTIVE_REGISTRATION. 16 17 See the executable wrapper, print_selective_registration_header.py, for more 18 information. 19 """ 20 21 from __future__ import absolute_import 22 from __future__ import division 23 from __future__ import print_function 24 25 import os 26 import sys 27 28 from google.protobuf import text_format 29 30 from tensorflow.core.framework import graph_pb2 31 from tensorflow.python import pywrap_tensorflow 32 from tensorflow.python.platform import gfile 33 from tensorflow.python.platform import tf_logging 34 35 36 def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): 37 """Gets the ops and kernels needed from the model files.""" 38 ops = set() 39 40 for proto_file in proto_files: 41 tf_logging.info('Loading proto file %s', proto_file) 42 # Load GraphDef. 43 file_data = gfile.GFile(proto_file, 'rb').read() 44 if proto_fileformat == 'rawproto': 45 graph_def = graph_pb2.GraphDef.FromString(file_data) 46 else: 47 assert proto_fileformat == 'textproto' 48 graph_def = text_format.Parse(file_data, graph_pb2.GraphDef()) 49 50 # Find all ops and kernels used by the graph. 51 for node_def in graph_def.node: 52 if not node_def.device: 53 node_def.device = '/cpu:0' 54 kernel_class = pywrap_tensorflow.TryFindKernelClass( 55 node_def.SerializeToString()) 56 if kernel_class: 57 op_and_kernel = (str(node_def.op), str(kernel_class.decode('utf-8'))) 58 if op_and_kernel not in ops: 59 ops.add(op_and_kernel) 60 else: 61 print( 62 'Warning: no kernel found for op %s' % node_def.op, file=sys.stderr) 63 64 # Add default ops. 65 if default_ops_str and default_ops_str != 'all': 66 for s in default_ops_str.split(','): 67 op, kernel = s.split(':') 68 op_and_kernel = (op, kernel) 69 if op_and_kernel not in ops: 70 ops.add(op_and_kernel) 71 72 return list(sorted(ops)) 73 74 75 def get_header_from_ops_and_kernels(ops_and_kernels, 76 include_all_ops_and_kernels): 77 """Returns a header for use with tensorflow SELECTIVE_REGISTRATION. 78 79 Args: 80 ops_and_kernels: a set of (op_name, kernel_class_name) pairs to include. 81 include_all_ops_and_kernels: if True, ops_and_kernels is ignored and all op 82 kernels are included. 83 84 Returns: 85 the string of the header that should be written as ops_to_register.h. 86 """ 87 ops = set([op for op, _ in ops_and_kernels]) 88 result_list = [] 89 90 def append(s): 91 result_list.append(s) 92 93 _, script_name = os.path.split(sys.argv[0]) 94 append('// This file was autogenerated by %s' % script_name) 95 append('#ifndef OPS_TO_REGISTER') 96 append('#define OPS_TO_REGISTER') 97 98 if include_all_ops_and_kernels: 99 append('#define SHOULD_REGISTER_OP(op) true') 100 append('#define SHOULD_REGISTER_OP_KERNEL(clz) true') 101 append('#define SHOULD_REGISTER_OP_GRADIENT true') 102 else: 103 line = ''' 104 namespace { 105 constexpr const char* skip(const char* x) { 106 return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x; 107 } 108 109 constexpr bool isequal(const char* x, const char* y) { 110 return (*skip(x) && *skip(y)) 111 ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1)) 112 : (!*skip(x) && !*skip(y)); 113 } 114 115 template<int N> 116 struct find_in { 117 static constexpr bool f(const char* x, const char* const y[N]) { 118 return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1); 119 } 120 }; 121 122 template<> 123 struct find_in<0> { 124 static constexpr bool f(const char* x, const char* const y[]) { 125 return false; 126 } 127 }; 128 } // end namespace 129 ''' 130 line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n' 131 for _, kernel_class in ops_and_kernels: 132 line += '"%s",\n' % kernel_class 133 line += '};' 134 append(line) 135 append('#define SHOULD_REGISTER_OP_KERNEL(clz) ' 136 '(find_in<sizeof(kNecessaryOpKernelClasses) ' 137 '/ sizeof(*kNecessaryOpKernelClasses)>::f(clz, ' 138 'kNecessaryOpKernelClasses))') 139 append('') 140 141 append('constexpr inline bool ShouldRegisterOp(const char op[]) {') 142 append(' return false') 143 for op in sorted(ops): 144 append(' || isequal(op, "%s")' % op) 145 append(' ;') 146 append('}') 147 append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)') 148 append('') 149 150 append('#define SHOULD_REGISTER_OP_GRADIENT ' + ( 151 'true' if 'SymbolicGradient' in ops else 'false')) 152 153 append('#endif') 154 return '\n'.join(result_list) 155 156 157 def get_header(graphs, 158 proto_fileformat='rawproto', 159 default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'): 160 """Computes a header for use with tensorflow SELECTIVE_REGISTRATION. 161 162 Args: 163 graphs: a list of paths to GraphDef files to include. 164 proto_fileformat: optional format of proto file, either 'textproto' or 165 'rawproto' (default). 166 default_ops: optional comma-separated string of operator:kernel pairs to 167 always include implementation for. Pass 'all' to have all operators and 168 kernels included. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'. 169 Returns: 170 the string of the header that should be written as ops_to_register.h. 171 """ 172 ops_and_kernels = get_ops_and_kernels(proto_fileformat, graphs, default_ops) 173 if not ops_and_kernels: 174 print('Error reading graph!') 175 return 1 176 177 return get_header_from_ops_and_kernels(ops_and_kernels, default_ops == 'all') 178