# (Code-browser navigation header removed; file lives under tensorflow/python/tools.)
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for print_selective_registration_header."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import sys
     23 
     24 from google.protobuf import text_format
     25 
     26 from tensorflow.core.framework import graph_pb2
     27 from tensorflow.python.platform import gfile
     28 from tensorflow.python.platform import test
     29 from tensorflow.python.tools import selective_registration_header_lib
     30 
# Text-format GraphDef fixtures parsed by the tests below.
# Note that this graph def is not valid to be loaded - its inputs are not
# assigned correctly in all cases.
# Contains a Reshape node, a float MatMul node, and a double MatMul node,
# all placed on CPU.
GRAPH_DEF_TXT = """
  node: {
    name: "node_1"
    op: "Reshape"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_FLOAT } }
  }
  node: {
    name: "node_2"
    op: "MatMul"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_FLOAT } }
    attr: { key: "transpose_a" value: { b: false } }
    attr: { key: "transpose_b" value: { b: false } }
  }
  node: {
    name: "node_3"
    op: "MatMul"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_DOUBLE } }
    attr: { key: "transpose_a" value: { b: false } }
    attr: { key: "transpose_b" value: { b: false } }
  }
"""

# A second, single-node (BiasAdd) graph, used both on its own and alongside
# GRAPH_DEF_TXT above.
GRAPH_DEF_TXT_2 = """
  node: {
    name: "node_4"
    op: "BiasAdd"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_FLOAT } }
  }

"""
     71 
     72 
     73 class PrintOpFilegroupTest(test.TestCase):
     74 
     75   def setUp(self):
     76     _, self.script_name = os.path.split(sys.argv[0])
     77 
     78   def WriteGraphFiles(self, graphs):
     79     fnames = []
     80     for i, graph in enumerate(graphs):
     81       fname = os.path.join(self.get_temp_dir(), 'graph%s.pb' % i)
     82       with gfile.GFile(fname, 'wb') as f:
     83         f.write(graph.SerializeToString())
     84       fnames.append(fname)
     85     return fnames
     86 
     87   def testGetOps(self):
     88     default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
     89     graphs = [
     90         text_format.Parse(d, graph_pb2.GraphDef())
     91         for d in [GRAPH_DEF_TXT, GRAPH_DEF_TXT_2]
     92     ]
     93 
     94     ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
     95         'rawproto', self.WriteGraphFiles(graphs), default_ops)
     96     self.assertListEqual(
     97         [
     98             ('BiasAdd', 'BiasOp<CPUDevice, float>'),  #
     99             ('MatMul', 'MatMulOp<CPUDevice, double, false >'),  #
    100             ('MatMul', 'MatMulOp<CPUDevice, float, false >'),  #
    101             ('NoOp', 'NoOp'),  #
    102             ('Reshape', 'ReshapeOp'),  #
    103             ('_Recv', 'RecvOp'),  #
    104             ('_Send', 'SendOp'),  #
    105         ],
    106         ops_and_kernels)
    107 
    108     graphs[0].node[0].ClearField('device')
    109     graphs[0].node[2].ClearField('device')
    110     ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
    111         'rawproto', self.WriteGraphFiles(graphs), default_ops)
    112     self.assertListEqual(
    113         [
    114             ('BiasAdd', 'BiasOp<CPUDevice, float>'),  #
    115             ('MatMul', 'MatMulOp<CPUDevice, double, false >'),  #
    116             ('MatMul', 'MatMulOp<CPUDevice, float, false >'),  #
    117             ('NoOp', 'NoOp'),  #
    118             ('Reshape', 'ReshapeOp'),  #
    119             ('_Recv', 'RecvOp'),  #
    120             ('_Send', 'SendOp'),  #
    121         ],
    122         ops_and_kernels)
    123 
    124   def testAll(self):
    125     default_ops = 'all'
    126     graphs = [
    127         text_format.Parse(d, graph_pb2.GraphDef())
    128         for d in [GRAPH_DEF_TXT, GRAPH_DEF_TXT_2]
    129     ]
    130     ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
    131         'rawproto', self.WriteGraphFiles(graphs), default_ops)
    132 
    133     header = selective_registration_header_lib.get_header_from_ops_and_kernels(
    134         ops_and_kernels, include_all_ops_and_kernels=True)
    135     self.assertListEqual(
    136         [
    137             '// This file was autogenerated by %s' % self.script_name,
    138             '#ifndef OPS_TO_REGISTER',  #
    139             '#define OPS_TO_REGISTER',  #
    140             '#define SHOULD_REGISTER_OP(op) true',  #
    141             '#define SHOULD_REGISTER_OP_KERNEL(clz) true',  #
    142             '#define SHOULD_REGISTER_OP_GRADIENT true',  #
    143             '#endif'
    144         ],
    145         header.split('\n'))
    146 
    147     self.assertListEqual(
    148         header.split('\n'),
    149         selective_registration_header_lib.get_header(
    150             self.WriteGraphFiles(graphs), 'rawproto', default_ops).split('\n'))
    151 
    152   def testGetSelectiveHeader(self):
    153     default_ops = ''
    154     graphs = [text_format.Parse(GRAPH_DEF_TXT_2, graph_pb2.GraphDef())]
    155 
    156     expected = '''// This file was autogenerated by %s
    157 #ifndef OPS_TO_REGISTER
    158 #define OPS_TO_REGISTER
    159 
    160     namespace {
    161       constexpr const char* skip(const char* x) {
    162         return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
    163       }
    164 
    165       constexpr bool isequal(const char* x, const char* y) {
    166         return (*skip(x) && *skip(y))
    167                    ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1))
    168                    : (!*skip(x) && !*skip(y));
    169       }
    170 
    171       template<int N>
    172       struct find_in {
    173         static constexpr bool f(const char* x, const char* const y[N]) {
    174           return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1);
    175         }
    176       };
    177 
    178       template<>
    179       struct find_in<0> {
    180         static constexpr bool f(const char* x, const char* const y[]) {
    181           return false;
    182         }
    183       };
    184     }  // end namespace
    185     constexpr const char* kNecessaryOpKernelClasses[] = {
    186 "BiasOp<CPUDevice, float>",
    187 };
    188 #define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses))
    189 
    190 constexpr inline bool ShouldRegisterOp(const char op[]) {
    191   return false
    192      || isequal(op, "BiasAdd")
    193   ;
    194 }
    195 #define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
    196 
    197 #define SHOULD_REGISTER_OP_GRADIENT false
    198 #endif''' % self.script_name
    199 
    200     header = selective_registration_header_lib.get_header(
    201         self.WriteGraphFiles(graphs), 'rawproto', default_ops)
    202     print(header)
    203     self.assertListEqual(expected.split('\n'), header.split('\n'))
    204 
    205 
    206 if __name__ == '__main__':
    207   test.main()
    208