1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for print_selective_registration_header.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 import sys 23 24 from google.protobuf import text_format 25 26 from tensorflow.core.framework import graph_pb2 27 from tensorflow.python.platform import gfile 28 from tensorflow.python.platform import test 29 from tensorflow.python.tools import selective_registration_header_lib 30 31 # Note that this graph def is not valid to be loaded - its inputs are not 32 # assigned correctly in all cases. 
GRAPH_DEF_TXT = """
  node: {
    name: "node_1"
    op: "Reshape"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_FLOAT } }
  }
  node: {
    name: "node_2"
    op: "MatMul"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_FLOAT } }
    attr: { key: "transpose_a" value: { b: false } }
    attr: { key: "transpose_b" value: { b: false } }
  }
  node: {
    name: "node_3"
    op: "MatMul"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_DOUBLE } }
    attr: { key: "transpose_a" value: { b: false } }
    attr: { key: "transpose_b" value: { b: false } }
  }
"""

# A second graph with a single BiasAdd node; the tests combine it with
# GRAPH_DEF_TXT to check that ops from several graph files are merged.
GRAPH_DEF_TXT_2 = """
  node: {
    name: "node_4"
    op: "BiasAdd"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_FLOAT } }
  }

"""


class PrintOpFilegroupTest(test.TestCase):
  """Tests for selective_registration_header_lib op/kernel and header output."""

  def setUp(self):
    # Chain to the base class so test.TestCase's own per-test setup still runs.
    super(PrintOpFilegroupTest, self).setUp()
    # The expected headers embed the basename of the running script.
    _, self.script_name = os.path.split(sys.argv[0])

  def _ParseGraphs(self, graph_def_texts):
    """Parses each text-format GraphDef string into a `graph_pb2.GraphDef`.

    Args:
      graph_def_texts: iterable of text-format GraphDef strings.

    Returns:
      List of parsed `GraphDef` protos, in input order.
    """
    return [
        text_format.Parse(text, graph_pb2.GraphDef())
        for text in graph_def_texts
    ]

  def WriteGraphFiles(self, graphs):
    """Serializes each GraphDef to its own binary file in the test temp dir.

    Files are named by position (graph0.pb, graph1.pb, ...), so a later call
    overwrites the files written by an earlier one.

    Args:
      graphs: iterable of `GraphDef` protos.

    Returns:
      List of written file paths, in the same order as `graphs`.
    """
    fnames = []
    for i, graph in enumerate(graphs):
      fname = os.path.join(self.get_temp_dir(), 'graph%s.pb' % i)
      with gfile.GFile(fname, 'wb') as f:
        f.write(graph.SerializeToString())
      fnames.append(fname)
    return fnames

  def testGetOps(self):
    """Checks (op, kernel class) extraction from raw GraphDef files."""
    default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
    graphs = self._ParseGraphs([GRAPH_DEF_TXT, GRAPH_DEF_TXT_2])
    # Ops from both graphs plus those named in `default_ops`, in the exact
    # order get_ops_and_kernels is expected to return them.
    expected_ops_and_kernels = [
        ('BiasAdd', 'BiasOp<CPUDevice, float>'),
        ('MatMul', 'MatMulOp<CPUDevice, double, false >'),
        ('MatMul', 'MatMulOp<CPUDevice, float, false >'),
        ('NoOp', 'NoOp'),
        ('Reshape', 'ReshapeOp'),
        ('_Recv', 'RecvOp'),
        ('_Send', 'SendOp'),
    ]

    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)
    self.assertListEqual(expected_ops_and_kernels, ops_and_kernels)

    # Dropping the explicit device from some nodes must not change the result:
    # they are still expected to resolve to the same CPU kernels.
    graphs[0].node[0].ClearField('device')
    graphs[0].node[2].ClearField('device')
    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)
    self.assertListEqual(expected_ops_and_kernels, ops_and_kernels)

  def testAll(self):
    """With default_ops='all' the header registers every op and kernel."""
    default_ops = 'all'
    graphs = self._ParseGraphs([GRAPH_DEF_TXT, GRAPH_DEF_TXT_2])
    graph_files = self.WriteGraphFiles(graphs)
    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'rawproto', graph_files, default_ops)

    header = selective_registration_header_lib.get_header_from_ops_and_kernels(
        ops_and_kernels, include_all_ops_and_kernels=True)
    self.assertListEqual(
        [
            '// This file was autogenerated by %s' % self.script_name,
            '#ifndef OPS_TO_REGISTER',
            '#define OPS_TO_REGISTER',
            '#define SHOULD_REGISTER_OP(op) true',
            '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
            '#define SHOULD_REGISTER_OP_GRADIENT true',
            '#endif',
        ],
        header.split('\n'))

    # get_header() must produce the same header directly from the graph files.
    self.assertListEqual(
        header.split('\n'),
        selective_registration_header_lib.get_header(
            graph_files, 'rawproto', default_ops).split('\n'))

  def testGetSelectiveHeader(self):
    """Checks the selective-registration header for a single BiasAdd graph."""
    default_ops = ''
    graphs = self._ParseGraphs([GRAPH_DEF_TXT_2])

    # The comparison below is exact per line, including leading whitespace.
    expected = '''// This file was autogenerated by %s
#ifndef OPS_TO_REGISTER
#define OPS_TO_REGISTER

    namespace {
      constexpr const char* skip(const char* x) {
        return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
      }

      constexpr bool isequal(const char* x, const char* y) {
        return (*skip(x) && *skip(y))
                   ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1))
                   : (!*skip(x) && !*skip(y));
      }

      template<int N>
      struct find_in {
        static constexpr bool f(const char* x, const char* const y[N]) {
          return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1);
        }
      };

      template<>
      struct find_in<0> {
        static constexpr bool f(const char* x, const char* const y[]) {
          return false;
        }
      };
    }  // end namespace
    constexpr const char* kNecessaryOpKernelClasses[] = {
"BiasOp<CPUDevice, float>",
};
#define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses))

constexpr inline bool ShouldRegisterOp(const char op[]) {
  return false
     || isequal(op, "BiasAdd")
  ;
}
#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)

#define SHOULD_REGISTER_OP_GRADIENT false
#endif''' % self.script_name

    header = selective_registration_header_lib.get_header(
        self.WriteGraphFiles(graphs), 'rawproto', default_ops)
    self.assertListEqual(expected.split('\n'), header.split('\n'))


if __name__ == '__main__':
  test.main()