/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/remote_fused_graph_ops.cc.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
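// RemoteFusedGraphExecuteOp runs a fused subgraph on a remote processor
// through an IRemoteFusedGraphExecutor. The executor is selected by the name
// stored in the serialized RemoteFusedGraphExecuteInfo attribute, and this
// kernel drives its lifecycle: Init/SetupGraph in the constructor,
// FillInputNode/ExecuteGraph/ReadOutputNode in Compute, and
// TeardownGraph/Finalize in the destructor. If no executor is registered
// under that name, the op emits empty tensors of the declared output types
// for compatibility.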
class RemoteFusedGraphExecuteOp : public OpKernel {
 public:
  explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx)
      : OpKernel(ctx), execute_info_() {
    string serialized_proto;
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr(RemoteFusedGraphExecuteUtils::
                              ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
                          &serialized_proto));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_types_));
    OP_REQUIRES(ctx, execute_info_.ParseFromString(serialized_proto),
                errors::InvalidArgument(
                    "Failed to parse RemoteFusedGraphExecuteInfo proto"));
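    // When an executor name is specified, look up the build function that was
    // registered for that name with RemoteFusedGraphExecuteUtils.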
    if (!execute_info_.executor_name().empty()) {
      const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc* build_func =
          RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(
              execute_info_.executor_name());
      if (build_func != nullptr) {
        TF_CHECK_OK((*build_func)(&remote_fused_graph_executor_));
        CHECK(remote_fused_graph_executor_->IsEnabled());
      } else {
        LOG(ERROR) << "Executor not found for "
                   << execute_info_.executor_name();
      }
    }

    if (remote_fused_graph_executor_) {
      // 1. Initialize remote processor
      remote_fused_graph_executor_->Init(execute_info_);
      // Explicitly clear the serialized executor parameters after
      // initialization to release unnecessary memory.
      execute_info_.clear_serialized_executor_parameters();

      // 2. Setup graph in remote processor
      remote_fused_graph_executor_->SetupGraph();
    }
  }

  ~RemoteFusedGraphExecuteOp() final {
    if (remote_fused_graph_executor_) {
      // 6. Teardown graph in remote processor
      remote_fused_graph_executor_->TeardownGraph();

      // 7. Finalize remote processor
      remote_fused_graph_executor_->Finalize();
    }
  }

  void Compute(OpKernelContext* const ctx) final {
    CHECK(ctx != nullptr);
    const int input_count = ctx->num_inputs();
    const int graph_input_count = execute_info_.graph_input_node_name_size();
    CHECK(input_count == graph_input_count &&
          input_count == input_types_.size())
        << "input_count = " << input_count
        << ", graph input count = " << graph_input_count
        << ", type count = " << input_types_.size();

    // 3. Send input tensors into remote processor
    for (int i = 0; i < graph_input_count; ++i) {
      const Tensor& input_tensor = ctx->input(i);
      const string& input_node_name = execute_info_.graph_input_node_name(i);
      if (remote_fused_graph_executor_) {
        remote_fused_graph_executor_->FillInputNode(input_node_name,
                                                    input_tensor);
      }
    }

    // 4. Execute graph in remote processor
    if (remote_fused_graph_executor_) {
      remote_fused_graph_executor_->ExecuteGraph();
    }

    // 5. Load outputs from remote processor
    const int output_count = ctx->num_outputs();
    CHECK(output_count == execute_info_.graph_output_node_name_size() &&
          output_count == output_types_.size());
    for (int i = 0; i < output_count; ++i) {
      Tensor* output = nullptr;
      const string& output_node_name = execute_info_.graph_output_node_name(i);
      if (remote_fused_graph_executor_) {
        remote_fused_graph_executor_->ReadOutputNode(
            output_node_name,
            [i, &ctx, &output](const TensorShape& shape) -> Tensor* {
              TF_CHECK_OK(ctx->allocate_output(i, shape, &output));
              return output;
            });
      } else {
        // For compatibility, return an empty tensor of the declared output
        // type when no executor is available.
        TensorShape ts({});
        TF_CHECK_OK(ctx->allocate_output(i, ts, &output));
      }
    }
  }

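  // Executing a remote fused graph generally involves a round trip to another
  // processor, so report this kernel as expensive to the scheduler.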
  bool IsExpensive() final { return true; }

 private:
  RemoteFusedGraphExecuteInfo execute_info_;
  std::unique_ptr<IRemoteFusedGraphExecutor> remote_fused_graph_executor_;
  DataTypeVector input_types_;
  DataTypeVector output_types_;

  TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOp);
};
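
// Register the kernel on CPU; the heavy lifting is delegated to the remote
// executor selected at construction time.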
REGISTER_KERNEL_BUILDER(Name("RemoteFusedGraphExecute").Device(DEVICE_CPU),
                        RemoteFusedGraphExecuteOp);

}  // namespace tensorflow