/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/compare_and_bitpack_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device, typename T>
class CompareAndBitpackOp : public OpKernel {
 public:
  explicit CompareAndBitpackOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* c) override {
    const Tensor& input_t = c->input(0);
    const Tensor& threshold_t = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsScalar(threshold_t.shape()),
        errors::InvalidArgument("Threshold must be a scalar, but saw shape: ",
                                threshold_t.shape().DebugString()));
    const TensorShape& input_shape = input_t.shape();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_shape),
                errors::InvalidArgument(
                    "Input should be at least a vector, but saw a scalar."));
    OP_REQUIRES(c, input_shape.dim_size(input_shape.dims() - 1) % 8 == 0,
                errors::InvalidArgument(
                    "Inner dimension of input should be "
                    "divisible by ",
                    8, ", but saw shape: ", input_shape.DebugString()));

    TensorShape output_shape = input_shape;
    int rank = input_shape.dims();
    output_shape.set_dim(rank - 1, input_shape.dim_size(rank - 1) / 8);

    Tensor* output_t;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output_t));

    auto input = input_t.flat_inner_dims<T>();
    auto threshold = threshold_t.scalar<T>();
    auto output = output_t->flat_inner_dims<uint8>();

    functor::CompareAndBitpack<Device, T> func;
    func(c, input, threshold, output);
  }
};

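// Registration of the CPU implementations.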
#define REGISTER_COMPARE_AND_BITPACK(type)                                    \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("CompareAndBitpack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      CompareAndBitpackOp<CPUDevice, type>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_COMPARE_AND_BITPACK);
TF_CALL_bool(REGISTER_COMPARE_AND_BITPACK);

#undef REGISTER_COMPARE_AND_BITPACK

namespace functor {

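// Generic shard function: for each output byte, compare eight consecutive
// input elements against the threshold and pack the results MSB-first.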
template <typename T, class = void, class = void>
struct ComputeShard {
  static EIGEN_STRONG_INLINE void Compute(typename TTypes<T>::ConstMatrix input,
                                          typename TTypes<uint8>::Matrix output,
                                          const T& thresh, int64 start,
                                          int64 limit) {
    for (int64 i = start; i < limit; ++i) {
      uint8* out = output.data() + i;
      const T* block = input.data() + 8 * i;
      *out = ((((block[0] > thresh) << 7)) | (((block[1] > thresh) << 6)) |
              (((block[2] > thresh) << 5)) | (((block[3] > thresh) << 4)) |
              (((block[4] > thresh) << 3)) | (((block[5] > thresh) << 2)) |
              (((block[6] > thresh) << 1)) | (((block[7] > thresh))));
    }
  }
};

// Specialization for bool on systems where sizeof(bool) == 1.
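// Each group of eight bools is loaded as a single int64; the value bit of
// every byte is then masked out and shifted into its position in the output
// byte, with the branches below handling the byte order of the host.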
template <typename T>
struct ComputeShard<T,
                    typename std::enable_if<std::is_same<T, bool>::value>::type,
                    typename std::enable_if<sizeof(T) == 1>::type> {
  static EIGEN_STRONG_INLINE void Compute(
      typename TTypes<bool>::ConstMatrix input,
      typename TTypes<uint8>::Matrix output, bool /*thresh*/, int64 start,
      int64 limit) {
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
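    // Big-endian host: input bool k lives in byte (7 - k) of the loaded
    // int64, so its value bit is bit 8 * (7 - k); shift it down to output
    // bit (7 - k).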
    for (int64 i = start; i < limit; ++i) {
      uint8* out = output.data() + i;
      const int64 block = *reinterpret_cast<const int64*>(input.data() + 8 * i);
      *out = ((((block & (1LL << (7 * 8))) >> (7 * 8 - 7))) |
              (((block & (1LL << (6 * 8))) >> (6 * 8 - 6))) |
              (((block & (1LL << (5 * 8))) >> (5 * 8 - 5))) |
              (((block & (1LL << (4 * 8))) >> (4 * 8 - 4))) |
              (((block & (1LL << (3 * 8))) >> (3 * 8 - 3))) |
              (((block & (1LL << (2 * 8))) >> (2 * 8 - 2))) |
              (((block & (1LL << 8)) >> (1 * 8 - 1))) | (((block & (1LL)))));
    }
#else
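    // Little-endian host: input bool k lives in byte k of the loaded int64,
    // so its value bit is bit 8 * k; move it to output bit (7 - k).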
    for (int64 i = start; i < limit; ++i) {
      uint8* out = output.data() + i;
      const int64 block = *reinterpret_cast<const int64*>(input.data() + 8 * i);
      *out =
          ((((block & (1LL << (7 * 8))) >> (7 * 8 - 0))) |
           (((block & (1LL << (6 * 8))) >> (6 * 8 - 1))) |
           (((block & (1LL << (5 * 8))) >> (5 * 8 - 2))) |
           (((block & (1LL << (4 * 8))) >> (4 * 8 - 3))) |
           (((block & (1LL << (3 * 8))) >> (3 * 8 - 4))) |
           (((block & (1LL << (2 * 8))) >> (2 * 8 - 5))) |
           (((block & (1LL << 8)) >> (1 * 8 - 6))) | (((block & (1LL)) << 7)));
    }
#endif
  }
};

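// CPU implementation: shards the flattened output across the intra-op thread
// pool, one packed byte per unit of work, with a rough per-byte cost of eight
// compares plus eight or/shift operations.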
template <typename T>
struct CompareAndBitpack<CPUDevice, T> {
  void operator()(OpKernelContext* c, typename TTypes<T>::ConstMatrix input,
                  typename TTypes<T>::ConstScalar threshold,
                  TTypes<uint8>::Matrix output) {
    const T thresh = threshold();
    auto shard = [&, thresh](int64 start, int64 limit) {
      ComputeShard<T>::Compute(input, output, thresh, start, limit);
    };
    int64 total_shards = output.size();  // Approximate cmp as an add and
                                         // bitwise-or + shift as an add.
    const double total_cost = 8 * (Eigen::TensorOpCost::AddCost<T>() +
                                   Eigen::TensorOpCost::AddCost<uint8>());
    const int64 shard_cost = (total_cost >= static_cast<double>(kint64max))
                                 ? kint64max
                                 : static_cast<int64>(total_cost);

    auto worker_threads = *(c->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, total_shards,
          shard_cost, shard);
  }
};

}  // namespace functor

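// Registration of the GPU implementations.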
#if GOOGLE_CUDA

#define REGISTER_COMPARE_AND_BITPACK(type)                                    \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("CompareAndBitpack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      CompareAndBitpackOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_COMPARE_AND_BITPACK);
TF_CALL_bool(REGISTER_COMPARE_AND_BITPACK);

#undef REGISTER_COMPARE_AND_BITPACK

namespace functor {

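// Forward declarations of the GPU functor specializations; they are
// instantiated in the op's CUDA source file, so the GPU kernel registrations
// above can link against them.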
#define DECLARE_GPU_SPEC(T)                                      \
  template <>                                                    \
  void CompareAndBitpack<GPUDevice, T>::operator()(              \
      OpKernelContext* c, typename TTypes<T>::ConstMatrix input, \
      typename TTypes<T>::ConstScalar threshold,                 \
      TTypes<uint8>::Matrix output);                             \
  extern template struct CompareAndBitpack<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC)
TF_CALL_bool(DECLARE_GPU_SPEC)

#undef DECLARE_GPU_SPEC

}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow