Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/session_mgr.h"
     17 
     18 #include <utility>
     19 
     20 #include "tensorflow/core/common_runtime/device_mgr.h"
     21 #include "tensorflow/core/common_runtime/renamed_device.h"
     22 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
     23 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
     24 #include "tensorflow/core/lib/strings/strcat.h"
     25 #include "tensorflow/core/protobuf/cluster.pb.h"
     26 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
     27 
     28 namespace tensorflow {
     29 
     30 SessionMgr::SessionMgr(
     31     WorkerEnv* worker_env, const string& default_worker_name,
     32     std::unique_ptr<WorkerCacheInterface> default_worker_cache,
     33     WorkerCacheFactory worker_cache_factory)
     34     : worker_env_(worker_env),
     35       default_worker_cache_(std::move(default_worker_cache)),
     36       legacy_session_(new WorkerSession(
     37           "", default_worker_name,
     38           std::unique_ptr<WorkerCacheInterface>(
     39               new WorkerCacheWrapper(default_worker_cache_.get())),
     40           std::unique_ptr<DeviceMgr>(worker_env->device_mgr),
     41           std::unique_ptr<GraphMgr>(
     42               new GraphMgr(worker_env, worker_env->device_mgr)))),
     43       worker_cache_factory_(std::move(worker_cache_factory)) {}
     44 
     45 string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
     46   return strings::StrCat("/job:", server_def.job_name(), "/replica:0/task:",
     47                          server_def.task_index());
     48 }
     49 
     50 Status SessionMgr::CreateSession(const string& session,
     51                                  const ServerDef& server_def,
     52                                  bool isolate_session_state) {
     53   mutex_lock l(mu_);
     54   if (session.empty()) {
     55     return errors::InvalidArgument("Session must be non-empty.");
     56   }
     57 
     58   const string worker_name = WorkerNameFromServerDef(server_def);
     59 
     60   WorkerCacheInterface* worker_cache = nullptr;
     61   if (server_def.cluster().job().empty()) {
     62     worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
     63   } else {
     64     TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
     65   }
     66 
     67   if (worker_cache != nullptr & default_worker_cache_.get() != nullptr) {
     68     worker_cache->SetLogging(this->is_logging_active_);
     69   }
     70 
     71   CHECK(!worker_env_->local_devices.empty())
     72       << "The WorkerEnv must have at least one device in `local_devices`.";
     73 
     74   std::vector<Device*> renamed_devices;
     75   for (Device* d : worker_env_->local_devices) {
     76     renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
     77         worker_name, d, false, isolate_session_state));
     78   }
     79   std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
     80 
     81   std::unique_ptr<GraphMgr> graph_mgr(
     82       new GraphMgr(worker_env_, device_mgr.get()));
     83 
     84   std::shared_ptr<WorkerSession> worker_session(new WorkerSession(
     85       session, worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache),
     86       std::move(device_mgr), std::move(graph_mgr)));
     87 
     88   sessions_.insert(std::make_pair(session, std::move(worker_session)));
     89   return Status::OK();
     90 }
     91 
     92 Status SessionMgr::DeleteSession(const string& session) {
     93   mutex_lock l(mu_);
     94   auto it = sessions_.find(session);
     95   if (it != sessions_.end()) {
     96     sessions_.erase(it);
     97   }
     98   return Status::OK();
     99 }
    100 
    101 std::shared_ptr<WorkerSession> SessionMgr::WorkerSessionForSessionUnlocked(
    102     const string& session) {
    103   auto it = sessions_.find(session);
    104   if (it == sessions_.end()) {
    105     return legacy_session_;
    106   } else {
    107     return it->second;
    108   }
    109 }
    110 
    111 std::shared_ptr<WorkerSession> SessionMgr::WorkerSessionForSession(
    112     const string& session) {
    113   mutex_lock l(mu_);
    114   return WorkerSessionForSessionUnlocked(session);
    115 }
    116 
    117 std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
    118   return legacy_session_;
    119 }
    120 
    121 void SessionMgr::SetLogging(bool active) {
    122   mutex_lock l(mu_);
    123   this->is_logging_active_ = active;
    124   // Legacy Session
    125   if (legacy_session_) {
    126     auto* worker_cache = legacy_session_->worker_cache.get();
    127     if (worker_cache) {
    128       worker_cache->SetLogging(active);
    129     }
    130   }
    131 
    132   for (const auto& session_kv : sessions_) {
    133     auto session = session_kv.second.get();
    134     if (session) {
    135       auto* worker_cache = session->worker_cache.get();
    136       if (worker_cache) {
    137         worker_cache->SetLogging(active);
    138       }
    139     }
    140   }
    141 }
    142 
    143 void SessionMgr::RetrieveLogs(tensorflow::int64 step_id,
    144                               LoggingResponse* response) {
    145   mutex_lock l(mu_);
    146   // Legacy Session
    147   if (legacy_session_) {
    148     auto* worker_cache = legacy_session_->worker_cache.get();
    149     if (worker_cache) {
    150       auto step_stats = StepStats();
    151       if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
    152         auto* labeled_step_stats = response->add_step();
    153         labeled_step_stats->set_step_id(step_id);
    154         labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
    155       }
    156     }
    157   }
    158   for (const auto& session_kv : sessions_) {
    159     auto session = session_kv.second.get();
    160     if (session) {
    161       auto* worker_cache = session->worker_cache.get();
    162       if (worker_cache) {
    163         auto step_stats = StepStats();
    164         if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
    165           auto* labeled_step_stats = response->add_step();
    166           labeled_step_stats->set_step_id(step_id);
    167           labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
    168         }
    169       }
    170     }
    171   }
    172 }
    173 
    174 void SessionMgr::ClearLogs() {
    175   mutex_lock l(mu_);
    176   // Legacy Session
    177   if (legacy_session_) {
    178     auto* worker_cache = legacy_session_->worker_cache.get();
    179     if (worker_cache) {
    180       worker_cache->ClearLogs();
    181     }
    182   }
    183 
    184   for (const auto& session_kv : sessions_) {
    185     auto session = session_kv.second.get();
    186     if (session) {
    187       auto* worker_cache = session->worker_cache.get();
    188       if (worker_cache) {
    189         worker_cache->ClearLogs();
    190       }
    191     }
    192   }
    193 }
    194 }  // namespace tensorflow
    195