1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_ 17 #define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_ 18 19 #include <deque> 20 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/numeric_op.h" 23 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/resource_mgr.h" 26 27 namespace tensorflow { 28 29 /** 30 * ConditionalAccumulator/ConditionalAccumulatorBase implements an aggregation 31 * object for adding gradients. 32 * The two main methods of this class are TryApplyGrad and TryTakeGrad. 33 * 34 * TryApplyGrad tries add a gradient to the accumulator. The attempt is 35 * successful if local_step >= global_step, i.e., if the gradient is not stale, 36 * having been computed using up-to-date information. Otherwise, the gradient is 37 * silently dropped. 38 * 39 * TryTakeGrad logs an attempt to read the average gradient. The attempt is 40 * blocked until the number of gradients accumulated (via TryApplyGrad) is equal 41 * or exceeds the number requested by TryTakeGrad. 42 * Once this condition is satisfied, the following actions are taken: 43 * (1) the value of the average gradient is returned 44 * (2) the count of accumulated gradients is reset to 0 45 * (3) the internal global_step value (current_global_step_) is incremented by 1 46 */ 47 class ConditionalAccumulatorBase : public ResourceBase { 48 public: 49 // Args: 50 // dtype: The datatype of the gradients to be accumulated. 51 // shape: The shape of the accumulated gradients. 52 // name: A name to use for the ConditionalAccumulator. 53 ConditionalAccumulatorBase(const DataType& dtype, 54 const PartialTensorShape& shape, 55 const string& name); 56 57 typedef AsyncOpKernel::DoneCallback DoneCallback; 58 59 virtual void TryApplyGrad(int64 local_step, OpKernelContext* ctx) = 0; 60 void TryTakeGrad(int num_required, OpKernelContext* ctx, 61 DoneCallback callback); 62 63 // Accessor methods 64 uint32 num_accumulated() { 65 mutex_lock lock(mu_); 66 return counter_; 67 } 68 69 const DataType& dtype() const { return dtype_; } 70 71 string DebugString() override { return "A conditional accumulator"; } 72 73 // SetGlobalStep is a modifier method for current_global_step. 74 // It returns an InvalidArgument error if the new_global_step is less than 75 // current_global_step. 76 Status SetGlobalStep(int64 new_global_step); 77 78 Status MatchesNodeDef(const NodeDef& node_def); 79 80 protected: 81 // Virtual methods to be implemented by sub-classes for different datatypes. 82 // Implements arithmetic operations specific to datatype. 83 virtual void DivideAccumGradByCounter(OpKernelContext* ctx) 84 EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; 85 virtual bool SetOutput(OpKernelContext* ctx) = 0; 86 87 enum RunResult { kNoProgress, kComplete }; 88 89 // Helper struct holding information about a TakeGrad attempt 90 struct Attempt; 91 typedef std::function<RunResult(Attempt*)> RunCallback; 92 struct Attempt { 93 int elements_requested; 94 DoneCallback done_callback; // must be run outside mu_ 95 OpKernelContext* context; 96 CancellationManager* cancellation_manager; // not owned 97 CancellationToken cancellation_token; 98 RunCallback run_callback; // must be run while holding mu_ 99 bool is_cancelled; 100 101 Attempt(int elements_requested, DoneCallback done_callback, 102 OpKernelContext* context, CancellationManager* cancellation_manager, 103 CancellationToken cancellation_token, RunCallback run_callback) 104 : elements_requested(elements_requested), 105 done_callback(std::move(done_callback)), 106 context(context), 107 cancellation_manager(cancellation_manager), 108 cancellation_token(cancellation_token), 109 run_callback(std::move(run_callback)), 110 is_cancelled(false) {} 111 }; 112 113 // Helper struct for deregistration of a cancellation token and executing a 114 // DoneCallback after a TakeGrad attempt is complete. 115 struct CleanUp { 116 CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm) 117 : finished(f), to_deregister(ct), cm(cm) {} 118 DoneCallback finished; 119 CancellationToken to_deregister; 120 CancellationManager* cm; 121 }; 122 123 // Fields 124 125 const DataType dtype_; 126 const PartialTensorShape shape_; 127 const string name_; 128 mutex mu_; 129 int counter_ GUARDED_BY(mu_); 130 int64 current_global_step_ GUARDED_BY(mu_); 131 132 std::deque<Attempt> takegrad_attempts_ GUARDED_BY(mu_); 133 134 // Methods 135 136 // Helper function for creating cancellation callback 137 void Cancel(CancellationManager* cancellation_manager, 138 CancellationToken token); 139 140 // Helper functions to process TakeGrad attempts. 141 // FlushUnlocked is called at the end of each TryApplyGrad and TryTakeGrad 142 // calls to try to clear the TakeGrad attempts. This in turn calls 143 // TryAttemptLocked, which then executes the RunCallback of the logged 144 // attempts. 145 // Both functions are modeled after core/kernels/queue_base. 146 // Note: ApplyGrad attempts never block -- unlike in a queue with limited 147 // capacity, we can always add the newest gradient to our accumulator 148 // (if it is not stale) or drop it silently (if it is stale). 149 void FlushUnlocked(); 150 bool TryAttemptLocked(std::vector<CleanUp>* clean_up) 151 EXCLUSIVE_LOCKS_REQUIRED(mu_); 152 153 // Helper methods 154 // void DeepCopy(Tensor* dst); 155 bool TakeGradLockedHelper(OpKernelContext* ctx, DoneCallback callback) 156 EXCLUSIVE_LOCKS_REQUIRED(mu_); 157 }; 158 159 /* 160 * Modifications to convenience macros defined in core/framework/op_kernel.h. 161 * The below macros return a boolean if the test fails, so that the calling 162 * function can get an indication that a failure has occurred. 163 */ 164 #define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS) \ 165 do { \ 166 if (!TF_PREDICT_TRUE(EXP)) { \ 167 (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \ 168 return false; \ 169 } \ 170 } while (0) 171 172 #define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS) \ 173 do { \ 174 ::tensorflow::Status _s(STATUS); \ 175 if (!TF_PREDICT_TRUE(_s.ok())) { \ 176 (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ 177 return false; \ 178 } \ 179 } while (0) 180 181 /* 182 * Convenience classes for helping to convert between numeric types. 183 * The specialization for Eigen::half here simplifies specialization of 184 * ConditionalAccumulator classes later. 185 */ 186 template <typename T, typename U> 187 class TypeConverter { 188 public: 189 static T ConvertUToT(U c) { return c; /* implicit conversion */ } 190 }; 191 192 template <typename U> 193 class TypeConverter<Eigen::half, U> { 194 public: 195 static Eigen::half ConvertUToT(U c) { 196 return Eigen::half_impl::float_to_half_rtne(c); 197 } 198 }; 199 200 } // namespace tensorflow 201 202 #endif // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_ 203