Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
     12 implied.
     13 See the License for the specific language governing permissions and
     14 limitations under the License.
     15 ==============================================================================*/
     16 
     17 // See docs in ../ops/math_ops.cc
     18 
     19 #define EIGEN_USE_THREADS
     20 
     21 #include <bitset>
     22 
     23 #include "tensorflow/core/kernels/population_count_op.h"
     24 
     25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     26 #include "tensorflow/core/framework/op_kernel.h"
     27 #include "tensorflow/core/framework/register_types.h"
     28 #include "tensorflow/core/framework/tensor.h"
     29 #include "tensorflow/core/framework/tensor_shape.h"
     30 #include "tensorflow/core/framework/types.h"
     31 #include "tensorflow/core/lib/core/status.h"
     32 #include "tensorflow/core/util/work_sharder.h"
     33 
     34 namespace tensorflow {
     35 
     36 typedef Eigen::ThreadPoolDevice CPUDevice;
     37 typedef Eigen::GpuDevice GPUDevice;
     38 
     39 template <typename Device, typename T>
     40 class PopulationCountOp : public OpKernel {
     41  public:
     42   explicit PopulationCountOp(OpKernelConstruction* context)
     43       : OpKernel(context) {}
     44 
     45   void Compute(OpKernelContext* c) override {
     46     const Tensor& input_t = c->input(0);
     47     Tensor* output_t;
     48     OP_REQUIRES_OK(c, c->allocate_output(0, input_t.shape(), &output_t));
     49 
     50     auto input = input_t.flat<T>();
     51     auto output = output_t->flat<uint8>();
     52 
     53     functor::PopulationCount<Device, T> popcnt;
     54     popcnt(c, input, output);
     55   }
     56 };
     57 
     58 #define REGISTER_POPULATION_COUNT(type)                                     \
     59   REGISTER_KERNEL_BUILDER(                                                  \
     60       Name("PopulationCount").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
     61       PopulationCountOp<CPUDevice, type>);
     62 
     63 TF_CALL_uint8(REGISTER_POPULATION_COUNT);
     64 TF_CALL_int8(REGISTER_POPULATION_COUNT);
     65 TF_CALL_uint16(REGISTER_POPULATION_COUNT);
     66 TF_CALL_int16(REGISTER_POPULATION_COUNT);
     67 TF_CALL_int32(REGISTER_POPULATION_COUNT);
     68 TF_CALL_int64(REGISTER_POPULATION_COUNT);
     69 
     70 #undef REGISTER_POPULATION_COUNT
     71 
     72 namespace functor {
     73 
     74 namespace {
     75 
     76 template <typename T>
     77 inline uint8 PopCnt(const T v);
     78 
     79 #define POPCNT(T, N)                  \
     80   template <>                         \
     81   uint8 PopCnt<T>(const T v) {        \
     82     return std::bitset<N>(v).count(); \
     83   }
     84 
     85 POPCNT(int8, 8);
     86 POPCNT(uint8, 8);
     87 POPCNT(int16, 16);
     88 POPCNT(uint16, 16);
     89 POPCNT(int32, 32);
     90 POPCNT(int64, 64);
     91 
     92 #undef POPCNT
     93 
     94 }  // namespace
     95 
     96 template <typename T>
     97 struct PopulationCount<CPUDevice, T> {
     98   void operator()(OpKernelContext* c, typename TTypes<T>::ConstFlat input,
     99                   TTypes<uint8>::Flat output) {
    100     const T* input_ptr = input.data();
    101     uint8* output_ptr = output.data();
    102     auto shard = [input_ptr, output_ptr](int64 start, int64 limit) {
    103       for (int64 i = start; i < limit; ++i) {
    104         output_ptr[i] = PopCnt<T>(input_ptr[i]);
    105       }
    106     };
    107     int64 total_shards = input.size();
    108     // Approximating cost of popcnt: convert T to int64
    109     // (std::bitset constructor) and convert int64 to uint8
    110     // (bitset.count() -> output).  The .count() itself is relatively cheap.
    111     const double total_cost = (Eigen::TensorOpCost::CastCost<T, uint8>() +
    112                                Eigen::TensorOpCost::CastCost<int64, uint8>());
    113     const int64 shard_cost = (total_cost >= static_cast<double>(kint64max))
    114                                  ? kint64max
    115                                  : static_cast<int64>(total_cost);
    116 
    117     auto worker_threads = *(c->device()->tensorflow_cpu_worker_threads());
    118     Shard(worker_threads.num_threads, worker_threads.workers, total_shards,
    119           shard_cost, shard);
    120   }
    121 };
    122 
    123 }  // namespace functor
    124 
    125 #if GOOGLE_CUDA
    126 
    127 #define REGISTER_POPULATION_COUNT(type)                                     \
    128   REGISTER_KERNEL_BUILDER(                                                  \
    129       Name("PopulationCount").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
    130       PopulationCountOp<GPUDevice, type>)
    131 
    132 TF_CALL_uint8(REGISTER_POPULATION_COUNT);
    133 TF_CALL_int8(REGISTER_POPULATION_COUNT);
    134 TF_CALL_uint16(REGISTER_POPULATION_COUNT);
    135 TF_CALL_int16(REGISTER_POPULATION_COUNT);
    136 TF_CALL_int32(REGISTER_POPULATION_COUNT);
    137 TF_CALL_int64(REGISTER_POPULATION_COUNT);
    138 
    139 #undef REGISTER_POPULATION_COUNT
    140 
    141 namespace functor {
    142 
    143 #define DECLARE_GPU_SPEC(T)                                    \
    144   template <>                                                  \
    145   void PopulationCount<GPUDevice, T>::operator()(              \
    146       OpKernelContext* c, typename TTypes<T>::ConstFlat input, \
    147       TTypes<uint8>::Flat output);                             \
    148   extern template struct PopulationCount<GPUDevice, T>
    149 
    150 TF_CALL_uint8(DECLARE_GPU_SPEC);
    151 TF_CALL_int8(DECLARE_GPU_SPEC);
    152 TF_CALL_uint16(DECLARE_GPU_SPEC);
    153 TF_CALL_int16(DECLARE_GPU_SPEC);
    154 TF_CALL_int32(DECLARE_GPU_SPEC);
    155 TF_CALL_int64(DECLARE_GPU_SPEC);
    156 
    157 #undef DECLARE_GPU_SPEC
    158 
    159 }  // namespace functor
    160 
    161 #endif  // GOOGLE_CUDA
    162 
    163 }  // namespace tensorflow
    164