/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS
// TODO(b/31098934): Figure out why this is necessary here but not in
// any other place, e.g., the cwise lgamma ops.
#define EIGEN_HAS_C99_MATH 1

#include "tensorflow/core/kernels/betainc_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class BetaincOp : public OpKernel {
 public:
  explicit BetaincOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
    const Tensor& x = ctx->input(2);

    const TensorShape& a_shape = a.shape();
    const TensorShape& b_shape = b.shape();
    const TensorShape& x_shape = x.shape();
    if (a_shape.dims() > 0 && b_shape.dims() > 0) {
      OP_REQUIRES(ctx, a_shape == b_shape,
                  errors::InvalidArgument(
                      "Shapes of a and b are inconsistent: ",
                      a_shape.DebugString(), " vs. ", b_shape.DebugString()));
    }
    if (a_shape.dims() > 0 && x_shape.dims() > 0) {
      OP_REQUIRES(ctx, a_shape == x_shape,
                  errors::InvalidArgument(
                      "Shapes of a and x are inconsistent: ",
                      a_shape.DebugString(), " vs. ", x_shape.DebugString()));
    }
    if (b_shape.dims() > 0 && x_shape.dims() > 0) {
      OP_REQUIRES(ctx, b_shape == x_shape,
                  errors::InvalidArgument(
                      "Shapes of b and x are inconsistent: ",
                      b_shape.DebugString(), " vs. ", x_shape.DebugString()));
    }

    TensorShape merged_shape(a_shape);
    if (b_shape.dims() > 0) merged_shape = b_shape;
    if (x_shape.dims() > 0) merged_shape = x_shape;

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, merged_shape, &output));

    if (a_shape == b_shape && a_shape == x_shape) {
      functor::Betainc<Device, T, 1> functor;
      functor(ctx->eigen_device<Device>(), a.flat<T>(), b.flat<T>(),
              x.flat<T>(), output->flat<T>());
      return;
    }

    auto merged_shape_vec = BCast::FromShape(merged_shape);
    BCast a_shaper(BCast::FromShape(a_shape), merged_shape_vec);
    BCast b_shaper(BCast::FromShape(b_shape), merged_shape_vec);
    BCast x_shaper(BCast::FromShape(x_shape), merged_shape_vec);

    int ndims = static_cast<int>(a_shaper.x_reshape().size());

    switch (ndims) {
#define CASE(NDIM)                                                        \
  case NDIM: {                                                            \
    functor::Betainc<Device, T, NDIM> functor;                            \
    auto a_value = a.shaped<T, NDIM>(a_shaper.x_reshape());               \
    auto b_value = b.shaped<T, NDIM>(b_shaper.x_reshape());               \
    auto x_value = x.shaped<T, NDIM>(x_shaper.x_reshape());               \
    functor.BCast(ctx->eigen_device<Device>(), a_value,                   \
                  BCast::ToIndexArray<NDIM>(a_shaper.x_bcast()), b_value, \
                  BCast::ToIndexArray<NDIM>(b_shaper.x_bcast()), x_value, \
                  BCast::ToIndexArray<NDIM>(x_shaper.x_bcast()),          \
                  output->shaped<T, NDIM>(a_shaper.y_reshape()));         \
    return;                                                               \
  }

      CASE(1);
      CASE(2);
      default: {
        ctx->SetStatus(errors::InvalidArgument(
            "Broadcasting rank not supported: ", ndims));
        return;
      }
    }
  }
};

#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Betainc").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BetaincOp<CPUDevice, type>);

REGISTER_KERNELS(float);
REGISTER_KERNELS(double);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC_NDIM(T, NDIM)                               \
  template <>                                                        \
  void Betainc<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor a,   \
      typename TTypes<T, NDIM>::ConstTensor b,                       \
      typename TTypes<T, NDIM>::ConstTensor x,                       \
      typename TTypes<T, NDIM>::Tensor output);                      \
  template <>                                                        \
  void Betainc<GPUDevice, T, NDIM>::BCast(                           \
      const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor a,   \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_a, \
      typename TTypes<T, NDIM>::ConstTensor b,                       \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_b, \
      typename TTypes<T, NDIM>::ConstTensor x,                       \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_x, \
      typename TTypes<T, NDIM>::Tensor output);                      \
  extern template struct Betainc<GPUDevice, T, NDIM>;

#define DECLARE_GPU_SPEC(T)   \
  DECLARE_GPU_SPEC_NDIM(T, 1) \
  DECLARE_GPU_SPEC_NDIM(T, 2)

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);

#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPEC_NDIM
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNELS(type)                                  \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Betainc").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BetaincOp<GPUDevice, type>);

REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA

}  // namespace tensorflow