/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
#define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_

#include <deque>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"

namespace tensorflow {
/**
 * ConditionalAccumulator/ConditionalAccumulatorBase implements an aggregation
 * object for adding gradients.
 * The two main methods of this class are TryApplyGrad and TryTakeGrad.
 *
 * TryApplyGrad tries to add a gradient to the accumulator. The attempt is
 * successful if local_step >= global_step, i.e., if the gradient is not stale,
 * having been computed using up-to-date information. Otherwise, the gradient
 * is silently dropped.
 *
 * TryTakeGrad logs an attempt to read the average gradient. The attempt is
 * blocked until the number of gradients accumulated (via TryApplyGrad) equals
 * or exceeds the number requested by TryTakeGrad.
 * Once this condition is satisfied, the following actions are taken:
 * (1) the value of the average gradient is returned
 * (2) the count of accumulated gradients is reset to 0
 * (3) the internal global_step value (current_global_step_) is incremented by 1
 */
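// Illustrative usage from an op kernel (a minimal sketch; `accum` stands for
// a pointer to a concrete subclass such as the typed ConditionalAccumulator,
// and `done` is the enclosing AsyncOpKernel's DoneCallback):
//
//   accum->TryApplyGrad(local_step, ctx);  // non-blocking; a stale gradient
//                                          // is silently dropped
//   accum->TryTakeGrad(num_required, ctx,  // completes asynchronously once
//                      std::move(done));   // num_required gradients arrive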
class ConditionalAccumulatorBase : public ResourceBase {
 public:
  // Args:
  //   dtype: The datatype of the gradients to be accumulated.
  //   shape: The shape of the accumulated gradients.
  //   name:  A name to use for the ConditionalAccumulator.
  ConditionalAccumulatorBase(const DataType& dtype,
                             const PartialTensorShape& shape,
                             const string& name);

  typedef AsyncOpKernel::DoneCallback DoneCallback;

  virtual void TryApplyGrad(int64 local_step, OpKernelContext* ctx) = 0;
  void TryTakeGrad(int num_required, OpKernelContext* ctx,
                   DoneCallback callback);

  // Accessor methods
  uint32 num_accumulated() {
    mutex_lock lock(mu_);
    return counter_;
  }

  const DataType& dtype() const { return dtype_; }

  string DebugString() override { return "A conditional accumulator"; }

  // SetGlobalStep is a modifier method for current_global_step_.
  // It returns an InvalidArgument error if new_global_step is less than
  // current_global_step_.
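  // For example (illustrative): if current_global_step_ is 10, then
  // SetGlobalStep(12) succeeds and advances the step to 12, while
  // SetGlobalStep(7) returns an InvalidArgument error.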
  Status SetGlobalStep(int64 new_global_step);

  Status MatchesNodeDef(const NodeDef& node_def);

 protected:
  // Virtual methods to be implemented by subclasses for different datatypes.
  // These implement the arithmetic operations specific to the datatype.
  virtual void DivideAccumGradByCounter(OpKernelContext* ctx)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
  virtual bool SetOutput(OpKernelContext* ctx) = 0;
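
  // A minimal sketch of how a typed subclass might implement these
  // (illustrative; names and details are hypothetical):
  //
  //   void DivideAccumGradByCounter(OpKernelContext* ctx) override
  //       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  //     // Scale the accumulated gradient tensor by 1/counter_, e.g. by
  //     // dividing by TypeConverter<T, int>::ConvertUToT(counter_).
  //   }
  //
  //   bool SetOutput(OpKernelContext* ctx) override {
  //     // Allocate the output and copy the accumulated gradient into it,
  //     // using OP_REQUIRES_BOOLEAN-style checks to return false on error.
  //     return true;
  //   }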

  enum RunResult { kNoProgress, kComplete };

  // Helper struct holding information about a TakeGrad attempt
  struct Attempt;
  typedef std::function<RunResult(Attempt*)> RunCallback;
  struct Attempt {
    int elements_requested;
    DoneCallback done_callback;  // must be run outside mu_
    OpKernelContext* context;
    CancellationManager* cancellation_manager;  // not owned
    CancellationToken cancellation_token;
    RunCallback run_callback;  // must be run while holding mu_
    bool is_cancelled;

    Attempt(int elements_requested, DoneCallback done_callback,
            OpKernelContext* context, CancellationManager* cancellation_manager,
            CancellationToken cancellation_token, RunCallback run_callback)
        : elements_requested(elements_requested),
          done_callback(std::move(done_callback)),
          context(context),
          cancellation_manager(cancellation_manager),
          cancellation_token(cancellation_token),
          run_callback(std::move(run_callback)),
          is_cancelled(false) {}
  };

  // Helper struct for deregistering a cancellation token and executing a
  // DoneCallback after a TakeGrad attempt completes.
  struct CleanUp {
    CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
        : finished(std::move(f)), to_deregister(ct), cm(cm) {}
    DoneCallback finished;
    CancellationToken to_deregister;
    CancellationManager* cm;
  };

  // Fields

  const DataType dtype_;
  const PartialTensorShape shape_;
  const string name_;
  mutex mu_;
  int counter_ GUARDED_BY(mu_);
  int64 current_global_step_ GUARDED_BY(mu_);

  std::deque<Attempt> takegrad_attempts_ GUARDED_BY(mu_);

  // Methods

  // Helper function for creating a cancellation callback
  void Cancel(CancellationManager* cancellation_manager,
              CancellationToken token);

  // Helper functions to process TakeGrad attempts.
  // FlushUnlocked is called at the end of each TryApplyGrad and TryTakeGrad
  // call to try to clear the pending TakeGrad attempts. It in turn calls
  // TryAttemptLocked, which executes the RunCallback of the logged attempts.
  // Both functions are modeled after core/kernels/queue_base.
  // Note: ApplyGrad attempts never block -- unlike in a queue with limited
  //       capacity, we can always add the newest gradient to our accumulator
  //       (if it is not stale) or drop it silently (if it is stale).
  void FlushUnlocked();
  bool TryAttemptLocked(std::vector<CleanUp>* clean_up)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);
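
  // Illustrative sketch of the flush control flow (a minimal sketch modeled
  // on queue_base; the actual implementation in
  // conditional_accumulator_base.cc may differ in detail):
  //
  //   void FlushUnlocked() {
  //     std::vector<CleanUp> clean_up;
  //     {
  //       mutex_lock lock(mu_);
  //       // Drain all attempts that can currently make progress.
  //       while (TryAttemptLocked(&clean_up)) {
  //       }
  //     }
  //     // Outside mu_: deregister cancellation callbacks and run the
  //     // DoneCallbacks, which must not execute under the lock.
  //     for (CleanUp& c : clean_up) {
  //       if (c.cm != nullptr) c.cm->TryDeregisterCallback(c.to_deregister);
  //       c.finished();
  //     }
  //   }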

  // Helper methods
  //  void DeepCopy(Tensor* dst);
  bool TakeGradLockedHelper(OpKernelContext* ctx, DoneCallback callback)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);
};

/*
 * Modifications to convenience macros defined in core/framework/op_kernel.h.
 * Unlike the originals, the macros below make the calling function return
 * false when a check fails, so the caller gets an indication that a failure
 * has occurred. See the illustrative example after the definitions.
 */
#define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS)          \
  do {                                                 \
    if (!TF_PREDICT_TRUE(EXP)) {                       \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      return false;                                    \
    }                                                  \
  } while (0)

#define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS)                 \
  do {                                                      \
    ::tensorflow::Status _s(STATUS);                        \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return false;                                         \
    }                                                       \
  } while (0)
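
// Illustrative example (hypothetical subclass and helper names): these macros
// are intended for helpers that report success through a bool return value,
// such as SetOutput overrides:
//
//   bool MyAccumulator::SetOutput(OpKernelContext* ctx) {
//     OP_REQUIRES_BOOLEAN(
//         ctx, counter_ > 0,
//         errors::InvalidArgument("No gradients have been accumulated"));
//     OP_REQUIRES_OK_BOOLEAN(ctx, SomeStatusReturningHelper());
//     return true;
//   }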

/*
 * Convenience classes that help convert between numeric types.
 * The specialization for Eigen::half here simplifies specialization of
 * ConditionalAccumulator classes later.
 */
template <typename T, typename U>
class TypeConverter {
 public:
  static T ConvertUToT(U c) { return c; /* implicit conversion */ }
};

template <typename U>
class TypeConverter<Eigen::half, U> {
 public:
  static Eigen::half ConvertUToT(U c) {
    return Eigen::half_impl::float_to_half_rtne(c);
  }
};
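
// Illustrative usage (a minimal sketch):
//
//   float f = TypeConverter<float, int>::ConvertUToT(3);  // implicit cast
//   Eigen::half h =
//       TypeConverter<Eigen::half, float>::ConvertUToT(0.5f);  // RTNE rounding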

}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_