/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/conditional_accumulator_base.h"

namespace tensorflow {

ConditionalAccumulatorBase::ConditionalAccumulatorBase(
    const DataType& dtype, const PartialTensorShape& shape, const string& name)
    : dtype_(dtype), shape_(shape), name_(name) {
  counter_ = 0;
  current_global_step_ = 0;
}

Status ConditionalAccumulatorBase::MatchesNodeDef(const NodeDef& node_def) {
  // TODO(xinghao@): implement the checks for the node definition
  return Status::OK();
}

/**
 * Sets the time step of the accumulator to be in line with the global time
 * step. Logs a warning if the accumulator's time step is already larger than
 * the provided time step.
 */
Status ConditionalAccumulatorBase::SetGlobalStep(int64 new_global_step) {
  mutex_lock lock(mu_);
  if (new_global_step < current_global_step_) {
    LOG(WARNING) << "Attempt to set current_global_step_ to smaller value: "
                 << "current_global_step_ = " << current_global_step_
                 << " > " << new_global_step << " = new_global_step.";
  }
  current_global_step_ = new_global_step;
  return Status::OK();
}
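
// Example (illustrative sketch, not from the original sources): how a caller
// might advance the accumulator to match the trainer's global step. `accum`
// is a hypothetical pointer to a concrete ConditionalAccumulatorBase subclass.
//
//   Status s = accum->SetGlobalStep(global_step);
//   if (!s.ok()) return s;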

/**
 * Records an attempt to extract the average gradient, and tries to flush all
 * pending TakeGrad attempts.
 * A TakeGrad attempt is blocked until counter_ >= num_required, i.e., until
 * sufficient gradients have been accumulated.
 *
 * num_required: Number of gradients that need to be accumulated before the
 *               attempt is unblocked.
 * ctx:          Context in which the op is executed.
 * callback:     A callback to be executed after the attempt has been completed.
 */
void ConditionalAccumulatorBase::TryTakeGrad(int num_required,
                                             OpKernelContext* ctx,
                                             DoneCallback callback) {
  if (num_required <= 0) {
    ctx->CtxFailureWithWarning(errors::InvalidArgument(
        "Argument num_required must be positive, but was ", num_required));
    callback();
  } else {
    CancellationManager* cm = ctx->cancellation_manager();
    CancellationToken token = cm->get_cancellation_token();
    bool already_cancelled;
    {
      mutex_lock l(mu_);
      already_cancelled = !cm->RegisterCallback(
          token, [this, cm, token]() { Cancel(cm, token); });
      if (!already_cancelled) {
        takegrad_attempts_.emplace_back(
            num_required, callback, ctx, cm, token,
            [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
              if (counter_ >= attempt->elements_requested) {
                bool successful_take_grad = TakeGradLockedHelper(
                    attempt->context, attempt->done_callback);
                if (successful_take_grad) {
                  return kComplete;
                } else {
                  // Try again
                  return kNoProgress;
                }
              } else {
                return kNoProgress;
              }
            });
      }
    }
    if (!already_cancelled) {
      FlushUnlocked();
    } else {
      ctx->SetStatus(errors::Cancelled("TakeGrad operation was cancelled"));
      callback();
    }
  }
}
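
// Example (illustrative sketch, not from the original sources): how an async
// op kernel might drive TryTakeGrad. The kernel shape and the `accum` and
// `num_required` names are assumptions for illustration; the real kernels
// live elsewhere in this directory.
//
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     // Blocks the op until enough gradients have been accumulated, then
//     // invokes `done`.
//     accum->TryTakeGrad(num_required, ctx, done);
//   }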

/**
 * Cancellation callback: marks the matching TakeGrad attempt as cancelled,
 * sets a Cancelled status on its context, and runs its done callback outside
 * the lock.
 */
void ConditionalAccumulatorBase::Cancel(
    CancellationManager* cancellation_manager, CancellationToken token) {
  DoneCallback callback = nullptr;
  {
    mutex_lock lock(mu_);

    for (Attempt& attempt : takegrad_attempts_) {
      if (attempt.cancellation_manager == cancellation_manager &&
          attempt.cancellation_token == token) {
        if (!attempt.is_cancelled) {
          attempt.is_cancelled = true;
          attempt.context->SetStatus(
              errors::Cancelled("TakeGrad operation was cancelled"));
          std::swap(callback, attempt.done_callback);
        }
        break;
      }
    }
  }
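  // Invoke the done callback outside of mu_: FlushUnlocked() below re-acquires
  // mu_, and the callback itself may re-enter the accumulator, so running it
  // under the lock could deadlock.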
  if (callback) {
    callback();
    FlushUnlocked();
  }
}

/**
 * Tries to run the blocked TakeGrad attempts at the front of the queue, in
 * FIFO order, stopping at the first attempt that cannot make progress.
 * Completed attempts have their clean-up actions deferred into *clean_up.
 * Returns true if at least one attempt completed.
 */
bool ConditionalAccumulatorBase::TryAttemptLocked(
    std::vector<CleanUp>* clean_up) {
  bool progress = false;
  bool done = false;
  while (!done && !takegrad_attempts_.empty()) {
    if (takegrad_attempts_.front().is_cancelled) {
      VLOG(1) << "Skipping cancelled TakeGrad attempt";
      takegrad_attempts_.pop_front();
    } else {
      Attempt* cur_attempt = &takegrad_attempts_.front();
      switch (cur_attempt->run_callback(cur_attempt)) {
        case kNoProgress:
          done = true;
          break;
        case kComplete:
          progress = true;
          clean_up->emplace_back(std::move(cur_attempt->done_callback),
                                 cur_attempt->cancellation_token,
                                 cur_attempt->context->cancellation_manager());
          takegrad_attempts_.pop_front();
          break;
      }
    }
  }
  return progress;
}

/**
 * Flushes blocked TakeGrad attempts: repeatedly runs TryAttemptLocked under
 * mu_ until no further progress is made, then performs the deferred clean-up
 * actions outside the lock.
 */
void ConditionalAccumulatorBase::FlushUnlocked() {
  std::vector<CleanUp> clean_up;
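  // Hold a reference for the duration of the flush: completed attempts run
  // callbacks that may release their references to this accumulator, and
  // *this must stay alive until the loop below finishes.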
  Ref();
  {
    mutex_lock lock(mu_);
    bool changed;
    do {
      changed = TryAttemptLocked(&clean_up);
    } while (changed);
  }
  Unref();
  for (const auto& to_clean : clean_up) {
    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
      // NOTE(mrry): We can safely ignore the return value of
      // DeregisterCallback because the mutex mu_ ensures that the
      // cleanup action only executes once.
      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
    }
    to_clean.finished();
  }
}

bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
                                                      DoneCallback callback) {
  // At this point, the accumulation condition (counter_ >= num_required) has
  // already been checked by the caller.

  // Implicitly increment global_step
  current_global_step_++;

  // Average the accumulated gradient
  DivideAccumGradByCounter(ctx);

  // Set output for accumulated gradient tensor
  bool successful_set_output = SetOutput(ctx);

  // Reset counter
  if (successful_set_output) counter_ = 0;

  return successful_set_output;
}
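
// Note on semantics: a successful TakeGrad returns the accumulated gradient
// divided by counter_, i.e. the average of the gradients applied since the
// last take. For example, with num_required = 2 and accumulated gradients g1
// and g2, the output is (g1 + g2) / 2; counter_ then resets to 0 and a new
// round of accumulation begins.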

}  // namespace tensorflow