1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_ 17 #define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_ 18 19 #include "tensorflow/core/kernels/fill_functor.h" 20 #include "tensorflow/core/kernels/typed_conditional_accumulator_base.h" 21 22 namespace tensorflow { 23 24 /** 25 * An aggregation object for adding dense gradients. 26 * 27 * The two main methods of this class are TryApplyGrad and TryTakeGrad. 28 * 29 * TryApplyGrad tries add a gradient to the accumulator. The attempt is 30 * successful if local_step >= global_step, i.e., if the gradient is not stale, 31 * having been computed using up-to-date information. Otherwise, the gradient is 32 * silently dropped. 33 * 34 * TryTakeGrad logs an attempt to read the average gradient. The attempt is 35 * blocked until the number of gradients accumulated (via TryApplyGrad) is equal 36 * or exceeds the number requested by TryTakeGrad. 37 * Once this condition is satisfied, the following actions are taken: 38 * (1) the value of the average gradient is returned 39 * (2) the count of accumulated gradients is reset to 0 40 * (3) the internal global_step value (current_global_step_) is incremented by 1 41 * 42 * ConditionalAccumulator is the datatype-dependent templated sub-class of 43 * ConditionalAccumulatorBase. It implements the virtual arithmetic methods that 44 * are used by for aggregating, averaging, allocating, returning dense Tensors. 45 */ 46 template <typename Device, typename T> 47 class ConditionalAccumulator 48 : public TypedConditionalAccumulatorBase<const Tensor> { 49 public: 50 // Args: 51 // dtype: The datatype of the gradients to be accumulated. 52 // shape: The shape of the accumulated gradients. 53 // name: A name to use for the ConditionalAccumulator. 54 ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, 55 const string& name) 56 : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {} 57 ~ConditionalAccumulator() override{}; 58 59 protected: 60 // accum_grad is the tensor that holds the aggregate gradient. 61 // It is initialized the first time ApplyGrad is called. 62 Tensor* accum_grad_ = nullptr; 63 PersistentTensor accum_grad_persistent_; 64 65 functor::SetZeroFunctor<Device, T> set_zero_functor_; 66 67 Status ValidateShape(const Tensor* tensor) 68 EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { 69 // Must be compatible with accumulated gradient if available 70 if (counter_ > 0) { 71 if (!accum_grad_->shape().IsSameSize(tensor->shape())) { 72 return errors::InvalidArgument("Shape mismatch: expected ", 73 accum_grad_->shape().DebugString(), 74 ", got ", tensor->shape().DebugString()); 75 } 76 } 77 // Must also be compatible with given shape 78 if (!shape_.IsCompatibleWith(tensor->shape())) { 79 return errors::InvalidArgument("Shape mismatch: expected ", 80 shape_.DebugString(), ", got ", 81 tensor->shape().DebugString()); 82 } 83 return Status::OK(); 84 } 85 86 void AllocateAndAssignToAccumGradFunction(OpKernelContext* ctx, 87 const Tensor* grad) override { 88 // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! 89 ctx->allocate_persistent(dtype_, grad->shape(), &accum_grad_persistent_, 90 &accum_grad_) 91 .IgnoreError(); 92 accum_grad_->flat<T>().device(ctx->template eigen_device<Device>()) = 93 grad->flat<T>(); 94 } 95 96 void AddToAccumGradFunction(OpKernelContext* ctx, 97 const Tensor* grad) override { 98 accum_grad_->flat<T>().device(ctx->template eigen_device<Device>()) += 99 grad->flat<T>(); 100 } 101 102 void DivideAccumGradByCounter(OpKernelContext* ctx) override 103 EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { 104 Tensor c(DataTypeToEnum<T>::value, {}); 105 c.scalar<T>()() = TypeConverter<T, int>::ConvertUToT(this->counter_); 106 this->accum_grad_->template flat<T>().device( 107 ctx->template eigen_device<Device>()) = 108 this->accum_grad_->template flat<T>() / c.scalar<T>()(); 109 } 110 111 bool SetOutput(OpKernelContext* ctx) override { 112 ctx->set_output(0, *accum_grad_); 113 return true; 114 } 115 116 bool GetAndValidateTensorInputForApplyGrad(OpKernelContext* ctx, 117 const Tensor** tensor) override 118 EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { 119 // Get input gradient tensor 120 const Tensor* grad_tensor; 121 OP_REQUIRES_OK_BOOLEAN(ctx, ctx->input("gradient", &grad_tensor)); 122 *tensor = grad_tensor; 123 OP_REQUIRES_OK_BOOLEAN(ctx, this->ValidateShape(*tensor)); 124 return true; 125 } 126 127 void CleanUpGradTensor(const Tensor* tensor) override { 128 // do nothing 129 } 130 131 TF_DISALLOW_COPY_AND_ASSIGN(ConditionalAccumulator); 132 }; 133 134 } // namespace tensorflow 135 136 #endif // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_ 137