Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
     17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
     18 
     19 #include <functional>
     20 
     21 #include "tensorflow/core/distributed_runtime/call_options.h"
     22 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
     23 #include "tensorflow/core/lib/core/notification.h"
     24 #include "tensorflow/core/lib/core/status.h"
     25 #include "tensorflow/core/platform/types.h"
     26 #include "tensorflow/core/protobuf/worker.pb.h"
     27 
     28 namespace tensorflow {
     29 
     30 // Status callback.
     31 typedef std::function<void(const Status&)> StatusCallback;
     32 
     33 // Custom decoder for a response to RecvTensorAsync.
     34 class TensorResponse;
     35 
     36 // Interface for talking with the TensorFlow Worker service.
     37 class WorkerInterface {
     38  public:
     39   virtual void GetStatusAsync(const GetStatusRequest* request,
     40                               GetStatusResponse* response,
     41                               StatusCallback done) = 0;
     42 
     43   virtual void CreateWorkerSessionAsync(
     44       const CreateWorkerSessionRequest* request,
     45       CreateWorkerSessionResponse* response, StatusCallback done) = 0;
     46 
     47   virtual void DeleteWorkerSessionAsync(
     48       const DeleteWorkerSessionRequest* request,
     49       DeleteWorkerSessionResponse* response, StatusCallback done) = 0;
     50 
     51   virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
     52                                   RegisterGraphResponse* response,
     53                                   StatusCallback done) = 0;
     54 
     55   virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
     56                                     DeregisterGraphResponse* response,
     57                                     StatusCallback done) = 0;
     58 
     59   virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
     60                              MutableRunGraphResponseWrapper* repsonse,
     61                              StatusCallback done) = 0;
     62 
     63   virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
     64                              RunGraphResponse* response, StatusCallback done) {
     65     // TODO(mrry): Convert this to std::bind/std::move if the overhead
     66     // of std::function copying becomes too much.
     67     RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
     68     MutableRunGraphResponseWrapper* wrapped_response =
     69         new NonOwnedProtoRunGraphResponse(response);
     70     RunGraphAsync(opts, wrapped_request, wrapped_response,
     71                   [wrapped_request, wrapped_response, done](const Status& s) {
     72                     done(s);
     73                     delete wrapped_request;
     74                     delete wrapped_response;
     75                   });
     76   }
     77 
     78   // Returns a request object for use in calls to
     79   // `RunGraphAsync()`. Ownership is transferred to the caller.
     80   //
     81   // The message returned from this method must only be used in a
     82   // `RunGraph()` call on the same `WorkerInterface` instance.
     83   virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
     84     return new MutableProtoRunGraphRequest;
     85   }
     86 
     87   // Returns a response object for use in calls to
     88   // `RunGraphAsync()`. Ownership is transferred to the caller.
     89   //
     90   // The message returned from this method must only be used in a
     91   // `RunGraph()` call on the same `WorkerInterface` instance.
     92   virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
     93     return new OwnedProtoRunGraphResponse;
     94   }
     95 
     96   virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
     97                                  CleanupGraphResponse* response,
     98                                  StatusCallback done) = 0;
     99 
    100   virtual void CleanupAllAsync(const CleanupAllRequest* request,
    101                                CleanupAllResponse* response,
    102                                StatusCallback done) = 0;
    103 
    104   virtual void RecvTensorAsync(CallOptions* opts,
    105                                const RecvTensorRequest* request,
    106                                TensorResponse* response,
    107                                StatusCallback done) = 0;
    108 
    109   virtual void LoggingAsync(const LoggingRequest* request,
    110                             LoggingResponse* response, StatusCallback done) = 0;
    111 
    112   virtual void TracingAsync(const TracingRequest* request,
    113                             TracingResponse* response, StatusCallback done) = 0;
    114 
    115   Status GetStatus(const GetStatusRequest* request,
    116                    GetStatusResponse* response) {
    117     return CallAndWait(&ME::GetStatusAsync, request, response);
    118   }
    119 
    120   Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
    121                              CreateWorkerSessionResponse* response) {
    122     return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
    123   }
    124 
    125   Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
    126                              DeleteWorkerSessionResponse* response) {
    127     return CallAndWait(&ME::DeleteWorkerSessionAsync, request, response);
    128   }
    129 
    130   Status RegisterGraph(const RegisterGraphRequest* request,
    131                        RegisterGraphResponse* response) {
    132     return CallAndWait(&ME::RegisterGraphAsync, request, response);
    133   }
    134 
    135   Status DeregisterGraph(const DeregisterGraphRequest* request,
    136                          DeregisterGraphResponse* response) {
    137     return CallAndWait(&ME::DeregisterGraphAsync, request, response);
    138   }
    139 
    140   Status CleanupGraph(const CleanupGraphRequest* request,
    141                       CleanupGraphResponse* response) {
    142     return CallAndWait(&ME::CleanupGraphAsync, request, response);
    143   }
    144 
    145   Status CleanupAll(const CleanupAllRequest* request,
    146                     CleanupAllResponse* response) {
    147     return CallAndWait(&ME::CleanupAllAsync, request, response);
    148   }
    149 
    150   Status Logging(const LoggingRequest* request, LoggingResponse* response) {
    151     return CallAndWait(&ME::LoggingAsync, request, response);
    152   }
    153 
    154   Status Tracing(const TracingRequest* request, TracingResponse* response) {
    155     return CallAndWait(&ME::TracingAsync, request, response);
    156   }
    157 
    158  protected:
    159   // Instances of WorkerInterface must be deleted by a call to
    160   // WorkerCacheInterface::ReleaseWorker().
    161   virtual ~WorkerInterface() {}
    162   friend class WorkerCacheInterface;
    163 
    164   // NOTE: This should only be called by implementations of this
    165   // interface whose CreateRunGraphResponse() method returns a
    166   // proto-based wrappers for the RunGraphResponse message.
    167   RunGraphResponse* get_proto_from_wrapper(
    168       MutableRunGraphResponseWrapper* wrapper) {
    169     return wrapper->get_proto();
    170   }
    171 
    172  private:
    173   typedef WorkerInterface ME;
    174 
    175   template <typename Method, typename Req, typename Resp>
    176   Status CallAndWait(Method func, const Req* req, Resp* resp) {
    177     Status ret;
    178     Notification n;
    179     (this->*func)(req, resp, [&ret, &n](const Status& s) {
    180       ret = s;
    181       n.Notify();
    182     });
    183     n.WaitForNotification();
    184     return ret;
    185   }
    186 };
    187 
    188 }  // namespace tensorflow
    189 
    190 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
    191