Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/local_master.h"
     17 
     18 #include <unordered_map>
     19 
     20 #include "tensorflow/core/distributed_runtime/master.h"
     21 #include "tensorflow/core/platform/mutex.h"
     22 
     23 namespace tensorflow {
     24 
     25 namespace {
     26 Status WaitForNotification(CallOptions* call_options,
     27                            const int64 default_timeout_in_ms, Notification* n) {
     28   int64 timeout_in_ms = call_options->GetTimeout();
     29   if (timeout_in_ms == 0) {
     30     timeout_in_ms = default_timeout_in_ms;
     31   }
     32   if (timeout_in_ms > 0) {
     33     int64 timeout_in_us = timeout_in_ms * 1000;
     34     bool notified = WaitForNotificationWithTimeout(n, timeout_in_us);
     35     if (!notified) {
     36       call_options->StartCancel();
     37       // The call has borrowed pointers to the request and response
     38       // messages, so we must still wait for the call to complete.
     39       n->WaitForNotification();
     40       return errors::DeadlineExceeded("Operation timed out.");
     41     }
     42   } else {
     43     n->WaitForNotification();
     44   }
     45   return Status::OK();
     46 }
     47 }  // namespace
     48 
     49 LocalMaster::LocalMaster(Master* master_impl, const int64 default_timeout_in_ms)
     50     : master_impl_(master_impl),
     51       default_timeout_in_ms_(default_timeout_in_ms) {}
     52 
     53 Status LocalMaster::CreateSession(CallOptions* call_options,
     54                                   const CreateSessionRequest* request,
     55                                   CreateSessionResponse* response) {
     56   Notification n;
     57   Status ret;
     58   master_impl_->CreateSession(request, response, [&n, &ret](const Status& s) {
     59     ret.Update(s);
     60     n.Notify();
     61   });
     62   TF_RETURN_IF_ERROR(
     63       WaitForNotification(call_options, default_timeout_in_ms_, &n));
     64   return ret;
     65 }
     66 
     67 Status LocalMaster::ExtendSession(CallOptions* call_options,
     68                                   const ExtendSessionRequest* request,
     69                                   ExtendSessionResponse* response) {
     70   Notification n;
     71   Status ret;
     72   master_impl_->ExtendSession(request, response, [&n, &ret](const Status& s) {
     73     ret.Update(s);
     74     n.Notify();
     75   });
     76   TF_RETURN_IF_ERROR(
     77       WaitForNotification(call_options, default_timeout_in_ms_, &n));
     78   return ret;
     79 }
     80 
     81 Status LocalMaster::PartialRunSetup(CallOptions* call_options,
     82                                     const PartialRunSetupRequest* request,
     83                                     PartialRunSetupResponse* response) {
     84   Notification n;
     85   Status ret;
     86   master_impl_->PartialRunSetup(request, response, [&n, &ret](const Status& s) {
     87     ret.Update(s);
     88     n.Notify();
     89   });
     90   TF_RETURN_IF_ERROR(
     91       WaitForNotification(call_options, default_timeout_in_ms_, &n));
     92   return ret;
     93 }
     94 
     95 Status LocalMaster::RunStep(CallOptions* call_options,
     96                             RunStepRequestWrapper* request,
     97                             MutableRunStepResponseWrapper* response) {
     98   Notification n;
     99   Status ret;
    100   master_impl_->RunStep(call_options, request, response,
    101                         [&n, &ret](const Status& s) {
    102                           ret.Update(s);
    103                           n.Notify();
    104                         });
    105   TF_RETURN_IF_ERROR(
    106       WaitForNotification(call_options, default_timeout_in_ms_, &n));
    107   return ret;
    108 }
    109 
    110 MutableRunStepRequestWrapper* LocalMaster::CreateRunStepRequest() {
    111   return new InMemoryRunStepRequest;
    112 }
    113 
    114 MutableRunStepResponseWrapper* LocalMaster::CreateRunStepResponse() {
    115   return new InMemoryRunStepResponse;
    116 }
    117 
    118 Status LocalMaster::CloseSession(CallOptions* call_options,
    119                                  const CloseSessionRequest* request,
    120                                  CloseSessionResponse* response) {
    121   Notification n;
    122   Status ret;
    123   master_impl_->CloseSession(request, response, [&n, &ret](const Status& s) {
    124     ret.Update(s);
    125     n.Notify();
    126   });
    127   TF_RETURN_IF_ERROR(
    128       WaitForNotification(call_options, default_timeout_in_ms_, &n));
    129   return ret;
    130 }
    131 
    132 Status LocalMaster::ListDevices(CallOptions* call_options,
    133                                 const ListDevicesRequest* request,
    134                                 ListDevicesResponse* response) {
    135   Notification n;
    136   Status ret;
    137   master_impl_->ListDevices(request, response, [&n, &ret](const Status& s) {
    138     ret.Update(s);
    139     n.Notify();
    140   });
    141   TF_RETURN_IF_ERROR(
    142       WaitForNotification(call_options, default_timeout_in_ms_, &n));
    143   return ret;
    144 }
    145 
    146 Status LocalMaster::Reset(CallOptions* call_options,
    147                           const ResetRequest* request,
    148                           ResetResponse* response) {
    149   Notification n;
    150   Status ret;
    151   master_impl_->Reset(request, response, [&n, &ret](const Status& s) {
    152     ret.Update(s);
    153     n.Notify();
    154   });
    155   TF_RETURN_IF_ERROR(
    156       WaitForNotification(call_options, default_timeout_in_ms_, &n));
    157   return ret;
    158 }
    159 
    160 namespace {
    161 mutex* get_local_master_registry_lock() {
    162   static mutex local_master_registry_lock(LINKER_INITIALIZED);
    163   return &local_master_registry_lock;
    164 }
    165 
    166 struct MasterInfo {
    167   Master* master;
    168   const int64 default_timeout_in_ms;
    169 
    170   MasterInfo(Master* master, const int64 default_timeout_in_ms)
    171       : master(master), default_timeout_in_ms(default_timeout_in_ms) {}
    172 };
    173 
    174 typedef std::unordered_map<string, MasterInfo> LocalMasterRegistry;
    175 LocalMasterRegistry* local_master_registry() {
    176   static LocalMasterRegistry* local_master_registry_ = new LocalMasterRegistry;
    177   return local_master_registry_;
    178 }
    179 }  // namespace
    180 
    181 /* static */
    182 void LocalMaster::Register(const string& target, Master* master,
    183                            int64 default_timeout_in_ms) {
    184   mutex_lock l(*get_local_master_registry_lock());
    185   local_master_registry()->insert(
    186       {target, MasterInfo(master, default_timeout_in_ms)});
    187 }
    188 
    189 /* static */
    190 std::unique_ptr<LocalMaster> LocalMaster::Lookup(const string& target) {
    191   std::unique_ptr<LocalMaster> ret;
    192   mutex_lock l(*get_local_master_registry_lock());
    193   auto iter = local_master_registry()->find(target);
    194   if (iter != local_master_registry()->end()) {
    195     ret.reset(new LocalMaster(iter->second.master,
    196                               iter->second.default_timeout_in_ms));
    197   }
    198   return ret;
    199 }
    200 
    201 }  // namespace tensorflow
    202