Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include "tensorflow/core/kernels/conditional_accumulator.h"
     19 #include "tensorflow/core/kernels/conditional_accumulator_base_op.h"
     20 
     21 namespace tensorflow {
     22 
     23 /**
     24  * Defines a ConditionalAccumulatorOp, which constructs a ConditionalAccumulator
     25  * and returns its handle.
     26  */
     27 template <typename Device, typename T>
     28 class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
     29  public:
     30   explicit ConditionalAccumulatorOp(OpKernelConstruction* context)
     31       : ConditionalAccumulatorBaseOp(context) {}
     32 
     33  protected:
     34   Creator GetCreator() const override {
     35     return [this](ConditionalAccumulatorBase** ret) {
     36       ConditionalAccumulator<Device, T>* accumulator =
     37           new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name());
     38       *ret = accumulator;
     39       return Status::OK();
     40     };
     41   }
     42 
     43   TF_DISALLOW_COPY_AND_ASSIGN(ConditionalAccumulatorOp);
     44 };
     45 
     46 #define REGISTER_KERNELS(type, dev)                           \
     47   REGISTER_KERNEL_BUILDER(Name("ConditionalAccumulator")      \
     48                               .Device(DEVICE_##dev)           \
     49                               .TypeConstraint<type>("dtype"), \
     50                           ConditionalAccumulatorOp<dev##Device, type>)
     51 
     52 #define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS(type, CPU)
     53 
     54 TF_CALL_half(REGISTER_KERNELS_CPU);
     55 TF_CALL_float(REGISTER_KERNELS_CPU);
     56 TF_CALL_double(REGISTER_KERNELS_CPU);
     57 
     58 #undef REGISTER_KERNELS_CPU
     59 #undef REGISTER_KERNELS
     60 
     61 /**
     62  * Defines a AccumulateGradientOp, the execution of which adds a gradient to the
     63  * given ConditionalAccumulator.
     64  */
     65 class AccumulatorApplyGradientOp
     66     : public ConditionalAccumulatorBaseApplyGradientOp {
     67  public:
     68   explicit AccumulatorApplyGradientOp(OpKernelConstruction* context)
     69       : ConditionalAccumulatorBaseApplyGradientOp(context) {}
     70 
     71  protected:
     72   void CheckSignature(OpKernelContext* ctx,
     73                       ConditionalAccumulatorBase* accumulator) override {
     74     // Check input signature
     75     DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT64};
     76     expected_inputs.push_back(accumulator->dtype());
     77     OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
     78   }
     79 
     80  private:
     81   TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorApplyGradientOp);
     82 };
     83 
     84 REGISTER_KERNEL_BUILDER(Name("AccumulatorApplyGradient").Device(DEVICE_CPU),
     85                         AccumulatorApplyGradientOp);
     86 
     87 /**
     88  * Defines a ConditionalAccumulatorBaseTakeGradientOp, the execution of which
     89  * returns the average gradient accumulated by the given ConditionalAccumulator.
     90  */
     91 class AccumulatorTakeGradientOp
     92     : public ConditionalAccumulatorBaseTakeGradientOp {
     93  public:
     94   explicit AccumulatorTakeGradientOp(OpKernelConstruction* context)
     95       : ConditionalAccumulatorBaseTakeGradientOp(context) {}
     96 
     97  protected:
     98   void CheckSignature(OpKernelContext* ctx,
     99                       ConditionalAccumulatorBase* accumulator,
    100                       DoneCallback callback) override {
    101     // Check signature
    102     OP_REQUIRES_OK_ASYNC(
    103         ctx,
    104         ctx->MatchSignature({DT_STRING_REF, DT_INT32}, {accumulator->dtype()}),
    105         callback);
    106   }
    107 
    108  private:
    109   TF_DISALLOW_COPY_AND_ASSIGN(AccumulatorTakeGradientOp);
    110 };
    111 
    112 REGISTER_KERNEL_BUILDER(Name("AccumulatorTakeGradient").Device(DEVICE_CPU),
    113                         AccumulatorTakeGradientOp);
    114 
    115 }  // namespace tensorflow
    116