1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include "tensorflow/core/kernels/conditional_accumulator_base_op.h" 19 #include "tensorflow/core/kernels/sparse_conditional_accumulator.h" 20 21 namespace tensorflow { 22 23 /** 24 * Defines a SparseConditionalAccumulatorOp, which constructs a 25 * SparseConditionalAccumulator and returns its handle. 26 */ 27 template <typename Device, typename T> 28 class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { 29 public: 30 explicit SparseConditionalAccumulatorOp(OpKernelConstruction* context) 31 : ConditionalAccumulatorBaseOp(context) {} 32 33 protected: 34 Creator GetCreator() const override { 35 return [this](ConditionalAccumulatorBase** ret) { 36 SparseConditionalAccumulator<Device, T>* accumulator = 37 new SparseConditionalAccumulator<Device, T>(dtype_, shape_, 38 cinfo_.name()); 39 *ret = accumulator; 40 return Status::OK(); 41 }; 42 } 43 44 TF_DISALLOW_COPY_AND_ASSIGN(SparseConditionalAccumulatorOp); 45 }; 46 47 #define REGISTER_KERNELS(type, dev) \ 48 REGISTER_KERNEL_BUILDER(Name("SparseConditionalAccumulator") \ 49 .Device(DEVICE_##dev) \ 50 .TypeConstraint<type>("dtype"), \ 51 SparseConditionalAccumulatorOp<dev##Device, type>) 52 53 #define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS(type, CPU) 54 55 TF_CALL_half(REGISTER_KERNELS_CPU); 56 TF_CALL_float(REGISTER_KERNELS_CPU); 57 TF_CALL_double(REGISTER_KERNELS_CPU); 58 59 #undef REGISTER_KERNELS_CPU 60 #undef REGISTER_KERNELS 61 62 /** 63 * Defines a SparseAccumulateGradientOp, the execution of which adds a gradient 64 * to the given SparseConditionalAccumulator. 65 */ 66 class SparseAccumulatorApplyGradientOp 67 : public ConditionalAccumulatorBaseApplyGradientOp { 68 public: 69 explicit SparseAccumulatorApplyGradientOp(OpKernelConstruction* context) 70 : ConditionalAccumulatorBaseApplyGradientOp(context) {} 71 72 protected: 73 void CheckSignature(OpKernelContext* ctx, 74 ConditionalAccumulatorBase* accumulator) override { 75 // Check input signature 76 DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT64, DT_INT64}; 77 expected_inputs.push_back(accumulator->dtype()); 78 expected_inputs.push_back(DT_INT64); 79 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); 80 } 81 82 private: 83 TF_DISALLOW_COPY_AND_ASSIGN(SparseAccumulatorApplyGradientOp); 84 }; 85 86 REGISTER_KERNEL_BUILDER( 87 Name("SparseAccumulatorApplyGradient").Device(DEVICE_CPU), 88 SparseAccumulatorApplyGradientOp); 89 90 /** 91 * Defines a SparseAccumulatorTakeGradientOp, the execution of which returns the 92 * average sparse gradient accumulated by the given ConditionalAccumulator. 93 */ 94 class SparseAccumulatorTakeGradientOp 95 : public ConditionalAccumulatorBaseTakeGradientOp { 96 public: 97 explicit SparseAccumulatorTakeGradientOp(OpKernelConstruction* context) 98 : ConditionalAccumulatorBaseTakeGradientOp(context) {} 99 100 protected: 101 void CheckSignature(OpKernelContext* ctx, 102 ConditionalAccumulatorBase* accumulator, 103 DoneCallback callback) override { 104 // Check signature 105 OP_REQUIRES_OK_ASYNC( 106 ctx, 107 ctx->MatchSignature({DT_STRING_REF, DT_INT32}, 108 {DT_INT64, accumulator->dtype(), DT_INT64}), 109 callback); 110 } 111 112 private: 113 TF_DISALLOW_COPY_AND_ASSIGN(SparseAccumulatorTakeGradientOp); 114 }; 115 116 REGISTER_KERNEL_BUILDER( 117 Name("SparseAccumulatorTakeGradient").Device(DEVICE_CPU), 118 SparseAccumulatorTakeGradientOp); 119 120 } // namespace tensorflow 121