1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_ 17 #define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_ 18 19 #define EIGEN_USE_THREADS 20 21 #include "tensorflow/core/kernels/conditional_accumulator_base.h" 22 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/framework/resource_mgr.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/lib/core/errors.h" 30 #include "tensorflow/core/platform/macros.h" 31 #include "tensorflow/core/platform/mutex.h" 32 #include "tensorflow/core/platform/thread_annotations.h" 33 #include "tensorflow/core/platform/types.h" 34 35 typedef Eigen::ThreadPoolDevice CPUDevice; 36 37 typedef std::function<void()> DoneCallback; 38 39 namespace tensorflow { 40 41 /** 42 * Defines a ConditionalAccumulatorBaseOp, which constructs a 43 * ConditionalAccumulatorBase (via sub-class's Creator) and returns its handle. 44 */ 45 class ConditionalAccumulatorBaseOp : public OpKernel { 46 public: 47 explicit ConditionalAccumulatorBaseOp(OpKernelConstruction* context) 48 : OpKernel(context), accumulator_handle_set_(false) { 49 OP_REQUIRES_OK(context, 50 context->allocate_persistent(DT_STRING, TensorShape({2}), 51 &accumulator_handle_, nullptr)); 52 OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); 53 OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); 54 } 55 56 void Compute(OpKernelContext* ctx) override { 57 mutex_lock l(mu_); 58 if (!accumulator_handle_set_) { 59 OP_REQUIRES_OK(ctx, SetAccumulatorHandle(ctx)); 60 } 61 ctx->set_output_ref(0, &mu_, accumulator_handle_.AccessTensor(ctx)); 62 } 63 64 protected: 65 ~ConditionalAccumulatorBaseOp() override { 66 // If the accumulator object was not shared, delete it. 67 if (accumulator_handle_set_ && cinfo_.resource_is_private_to_kernel()) { 68 TF_CHECK_OK((cinfo_.resource_manager() 69 ->template Delete<ConditionalAccumulatorBase>( 70 cinfo_.container(), cinfo_.name()))); 71 } 72 } 73 74 protected: 75 typedef std::function<Status(ConditionalAccumulatorBase**)> Creator; 76 77 // Subclasses must override this 78 virtual Creator GetCreator() const = 0; 79 80 // Variables required to construct ConditionalAccumulator 81 DataType dtype_; 82 PartialTensorShape shape_; 83 ContainerInfo cinfo_; 84 85 private: 86 Status SetAccumulatorHandle(OpKernelContext* ctx) 87 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 88 TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); 89 90 // Check input signature 91 DataTypeVector expected_inputs = {}; 92 TF_RETURN_IF_ERROR(ctx->MatchSignature(expected_inputs, {DT_STRING_REF})); 93 94 Creator creator = GetCreator(); 95 ConditionalAccumulatorBase* accumulator; 96 TF_RETURN_IF_ERROR( 97 (cinfo_.resource_manager() 98 ->template LookupOrCreate<ConditionalAccumulatorBase>( 99 cinfo_.container(), cinfo_.name(), &accumulator, creator))); 100 core::ScopedUnref unref_me(accumulator); 101 102 // Verify that the shared accumulator is compatible 103 // with the requested arguments. 104 TF_RETURN_IF_ERROR(accumulator->MatchesNodeDef(def())); 105 auto h = accumulator_handle_.AccessTensor(ctx)->template flat<string>(); 106 h(0) = cinfo_.container(); 107 h(1) = cinfo_.name(); 108 accumulator_handle_set_ = true; 109 return Status::OK(); 110 } 111 112 mutex mu_; 113 PersistentTensor accumulator_handle_ GUARDED_BY(mu_); 114 bool accumulator_handle_set_ GUARDED_BY(mu_); 115 }; 116 117 /** 118 * General OpKernel for ConditionalAccumulatorBase-related ops. 119 */ 120 class ConditionalAccumulatorBaseAsyncOpKernel : public AsyncOpKernel { 121 public: 122 explicit ConditionalAccumulatorBaseAsyncOpKernel( 123 OpKernelConstruction* context) 124 : AsyncOpKernel(context) {} 125 126 void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final { 127 ConditionalAccumulatorBase* accumulator; 128 OP_REQUIRES_OK_ASYNC( 129 ctx, GetResourceFromContext(ctx, "handle", &accumulator), callback); 130 ComputeAsync(ctx, accumulator, [callback, accumulator]() { 131 accumulator->Unref(); 132 callback(); 133 }); 134 } 135 136 protected: 137 virtual void ComputeAsync(OpKernelContext* ctx, 138 ConditionalAccumulatorBase* accumulator, 139 DoneCallback callback) = 0; 140 }; 141 142 /** 143 * General OpKernel for ConditionalAccumulatorBase-related ops. 144 */ 145 class ConditionalAccumulatorBaseSyncOpKernel : public OpKernel { 146 public: 147 explicit ConditionalAccumulatorBaseSyncOpKernel(OpKernelConstruction* context) 148 : OpKernel(context) {} 149 150 void Compute(OpKernelContext* ctx) final { 151 ConditionalAccumulatorBase* accumulator; 152 OP_REQUIRES_OK(ctx, GetResourceFromContext(ctx, "handle", &accumulator)); 153 Compute(ctx, accumulator); 154 accumulator->Unref(); 155 } 156 157 protected: 158 virtual void Compute(OpKernelContext* ctx, 159 ConditionalAccumulatorBase* accumulator) = 0; 160 }; 161 162 /** 163 * Defines a AccumulateGradientOp, the execution of which adds a gradient to the 164 * given ConditionalAccumulator. 165 */ 166 class ConditionalAccumulatorBaseApplyGradientOp 167 : public ConditionalAccumulatorBaseSyncOpKernel { 168 public: 169 explicit ConditionalAccumulatorBaseApplyGradientOp( 170 OpKernelConstruction* context) 171 : ConditionalAccumulatorBaseSyncOpKernel(context) {} 172 173 protected: 174 virtual void CheckSignature(OpKernelContext* ctx, 175 ConditionalAccumulatorBase* accumulator) = 0; 176 177 void Compute(OpKernelContext* ctx, 178 ConditionalAccumulatorBase* accumulator) override { 179 // Check input signature 180 CheckSignature(ctx, accumulator); 181 182 // Get input local_step 183 const Tensor* local_step_tensor; 184 OP_REQUIRES_OK(ctx, ctx->input("local_step", &local_step_tensor)); 185 if (!TensorShapeUtils::IsScalar(local_step_tensor->shape())) { 186 ctx->CtxFailureWithWarning(errors::InvalidArgument( 187 "Argument local_step must be scalar, but had bad shape ", 188 local_step_tensor->shape().DebugString())); 189 } 190 191 // Actually try to apply gradient now 192 accumulator->TryApplyGrad(local_step_tensor->scalar<int64>()(), ctx); 193 } 194 }; 195 196 /** 197 * Defines a TakeAccumulatedGradientOp, the execution of which adds a gradient 198 * to the given ConditionalAccumulator. 199 */ 200 class ConditionalAccumulatorBaseTakeGradientOp 201 : public ConditionalAccumulatorBaseAsyncOpKernel { 202 public: 203 explicit ConditionalAccumulatorBaseTakeGradientOp( 204 OpKernelConstruction* context) 205 : ConditionalAccumulatorBaseAsyncOpKernel(context) {} 206 207 protected: 208 virtual void CheckSignature(OpKernelContext* ctx, 209 ConditionalAccumulatorBase* accumulator, 210 DoneCallback callback) = 0; 211 212 void ComputeAsync(OpKernelContext* ctx, 213 ConditionalAccumulatorBase* accumulator, 214 DoneCallback callback) override { 215 // Check signature 216 CheckSignature(ctx, accumulator, callback); 217 218 // Get input num_required 219 const Tensor* num_required_tensor; 220 OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_required", &num_required_tensor), 221 callback); 222 if (!TensorShapeUtils::IsScalar(num_required_tensor->shape())) { 223 ctx->CtxFailureWithWarning(errors::InvalidArgument( 224 "Argument num_required must be scalar, but had bad shape ", 225 num_required_tensor->shape().DebugString())); 226 callback(); 227 } 228 229 // Actually try to take gradient now 230 accumulator->TryTakeGrad(num_required_tensor->scalar<int32>()(), ctx, 231 callback); 232 } 233 }; 234 235 } // namespace tensorflow 236 237 #endif // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_ 238