Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #if GOOGLE_CUDA
     19 #define EIGEN_USE_GPU
     20 #endif
     21 
     22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/register_types.h"
     25 #include "tensorflow/core/kernels/assign_op.h"
     26 #include "tensorflow/core/kernels/dense_update_functor.h"
     27 #include "tensorflow/core/lib/core/errors.h"
     28 #include "tensorflow/core/platform/mutex.h"
     29 #include "tensorflow/core/platform/types.h"
     30 
     31 namespace tensorflow {
     32 
     33 template <typename Device, typename T>
     34 class AssignOpT : public AssignOp {
     35  public:
     36   using AssignOp::AssignOp;
     37 
     38   void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override {
     39     functor::DenseUpdate<Device, T, ASSIGN> copy;
     40     copy(context->eigen_device<Device>(), lhs->flat<T>(), rhs.flat<T>());
     41   }
     42 };
     43 
     44 // TODO(jeff): Get rid of use_exclusive_lock_ option
     45 template <typename Device, typename T, DenseUpdateType OP>
     46 class DenseUpdateOp : public OpKernel {
     47  public:
     48   explicit DenseUpdateOp(OpKernelConstruction* context) : OpKernel(context) {
     49     OP_REQUIRES_OK(context,
     50                    context->GetAttr("use_locking", &use_exclusive_lock_));
     51     const DataType dt = DataTypeToEnum<T>::v();
     52     OP_REQUIRES_OK(context, context->MatchSignature({MakeRefType(dt), dt},
     53                                                     {MakeRefType(dt)}));
     54   }
     55 
     56   void Compute(OpKernelContext* context) override {
     57     // We always return the input ref.
     58     context->forward_ref_input_to_ref_output(0, 0);
     59 
     60     if (use_exclusive_lock_) {
     61       mutex_lock l(*context->input_ref_mutex(0));
     62       DoUpdate(context);
     63     } else {
     64       DoUpdate(context);
     65     }
     66   }
     67 
     68  private:
     69   void DoUpdate(OpKernelContext* context) {
     70     Tensor Tparams = context->mutable_input(0, use_exclusive_lock_);
     71     const Tensor& Tupdate = context->input(1);
     72     OP_REQUIRES(context, Tparams.IsInitialized(),
     73                 errors::FailedPrecondition("Attempting to use uninitialized "
     74                                            "parameters: ",
     75                                            requested_input(0)));
     76     OP_REQUIRES(
     77         context, Tparams.IsSameSize(Tupdate),
     78         errors::InvalidArgument("Parameters and update must be the same size"));
     79 
     80     functor::DenseUpdate<Device, T, OP> update_functor;
     81     update_functor(context->template eigen_device<Device>(), Tparams.flat<T>(),
     82                    Tupdate.flat<T>());
     83   }
     84 
     85   bool use_exclusive_lock_;
     86 };
     87 
     88 typedef Eigen::ThreadPoolDevice CPUDevice;
     89 typedef Eigen::GpuDevice GPUDevice;
     90 #ifdef TENSORFLOW_USE_SYCL
     91 typedef Eigen::SyclDevice SYCLDevice;
     92 #endif  // TENSORFLOW_USE_SYCL
     93 
     94 #define REGISTER_KERNELS(type)                                     \
     95   REGISTER_KERNEL_BUILDER(                                         \
     96       Name("Assign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
     97       AssignOpT<CPUDevice, type>);
     98 
     99 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
    100 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
    101 #undef REGISTER_KERNELS
    102 
    103 #if GOOGLE_CUDA
    104 // Only register 'Assign' on GPU for the subset of types also supported by
    105 // 'Variable' (see variable_ops.cc.)
    106 #define REGISTER_GPU_KERNELS(type)                                 \
    107   REGISTER_KERNEL_BUILDER(                                         \
    108       Name("Assign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
    109       AssignOpT<GPUDevice, type>);
    110 
    111 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
    112 TF_CALL_int64(REGISTER_GPU_KERNELS);
    113 #undef REGISTER_GPU_KERNELS
    114 #endif  // GOOGLE_CUDA
    115 
    116 #ifdef TENSORFLOW_USE_SYCL
    117 #define REGISTER_SYCL_KERNELS(type)                                 \
    118   REGISTER_KERNEL_BUILDER(                                          \
    119       Name("Assign").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
    120       AssignOpT<SYCLDevice, type>);
    121 
    122 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
    123 #undef REGISTER_SYCL_KERNELS
    124 #endif  // TENSORFLOW_USE_SYCL
    125 
    126 #define REGISTER_KERNELS(type)                                        \
    127   REGISTER_KERNEL_BUILDER(                                            \
    128       Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
    129       DenseUpdateOp<CPUDevice, type, DenseUpdateType::ADD>);          \
    130   REGISTER_KERNEL_BUILDER(                                            \
    131       Name("AssignSub").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
    132       DenseUpdateOp<CPUDevice, type, DenseUpdateType::SUB>);
    133 
    134 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
    135 #undef REGISTER_KERNELS
    136 
    137 #if GOOGLE_CUDA
    138 #define REGISTER_GPU_KERNELS(type)                                    \
    139   REGISTER_KERNEL_BUILDER(                                            \
    140       Name("AssignAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
    141       DenseUpdateOp<GPUDevice, type, DenseUpdateType::ADD>);          \
    142   REGISTER_KERNEL_BUILDER(                                            \
    143       Name("AssignSub").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
    144       DenseUpdateOp<GPUDevice, type, DenseUpdateType::SUB>);
    145 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
    146 TF_CALL_int64(REGISTER_GPU_KERNELS);
    147 #undef REGISTER_GPU_KERNELS
    148 #endif  // end GOOGLE_CUDA
    149 
    150 #ifdef TENSORFLOW_USE_SYCL
    151 #define REGISTER_SYCL_KERNELS(type)                                    \
    152   REGISTER_KERNEL_BUILDER(                                             \
    153       Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
    154       DenseUpdateOp<SYCLDevice, type, DenseUpdateType::ADD>);          \
    155   REGISTER_KERNEL_BUILDER(                                             \
    156       Name("AssignSub").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
    157       DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>);
    158 
    159 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
    160 #undef REGISTER_SYCL_KERNELS
    161 #endif  // TENSORFLOW_USE_SYCL
    162 }  // namespace tensorflow
    163