Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include "tensorflow/core/kernels/conditional_accumulator_base_op.h"
     19 #include "tensorflow/core/kernels/sparse_conditional_accumulator.h"
     20 
     21 namespace tensorflow {
     22 
     23 /**
     24  * Defines a SparseConditionalAccumulatorOp, which constructs a
     25  * SparseConditionalAccumulator and returns its handle.
     26  */
     27 template <typename Device, typename T>
     28 class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
     29  public:
     30   explicit SparseConditionalAccumulatorOp(OpKernelConstruction* context)
     31       : ConditionalAccumulatorBaseOp(context) {}
     32 
     33  protected:
     34   Creator GetCreator() const override {
     35     return [this](ConditionalAccumulatorBase** ret) {
     36       SparseConditionalAccumulator<Device, T>* accumulator =
     37           new SparseConditionalAccumulator<Device, T>(dtype_, shape_,
     38                                                       cinfo_.name());
     39       *ret = accumulator;
     40       return Status::OK();
     41     };
     42   }
     43 
     44   TF_DISALLOW_COPY_AND_ASSIGN(SparseConditionalAccumulatorOp);
     45 };
     46 
     47 #define REGISTER_KERNELS(type, dev)                            \
     48   REGISTER_KERNEL_BUILDER(Name("SparseConditionalAccumulator") \
     49                               .Device(DEVICE_##dev)            \
     50                               .TypeConstraint<type>("dtype"),  \
     51                           SparseConditionalAccumulatorOp<dev##Device, type>)
     52 
     53 #define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS(type, CPU)
     54 
     55 TF_CALL_half(REGISTER_KERNELS_CPU);
     56 TF_CALL_float(REGISTER_KERNELS_CPU);
     57 TF_CALL_double(REGISTER_KERNELS_CPU);
     58 
     59 #undef REGISTER_KERNELS_CPU
     60 #undef REGISTER_KERNELS
     61 
     62 /**
     63  * Defines a SparseAccumulateGradientOp, the execution of which adds a gradient
     64  * to the given SparseConditionalAccumulator.
     65  */
     66 class SparseAccumulatorApplyGradientOp
     67     : public ConditionalAccumulatorBaseApplyGradientOp {
     68  public:
     69   explicit SparseAccumulatorApplyGradientOp(OpKernelConstruction* context)
     70       : ConditionalAccumulatorBaseApplyGradientOp(context) {}
     71 
     72  protected:
     73   void CheckSignature(OpKernelContext* ctx,
     74                       ConditionalAccumulatorBase* accumulator) override {
     75     // Check input signature
     76     DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT64, DT_INT64};
     77     expected_inputs.push_back(accumulator->dtype());
     78     expected_inputs.push_back(DT_INT64);
     79     OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
     80   }
     81 
     82  private:
     83   TF_DISALLOW_COPY_AND_ASSIGN(SparseAccumulatorApplyGradientOp);
     84 };
     85 
     86 REGISTER_KERNEL_BUILDER(
     87     Name("SparseAccumulatorApplyGradient").Device(DEVICE_CPU),
     88     SparseAccumulatorApplyGradientOp);
     89 
     90 /**
     91  * Defines a SparseAccumulatorTakeGradientOp, the execution of which returns the
     92  * average sparse gradient accumulated by the given ConditionalAccumulator.
     93  */
     94 class SparseAccumulatorTakeGradientOp
     95     : public ConditionalAccumulatorBaseTakeGradientOp {
     96  public:
     97   explicit SparseAccumulatorTakeGradientOp(OpKernelConstruction* context)
     98       : ConditionalAccumulatorBaseTakeGradientOp(context) {}
     99 
    100  protected:
    101   void CheckSignature(OpKernelContext* ctx,
    102                       ConditionalAccumulatorBase* accumulator,
    103                       DoneCallback callback) override {
    104     // Check signature
    105     OP_REQUIRES_OK_ASYNC(
    106         ctx,
    107         ctx->MatchSignature({DT_STRING_REF, DT_INT32},
    108                             {DT_INT64, accumulator->dtype(), DT_INT64}),
    109         callback);
    110   }
    111 
    112  private:
    113   TF_DISALLOW_COPY_AND_ASSIGN(SparseAccumulatorTakeGradientOp);
    114 };
    115 
    116 REGISTER_KERNEL_BUILDER(
    117     Name("SparseAccumulatorTakeGradient").Device(DEVICE_CPU),
    118     SparseAccumulatorTakeGradientOp);
    119 
    120 }  // namespace tensorflow
    121