/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

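// SelectOp merges two tensors according to a boolean condition: conceptually,
// output = condition ? then : else. Compute() dispatches on the shape of
// 'condition': a scalar condition selects one input wholesale, a vector
// condition paired with higher-rank inputs selects whole outer (batch) rows,
// and otherwise the selection is elementwise over matching shapes.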
template <typename Device, typename T>
class SelectOp : public OpKernel {
 public:
  explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond;
    const Tensor* then;
    const Tensor* else_;
    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));

    // A scalar condition picks either 'then' or 'else' wholesale.
    if (TensorShapeUtils::IsScalar(cond->shape())) {
      ComputeScalar(ctx, cond, then, else_);
      return;
    }

    // A vector condition paired with a higher-rank 'then' selects whole
    // outer (batch) rows; otherwise the selection is elementwise.
    bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
                         !TensorShapeUtils::IsVector(then->shape()));

    if (broadcasting) {
      ComputeBroadcasting(ctx, cond, then, else_);
    } else {
      ComputeElementwise(ctx, cond, then, else_);
    }
  }
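  // Dispatch examples, for illustration:
  //   cond: []      then/else: [2, 3] -> ComputeScalar
  //   cond: [2]     then/else: [2, 3] -> ComputeBroadcasting (row-wise)
  //   cond: [2, 3]  then/else: [2, 3] -> ComputeElementwise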

 protected:
  void ComputeBroadcasting(OpKernelContext* ctx, const Tensor* cond,
                           const Tensor* then, const Tensor* else_) {
    // Preliminary validation of sizes.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(cond->shape()),
        errors::InvalidArgument("'cond' must be a vector, but saw shape: ",
                                cond->shape().DebugString()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(cond->NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("cond vector larger than ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(then->flat_outer_dims<T>().dimension(1),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("flat outer dims dim 1 size >= ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then->shape()),
                errors::InvalidArgument(
                    "'then' must be at least a vector, but saw shape: ",
                    then->shape().DebugString()));
    OP_REQUIRES(
        ctx, then->shape().dim_size(0) == cond->NumElements(),
        errors::InvalidArgument(
            "Number of batches of 'then' must match size of 'cond', but saw: ",
            then->shape().dim_size(0), " vs. ", cond->NumElements()));
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
            "'then' and 'else' must have the same size, but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::BatchSelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
           cond->vec<bool>(), then->flat_outer_dims<T>(),
           else_->flat_outer_dims<T>());
    }
  }
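  // ComputeBroadcasting example, for illustration:
  //   cond = [true, false]
  //   then = [[1, 2], [3, 4]], else = [[5, 6], [7, 8]]
  //   output = [[1, 2], [7, 8]]  (cond[i] picks row i from 'then' or 'else')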

  void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond,
                          const Tensor* then, const Tensor* else_) {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::SelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
           then->flat<T>(), else_->flat<T>());
    }
  }

  void ComputeScalar(OpKernelContext* ctx, const Tensor* cond,
                     const Tensor* then, const Tensor* else_) {
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
            "'then' and 'else' must have the same size, but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));

    if (output->NumElements() > 0) {
      functor::SelectScalarFunctor<Device, T> func;
      TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
           then->flat<T>(), else_->flat<T>());
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};

#define REGISTER_SELECT(type)                                      \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SelectOp<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_SELECT);

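// For illustration, REGISTER_SELECT(float) expands to:
//   REGISTER_KERNEL_BUILDER(
//       Name("Select").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       SelectOp<CPUDevice, float>);
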
#if GOOGLE_CUDA

// Registration of the GPU implementations.
#define REGISTER_SELECT_GPU(type)                                  \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectOp<GPUDevice, type>);

REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
// Registration of the SYCL implementations.
#define REGISTER_SELECT_SYCL(type)                                  \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      SelectOp<SYCLDevice, type>);

REGISTER_SELECT_SYCL(float);
REGISTER_SELECT_SYCL(double);
REGISTER_SELECT_SYCL(int32);
REGISTER_SELECT_SYCL(int64);
#undef REGISTER_SELECT_SYCL
#endif  // TENSORFLOW_USE_SYCL

namespace functor {

// CPU (and SYCL) specializations of the Select functors.
template <typename Device, typename T>
struct SelectFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    Assign(d, out, cond_flat.select(then_flat, else_flat));
  }
};
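
// Note: cond_flat.select(then_flat, else_flat) above is Eigen's elementwise
// ternary: out(i) = cond_flat(i) ? then_flat(i) : else_flat(i).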

template <typename T>
struct SelectFunctor<CPUDevice, T> : SelectFunctorBase<CPUDevice, T> {};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T>
struct SelectScalarFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  TTypes<bool>::ConstScalar cond,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    out.device(d) = cond() ? then_flat : else_flat;
  }
};
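
// Note: cond() above reads the scalar predicate once, so a single branch
// expression is assigned to the whole output (no elementwise select needed).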

// CPU (and SYCL) specializations of the Select functors with a scalar
// condition.
template <typename T>
struct SelectScalarFunctor<CPUDevice, T>
    : SelectScalarFunctorBase<CPUDevice, T> {};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectScalarFunctor<SYCLDevice, T>
    : SelectScalarFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T>
struct BatchSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const Eigen::DenseIndex batch = cond_vec.size();
    const Eigen::DenseIndex all_but_batch = then_flat_outer_dims.dimension(1);

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::array<Eigen::DenseIndex, 2> broadcast_dims{{1, all_but_batch}};
    Eigen::Tensor<Eigen::DenseIndex, 2>::Dimensions reshape_dims{{batch, 1}};
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::DenseIndex> broadcast_dims;
    broadcast_dims.set(1, all_but_batch);
    Eigen::IndexList<Eigen::DenseIndex, Eigen::type2index<1> > reshape_dims;
    reshape_dims.set(0, batch);
#endif

    Assign(d, output_flat_outer_dims,
           cond_vec.reshape(reshape_dims)
               .broadcast(broadcast_dims)
               .select(then_flat_outer_dims, else_flat_outer_dims));
  }
};
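
// Shape walk-through for the reshape/broadcast above, with batch = 2 and
// all_but_batch = 3:
//   cond_vec: [2] --reshape--> [2, 1] --broadcast--> [2, 3]
// Each condition value is repeated across its whole row before the
// elementwise select, so cond_vec(i) decides row i in one shot.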

template <typename T>
struct BatchSelectFunctor<CPUDevice, T> : BatchSelectFunctorBase<CPUDevice, T> {
};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct BatchSelectFunctor<SYCLDevice, T>
    : BatchSelectFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

}  // namespace functor

}  // namespace tensorflow