Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
     17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
     18 
     19 #include <unordered_map>
     20 
     21 #include "tensorflow/core/distributed_runtime/worker_interface.h"
     22 #include "tensorflow/core/framework/cancellation.h"
     23 #include "tensorflow/core/lib/core/status.h"
     24 #include "tensorflow/core/platform/macros.h"
     25 #include "tensorflow/core/platform/mutex.h"
     26 #include "tensorflow/core/platform/types.h"
     27 
     28 namespace tensorflow {
     29 
     30 // PartialRunMgr keeps track of pending partial run requests, and ensures that
     31 // a partial run is only marked complete once the corresponding executor has
     32 // run to completion.
     33 //
     34 // In tensorflow workers, the executor runs operations asynchronously until
     35 // specified fetches (operations that return tensors) or targets (operations
     36 // that don't return tensors) are reached. A PartialRun has two components: a
     37 // setup which specifies all desired fetches and targets, and run calls that
     38 // specify which fetch values (from the setup call) to retrieve.
     39 // On the last partial run call, it is possible to satisfy the
     40 // required fetches before the executor has completed running the graph to all
     41 // the desired targets.
     42 // PartialRunMgr is used to ensure that we don't complete and return the final
     43 // partial run call to the user until both the partial run and the executor
     44 // have completed.
     45 //
     46 // PartialRunMgr is thread-safe: its per-step state is guarded by mu_.
     47 class PartialRunMgr {
     48  public:
     49   // Finds or creates the CancellationManager associated with step_id.
     50   // The PartialRunMgr owns the cancellation_manager.
     51   // Returns true if a new CancellationManager was created
     52   // (i.e., this is a new partial run).
          //
          //   step_id: key identifying the partial run.
          //     NOTE(review): declared as a plain int; confirm this matches the
          //     width of the step ids that callers pass in.
          //   cancellation_manager: output parameter.
          //     NOTE(review): presumably set to the manager for step_id; since
          //     ownership stays with PartialRunMgr, callers must not delete it.
     53   bool FindOrCreate(int step_id, CancellationManager** cancellation_manager);
     54 
     55   // Calls the final callback if the PartialRunRequest has already completed.
     56   // Otherwise stores the executor_status to be propagated when the
     57   // PartialRunRequest completes (PartialRunDone has been called).
          //
          //   step_id: key identifying the partial run.
          //   executor_status: completion status reported by the executor.
     58   void ExecutorDone(int step_id, const Status& executor_status);
     59 
     60   // Calls done if the executor has already completed (ExecutorDone has been
     61   // called). Otherwise, stores the status and done callback, calling them when
     62   // ExecutorDone is called. The callback is therefore invoked on the thread
     63   // that calls PartialRunDone or on the thread that calls ExecutorDone.
     64   // If executor_status in ExecutorDone is not OK, it takes precedence over
     65   // status and is passed to the done callback.
          //
          //   step_id: key identifying the partial run.
          //   done: callback that receives the final status of the partial run.
          //   status: status of the partial run request itself.
     66   void PartialRunDone(int step_id, StatusCallback done, const Status& status);
     67 
     68  private:
     69   // PartialRunState stores state associated with a pending partial run request.
     70   // This is protected by the mutex in PartialRunMgr.
          //
          // NOTE(review): std::unique_ptr is used below, but <memory> is not among
          // this header's direct includes; it is presumably reached transitively.
          // Consider including it explicitly.
     71   struct PartialRunState {
            // Cancellation manager for this step; owned by this state object.
     72     std::unique_ptr<CancellationManager> cancellation_manager;
     73 
            // Starts out false.
            // NOTE(review): presumably set once ExecutorDone has run for this
            // step; confirm against the implementation file.
     74     bool executor_done = false;
            // Starts out null.
            // NOTE(review): presumably holds the `done` callback handed to
            // PartialRunDone when that call arrives before ExecutorDone; confirm.
     75     StatusCallback final_callback = nullptr;
            // NOTE(review): presumably the status that is eventually passed to
            // final_callback; confirm which call writes it.
     76     Status final_status;
     77   };
     78 
          // Guards step_id_to_partial_run_ (see the GUARDED_BY annotation below).
     79   mutex mu_;
     80 
          // Maps a step id to the state of its pending partial run. Each entry is
          // owned by the map through a unique_ptr.
     81   std::unordered_map<int, std::unique_ptr<PartialRunState>>
     82       step_id_to_partial_run_ GUARDED_BY(mu_);
     83 };
     84 
     85 }  // namespace tensorflow
     86 
     87 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
     88