Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include <deque>
     19 #include <utility>
     20 
     21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     22 #include "tensorflow/core/framework/resource_mgr.h"
     23 #include "tensorflow/core/framework/variant.h"
     24 #include "tensorflow/core/framework/variant_encode_decode.h"
     25 #include "tensorflow/core/kernels/ops_util.h"
     26 #include "tensorflow/core/lib/core/errors.h"
     27 #include "tensorflow/core/lib/core/threadpool.h"
     28 #include "tensorflow/core/platform/macros.h"
     29 #include "tensorflow/core/platform/mutex.h"
     30 #include "tensorflow/core/platform/types.h"
     31 
     32 namespace tensorflow {
     33 
     34 namespace {
     35 
     36 class Mutex : public ResourceBase {
     37  public:
     38   explicit Mutex(OpKernelContext* c, const string& name)
     39       : locked_(false),
     40         thread_pool_(new thread::ThreadPool(
     41             c->env(), ThreadOptions(),
     42             strings::StrCat("mutex_lock_thread_", SanitizeThreadSuffix(name)),
     43             1 /* num_threads */, false /* low_latency_hint */)),
     44         name_(name) {
     45     VLOG(2) << "Creating mutex with name " << name << ": " << this;
     46   }
     47 
     48   string DebugString() const override {
     49     return strings::StrCat("Mutex ", name_);
     50   }
     51 
     52   class LockReleaser {
     53    public:
     54     explicit LockReleaser(Mutex* mutex) : mutex_(mutex) {}
     55 
     56     LockReleaser(const LockReleaser&) = delete;
     57     LockReleaser& operator=(const LockReleaser&) = delete;
     58 
     59     virtual ~LockReleaser() {
     60       VLOG(3) << "Destroying LockReleaser " << this << " for mutex: " << mutex_;
     61       if (mutex_) {
     62         mutex_lock lock(mutex_->mu_);
     63         mutex_->locked_ = false;
     64         mutex_->cv_.notify_all();
     65         VLOG(3) << "Destroying LockReleaser " << this
     66                 << ": sent notifications.";
     67       }
     68     }
     69 
     70    private:
     71     Mutex* mutex_;
     72   };
     73 
     74   struct SharedLockReleaser {
     75     std::shared_ptr<LockReleaser> shared_lock;
     76 
     77     explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock)
     78         : shared_lock(std::forward<decltype(lock)>(lock)) {
     79       VLOG(3) << "Creating shared_ptr of " << shared_lock.get()
     80               << " count is: " << shared_lock.use_count();
     81     }
     82 
     83     SharedLockReleaser(SharedLockReleaser&& rhs)
     84         : shared_lock(std::move(rhs.shared_lock)) {
     85       VLOG(3) << "Moving SharedLockReleaser of " << shared_lock.get()
     86               << " count is: " << shared_lock.use_count();
     87     }
     88 
     89     SharedLockReleaser(const SharedLockReleaser& rhs)
     90         : shared_lock(rhs.shared_lock) {
     91       VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()
     92               << " count is: " << shared_lock.use_count();
     93     }
     94 
     95     ~SharedLockReleaser() {
     96       VLOG(3) << "Destroying SharedLockReleaser of " << shared_lock.get()
     97               << " count is: " << shared_lock.use_count();
     98     }
     99 
    100     void Encode(VariantTensorData*) const {
    101       // Not supported.
    102     }
    103 
    104     bool Decode(const VariantTensorData&) {
    105       return false;  // Not supported.
    106     }
    107   };
    108 
    109   void AcquireAsync(
    110       OpKernelContext* c,
    111       std::function<void(const Status& s, SharedLockReleaser lock)> fn) {
    112     CancellationManager* cm = c->cancellation_manager();
    113     CancellationToken token{};
    114     bool* cancelled = nullptr;
    115     if (cm) {
    116       cancelled = new bool(false);  // GUARDED_BY(mu_);
    117       token = cm->get_cancellation_token();
    118       const bool already_cancelled =
    119           !cm->RegisterCallback(token, [this, cancelled]() {
    120             mutex_lock lock(mu_);
    121             *cancelled = true;
    122             cv_.notify_all();
    123           });
    124       if (already_cancelled) {
    125         delete cancelled;
    126         fn(errors::Cancelled("Lock acquisition cancelled."),
    127            SharedLockReleaser{nullptr});
    128         return;
    129       }
    130     }
    131     thread_pool_->Schedule(std::bind(
    132         [this, cm, cancelled,
    133          token](std::function<void(const Status& s, SharedLockReleaser&& lock)>
    134                     fn_) {
    135           bool local_locked;
    136           {
    137             mutex_lock lock(mu_);
    138             while (locked_ && !(cancelled && *cancelled)) {
    139               cv_.wait(lock);
    140             }
    141             local_locked = locked_ = !(cancelled && *cancelled);
    142           }
    143           if (cm) {
    144             cm->DeregisterCallback(token);
    145             delete cancelled;
    146           }
    147           if (local_locked) {  // Not cancelled.
    148             fn_(Status::OK(),
    149                 SharedLockReleaser{std::make_shared<LockReleaser>(this)});
    150           } else {
    151             fn_(errors::Cancelled("Lock acquisition cancelled."),
    152                 SharedLockReleaser{nullptr});
    153           }
    154         },
    155         std::move(fn)));
    156   }
    157 
    158  private:
    159   mutex mu_;
    160   condition_variable cv_ GUARDED_BY(mu_);
    161   bool locked_ GUARDED_BY(mu_);
    162   std::unique_ptr<thread::ThreadPool> thread_pool_;
    163   string name_;
    164 };
    165 
    166 }  // namespace
    167 
    168 class MutexLockOp : public AsyncOpKernel {
    169  public:
    170   explicit MutexLockOp(OpKernelConstruction* c) : AsyncOpKernel(c) {}
    171 
    172  public:
    173   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    174     Mutex* mutex = nullptr;
    175     OP_REQUIRES_OK_ASYNC(
    176         c,
    177         LookupOrCreateResource<Mutex>(c, HandleFromInput(c, 0), &mutex,
    178                                       [c](Mutex** ptr) {
    179                                         *ptr = new Mutex(
    180                                             c, HandleFromInput(c, 0).name());
    181                                         return Status::OK();
    182                                       }),
    183         done);
    184 
    185     Tensor* variant;
    186     OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, TensorShape({}), &variant),
    187                          done);
    188 
    189     mutex->AcquireAsync(
    190         c, std::bind(
    191                [c, variant, mutex](DoneCallback done_,
    192                                    // End of bound arguments.
    193                                    const Status& s,
    194                                    Mutex::SharedLockReleaser&& lock) {
    195                  VLOG(2) << "Finished locking mutex " << mutex
    196                          << " with lock: " << lock.shared_lock.get()
    197                          << " status: " << s.ToString();
    198                  if (s.ok()) {
    199                    variant->scalar<Variant>()() = std::move(lock);
    200                  } else {
    201                    c->SetStatus(s);
    202                  }
    203                  mutex->Unref();
    204                  done_();
    205                },
    206                std::move(done), std::placeholders::_1, std::placeholders::_2));
    207   }
    208 };
    209 
    210 class ConsumeMutexLockOp : public OpKernel {
    211  public:
    212   explicit ConsumeMutexLockOp(OpKernelConstruction* context)
    213       : OpKernel(context) {}
    214 
    215   void Compute(OpKernelContext* c) override {
    216     VLOG(2) << "Executing ConsumeMutexLockOp";
    217     const Tensor& lock_t = c->input(0);
    218     OP_REQUIRES(
    219         c, lock_t.dims() == 0,
    220         errors::InvalidArgument("Expected input to be a scalar, saw shape: ",
    221                                 lock_t.shape().DebugString()));
    222     OP_REQUIRES(
    223         c, lock_t.dtype() == DT_VARIANT,
    224         errors::InvalidArgument("Expected input to be a variant, saw type: ",
    225                                 DataTypeString(lock_t.dtype())));
    226     const auto* lock =
    227         lock_t.scalar<Variant>()().get<Mutex::SharedLockReleaser>();
    228     OP_REQUIRES(c, lock,
    229                 errors::InvalidArgument(
    230                     "Expected input to contain a SharedLockReleaser "
    231                     "object, but saw variant: '",
    232                     lock_t.scalar<Variant>()().DebugString(), "'"));
    233     const int use_count = lock->shared_lock.use_count();
    234     OP_REQUIRES(
    235         c, use_count == 1,
    236         errors::InvalidArgument("Expected use count of lock to be 1, but saw: ",
    237                                 use_count));
    238   }
    239 
    240   bool IsExpensive() override { return false; }
    241 };
    242 
    243 REGISTER_KERNEL_BUILDER(Name("MutexLock").Device(DEVICE_CPU), MutexLockOp);
    244 
    245 REGISTER_KERNEL_BUILDER(Name("MutexLock")
    246                             .Device(DEVICE_GPU)
    247                             .HostMemory("mutex_lock")
    248                             .HostMemory("mutex"),
    249                         MutexLockOp);
    250 
    251 REGISTER_KERNEL_BUILDER(
    252     Name("MutexV2").Device(DEVICE_CPU).HostMemory("resource"),
    253     ResourceHandleOp<Mutex>);
    254 
    255 REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_GPU),
    256                         ResourceHandleOp<Mutex>);
    257 
    258 REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock").Device(DEVICE_CPU),
    259                         ConsumeMutexLockOp);
    260 
    261 REGISTER_KERNEL_BUILDER(
    262     Name("ConsumeMutexLock").Device(DEVICE_GPU).HostMemory("mutex_lock"),
    263     ConsumeMutexLockOp);
    264 
    265 }  // namespace tensorflow
    266