1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/aot/codegen.h" 17 18 #include <string> 19 #include <vector> 20 21 #include "llvm/Support/TargetSelect.h" 22 #include "tensorflow/compiler/xla/shape_util.h" 23 #include "tensorflow/core/lib/core/status.h" 24 #include "tensorflow/core/lib/core/status_test_util.h" 25 #include "tensorflow/core/lib/core/stringpiece.h" 26 #include "tensorflow/core/lib/io/path.h" 27 #include "tensorflow/core/platform/env.h" 28 #include "tensorflow/core/platform/test.h" 29 30 namespace tensorflow { 31 namespace tfcompile { 32 namespace { 33 34 void ExpectErrorContains(const Status& status, StringPiece str) { 35 EXPECT_NE(Status::OK(), status); 36 EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) 37 << "expected error: " << status.error_message() << " to contain: " << str; 38 } 39 40 TEST(ValidateCppIdent, Simple) { 41 TF_EXPECT_OK(ValidateCppIdent("a", "")); 42 TF_EXPECT_OK(ValidateCppIdent("abc", "")); 43 TF_EXPECT_OK(ValidateCppIdent("_abc", "")); 44 TF_EXPECT_OK(ValidateCppIdent("_abc123", "")); 45 // Make sure we didn't skip a valid letter or digit 46 string ident; 47 for (char c = 'a'; c <= 'z'; c++) { 48 ident.append(1, c); 49 } 50 for (char c = 'A'; c <= 'Z'; c++) { 51 ident.append(1, c); 52 } 53 for (char c = '0'; c <= '9'; c++) { 54 ident.append(1, c); 55 } 56 ident += "_"; 57 TF_EXPECT_OK(ValidateCppIdent(ident, "")); 58 59 ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier"); 60 ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char"); 61 ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char"); 62 ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char"); 63 ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char"); 64 ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char"); 65 ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char"); 66 ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char"); 67 } 68 69 class ParseCppClassTest : public ::testing::Test { 70 protected: 71 void ExpectOK(const string& cpp_class, const string& want_class_name, 72 const std::vector<string>& want_namespaces) { 73 string class_name; 74 std::vector<string> namespaces; 75 TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces)); 76 EXPECT_EQ(class_name, want_class_name); 77 EXPECT_EQ(namespaces, want_namespaces); 78 } 79 80 void ExpectFail(const string& cpp_class) { 81 string class_name; 82 std::vector<string> namespaces; 83 EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces), Status::OK()); 84 } 85 }; 86 87 TEST_F(ParseCppClassTest, ParseOK) { 88 ExpectOK("MyClass", "MyClass", {}); 89 ExpectOK("_MyClass", "_MyClass", {}); 90 ExpectOK("a::MyClass", "MyClass", {"a"}); 91 ExpectOK("a::foo::MyClass", "MyClass", {"a", "foo"}); 92 ExpectOK("a::foo::b::MyClass", "MyClass", {"a", "foo", "b"}); 93 ExpectOK("a::foo::b::bar::MyClass", "MyClass", {"a", "foo", "b", "bar"}); 94 ExpectOK("foo::MyClass", "MyClass", {"foo"}); 95 ExpectOK("_foo::MyClass", "MyClass", {"_foo"}); 96 ExpectOK("_foo::_MyClass", "_MyClass", {"_foo"}); 97 // Make sure we didn't skip a valid letter or digit 98 string ident; 99 for (char c = 'a'; c <= 'z'; c++) { 100 ident.append(1, c); 101 } 102 for (char c = 'A'; c <= 'Z'; c++) { 103 ident.append(1, c); 104 } 105 for (char c = '0'; c <= '9'; c++) { 106 ident.append(1, c); 107 } 108 ident += "_"; 109 ExpectOK(ident, ident, {}); 110 ExpectOK(ident + "::" + ident, ident, {ident}); 111 ExpectOK(ident + "::" + ident + "::" + ident, ident, {ident, ident}); 112 } 113 114 TEST_F(ParseCppClassTest, ParseFail) { 115 ExpectFail(""); 116 ExpectFail("::"); 117 ExpectFail("::MyClass"); // valid C++, but disallowed for simpler code. 118 ExpectFail("0"); 119 ExpectFail("a.b"); 120 ExpectFail("a:b"); 121 ExpectFail("good::.bad"); 122 ExpectFail("good:::bad"); 123 ExpectFail("good:: bad"); 124 ExpectFail("good::0bad"); 125 } 126 127 static void CompareWithGoldenFile( 128 const string& tensorflow_relative_golden_file_name, 129 const string& expected_contents) { 130 // To update the golden file, flip update_golden to true and run the 131 // following: 132 // bazel test --test_strategy=local \ 133 // third_party/tensorflow/compiler/aot:codegen_test 134 const bool update_golden = false; 135 const string golden_file_name = io::JoinPath( 136 testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name); 137 138 if (update_golden) { 139 TF_EXPECT_OK( 140 WriteStringToFile(Env::Default(), golden_file_name, expected_contents)); 141 } 142 143 string golden_file_contents; 144 TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name, 145 &golden_file_contents)); 146 EXPECT_EQ(golden_file_contents, expected_contents); 147 } 148 149 TEST(CodegenTest, Golden) { 150 // Normally CpuCompiler::CpuCompiler does this, but in this test we've 151 // bypassed the Cpu compiler so we have to do this manually. 152 llvm::InitializeNativeTarget(); 153 llvm::InitializeNativeTargetAsmPrinter(); 154 LLVMInitializeX86Target(); 155 LLVMInitializeX86TargetMC(); 156 157 CodegenOpts opts; 158 opts.class_name = "MyClass"; 159 opts.target_triple = "x86_64-pc-linux"; 160 opts.namespaces = {"foo", "bar"}; 161 opts.gen_name_to_index = true; 162 opts.gen_program_shape = true; 163 tf2xla::Config config; 164 tf2xla::Feed* feed = config.add_feed(); 165 feed->mutable_id()->set_node_name("feed0"); 166 feed->set_name("myfeed"); 167 feed = config.add_feed(); 168 feed->mutable_id()->set_node_name("feed1"); 169 tf2xla::Fetch* fetch = config.add_fetch(); 170 fetch->mutable_id()->set_node_name("fetch0"); 171 fetch->set_name("myfetch"); 172 CompileResult compile_result; 173 compile_result.aot.reset( 174 new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5)); 175 compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( 176 { 177 xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), 178 xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), 179 }, 180 xla::ShapeUtil::MakeTupleShape( 181 {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); 182 compile_result.entry_point = "entry_point"; 183 compile_result.pointer_size = 8; 184 185 MetadataResult metadata_result; 186 TF_ASSERT_OK(GenerateMetadata(opts, compile_result, &metadata_result)); 187 188 // The other fields in metadata_result are tested as part of the generated 189 // header test. 190 191 CompareWithGoldenFile("compiler/aot/codegen_test_o.golden", 192 metadata_result.object_file_data); 193 194 string header; 195 TF_ASSERT_OK( 196 GenerateHeader(opts, config, compile_result, metadata_result, &header)); 197 198 CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header); 199 } 200 } // namespace 201 } // namespace tfcompile 202 } // namespace tensorflow 203