1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include "tensorflow/core/kernels/conditional_accumulator.h" 19 #include "tensorflow/core/kernels/conditional_accumulator_base_op.h" 20 21 namespace tensorflow { 22 23 /** 24 * Defines a ConditionalAccumulatorOp, which constructs a ConditionalAccumulator 25 * and returns its handle. 26 */ 27 template <typename Device, typename T> 28 class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { 29 public: 30 explicit ConditionalAccumulatorOp(OpKernelConstruction* context) 31 : ConditionalAccumulatorBaseOp(context) {} 32 33 protected: 34 Creator GetCreator() const override { 35 return [this](ConditionalAccumulatorBase** ret) { 36 ConditionalAccumulator<Device, T>* accumulator = 37 new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name()); 38 *ret = accumulator; 39 return Status::OK(); 40 }; 41 } 42 43 TF_DISALLOW_COPY_AND_ASSIGN(ConditionalAccumulatorOp); 44 }; 45 46 #define REGISTER_KERNELS(type, dev) \ 47 REGISTER_KERNEL_BUILDER(Name("ConditionalAccumulator") \ 48 .Device(DEVICE_##dev) \ 49 .TypeConstraint<type>("dtype"), \ 50 ConditionalAccumulatorOp<dev##Device, type>) 51 52 #define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS(type, CPU) 53 54 TF_CALL_half(REGISTER_KERNELS_CPU); 55 TF_CALL_float(REGISTER_KERNELS_CPU); 56 TF_CALL_double(REGISTER_KERNELS_CPU); 57 58 #undef REGISTER_KERNELS_CPU 59 #undef REGISTER_KERNELS 60 61 /** 62 * Defines a AccumulateGradientOp, the execution of which adds a gradient to the 63 * given ConditionalAccumulator. 64 */ 65 class AccumulatorApplyGradientOp 66 : public ConditionalAccumulatorBaseApplyGradientOp { 67 public: 68 explicit AccumulatorApplyGradientOp(OpKernelConstruction* context) 69 : ConditionalAccumulatorBaseApplyGradientOp(context) {} 70 71 protected: 72 void CheckSignature(OpKernelContext* ctx, 73 ConditionalAccumulatorBase* accumulator) override { 74 // Check input signature 75 DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT64}; 76 expected_inputs.push_back(accumulator->dtype()); 77 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); 78 } 79 80 private: 81 TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorApplyGradientOp); 82 }; 83 84 REGISTER_KERNEL_BUILDER(Name("AccumulatorApplyGradient").Device(DEVICE_CPU), 85 AccumulatorApplyGradientOp); 86 87 /** 88 * Defines a ConditionalAccumulatorBaseTakeGradientOp, the execution of which 89 * returns the average gradient accumulated by the given ConditionalAccumulator. 90 */ 91 class AccumulatorTakeGradientOp 92 : public ConditionalAccumulatorBaseTakeGradientOp { 93 public: 94 explicit AccumulatorTakeGradientOp(OpKernelConstruction* context) 95 : ConditionalAccumulatorBaseTakeGradientOp(context) {} 96 97 protected: 98 void CheckSignature(OpKernelContext* ctx, 99 ConditionalAccumulatorBase* accumulator, 100 DoneCallback callback) override { 101 // Check signature 102 OP_REQUIRES_OK_ASYNC( 103 ctx, 104 ctx->MatchSignature({DT_STRING_REF, DT_INT32}, {accumulator->dtype()}), 105 callback); 106 } 107 108 private: 109 TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorTakeGradientOp); 110 }; 111 112 REGISTER_KERNEL_BUILDER(Name("AccumulatorTakeGradient").Device(DEVICE_CPU), 113 AccumulatorTakeGradientOp); 114 115 } // namespace tensorflow 116