Home | History | Annotate | Download | only in framework
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/cc/framework/cc_op_gen.h"
     17 
     18 #include "tensorflow/core/framework/op_def.pb.h"
     19 #include "tensorflow/core/framework/op_gen_lib.h"
     20 #include "tensorflow/core/lib/core/status_test_util.h"
     21 #include "tensorflow/core/lib/io/path.h"
     22 #include "tensorflow/core/platform/test.h"
     23 
     24 namespace tensorflow {
     25 namespace {
     26 
// Text-format OpList containing a single op, "Foo", which every test below
// parses and feeds to the C++ op wrapper generator. Foo has two inputs
// ("images", then "dim"), one DT_FLOAT output ("output") and one "type" attr
// ("T", restricted to DT_UINT8 / DT_INT8).
//
// NOTE(review): "images" declares neither `type` nor `type_attr`, and the
// default_value of the "type" attr T is written as `i: 1` rather than a
// `type:` value. The header checks in these tests presumably do not depend on
// either detail; confirm before reusing this OpDef for anything stricter.
constexpr char kBaseOpDef[] = R"(
op {
  name: "Foo"
  input_arg {
    name: "images"
    description: "Images to process."
  }
  input_arg {
    name: "dim"
    description: "Description for dim."
    type: DT_FLOAT
  }
  output_arg {
    name: "output"
    description: "Description for output."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for images"
    allowed_values {
      list {
        type: DT_UINT8
        type: DT_INT8
      }
    }
    default_value {
      i: 1
    }
  }
  summary: "Summary for op Foo."
  description: "Description for op Foo."
}
)";
     62 
     63 void ExpectHasSubstr(StringPiece s, StringPiece expected) {
     64   EXPECT_TRUE(s.contains(expected))
     65       << "'" << s << "' does not contain '" << expected << "'";
     66 }
     67 
     68 void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
     69   EXPECT_FALSE(s.contains(expected))
     70       << "'" << s << "' contains '" << expected << "'";
     71 }
     72 
     73 void ExpectSubstrOrder(const string& s, const string& before,
     74                        const string& after) {
     75   int before_pos = s.find(before);
     76   int after_pos = s.find(after);
     77   ASSERT_NE(std::string::npos, before_pos);
     78   ASSERT_NE(std::string::npos, after_pos);
     79   EXPECT_LT(before_pos, after_pos)
     80       << before << " is not before " << after << " in " << s;
     81 }
     82 
     83 // Runs WriteCCOps and stores output in (internal_)cc_file_path and
     84 // (internal_)h_file_path.
     85 void GenerateCcOpFiles(Env* env, const OpList& ops,
     86                        const ApiDefMap& api_def_map, string* h_file_text,
     87                        string* internal_h_file_text) {
     88   const string& tmpdir = testing::TmpDir();
     89 
     90   const auto h_file_path = io::JoinPath(tmpdir, "test.h");
     91   const auto cc_file_path = io::JoinPath(tmpdir, "test.cc");
     92   const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h");
     93   const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc");
     94 
     95   WriteCCOps(ops, api_def_map, h_file_path, cc_file_path);
     96 
     97   TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text));
     98   TF_ASSERT_OK(
     99       ReadFileToString(env, internal_h_file_path, internal_h_file_text));
    100 }
    101 
    102 TEST(CcOpGenTest, TestVisibilityChangedToHidden) {
    103   const string api_def = R"(
    104 op {
    105   graph_op_name: "Foo"
    106   visibility: HIDDEN
    107 }
    108 )";
    109   Env* env = Env::Default();
    110   OpList op_defs;
    111   protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);  // NOLINT
    112   ApiDefMap api_def_map(op_defs);
    113 
    114   string h_file_text, internal_h_file_text;
    115   // Without ApiDef
    116   GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
    117                     &internal_h_file_text);
    118   ExpectHasSubstr(h_file_text, "class Foo");
    119   ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo");
    120 
    121   // With ApiDef
    122   TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
    123   GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
    124                     &internal_h_file_text);
    125   ExpectHasSubstr(internal_h_file_text, "class Foo");
    126   ExpectDoesNotHaveSubstr(h_file_text, "class Foo");
    127 }
    128 
    129 TEST(CcOpGenTest, TestArgNameChanges) {
    130   const string api_def = R"(
    131 op {
    132   graph_op_name: "Foo"
    133   arg_order: "dim"
    134   arg_order: "images"
    135 }
    136 )";
    137   Env* env = Env::Default();
    138   OpList op_defs;
    139   protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);  // NOLINT
    140 
    141   ApiDefMap api_def_map(op_defs);
    142   string cc_file_text, h_file_text;
    143   string internal_cc_file_text, internal_h_file_text;
    144   // Without ApiDef
    145   GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
    146                     &internal_h_file_text);
    147   ExpectSubstrOrder(h_file_text, "Input images", "Input dim");
    148 
    149   // With ApiDef
    150   TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
    151   GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
    152                     &internal_h_file_text);
    153   ExpectSubstrOrder(h_file_text, "Input dim", "Input images");
    154 }
    155 
    156 TEST(CcOpGenTest, TestEndpoints) {
    157   const string api_def = R"(
    158 op {
    159   graph_op_name: "Foo"
    160   endpoint {
    161     name: "Foo1"
    162   }
    163   endpoint {
    164     name: "Foo2"
    165   }
    166 }
    167 )";
    168   Env* env = Env::Default();
    169   OpList op_defs;
    170   protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);  // NOLINT
    171 
    172   ApiDefMap api_def_map(op_defs);
    173   string cc_file_text, h_file_text;
    174   string internal_cc_file_text, internal_h_file_text;
    175   // Without ApiDef
    176   GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
    177                     &internal_h_file_text);
    178   ExpectHasSubstr(h_file_text, "class Foo {");
    179   ExpectDoesNotHaveSubstr(h_file_text, "class Foo1");
    180   ExpectDoesNotHaveSubstr(h_file_text, "class Foo2");
    181 
    182   // With ApiDef
    183   TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
    184   GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
    185                     &internal_h_file_text);
    186   ExpectHasSubstr(h_file_text, "class Foo1");
    187   ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2");
    188   ExpectDoesNotHaveSubstr(h_file_text, "class Foo {");
    189 }
    190 }  // namespace
    191 }  // namespace tensorflow
    192