Home | History | Annotate | Download | only in aot
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/aot/codegen.h"
     17 
     18 #include <string>
     19 #include <vector>
     20 
     21 #include "llvm/Support/TargetSelect.h"
     22 #include "tensorflow/compiler/xla/shape_util.h"
     23 #include "tensorflow/core/lib/core/status.h"
     24 #include "tensorflow/core/lib/core/status_test_util.h"
     25 #include "tensorflow/core/lib/core/stringpiece.h"
     26 #include "tensorflow/core/lib/io/path.h"
     27 #include "tensorflow/core/platform/env.h"
     28 #include "tensorflow/core/platform/test.h"
     29 
     30 namespace tensorflow {
     31 namespace tfcompile {
     32 namespace {
     33 
     34 void ExpectErrorContains(const Status& status, StringPiece str) {
     35   EXPECT_NE(Status::OK(), status);
     36   EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
     37       << "expected error: " << status.error_message() << " to contain: " << str;
     38 }
     39 
     40 TEST(ValidateCppIdent, Simple) {
     41   TF_EXPECT_OK(ValidateCppIdent("a", ""));
     42   TF_EXPECT_OK(ValidateCppIdent("abc", ""));
     43   TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
     44   TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
     45   // Make sure we didn't skip a valid letter or digit
     46   string ident;
     47   for (char c = 'a'; c <= 'z'; c++) {
     48     ident.append(1, c);
     49   }
     50   for (char c = 'A'; c <= 'Z'; c++) {
     51     ident.append(1, c);
     52   }
     53   for (char c = '0'; c <= '9'; c++) {
     54     ident.append(1, c);
     55   }
     56   ident += "_";
     57   TF_EXPECT_OK(ValidateCppIdent(ident, ""));
     58 
     59   ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
     60   ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
     61   ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
     62   ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
     63   ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
     64   ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
     65   ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
     66   ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
     67 }
     68 
     69 class ParseCppClassTest : public ::testing::Test {
     70  protected:
     71   void ExpectOK(const string& cpp_class, const string& want_class_name,
     72                 const std::vector<string>& want_namespaces) {
     73     string class_name;
     74     std::vector<string> namespaces;
     75     TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces));
     76     EXPECT_EQ(class_name, want_class_name);
     77     EXPECT_EQ(namespaces, want_namespaces);
     78   }
     79 
     80   void ExpectFail(const string& cpp_class) {
     81     string class_name;
     82     std::vector<string> namespaces;
     83     EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces), Status::OK());
     84   }
     85 };
     86 
     87 TEST_F(ParseCppClassTest, ParseOK) {
     88   ExpectOK("MyClass", "MyClass", {});
     89   ExpectOK("_MyClass", "_MyClass", {});
     90   ExpectOK("a::MyClass", "MyClass", {"a"});
     91   ExpectOK("a::foo::MyClass", "MyClass", {"a", "foo"});
     92   ExpectOK("a::foo::b::MyClass", "MyClass", {"a", "foo", "b"});
     93   ExpectOK("a::foo::b::bar::MyClass", "MyClass", {"a", "foo", "b", "bar"});
     94   ExpectOK("foo::MyClass", "MyClass", {"foo"});
     95   ExpectOK("_foo::MyClass", "MyClass", {"_foo"});
     96   ExpectOK("_foo::_MyClass", "_MyClass", {"_foo"});
     97   // Make sure we didn't skip a valid letter or digit
     98   string ident;
     99   for (char c = 'a'; c <= 'z'; c++) {
    100     ident.append(1, c);
    101   }
    102   for (char c = 'A'; c <= 'Z'; c++) {
    103     ident.append(1, c);
    104   }
    105   for (char c = '0'; c <= '9'; c++) {
    106     ident.append(1, c);
    107   }
    108   ident += "_";
    109   ExpectOK(ident, ident, {});
    110   ExpectOK(ident + "::" + ident, ident, {ident});
    111   ExpectOK(ident + "::" + ident + "::" + ident, ident, {ident, ident});
    112 }
    113 
    114 TEST_F(ParseCppClassTest, ParseFail) {
    115   ExpectFail("");
    116   ExpectFail("::");
    117   ExpectFail("::MyClass");  // valid C++, but disallowed for simpler code.
    118   ExpectFail("0");
    119   ExpectFail("a.b");
    120   ExpectFail("a:b");
    121   ExpectFail("good::.bad");
    122   ExpectFail("good:::bad");
    123   ExpectFail("good:: bad");
    124   ExpectFail("good::0bad");
    125 }
    126 
    127 static void CompareWithGoldenFile(
    128     const string& tensorflow_relative_golden_file_name,
    129     const string& expected_contents) {
    130   // To update the golden file, flip update_golden to true and run the
    131   // following:
    132   // bazel test --test_strategy=local \
    133   //   third_party/tensorflow/compiler/aot:codegen_test
    134   const bool update_golden = false;
    135   const string golden_file_name = io::JoinPath(
    136       testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name);
    137 
    138   if (update_golden) {
    139     TF_EXPECT_OK(
    140         WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
    141   }
    142 
    143   string golden_file_contents;
    144   TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
    145                                 &golden_file_contents));
    146   EXPECT_EQ(golden_file_contents, expected_contents);
    147 }
    148 
    149 TEST(CodegenTest, Golden) {
    150   // Normally CpuCompiler::CpuCompiler does this, but in this test we've
    151   // bypassed the Cpu compiler so we have to do this manually.
    152   llvm::InitializeNativeTarget();
    153   llvm::InitializeNativeTargetAsmPrinter();
    154   LLVMInitializeX86Target();
    155   LLVMInitializeX86TargetMC();
    156 
    157   CodegenOpts opts;
    158   opts.class_name = "MyClass";
    159   opts.target_triple = "x86_64-pc-linux";
    160   opts.namespaces = {"foo", "bar"};
    161   opts.gen_name_to_index = true;
    162   opts.gen_program_shape = true;
    163   tf2xla::Config config;
    164   tf2xla::Feed* feed = config.add_feed();
    165   feed->mutable_id()->set_node_name("feed0");
    166   feed->set_name("myfeed");
    167   feed = config.add_feed();
    168   feed->mutable_id()->set_node_name("feed1");
    169   tf2xla::Fetch* fetch = config.add_fetch();
    170   fetch->mutable_id()->set_node_name("fetch0");
    171   fetch->set_name("myfetch");
    172   CompileResult compile_result;
    173   compile_result.aot.reset(
    174       new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5));
    175   compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
    176       {
    177           xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
    178           xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
    179       },
    180       xla::ShapeUtil::MakeTupleShape(
    181           {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}));
    182   compile_result.entry_point = "entry_point";
    183   compile_result.pointer_size = 8;
    184 
    185   MetadataResult metadata_result;
    186   TF_ASSERT_OK(GenerateMetadata(opts, compile_result, &metadata_result));
    187 
    188   // The other fields in metadata_result are tested as part of the generated
    189   // header test.
    190 
    191   CompareWithGoldenFile("compiler/aot/codegen_test_o.golden",
    192                         metadata_result.object_file_data);
    193 
    194   string header;
    195   TF_ASSERT_OK(
    196       GenerateHeader(opts, config, compile_result, metadata_result, &header));
    197 
    198   CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header);
    199 }
    200 }  // namespace
    201 }  // namespace tfcompile
    202 }  // namespace tensorflow
    203