/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS
// TODO(b/31098934): Figure out why this is necessary here but not in
// any other place, e.g., the cwise lgamma ops.
#define EIGEN_HAS_C99_MATH 1

#include "tensorflow/core/kernels/betainc_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

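// Computes the regularized incomplete beta integral I_x(a, b) elementwise,
// broadcasting scalar (rank-0) inputs against tensor inputs where needed.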
template <typename Device, typename T>
class BetaincOp : public OpKernel {
 public:
  explicit BetaincOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
    const Tensor& x = ctx->input(2);

    const TensorShape& a_shape = a.shape();
    const TensorShape& b_shape = b.shape();
    const TensorShape& x_shape = x.shape();
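    // Non-scalar inputs must agree exactly; rank-0 (scalar) inputs skip the
    // check and are broadcast against the others below.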
    if (a_shape.dims() > 0 && b_shape.dims() > 0) {
      OP_REQUIRES(ctx, a_shape == b_shape,
                  errors::InvalidArgument(
                      "Shapes of a and b are inconsistent: ",
                      a_shape.DebugString(), " vs. ", b_shape.DebugString()));
    }
    if (a_shape.dims() > 0 && x_shape.dims() > 0) {
      OP_REQUIRES(ctx, a_shape == x_shape,
                  errors::InvalidArgument(
                      "Shapes of a and x are inconsistent: ",
                      a_shape.DebugString(), " vs. ", x_shape.DebugString()));
    }
    if (b_shape.dims() > 0 && x_shape.dims() > 0) {
      OP_REQUIRES(ctx, b_shape == x_shape,
                  errors::InvalidArgument(
                      "Shapes of b and x are inconsistent: ",
                      b_shape.DebugString(), " vs. ", x_shape.DebugString()));
    }

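    // The output takes the shape of whichever input is non-scalar; if all
    // inputs are scalars the output is a scalar as well.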
    TensorShape merged_shape(a_shape);
    if (b_shape.dims() > 0) merged_shape = b_shape;
    if (x_shape.dims() > 0) merged_shape = x_shape;

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, merged_shape, &output));

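    // Fast path: all three inputs already share a shape, so apply the functor
    // elementwise with no broadcasting.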
    if (a_shape == b_shape && a_shape == x_shape) {
      functor::Betainc<Device, T, 1> functor;
      functor(ctx->eigen_device<Device>(), a.flat<T>(), b.flat<T>(),
              x.flat<T>(), output->flat<T>());
      return;
    }

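    // Otherwise compute, for each input, the reshape and broadcast arrays
    // needed to expand it to the merged output shape.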
    auto merged_shape_vec = BCast::FromShape(merged_shape);
    BCast a_shaper(BCast::FromShape(a_shape), merged_shape_vec);
    BCast b_shaper(BCast::FromShape(b_shape), merged_shape_vec);
    BCast x_shaper(BCast::FromShape(x_shape), merged_shape_vec);

    int ndims = static_cast<int>(a_shaper.x_reshape().size());

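    // Dispatch on the collapsed broadcast rank: each CASE reshapes the inputs
    // to NDIM dimensions and invokes the broadcasting functor. Only ranks 1
    // and 2 are instantiated; anything else falls through to the error below.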
    switch (ndims) {
#define CASE(NDIM)                                                        \
  case NDIM: {                                                            \
    functor::Betainc<Device, T, NDIM> functor;                            \
    auto a_value = a.shaped<T, NDIM>(a_shaper.x_reshape());               \
    auto b_value = b.shaped<T, NDIM>(b_shaper.x_reshape());               \
    auto x_value = x.shaped<T, NDIM>(x_shaper.x_reshape());               \
    functor.BCast(ctx->eigen_device<Device>(), a_value,                   \
                  BCast::ToIndexArray<NDIM>(a_shaper.x_bcast()), b_value, \
                  BCast::ToIndexArray<NDIM>(b_shaper.x_bcast()), x_value, \
                  BCast::ToIndexArray<NDIM>(x_shaper.x_bcast()),          \
                  output->shaped<T, NDIM>(a_shaper.y_reshape()));         \
    return;                                                               \
  }

      CASE(1);
      CASE(2);
      default: {
        ctx->SetStatus(errors::InvalidArgument(
            "Broadcasting rank not supported: ", ndims));
        return;
      }
    }
  }
};

#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Betainc").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BetaincOp<CPUDevice, type>);

REGISTER_KERNELS(float);
REGISTER_KERNELS(double);
#undef REGISTER_KERNELS
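
// The kernel registered above backs the Python API tf.math.betainc, which
// computes the regularized incomplete beta integral I_x(a, b). A minimal
// usage sketch, assuming the public TensorFlow Python API:
//
//   import tensorflow as tf
//   y = tf.math.betainc(0.5, 0.5, 0.3)  # I_0.3(0.5, 0.5)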

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC_NDIM(T, NDIM)                               \
  template <>                                                        \
  void Betainc<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor a,   \
      typename TTypes<T, NDIM>::ConstTensor b,                       \
      typename TTypes<T, NDIM>::ConstTensor x,                       \
      typename TTypes<T, NDIM>::Tensor output);                      \
  template <>                                                        \
  void Betainc<GPUDevice, T, NDIM>::BCast(                           \
      const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor a,   \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_a, \
      typename TTypes<T, NDIM>::ConstTensor b,                       \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_b, \
      typename TTypes<T, NDIM>::ConstTensor x,                       \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_x, \
      typename TTypes<T, NDIM>::Tensor output);                      \
  extern template struct Betainc<GPUDevice, T, NDIM>;

#define DECLARE_GPU_SPEC(T)   \
  DECLARE_GPU_SPEC_NDIM(T, 1) \
  DECLARE_GPU_SPEC_NDIM(T, 2)

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);

#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPEC_NDIM
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNELS(type)                                  \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Betainc").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BetaincOp<GPUDevice, type>);

REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA

}  // namespace tensorflow