/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

class RemoteFusedGraphExecuteTest : public OpsTestBase {};

TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithOneDataType) {
  DataTypeVector input_types({DT_FLOAT, DT_FLOAT});
  DataTypeVector output_types({DT_FLOAT});
  TF_ASSERT_OK(
      NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute")
          .Input(FakeInput(2, DT_FLOAT))
          .Attr("Tinputs", input_types)
          .Attr("Toutputs", output_types)
          .Attr("serialized_remote_fused_graph_execute_info", "")
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  // TODO(satok): Add benchmark
}

TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithWrongDataType) {
  DataTypeVector input_types({DT_INT32, DT_INT32});
  DataTypeVector output_types({DT_FLOAT});
  ASSERT_FALSE(
      NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute")
          .Input(FakeInput(2, DT_FLOAT))
          .Attr("Tinputs", input_types)
          .Attr("Toutputs", output_types)
          .Attr("serialized_remote_fused_graph_execute_info", "")
          .Finalize(node_def())
          .ok());
  // TODO(satok): Add benchmark
}

////////////////////////////
// End-to-end test: Begin //
////////////////////////////
// This is an end-to-end test of a simple usage of
// RemoteFusedGraphExecuteOp.

constexpr const char* const NAME_A = "a";
constexpr const char* const NAME_B = "b";
constexpr const char* const NAME_A_PLUS_B = "a_plus_b";
constexpr const char* const REMOTE_FUSED_EXECUTE_OP_NODE_NAME =
    "remote_fused_execute_op";
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME =
    "build_test_remote_fused_graph_executor";

constexpr float NODE_A_VAL = 2.0f;
constexpr float NODE_A_VAL2 = 10.0f;
constexpr float NODE_B_VAL = 3.0f;
constexpr float FLOAT_VALUE_TOLERANCE = 1e-8f;

// Utility functions //
static Output BuildPlaceHolderOp(const string& name, const DataType dt,
                                 const TensorShape& tensor_shape, Scope* root) {
  const Scope& scope = root->WithOpName(name);
  Node* ret;
  const string unique_name = scope.GetUniqueNameForOp("Placeholder");
  NodeBuilder builder = NodeBuilder(unique_name, "Placeholder")
                            .Attr("dtype", dt)
                            .Attr("shape", tensor_shape);
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  CHECK(scope.ok());
  return Output(ret, 0);
}

static Output BuildRemoteFusedGraphExecuteOp(
    const string& name, const std::vector<Output>& output_list,
    const int output_node_count,
    const RemoteFusedGraphExecuteInfo& execute_info, Scope* root) {
  const Scope& scope = root->WithOpName(name);
  Node* ret;
  CHECK(scope.ok());
  auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
  const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");

  DataTypeVector input_types{DT_FLOAT};
  DataTypeVector output_types{DT_FLOAT};

  auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
                     .Input(node_out_list)
                     .Attr("Tinputs", input_types)
                     .Attr("Toutputs", output_types)
                     .Attr("serialized_remote_fused_graph_execute_info",
                           StringPiece(execute_info.SerializeAsString()));
  CHECK(scope.ok());
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  CHECK(scope.ok());
  return Output(ret, 0);
}

static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo(
    const GraphDef& original_graph) {
  RemoteFusedGraphExecuteInfo execute_info;
  execute_info.set_executor_name(REMOTE_FUSED_EXECUTOR_NAME);

  // In this example we simply copy all nodes. In general, you don't need to
  // include nodes that are unused for inference.
  for (const NodeDef& node : original_graph.node()) {
    NodeDef& copied_node = *execute_info.mutable_remote_graph()->add_node();
    copied_node = node;
    // Add the output tensor shape and type to the node.
    // TODO(satok): Use TensorShapeMap to determine the tensor shape type
    RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
        std::vector<DataType>({DT_FLOAT}),
        std::vector<TensorShape>({TensorShape()}), &copied_node);
  }

  // Add node A as input
  execute_info.add_graph_input_node_name(NAME_A);
  RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_a =
      *execute_info.add_default_graph_input_tensor_shape();
  shape_a.set_dtype(DT_FLOAT);
  // (Skip setting the shape of shape_a, as its shape has rank 0.)

  // Add node A + B as output
  execute_info.add_graph_output_node_name(NAME_A_PLUS_B);
  RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_a_plus_b =
      *execute_info.add_default_graph_output_tensor_shape();
  shape_a_plus_b.set_dtype(DT_FLOAT);
  // (Skip setting the shape of shape_a_plus_b, as its shape has rank 0.)

  return execute_info;
}

// 1. Create SampleRemoteFusedGraphExecutor to execute your fused graph
class SampleRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
 public:
  int GetVersion() final { return 1; }
  bool Init(const RemoteFusedGraphExecuteInfo& info) final {
    info_ = &info;
    for (const NodeDef& node_def : info.remote_graph().node()) {
      node_def_map_.emplace(node_def.name(), &node_def);
    }
    return true;
  }
  bool Finalize() final { return true; }
  bool SetupGraph() final { return true; }
  bool ExecuteGraph() final {
    CHECK(info_ != nullptr);
    // TODO(satok): Add utilities to make this function easier to implement.
    // CAVEAT: This test only handles the add op; a real executor can
    // implement this however it likes.
    CHECK_EQ(1, info_->graph_input_node_name_size());
    const string& input_node_name = info_->graph_input_node_name(0);
    const Tensor& input_tensor = input_tensor_cache_[input_node_name];
    const float input_val = *input_tensor.scalar<float>().data();
    // TODO(satok): Read NAME_B from node_a_plus_b
    const NodeDef& node_b = *node_def_map_.at(NAME_B);
    const TensorProto* proto = nullptr;
    TF_CHECK_OK(GetNodeAttr(node_b, "value", &proto));
    Tensor const_tensor;
    TF_CHECK_OK(RemoteFusedGraphExecuteUtils::MakeTensorFromProto(
        *proto, &const_tensor));
    const float b_val = *const_tensor.scalar<float>().data();
    Tensor output_a_plus_b(DT_FLOAT, {});
    output_a_plus_b.flat<float>().data()[0] = input_val + b_val;
    output_tensor_buf_.emplace(info_->graph_output_node_name(0),
                               output_a_plus_b);
    return true;
  }

  bool TeardownGraph() final { return true; }

  bool FillInputNode(const string& node_name, const Tensor& tensor) final {
    input_tensor_cache_[node_name] = tensor;
    return true;
  }

  bool ReadOutputNode(const string& node_name,
                      TensorAllocatorFunc tensor_allocator) final {
    // TODO(satok): Specify tensor shape by using default_graph_tensor_shape.
    const Tensor& buffered_output_tensor = output_tensor_buf_.at(node_name);
    const TensorShape& output_shape = buffered_output_tensor.shape();
    Tensor* output_tensor = tensor_allocator(output_shape);
    CHECK_EQ(buffered_output_tensor.dtype(), output_tensor->dtype());
    CHECK(output_tensor->CopyFrom(buffered_output_tensor, output_shape));
    return true;
  }

  Status FuseRemoteGraph(const GraphDef& original_graph_def,
                         const std::vector<string>& /*inputs*/,
                         const std::vector<string>& /*outputs*/,
                         GraphDef* fused_graph_def) final {
    *fused_graph_def = original_graph_def;
    return Status::OK();
  }

  bool IsEnabled() const final { return true; }

 private:
  const RemoteFusedGraphExecuteInfo* info_;
  std::unordered_map<string, Tensor> input_tensor_cache_;
  std::unordered_map<string, const NodeDef*> node_def_map_;
  std::unordered_map<string, Tensor> output_tensor_buf_;
};

// 2. Register a builder of your custom executor
namespace remote_fused_graph_execute_op {
Status BuildRemoteFusedGraphExecutor(
    std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
  executor->reset(new SampleRemoteFusedGraphExecutor());
  return Status::OK();
}

// Instantiating this class registers the executor with
// RemoteFusedGraphExecuteOp. This architecture keeps executors pluggable so
// that unnecessary libraries are not linked in.
static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
    k_test_remote_fused_graph_executor_build(REMOTE_FUSED_EXECUTOR_NAME,
                                             BuildRemoteFusedGraphExecutor);
}  // namespace remote_fused_graph_execute_op

// 3. Create a graph transform function to fuse your graph
static Status RewriteGraphToFusedGraph(const GraphDef& original_graph,
                                       GraphDef* fused_graph) {
  Scope root = Scope::NewRootScope();
  std::vector<Output> output_list;
  const Output op_a = BuildPlaceHolderOp(NAME_A, DT_FLOAT, {}, &root);
  output_list.emplace_back(op_a);
  const RemoteFusedGraphExecuteInfo execute_info =
      BuildRemoteFusedGraphExecuteInfo(original_graph);
  BuildRemoteFusedGraphExecuteOp(REMOTE_FUSED_EXECUTE_OP_NODE_NAME, output_list,
                                 1, execute_info, &root);
  GraphDef fused_graph_def;
  TF_CHECK_OK(root.ToGraphDef(&fused_graph_def));
  *fused_graph = fused_graph_def;
  return Status::OK();
}

// 4. Register the transform function
// You can register a transform function with REGISTER_GRAPH_TRANSFORM.
// In this test we don't use the graph transform tool, to avoid linking
// against the graph transform library.
// To register the transform function, change the interface of
// RewriteGraphToFusedGraph above to
//   Status RewriteGraphToFusedGraph(
//       const GraphDef& original_graph, const TransformFuncContext& context,
//       GraphDef* output_graph_def);
// and then register it like:
//   REGISTER_GRAPH_TRANSFORM("rewrite_graph", RewriteGraphToFusedGraph);

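// Alternatively, instead of changing RewriteGraphToFusedGraph itself, a thin
// wrapper with the TransformFuncContext signature could forward to it and be
// registered by name. The sketch below is illustrative only and is not
// compiled here (it would require linking the graph transform library); the
// wrapper function and the transform name are hypothetical.
//
//   Status RewriteGraphToFusedGraphTransform(
//       const GraphDef& original_graph, const TransformFuncContext& context,
//       GraphDef* output_graph_def) {
//     return RewriteGraphToFusedGraph(original_graph, output_graph_def);
//   }
//   REGISTER_GRAPH_TRANSFORM("rewrite_graph_to_fused_graph",
//                            RewriteGraphToFusedGraphTransform);
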
// 5. Fuse the original graph and run inference on the new fused graph
TEST(RemoteFusedExecuteGraphOp, EndToEndTest) {
  // 5.1 Load original graph
  GraphDef original_graph;
  TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
      NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &original_graph));

  // 5.2 Fuse graph
  GraphDef fused_graph;
  TF_ASSERT_OK(RewriteGraphToFusedGraph(original_graph, &fused_graph));

  // 5.3 Setup session
  std::vector<Tensor> output_tensors;
  SessionOptions session_options;
  session_options.env = Env::Default();
  std::unique_ptr<Session> session =
      std::unique_ptr<Session>(NewSession(session_options));
  Status status = session->Create(fused_graph);
  ASSERT_TRUE(status.ok());
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::FULL_TRACE);
  RunMetadata run_metadata;

  // 5.4 Setup input
  Tensor input_a(DT_FLOAT, {});
  input_a.flat<float>().data()[0] = NODE_A_VAL2;
  std::vector<std::pair<string, Tensor>> inputs;
  inputs.emplace_back(NAME_A, input_a);

  // 5.5 Setup output
  const std::vector<string> outputs{REMOTE_FUSED_EXECUTE_OP_NODE_NAME};

  // 5.6 Run inference
  status = session->Run(run_options, inputs, outputs, {}, &output_tensors,
                        &run_metadata);
  ASSERT_TRUE(status.ok());

  // 5.7 Check output tensor value
  ASSERT_EQ(1, output_tensors.size());
  EXPECT_NEAR(NODE_A_VAL2 + NODE_B_VAL,
              output_tensors.at(0).flat<float>().data()[0],
              FLOAT_VALUE_TOLERANCE);
}
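
// A possible benchmark for this op, addressing the TODO(satok) notes above.
// This is an illustrative sketch only and is not compiled or verified here;
// it assumes the classic test_benchmark.h interface in which a benchmark is
// a free function taking the iteration count and registered via BENCHMARK().
//
//   static void BM_RemoteFusedGraphExecuteAdd(int iters) {
//     GraphDef original_graph;
//     TF_CHECK_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
//         NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B,
//         &original_graph));
//     GraphDef fused_graph;
//     TF_CHECK_OK(RewriteGraphToFusedGraph(original_graph, &fused_graph));
//     SessionOptions session_options;
//     session_options.env = Env::Default();
//     std::unique_ptr<Session> session(NewSession(session_options));
//     TF_CHECK_OK(session->Create(fused_graph));
//     Tensor input_a(DT_FLOAT, {});
//     input_a.flat<float>().data()[0] = NODE_A_VAL2;
//     std::vector<Tensor> output_tensors;
//     for (int i = 0; i < iters; ++i) {
//       TF_CHECK_OK(session->Run({{NAME_A, input_a}},
//                                {REMOTE_FUSED_EXECUTE_OP_NODE_NAME}, {},
//                                &output_tensors));
//     }
//   }
//   BENCHMARK(BM_RemoteFusedGraphExecuteAdd);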

////////////////////////////
// End-to-end test: End   //
////////////////////////////

}  // namespace tensorflow