1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #if GOOGLE_CUDA 19 #define EIGEN_USE_GPU 20 #endif // GOOGLE_CUDA 21 22 #include "tensorflow/core/framework/register_types.h" 23 #include "tensorflow/core/kernels/bounds_check.h" 24 #include "tensorflow/core/kernels/cwise_ops_common.h" 25 26 namespace tensorflow { 27 28 typedef Eigen::ThreadPoolDevice CPUDevice; 29 typedef Eigen::GpuDevice GPUDevice; 30 31 #ifdef TENSORFLOW_USE_SYCL 32 typedef Eigen::SyclDevice SYCLDevice; 33 #endif // TENSORFLOW_USE_SYCL 34 35 template <typename Device, typename T> 36 class SelectOp : public OpKernel { 37 public: 38 explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {} 39 40 void Compute(OpKernelContext* ctx) override { 41 const Tensor* cond; 42 const Tensor* then; 43 const Tensor* else_; 44 OP_REQUIRES_OK(ctx, ctx->input("condition", &cond)); 45 OP_REQUIRES_OK(ctx, ctx->input("t", &then)); 46 OP_REQUIRES_OK(ctx, ctx->input("e", &else_)); 47 48 if (TensorShapeUtils::IsScalar(cond->shape())) { 49 ComputeScalar(ctx, cond, then, else_); 50 return; 51 } 52 53 bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) && 54 !TensorShapeUtils::IsVector(then->shape())); 55 56 if (broadcasting) { 57 ComputeBroadcasting(ctx, cond, then, else_); 58 } else { 59 ComputeElementwise(ctx, cond, then, 
else_); 60 } 61 } 62 63 protected: 64 void ComputeBroadcasting(OpKernelContext* ctx, const Tensor* cond, 65 const Tensor* then, const Tensor* else_) { 66 // Preliminary validation of sizes. 67 OP_REQUIRES( 68 ctx, TensorShapeUtils::IsVector(cond->shape()), 69 errors::InvalidArgument("'cond' must be a vector, but saw shape: ", 70 cond->shape().DebugString())); 71 OP_REQUIRES( 72 ctx, 73 FastBoundsCheck(cond->NumElements(), 74 std::numeric_limits<Eigen::DenseIndex>::max()), 75 errors::InvalidArgument("cond vector larger than ", 76 std::numeric_limits<Eigen::DenseIndex>::max())); 77 OP_REQUIRES( 78 ctx, 79 FastBoundsCheck(then->flat_outer_dims<T>().dimension(1), 80 std::numeric_limits<Eigen::DenseIndex>::max()), 81 errors::InvalidArgument("flat outer dims dim 1 size >= ", 82 std::numeric_limits<Eigen::DenseIndex>::max())); 83 84 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then->shape()), 85 errors::InvalidArgument( 86 "'then' must be at least a vector, but saw shape: ", 87 then->shape().DebugString())); 88 OP_REQUIRES( 89 ctx, then->shape().dim_size(0) == cond->NumElements(), 90 errors::InvalidArgument( 91 "Number of batches of 'then' must match size of 'cond', but saw: ", 92 then->shape().dim_size(0), " vs. ", cond->NumElements())); 93 OP_REQUIRES( 94 ctx, then->shape().IsSameSize(else_->shape()), 95 errors::InvalidArgument( 96 "'then' and 'else' must have the same size. but received: ", 97 then->shape().DebugString(), " vs. 
", 98 else_->shape().DebugString())); 99 100 Tensor* output = nullptr; 101 OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( 102 {"t", "e"}, "output", then->shape(), &output)); 103 if (output->NumElements() > 0) { 104 functor::BatchSelectFunctor<Device, T> func; 105 func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(), 106 cond->vec<bool>(), then->flat_outer_dims<T>(), 107 else_->flat_outer_dims<T>()); 108 } 109 } 110 111 void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond, 112 const Tensor* then, const Tensor* else_) { 113 if (!ctx->ValidateInputsAreSameShape(this)) return; 114 Tensor* output = nullptr; 115 OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( 116 {"t", "e"}, "output", then->shape(), &output)); 117 if (output->NumElements() > 0) { 118 functor::SelectFunctor<Device, T> func; 119 func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(), 120 then->flat<T>(), else_->flat<T>()); 121 } 122 } 123 124 void ComputeScalar(OpKernelContext* ctx, const Tensor* cond, 125 const Tensor* then, const Tensor* else_) { 126 OP_REQUIRES( 127 ctx, then->shape().IsSameSize(else_->shape()), 128 errors::InvalidArgument( 129 "'then' and 'else' must have the same size. but received: ", 130 then->shape().DebugString(), " vs. 
", 131 else_->shape().DebugString())); 132 133 Tensor* output = nullptr; 134 OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( 135 {"t", "e"}, "output", then->shape(), &output)); 136 137 if (output->NumElements() > 0) { 138 functor::SelectScalarFunctor<Device, T> func; 139 TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>(); 140 func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar, 141 then->flat<T>(), else_->flat<T>()); 142 } 143 } 144 145 private: 146 TF_DISALLOW_COPY_AND_ASSIGN(SelectOp); 147 }; 148 149 #define REGISTER_SELECT(type) \ 150 REGISTER_KERNEL_BUILDER( \ 151 Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 152 SelectOp<CPUDevice, type>); 153 154 TF_CALL_ALL_TYPES(REGISTER_SELECT); 155 156 #if GOOGLE_CUDA 157 158 // Registration of the GPU implementations. 159 #define REGISTER_SELECT_GPU(type) \ 160 REGISTER_KERNEL_BUILDER( \ 161 Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 162 SelectOp<GPUDevice, type>); 163 164 REGISTER_SELECT_GPU(Eigen::half); 165 REGISTER_SELECT_GPU(float); 166 REGISTER_SELECT_GPU(double); 167 REGISTER_SELECT_GPU(int32); 168 REGISTER_SELECT_GPU(int64); 169 REGISTER_SELECT_GPU(complex64); 170 REGISTER_SELECT_GPU(complex128); 171 172 #undef REGISTER_SELECT_GPU 173 174 #endif // GOOGLE_CUDA 175 176 #ifdef TENSORFLOW_USE_SYCL 177 // Registration of the SYCL implementations. 178 #define REGISTER_SELECT_SYCL(type) \ 179 REGISTER_KERNEL_BUILDER( \ 180 Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ 181 SelectOp<SYCLDevice, type>); 182 183 REGISTER_SELECT_SYCL(float); 184 REGISTER_SELECT_SYCL(double); 185 REGISTER_SELECT_SYCL(int32); 186 REGISTER_SELECT_SYCL(int64); 187 #undef REGISTER_SELECT_SYCL 188 #endif // TENSORFLOW_USE_SYCL 189 190 namespace functor { 191 192 // CPU Specializations of Select functors. 
// Elementwise select shared by CPU and SYCL devices:
// out[i] = cond_flat[i] ? then_flat[i] : else_flat[i].
template <typename Device, typename T>
struct SelectFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    // Eigen's select() builds a lazy ternary expression; Assign evaluates
    // it on device 'd' into 'out'.
    Assign(d, out, cond_flat.select(then_flat, else_flat));
  }
};

template <typename T>
struct SelectFunctor<CPUDevice, T> : SelectFunctorBase<CPUDevice, T> {};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

// Scalar-condition select: the single bool picks which whole input tensor
// is copied into 'out'.
template <typename Device, typename T>
struct SelectScalarFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  TTypes<bool>::ConstScalar cond,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    // cond() reads the scalar once; the chosen input is assigned wholesale.
    out.device(d) = cond() ? then_flat : else_flat;
  }
};

// CPU Specializations of Select functors with scalar
template <typename T>
struct SelectScalarFunctor<CPUDevice, T>
    : SelectScalarFunctorBase<CPUDevice, T> {};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectScalarFunctor<SYCLDevice, T>
    : SelectScalarFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

// Row-wise ("batch") select: cond_vec is reshaped to [batch, 1] and
// broadcast to [batch, all_but_batch] so each condition element selects an
// entire row of then/else.
template <typename Device, typename T>
struct BatchSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const Eigen::DenseIndex batch = cond_vec.size();
    const Eigen::DenseIndex all_but_batch = then_flat_outer_dims.dimension(1);

#if !defined(EIGEN_HAS_INDEX_LIST)
    // Fallback for Eigen builds without IndexList: plain runtime arrays.
    Eigen::array<Eigen::DenseIndex, 2> broadcast_dims{{1, all_but_batch}};
    Eigen::Tensor<Eigen::DenseIndex, 2>::Dimensions reshape_dims{{batch, 1}};
#else
    // IndexList encodes the compile-time-known 1 dimensions, letting Eigen
    // optimize the reshape/broadcast expression.
    Eigen::IndexList<Eigen::type2index<1>, Eigen::DenseIndex> broadcast_dims;
    broadcast_dims.set(1, all_but_batch);
    Eigen::IndexList<Eigen::DenseIndex, Eigen::type2index<1> > reshape_dims;
    reshape_dims.set(0, batch);
#endif

    Assign(d, output_flat_outer_dims,
           cond_vec.reshape(reshape_dims)
               .broadcast(broadcast_dims)
               .select(then_flat_outer_dims, else_flat_outer_dims));
  }
};

template <typename T>
struct BatchSelectFunctor<CPUDevice, T> : BatchSelectFunctorBase<CPUDevice, T> {
};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct BatchSelectFunctor<SYCLDevice, T>
    : BatchSelectFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

}  // namespace functor

}  // namespace tensorflow