/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/where_op.h"

#include <memory>
#include <numeric>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"

using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

namespace {
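// CountAccumulator returns the number of "true" elements in [begin, end):
// for numeric types, the number of entries that compare unequal to T(0);
// for bool, a direct sum of the values.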
template <typename T>
int64 CountAccumulator(const T* begin, const T* end) {
  return std::accumulate(begin, end, 0LL, [](int64 accum, const T& val) {
    return accum + (val != T(0));
  });
}

template <>
int64 CountAccumulator<bool>(const bool* begin, const bool* end) {
  return std::accumulate(begin, end, 0LL);
}

}  // namespace

template <typename T>
struct NumTrue<CPUDevice, T, int64> {
  static Status Compute(OpKernelContext* ctx, const CPUDevice& d,
                        typename TTypes<T>::ConstFlat input,
                        TTypes<int64>::Scalar num_true) {
    num_true() = CountAccumulator<T>(input.data(), input.data() + input.size());
    return Status::OK();
  }
};

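// Where<CPUDevice> scans the flattened input once and, for every nonzero
// element, writes its DIMS-dimensional coordinates into one output row.
// The flat (row-major) index is decomposed with precomputed strides, e.g.
// for a 2x3 input the strides are {3, 1}, so flat index 4 maps to (1, 1).
// *found_true counts the nonzero elements actually seen; indices are only
// written while they fit within the preallocated output rows.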
template <int DIMS, typename T, typename TIndex>
struct Where<CPUDevice, DIMS, T, TIndex> {
  EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
      typename TTypes<int64>::Matrix output,
      const typename Eigen::DSizes<TIndex, DIMS>& strides, TIndex true_n,
      TIndex index) {
    for (int i = 0; i < DIMS; ++i) {
      output(true_n, i) = index / strides[i];
      index -= output(true_n, i) * strides[i];
    }
  }

  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const CPUDevice& d,
      typename TTypes<T, DIMS>::ConstTensor input,
      typename TTypes<int64>::Matrix output, TIndex* found_true) {
    Eigen::DSizes<Eigen::DenseIndex, DIMS> dims = input.dimensions();
    Eigen::DSizes<TIndex, DIMS> strides;

    EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
                         static_cast<int>(Eigen::RowMajor)),
                        INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);

    strides[DIMS - 1] = 1;
    for (int i = DIMS - 2; i >= 0; --i) {
      strides[i] = strides[i + 1] * dims[i + 1];
    }

    Eigen::DenseIndex output_size = output.dimension(0);
    for (Eigen::DenseIndex n = 0; n < input.size(); ++n) {
      if (input.data()[n] != T(0)) {
        if (FastBoundsCheck(*found_true, output_size)) {
          WriteIndexRowMajor(output, strides, *found_true, n);
        }
        ++*found_true;
      }
    }
    return Status::OK();
  }
};

}  // namespace functor

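// CPU kernel for the Where op.  It makes two passes over the input: the
// first pass (NumTrue) counts the nonzero elements so the output can be
// allocated as a [num_true, input_dims] int64 matrix, and the second pass
// (Where) fills in the coordinates of each nonzero element.
//
// Example (Python-level semantics, for reference):
//   tf.where([[True, False], [False, True]])
//   ==> [[0, 0], [1, 1]]   // one row of coordinates per true element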
template <typename T>
class WhereCPUOp : public OpKernel {
 public:
  explicit WhereCPUOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    OP_REQUIRES(
        context, input.dtype() != DT_HALF,
        errors::Unimplemented("No WhereOp available for float16/half type on "
                              "CPU; dying in CPU WhereOp to avoid silently "
                              "creating costly copies from device."));

    const int input_dims = input.dims();

    Tensor num_true;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DT_INT64, TensorShape({}), &num_true));
    auto num_true_t = num_true.scalar<int64>();

    Status s = functor::NumTrue<CPUDevice, T, int64>::Compute(
        context, context->eigen_device<CPUDevice>(), input.flat<T>(),
        num_true_t);
    OP_REQUIRES_OK(context, s);
    TensorShape output_shape({num_true_t(), input_dims});
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // TODO(ebrevdo): Replace single-threaded copy with a
    // multithreaded block copy by getting block counts above instead
    // of a global NumTrue, then having each block filled in in
    // separate threads below.
    int64 found_true = 0;

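// Dispatch on the input rank (1-5); each case invokes the rank-specialized
// Where functor to fill the output with the coordinates of nonzero elements.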
#define HANDLE_DIM(NDIM)                                                      \
  case NDIM: {                                                                \
    Status s = functor::Where<CPUDevice, NDIM, T, int64>::Compute(            \
        context, context->eigen_device<CPUDevice>(), input.tensor<T, NDIM>(), \
        output->matrix<int64>(), &found_true);                                \
    OP_REQUIRES_OK(context, s);                                               \
  } break;

    switch (input_dims) {
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "WhereOp: Unhandled input dimensions: ", input_dims));
    }
#undef HANDLE_DIM

    OP_REQUIRES(
        context, found_true == num_true_t(),
        errors::InvalidArgument(
            "WhereOp: Race condition between counting the number of true "
            "elements and writing them.  When counting, saw ",
            num_true_t(), " elements; but when writing their indices, saw ",
            found_true, " elements."));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereCPUOp);
};

#define REGISTER_WHERE_OP(T) \
  REGISTER_KERNEL_BUILDER(   \
      Name("Where").Device(DEVICE_CPU).TypeConstraint<T>("T"), WhereCPUOp<T>);

TF_CALL_NUMBER_TYPES(REGISTER_WHERE_OP);
TF_CALL_bool(REGISTER_WHERE_OP);

#undef REGISTER_WHERE_OP

#if GOOGLE_CUDA

namespace functor {

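// Forward declarations of the GPU specializations of NumTrue and Where.
// Their definitions are compiled separately in the op's CUDA (.cu.cc)
// source; the extern template declarations below prevent this translation
// unit from instantiating them again.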
#define DECLARE_GPU_NUMTRUE(T, Tindex)                                      \
  template <>                                                               \
  Status NumTrue<GPUDevice, T, Tindex>::Compute(                            \
      OpKernelContext* ctx, const GPUDevice& d, TTypes<T>::ConstFlat input, \
      TTypes<Tindex>::Scalar num_true);                                     \
  extern template struct NumTrue<GPUDevice, T, Tindex>

#define DECLARE_GPU_NUMTRUE_TYPE(T) \
  DECLARE_GPU_NUMTRUE(T, int32);    \
  DECLARE_GPU_NUMTRUE(T, int64);

TF_CALL_NUMBER_TYPES(DECLARE_GPU_NUMTRUE_TYPE);
TF_CALL_bool(DECLARE_GPU_NUMTRUE_TYPE);

#undef DECLARE_GPU_NUMTRUE_TYPE
#undef DECLARE_GPU_NUMTRUE

#define DECLARE_GPU_WHERE_INDEX(Dims, T, Tindex)                  \
  template <>                                                     \
  Status Where<GPUDevice, Dims, T, Tindex>::Compute(              \
      OpKernelContext* ctx, const GPUDevice& d,                   \
      typename TTypes<T, Dims>::ConstTensor input,                \
      typename TTypes<int64>::Matrix output, Tindex* found_true); \
  extern template struct Where<GPUDevice, Dims, T, Tindex>;
#define DECLARE_GPU_WHERE(Dims, T)         \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int32); \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int64);

#define DECLARE_GPU_WHERE_TYPES(T) \
  DECLARE_GPU_WHERE(1, T);         \
  DECLARE_GPU_WHERE(2, T);         \
  DECLARE_GPU_WHERE(3, T);         \
  DECLARE_GPU_WHERE(4, T);         \
  DECLARE_GPU_WHERE(5, T);

TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_WHERE_TYPES);

#undef DECLARE_GPU_WHERE_TYPES
#undef DECLARE_GPU_WHERE
#undef DECLARE_GPU_WHERE_INDEX

}  // namespace functor

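// GPU kernel for the Where op.  The count of nonzero elements is produced
// on the device, so the output shape is not known until that count has been
// copied back to the host.  The op is therefore asynchronous: it enqueues
// the NumTrue kernel and the device-to-host copy, then defers output
// allocation and the Where kernel to a callback that runs once the copy has
// completed.  32-bit indices are used when the input is small enough to be
// addressed with int32, falling back to 64-bit indices otherwise.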
template <typename T>
class WhereGPUOp : public AsyncOpKernel {
 public:
  explicit WhereGPUOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const int input_dims = input.dims();

    if (input.NumElements() < std::numeric_limits<int32>::max()) {
      ComputeAsyncType<int32>(input, input_dims, context, done);
    } else {
      ComputeAsyncType<int64>(input, input_dims, context, done);
    }
  }

  template <typename Tindex>
  void ComputeAsyncType(const Tensor& input, const int input_dims,
                        OpKernelContext* context, DoneCallback done) {
    // Step 0: alloc nnz
    // Step 1: call nnz kernel
    // Step 2: copy nnz to host
    // Step 3: call create_output
    // Step 4: call where kernel
    Tensor num_true;
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_temp(DataTypeToEnum<Tindex>::v(),
                                                TensorShape({}), &num_true),
                         done);

    auto num_true_t = num_true.scalar<Tindex>();

    perftools::gputools::DeviceMemoryBase num_true_ptr(
        static_cast<void*>(num_true_t.data()));
    // Push kernel to stream to get number of true elements.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    Status s = functor::NumTrue<GPUDevice, T, Tindex>::Compute(
        context, d, input.flat<T>(), num_true_t);
    OP_REQUIRES_OK_ASYNC(context, s, done);

    // Copy num_true from the device to host scratch space.
    ScratchSpace<Tindex> num_true_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(num_true_host.mutable_data(), num_true_ptr,
                         sizeof(Tindex))
            .ok(),
        errors::Internal("WhereOp: failed to copy num_true from device"), done);

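    // The remaining work (output allocation and the Where kernel) needs the
    // host copy of num_true, so it is scheduled as a callback on the GPU
    // event manager; the callback fires only after the stream has finished
    // the work enqueued above, including the memcpy.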
    auto create_and_check_output = [context, &d, &input, input_dims,
                                    num_true_host, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Tindex num_true = *num_true_host.data();

      // TODO(ebrevdo): Properly copy back found_true value to CPU for
      // validation checking.  Currently Where<GPUDevice>::Compute()
      // does not perform this copy back to CPU.
      Tindex found_true = -1;

      // Steps 3 and 4: allocate the output and run the Where kernel
      // (selection/copy).
      Tensor* output;
      OP_REQUIRES_OK_ASYNC(context,
                           context->allocate_output(
                               0, TensorShape({num_true, input_dims}), &output),
                           done);

#define HANDLE_DIM(NDIM)                                              \
  case NDIM: {                                                        \
    Status s = functor::Where<GPUDevice, NDIM, T, Tindex>::Compute(   \
        context, d, input.tensor<T, NDIM>(), output->matrix<int64>(), \
        &found_true);                                                 \
    OP_REQUIRES_OK_ASYNC(context, s, done);                           \
  } break;

      switch (input_dims) {
        HANDLE_DIM(1);
        HANDLE_DIM(2);
        HANDLE_DIM(3);
        HANDLE_DIM(4);
        HANDLE_DIM(5);

        default:
          OP_REQUIRES_ASYNC(
              context, false,
              errors::InvalidArgument("WhereOp: Unhandled input dimensions: ",
                                      input_dims),
              done);
      }
#undef HANDLE_DIM

      // TODO(ebrevdo): Fix the copy back to host.

      // OP_REQUIRES_ASYNC(
      //     context, found_true == num_true,
      //     errors::InvalidArgument(
      //         "WhereOp: Race condition between counting the number of true "
      //         "elements and writing them.  When counting, saw ",
      //         num_true, " elements; but when writing their indices, saw ",
      //         found_true, " elements."),
      //     done);

      done();
    };
    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereGPUOp);
};

#define REGISTER_GPU_WHERE_OP(T) \
  REGISTER_KERNEL_BUILDER(       \
      Name("Where").Device(DEVICE_GPU).TypeConstraint<T>("T"), WhereGPUOp<T>);

TF_CALL_WHERE_GPU_TYPES(REGISTER_GPU_WHERE_OP);

#undef REGISTER_GPU_WHERE_OP

#endif  // GOOGLE_CUDA

}  // namespace tensorflow