Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_
     18 
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
     35 
     36 namespace xla {
     37 
     38 // Represents an asynchronously launched execution. Owns the stream (from the
     39 // passed run_options->stream()) on which the execution is launched and releases
     40 // the stream when destructed.
     41 class AsyncExecution {
     42  public:
     43   AsyncExecution(Backend* backend, std::vector<StreamPool::Ptr> streams,
     44                  const ExecutionProfile& profile, GlobalDataHandle result);
     45 
     46   Status BlockUntilDone() const;
     47 
     48   const GlobalDataHandle& result() const { return result_; }
     49 
     50   const ExecutionProfile& profile() const { return profile_; }
     51 
     52  private:
     53   // Backend to execute the computation on.
     54   Backend* backend_;
     55 
     56   // Stream on which the execution is launched.
     57   std::vector<StreamPool::Ptr> streams_;
     58 
     59   // Profile object of the execution to be returned to the user.
     60   ExecutionProfile profile_;
     61 
     62   // Data handle to the result of the execution. Data represented by this handle
     63   // is valid only after BlockUntilDone() is called.
     64   GlobalDataHandle result_;
     65 };
     66 
     67 // Tracks asynchronously launched executions for the XLA service.
     68 class ExecutionTracker {
     69  public:
     70   ExecutionTracker();
     71 
     72   // Registers an execution with its backend, streams, and data handle to the
     73   // execution result. Returns a handle for the registered execution.
     74   ExecutionHandle Register(Backend* backend,
     75                            std::vector<StreamPool::Ptr> stream,
     76                            const ExecutionProfile& profile,
     77                            GlobalDataHandle data);
     78 
     79   // Unregisters the execution for the given handle.
     80   Status Unregister(const ExecutionHandle& handle);
     81 
     82   // Resolves the given ExecutionHandle to an AsyncExecution. Returns an
     83   // error status if the given handle is not found, which means that the
     84   // execution is not yet registered or already unregistered.
     85   StatusOr<const AsyncExecution*> Resolve(const ExecutionHandle& handle);
     86 
     87  private:
     88   // The next handle to assign to an execution.
     89   int64 next_handle_ GUARDED_BY(execution_mutex_);
     90 
     91   // Mapping from ExecutionHandle handle to the corresponding registered
     92   // AsyncExecution object.
     93   std::map<int64, std::unique_ptr<AsyncExecution>> handle_to_execution_
     94       GUARDED_BY(execution_mutex_);
     95 
     96   tensorflow::mutex execution_mutex_;  // Guards the execution mapping.
     97 
     98   TF_DISALLOW_COPY_AND_ASSIGN(ExecutionTracker);
     99 };
    100 
    101 }  // namespace xla
    102 
    103 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_
    104