Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
     17 #define TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
     18 
     19 #include "tensorflow/core/kernels/conditional_accumulator_base.h"
     20 
     21 namespace tensorflow {
     22 
     23 /*
     24  * TypedConditionalAccumulatorBase is a templated companion of
     25  * ConditionalAccumulatorBase which allows for subclasses to use different
     26  * types for the input gradients. (See ConditionalAccumulator and
     27  * SparseConditionalAccumulator.)
     28  *
     29  * TypedConditionalAccumulatorBase defines virtual methods and implements
     30  * methods which depend on the gradient type. These are mainly methods that are
     31  * used for adding a new gradient to the accumulator.
     32  */
     33 template <typename GradientTensorType>
     34 class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
     35  public:
     36   TypedConditionalAccumulatorBase(const DataType& dtype,
     37                                   const PartialTensorShape& shape,
     38                                   const string& name)
     39       : ConditionalAccumulatorBase(dtype, shape, name) {}
     40 
     41   /**
     42    * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
     43    * successful (i.e., has its gradient applied) if its local_step >=
     44    * current_global_step_ at the time the attempt is processed. Otherwise, if
     45    * local_step < current_global_step_, the stale gradient is silently dropped.
     46    *
     47    * local_step: Time-step at which the gradient was computed.
     48    * grad:       Gradient tensor to be added to the accumulator.
     49    * ctx:        Context in which the op is executed.
     50    */
     51   void TryApplyGrad(int64 local_step, OpKernelContext* ctx) override {
     52     {
     53       mutex_lock l(mu_);
     54       if (local_step >= current_global_step_) {
     55         GradientTensorType* grad = nullptr;
     56         bool is_valid = GetAndValidateTensorInputForApplyGrad(ctx, &grad);
     57         if (is_valid) {
     58           if (counter_ > 0) {
     59             AddToAccumGradFunction(ctx, grad);
     60           } else {
     61             AllocateAndAssignToAccumGradFunction(ctx, grad);
     62           }
     63           counter_++;
     64         }
     65         CleanUpGradTensor(grad);
     66       }
     67     }
     68     FlushUnlocked();
     69   }
     70 
     71  protected:
     72   // Virtual methods to be implemented by sub-classes for different datatypes.
     73   // Implements arithmetic operations specific to datatype.
     74   virtual void AllocateAndAssignToAccumGradFunction(
     75       OpKernelContext* ctx, GradientTensorType* grad) = 0;
     76 
     77   virtual void AddToAccumGradFunction(OpKernelContext* ctx,
     78                                       GradientTensorType* grad) = 0;
     79 
     80   // Method for extracting and validating input provided in an OpKernelContext.
     81   // Returns true if input was successfully retrieved and is valid.
     82   // Gradient is returned via the GradientTensorType** tensor.
     83   virtual bool GetAndValidateTensorInputForApplyGrad(
     84       OpKernelContext* ctx, GradientTensorType** tensor)
     85       EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
     86 
     87   // Method for cleaning up any memory allocated in
     88   // GetAndValidateTensorInputForApplyGrad
     89   virtual void CleanUpGradTensor(GradientTensorType* tensor) = 0;
     90 };
     91 
     92 }  // namespace tensorflow
     93 
     94 #endif  // TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
     95