/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/conditional_accumulator_base.h"

namespace tensorflow {

ConditionalAccumulatorBase::ConditionalAccumulatorBase(
    const DataType& dtype, const PartialTensorShape& shape, const string& name)
    : dtype_(dtype), shape_(shape), name_(name) {
  counter_ = 0;
  current_global_step_ = 0;
}

Status ConditionalAccumulatorBase::MatchesNodeDef(const NodeDef& node_def) {
  // TODO(xinghao@): implement the checks for the node definition
  return Status::OK();
}

/**
 * Sets the accumulator's time step to match the given global time step.
 * Logs a warning if the accumulator's time step is already larger than the
 * provided time step.
 */
Status ConditionalAccumulatorBase::SetGlobalStep(int64 new_global_step) {
  mutex_lock lock(mu_);
  if (new_global_step < current_global_step_) {
    LOG(WARNING) << "Attempt to set current_global_step_ to smaller value: "
                 << "current_global_step_ = " << current_global_step_
                 << " > " << new_global_step << " = new_global_step.";
  }
  current_global_step_ = new_global_step;
  return Status::OK();
}
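
// A minimal illustrative sketch of the SetGlobalStep contract (`acc` is a
// hypothetical accumulator subclass instance, not part of this file):
//
//   TF_CHECK_OK(acc->SetGlobalStep(10));  // current_global_step_ becomes 10.
//   TF_CHECK_OK(acc->SetGlobalStep(5));   // Logs a warning, but still
//                                         // overwrites: step becomes 5.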

/**
 * Records an attempt to extract the average gradient, and tries to flush all
 * pending TakeGrad attempts.
 * A TakeGrad attempt remains blocked until counter_ >= num_required, i.e.,
 * until sufficiently many gradients have been accumulated.
 *
 * num_required: Number of gradients that need to be accumulated before the
 * attempt is unblocked.
 * ctx: Context in which the op is executed.
 * callback: A callback to be executed after the attempt has been completed.
 */
void ConditionalAccumulatorBase::TryTakeGrad(int num_required,
                                             OpKernelContext* ctx,
                                             DoneCallback callback) {
  if (num_required <= 0) {
    ctx->CtxFailureWithWarning(errors::InvalidArgument(
        "Argument num_required must be positive, but was ", num_required));
    callback();
  } else {
    CancellationManager* cm = ctx->cancellation_manager();
    CancellationToken token = cm->get_cancellation_token();
    bool already_cancelled;
    {
      mutex_lock l(mu_);
      already_cancelled = !cm->RegisterCallback(
          token, [this, cm, token]() { Cancel(cm, token); });
      if (!already_cancelled) {
        takegrad_attempts_.emplace_back(
            num_required, callback, ctx, cm, token,
            [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
              if (counter_ >= attempt->elements_requested) {
                bool successful_take_grad = TakeGradLockedHelper(
                    attempt->context, attempt->done_callback);
                if (successful_take_grad) {
                  return kComplete;
                } else {
                  // Try again
                  return kNoProgress;
                }
              } else {
                return kNoProgress;
              }
            });
      }
    }
    if (!already_cancelled) {
      FlushUnlocked();
    } else {
      ctx->SetStatus(errors::Cancelled("TakeGrad operation was cancelled"));
      callback();
    }
  }
}

/**
 * Cancellation callback: marks the pending TakeGrad attempt registered with
 * the given CancellationManager/token as cancelled and runs its callback.
 */
void ConditionalAccumulatorBase::Cancel(
    CancellationManager* cancellation_manager, CancellationToken token) {
  DoneCallback callback = nullptr;
  {
    mutex_lock lock(mu_);

    for (Attempt& attempt : takegrad_attempts_) {
      if (attempt.cancellation_manager == cancellation_manager &&
          attempt.cancellation_token == token) {
        if (!attempt.is_cancelled) {
          attempt.is_cancelled = true;
          attempt.context->SetStatus(
              errors::Cancelled("TakeGrad operation was cancelled"));
          std::swap(callback, attempt.done_callback);
        }
        break;
      }
    }
  }
  if (callback) {
    callback();
    FlushUnlocked();
  }
}

/**
 * Processes pending TakeGrad attempts from the front of the queue, in FIFO
 * order, until one makes no progress. Returns true if any attempt completed.
 * Requires mu_ to be held.
 */
bool ConditionalAccumulatorBase::TryAttemptLocked(
    std::vector<CleanUp>* clean_up) {
  bool progress = false;
  bool done = false;
  while (!done && !takegrad_attempts_.empty()) {
    if (takegrad_attempts_.front().is_cancelled) {
      VLOG(1) << "Skipping cancelled TakeGrad attempt";
      takegrad_attempts_.pop_front();
    } else {
      Attempt* cur_attempt = &takegrad_attempts_.front();
      switch (cur_attempt->run_callback(cur_attempt)) {
        case kNoProgress:
          done = true;
          break;
        case kComplete:
          progress = true;
          clean_up->emplace_back(std::move(cur_attempt->done_callback),
                                 cur_attempt->cancellation_token,
                                 cur_attempt->context->cancellation_manager());
          takegrad_attempts_.pop_front();
          break;
      }
    }
  }
  return progress;
}
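
// Illustrative trace of the FIFO semantics above (hypothetical values):
//
//   takegrad_attempts_: [A(num_required=4), B(num_required=1)], counter_ == 2
//   -> A's run_callback returns kNoProgress, so the loop stops immediately;
//      B is not examined even though counter_ >= 1. Attempts therefore
//      complete strictly in arrival order.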

/**
 * Acquires mu_ and repeatedly runs TryAttemptLocked until no more progress
 * can be made, then executes the resulting clean-up actions outside the lock.
 */
void ConditionalAccumulatorBase::FlushUnlocked() {
  std::vector<CleanUp> clean_up;
  Ref();
  {
    mutex_lock lock(mu_);
    bool changed;
    do {
      changed = TryAttemptLocked(&clean_up);
    } while (changed);
  }
  Unref();
  for (const auto& to_clean : clean_up) {
    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
      // NOTE(mrry): We can safely ignore the return value of
      // DeregisterCallback because the mutex mu_ ensures that the
      // cleanup action only executes once.
      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
    }
    to_clean.finished();
  }
}

bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
                                                      DoneCallback callback) {
  // At this point, the condition (counter_ >= num_required) has been met.

  // Implicitly increment global_step
  current_global_step_++;

  // Average the accumulated gradient
  DivideAccumGradByCounter(ctx);

  // Set output for accumulated gradient tensor
  bool successful_set_output = SetOutput(ctx);

  // Reset counter
  if (successful_set_output) counter_ = 0;

  return successful_set_output;
}

}  // namespace tensorflow
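
// A minimal usage sketch (illustrative, not part of this file: `accum` is a
// hypothetical ConditionalAccumulatorBase subclass instance and `ctx` an
// OpKernelContext supplied by the kernel framework):
//
//   // Completes once 3 gradients have been accumulated: the average is
//   // written to the op's output, counter_ resets, and the callback runs.
//   accum->TryTakeGrad(/*num_required=*/3, ctx, [ctx]() { /* done */ });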