/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_
#define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_

#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h"

namespace tensorflow {

/**
 * An aggregation object for adding dense gradients.
 *
 * The two main methods of this class are TryApplyGrad and TryTakeGrad.
 *
 * TryApplyGrad tries to add a gradient to the accumulator. The attempt
 * succeeds if local_step >= global_step, i.e., if the gradient is not stale,
 * having been computed using up-to-date information. Otherwise, the gradient
 * is silently dropped.
 *
 * TryTakeGrad logs an attempt to read the average gradient. The attempt is
 * blocked until the number of gradients accumulated (via TryApplyGrad) equals
 * or exceeds the number requested by TryTakeGrad.
 * Once this condition is satisfied, the following actions are taken:
 * (1) the value of the average gradient is returned
 * (2) the count of accumulated gradients is reset to 0
 * (3) the internal global_step value (current_global_step_) is incremented
 *     by 1
 *
 * ConditionalAccumulator is the datatype-dependent templated subclass of
 * ConditionalAccumulatorBase. It implements the virtual arithmetic methods
 * used for aggregating, averaging, allocating, and returning dense Tensors.
 */
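//
// A minimal usage sketch (illustrative only, not part of this header). The
// TryApplyGrad / TryTakeGrad signatures shown here are assumed from the base
// classes, and CPUDevice is assumed to be the usual Eigen::ThreadPoolDevice
// alias:
//
//   ConditionalAccumulator<CPUDevice, float> accum(
//       DT_FLOAT, PartialTensorShape({2}), "accum");
//
//   // Workers apply gradients tagged with the local step at which they were
//   // computed; stale gradients are silently dropped.
//   accum.TryApplyGrad(/*local_step=*/0, ctx);
//
//   // A consumer asks for the average of at least two gradients; the
//   // callback runs once enough gradients have been applied.
//   accum.TryTakeGrad(/*num_required=*/2, ctx, [] { /* average is ready */ });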
template <typename Device, typename T>
class ConditionalAccumulator
    : public TypedConditionalAccumulatorBase<const Tensor> {
 public:
  // Args:
  //   dtype: The datatype of the gradients to be accumulated.
  //   shape: The shape of the accumulated gradients.
  //   name:  A name to use for the ConditionalAccumulator.
  ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
                         const string& name)
      : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {}
  ~ConditionalAccumulator() override {}

 protected:
  // accum_grad_ is the tensor that holds the aggregated gradient.
  // It is initialized the first time ApplyGrad is called.
  Tensor* accum_grad_ = nullptr;
  PersistentTensor accum_grad_persistent_;

  functor::SetZeroFunctor<Device, T> set_zero_functor_;

  Status ValidateShape(const Tensor* tensor)
      EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    // Must be compatible with accumulated gradient if available
    if (counter_ > 0) {
      if (!accum_grad_->shape().IsSameSize(tensor->shape())) {
        return errors::InvalidArgument("Shape mismatch: expected ",
                                       accum_grad_->shape().DebugString(),
                                       ", got ", tensor->shape().DebugString());
      }
    }
    // Must also be compatible with given shape
    if (!shape_.IsCompatibleWith(tensor->shape())) {
      return errors::InvalidArgument("Shape mismatch: expected ",
                                     shape_.DebugString(), ", got ",
                                     tensor->shape().DebugString());
    }
    return Status::OK();
  }

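  // Illustrative shape example (hypothetical values): with shape_ built from
  // PartialTensorShape({-1, 2}), a first [5, 2] gradient passes
  // IsCompatibleWith; afterwards (counter_ > 0), a [4, 2] gradient is
  // rejected by IsSameSize even though it is still compatible with shape_,
  // and a [5, 3] gradient fails both checks.
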
  // Allocates the persistent tensor that backs accum_grad_ and initializes
  // it with the first gradient.
  void AllocateAndAssignToAccumGradFunction(OpKernelContext* ctx,
                                            const Tensor* grad) override {
    // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
    ctx->allocate_persistent(dtype_, grad->shape(), &accum_grad_persistent_,
                             &accum_grad_)
        .IgnoreError();
    accum_grad_->flat<T>().device(ctx->template eigen_device<Device>()) =
        grad->flat<T>();
  }

  // Adds grad element-wise to the running aggregate.
  void AddToAccumGradFunction(OpKernelContext* ctx,
                              const Tensor* grad) override {
    accum_grad_->flat<T>().device(ctx->template eigen_device<Device>()) +=
        grad->flat<T>();
  }

  // Divides the aggregate by the number of accumulated gradients, turning
  // the sum into an average.
  void DivideAccumGradByCounter(OpKernelContext* ctx) override
      EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    Tensor c(DataTypeToEnum<T>::value, {});
    c.scalar<T>()() = TypeConverter<T, int>::ConvertUToT(this->counter_);
    this->accum_grad_->template flat<T>().device(
        ctx->template eigen_device<Device>()) =
        this->accum_grad_->template flat<T>() / c.scalar<T>()();
  }

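  // Worked example (illustrative): after two TryApplyGrad calls with
  // gradients [2, 4] and [4, 8], counter_ == 2 and the aggregate is [6, 12];
  // DivideAccumGradByCounter rewrites it to the average [3, 6], which
  // SetOutput below then emits.
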
  // Emits the (now averaged) aggregate tensor as output 0 of the op.
  bool SetOutput(OpKernelContext* ctx) override {
    ctx->set_output(0, *accum_grad_);
    return true;
  }

  bool GetAndValidateTensorInputForApplyGrad(OpKernelContext* ctx,
                                             const Tensor** tensor) override
      EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    // Get input gradient tensor
    const Tensor* grad_tensor;
    OP_REQUIRES_OK_BOOLEAN(ctx, ctx->input("gradient", &grad_tensor));
    *tensor = grad_tensor;
    OP_REQUIRES_OK_BOOLEAN(ctx, this->ValidateShape(*tensor));
    return true;
  }

  void CleanUpGradTensor(const Tensor* tensor) override {
    // Do nothing: the input gradient is owned by the OpKernelContext, so
    // there is nothing for the accumulator to release here.
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ConditionalAccumulator);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_