1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/framework/cc_op_gen.h" 17 18 #include "tensorflow/core/framework/op_def.pb.h" 19 #include "tensorflow/core/framework/op_gen_lib.h" 20 #include "tensorflow/core/lib/core/status_test_util.h" 21 #include "tensorflow/core/lib/io/path.h" 22 #include "tensorflow/core/platform/test.h" 23 24 namespace tensorflow { 25 namespace { 26 27 constexpr char kBaseOpDef[] = R"( 28 op { 29 name: "Foo" 30 input_arg { 31 name: "images" 32 description: "Images to process." 33 } 34 input_arg { 35 name: "dim" 36 description: "Description for dim." 37 type: DT_FLOAT 38 } 39 output_arg { 40 name: "output" 41 description: "Description for output." 42 type: DT_FLOAT 43 } 44 attr { 45 name: "T" 46 type: "type" 47 description: "Type for images" 48 allowed_values { 49 list { 50 type: DT_UINT8 51 type: DT_INT8 52 } 53 } 54 default_value { 55 i: 1 56 } 57 } 58 summary: "Summary for op Foo." 59 description: "Description for op Foo." 60 } 61 )"; 62 63 void ExpectHasSubstr(StringPiece s, StringPiece expected) { 64 EXPECT_TRUE(s.contains(expected)) 65 << "'" << s << "' does not contain '" << expected << "'"; 66 } 67 68 void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) { 69 EXPECT_FALSE(s.contains(expected)) 70 << "'" << s << "' contains '" << expected << "'"; 71 } 72 73 void ExpectSubstrOrder(const string& s, const string& before, 74 const string& after) { 75 int before_pos = s.find(before); 76 int after_pos = s.find(after); 77 ASSERT_NE(std::string::npos, before_pos); 78 ASSERT_NE(std::string::npos, after_pos); 79 EXPECT_LT(before_pos, after_pos) 80 << before << " is not before " << after << " in " << s; 81 } 82 83 // Runs WriteCCOps and stores output in (internal_)cc_file_path and 84 // (internal_)h_file_path. 85 void GenerateCcOpFiles(Env* env, const OpList& ops, 86 const ApiDefMap& api_def_map, string* h_file_text, 87 string* internal_h_file_text) { 88 const string& tmpdir = testing::TmpDir(); 89 90 const auto h_file_path = io::JoinPath(tmpdir, "test.h"); 91 const auto cc_file_path = io::JoinPath(tmpdir, "test.cc"); 92 const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h"); 93 const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc"); 94 95 WriteCCOps(ops, api_def_map, h_file_path, cc_file_path); 96 97 TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text)); 98 TF_ASSERT_OK( 99 ReadFileToString(env, internal_h_file_path, internal_h_file_text)); 100 } 101 102 TEST(CcOpGenTest, TestVisibilityChangedToHidden) { 103 const string api_def = R"( 104 op { 105 graph_op_name: "Foo" 106 visibility: HIDDEN 107 } 108 )"; 109 Env* env = Env::Default(); 110 OpList op_defs; 111 protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT 112 ApiDefMap api_def_map(op_defs); 113 114 string h_file_text, internal_h_file_text; 115 // Without ApiDef 116 GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, 117 &internal_h_file_text); 118 ExpectHasSubstr(h_file_text, "class Foo"); 119 ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo"); 120 121 // With ApiDef 122 TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); 123 GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, 124 &internal_h_file_text); 125 ExpectHasSubstr(internal_h_file_text, "class Foo"); 126 ExpectDoesNotHaveSubstr(h_file_text, "class Foo"); 127 } 128 129 TEST(CcOpGenTest, TestArgNameChanges) { 130 const string api_def = R"( 131 op { 132 graph_op_name: "Foo" 133 arg_order: "dim" 134 arg_order: "images" 135 } 136 )"; 137 Env* env = Env::Default(); 138 OpList op_defs; 139 protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT 140 141 ApiDefMap api_def_map(op_defs); 142 string cc_file_text, h_file_text; 143 string internal_cc_file_text, internal_h_file_text; 144 // Without ApiDef 145 GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, 146 &internal_h_file_text); 147 ExpectSubstrOrder(h_file_text, "Input images", "Input dim"); 148 149 // With ApiDef 150 TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); 151 GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, 152 &internal_h_file_text); 153 ExpectSubstrOrder(h_file_text, "Input dim", "Input images"); 154 } 155 156 TEST(CcOpGenTest, TestEndpoints) { 157 const string api_def = R"( 158 op { 159 graph_op_name: "Foo" 160 endpoint { 161 name: "Foo1" 162 } 163 endpoint { 164 name: "Foo2" 165 } 166 } 167 )"; 168 Env* env = Env::Default(); 169 OpList op_defs; 170 protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT 171 172 ApiDefMap api_def_map(op_defs); 173 string cc_file_text, h_file_text; 174 string internal_cc_file_text, internal_h_file_text; 175 // Without ApiDef 176 GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, 177 &internal_h_file_text); 178 ExpectHasSubstr(h_file_text, "class Foo {"); 179 ExpectDoesNotHaveSubstr(h_file_text, "class Foo1"); 180 ExpectDoesNotHaveSubstr(h_file_text, "class Foo2"); 181 182 // With ApiDef 183 TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); 184 GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, 185 &internal_h_file_text); 186 ExpectHasSubstr(h_file_text, "class Foo1"); 187 ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2"); 188 ExpectDoesNotHaveSubstr(h_file_text, "class Foo {"); 189 } 190 } // namespace 191 } // namespace tensorflow 192