// (code-browser navigation residue, not part of the original file)
// Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
     17 #define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
     18 
     19 #define EIGEN_USE_THREADS
     20 
#include "tensorflow/core/kernels/conditional_accumulator_base.h"

#include <functional>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
     34 
     35 typedef Eigen::ThreadPoolDevice CPUDevice;
     36 
     37 typedef std::function<void()> DoneCallback;
     38 
     39 namespace tensorflow {
     40 
     41 /**
     42  * Defines a ConditionalAccumulatorBaseOp, which constructs a
     43  * ConditionalAccumulatorBase (via sub-class's Creator) and returns its handle.
     44  */
     45 class ConditionalAccumulatorBaseOp : public OpKernel {
     46  public:
     47   explicit ConditionalAccumulatorBaseOp(OpKernelConstruction* context)
     48       : OpKernel(context), accumulator_handle_set_(false) {
     49     OP_REQUIRES_OK(context,
     50                    context->allocate_persistent(DT_STRING, TensorShape({2}),
     51                                                 &accumulator_handle_, nullptr));
     52     OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
     53     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
     54   }
     55 
     56   void Compute(OpKernelContext* ctx) override {
     57     mutex_lock l(mu_);
     58     if (!accumulator_handle_set_) {
     59       OP_REQUIRES_OK(ctx, SetAccumulatorHandle(ctx));
     60     }
     61     ctx->set_output_ref(0, &mu_, accumulator_handle_.AccessTensor(ctx));
     62   }
     63 
     64  protected:
     65   ~ConditionalAccumulatorBaseOp() override {
     66     // If the accumulator object was not shared, delete it.
     67     if (accumulator_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
     68       TF_CHECK_OK((cinfo_.resource_manager()
     69                        ->template Delete<ConditionalAccumulatorBase>(
     70                            cinfo_.container(), cinfo_.name())));
     71     }
     72   }
     73 
     74  protected:
     75   typedef std::function<Status(ConditionalAccumulatorBase**)> Creator;
     76 
     77   // Subclasses must override this
     78   virtual Creator GetCreator() const = 0;
     79 
     80   // Variables required to construct ConditionalAccumulator
     81   DataType dtype_;
     82   PartialTensorShape shape_;
     83   ContainerInfo cinfo_;
     84 
     85  private:
     86   Status SetAccumulatorHandle(OpKernelContext* ctx)
     87       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     88     TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
     89 
     90     // Check input signature
     91     DataTypeVector expected_inputs = {};
     92     TF_RETURN_IF_ERROR(ctx->MatchSignature(expected_inputs, {DT_STRING_REF}));
     93 
     94     Creator creator = GetCreator();
     95     ConditionalAccumulatorBase* accumulator;
     96     TF_RETURN_IF_ERROR(
     97         (cinfo_.resource_manager()
     98              ->template LookupOrCreate<ConditionalAccumulatorBase>(
     99                  cinfo_.container(), cinfo_.name(), &accumulator, creator)));
    100     core::ScopedUnref unref_me(accumulator);
    101 
    102     // Verify that the shared accumulator is compatible
    103     // with the requested arguments.
    104     TF_RETURN_IF_ERROR(accumulator->MatchesNodeDef(def()));
    105     auto h = accumulator_handle_.AccessTensor(ctx)->template flat<string>();
    106     h(0) = cinfo_.container();
    107     h(1) = cinfo_.name();
    108     accumulator_handle_set_ = true;
    109     return Status::OK();
    110   }
    111 
    112   mutex mu_;
    113   PersistentTensor accumulator_handle_ GUARDED_BY(mu_);
    114   bool accumulator_handle_set_ GUARDED_BY(mu_);
    115 };
    116 
    117 /**
    118  * General OpKernel for ConditionalAccumulatorBase-related ops.
    119  */
    120 class ConditionalAccumulatorBaseAsyncOpKernel : public AsyncOpKernel {
    121  public:
    122   explicit ConditionalAccumulatorBaseAsyncOpKernel(
    123       OpKernelConstruction* context)
    124       : AsyncOpKernel(context) {}
    125 
    126   void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
    127     ConditionalAccumulatorBase* accumulator;
    128     OP_REQUIRES_OK_ASYNC(
    129         ctx, GetResourceFromContext(ctx, "handle", &accumulator), callback);
    130     ComputeAsync(ctx, accumulator, [callback, accumulator]() {
    131       accumulator->Unref();
    132       callback();
    133     });
    134   }
    135 
    136  protected:
    137   virtual void ComputeAsync(OpKernelContext* ctx,
    138                             ConditionalAccumulatorBase* accumulator,
    139                             DoneCallback callback) = 0;
    140 };
    141 
    142 /**
    143  * General OpKernel for ConditionalAccumulatorBase-related ops.
    144  */
    145 class ConditionalAccumulatorBaseSyncOpKernel : public OpKernel {
    146  public:
    147   explicit ConditionalAccumulatorBaseSyncOpKernel(OpKernelConstruction* context)
    148       : OpKernel(context) {}
    149 
    150   void Compute(OpKernelContext* ctx) final {
    151     ConditionalAccumulatorBase* accumulator;
    152     OP_REQUIRES_OK(ctx, GetResourceFromContext(ctx, "handle", &accumulator));
    153     Compute(ctx, accumulator);
    154     accumulator->Unref();
    155   }
    156 
    157  protected:
    158   virtual void Compute(OpKernelContext* ctx,
    159                        ConditionalAccumulatorBase* accumulator) = 0;
    160 };
    161 
    162 /**
    163  * Defines a AccumulateGradientOp, the execution of which adds a gradient to the
    164  * given ConditionalAccumulator.
    165  */
    166 class ConditionalAccumulatorBaseApplyGradientOp
    167     : public ConditionalAccumulatorBaseSyncOpKernel {
    168  public:
    169   explicit ConditionalAccumulatorBaseApplyGradientOp(
    170       OpKernelConstruction* context)
    171       : ConditionalAccumulatorBaseSyncOpKernel(context) {}
    172 
    173  protected:
    174   virtual void CheckSignature(OpKernelContext* ctx,
    175                               ConditionalAccumulatorBase* accumulator) = 0;
    176 
    177   void Compute(OpKernelContext* ctx,
    178                ConditionalAccumulatorBase* accumulator) override {
    179     // Check input signature
    180     CheckSignature(ctx, accumulator);
    181 
    182     // Get input local_step
    183     const Tensor* local_step_tensor;
    184     OP_REQUIRES_OK(ctx, ctx->input("local_step", &local_step_tensor));
    185     if (!TensorShapeUtils::IsScalar(local_step_tensor->shape())) {
    186       ctx->CtxFailureWithWarning(errors::InvalidArgument(
    187           "Argument local_step must be scalar, but had bad shape ",
    188           local_step_tensor->shape().DebugString()));
    189     }
    190 
    191     // Actually try to apply gradient now
    192     accumulator->TryApplyGrad(local_step_tensor->scalar<int64>()(), ctx);
    193   }
    194 };
    195 
    196 /**
    197  * Defines a TakeAccumulatedGradientOp, the execution of which adds a gradient
    198  * to the given ConditionalAccumulator.
    199  */
    200 class ConditionalAccumulatorBaseTakeGradientOp
    201     : public ConditionalAccumulatorBaseAsyncOpKernel {
    202  public:
    203   explicit ConditionalAccumulatorBaseTakeGradientOp(
    204       OpKernelConstruction* context)
    205       : ConditionalAccumulatorBaseAsyncOpKernel(context) {}
    206 
    207  protected:
    208   virtual void CheckSignature(OpKernelContext* ctx,
    209                               ConditionalAccumulatorBase* accumulator,
    210                               DoneCallback callback) = 0;
    211 
    212   void ComputeAsync(OpKernelContext* ctx,
    213                     ConditionalAccumulatorBase* accumulator,
    214                     DoneCallback callback) override {
    215     // Check signature
    216     CheckSignature(ctx, accumulator, callback);
    217 
    218     // Get input num_required
    219     const Tensor* num_required_tensor;
    220     OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_required", &num_required_tensor),
    221                          callback);
    222     if (!TensorShapeUtils::IsScalar(num_required_tensor->shape())) {
    223       ctx->CtxFailureWithWarning(errors::InvalidArgument(
    224           "Argument num_required must be scalar, but had bad shape ",
    225           num_required_tensor->shape().DebugString()));
    226       callback();
    227     }
    228 
    229     // Actually try to take gradient now
    230     accumulator->TryTakeGrad(num_required_tensor->scalar<int32>()(), ctx,
    231                              callback);
    232   }
    233 };
    234 
    235 }  // namespace tensorflow
    236 
    237 #endif  // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
    238