/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/histogram_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <typename T, typename Tout>
struct HistogramFixedWidthFunctor<CPUDevice, T, Tout> {
  static Status Compute(OpKernelContext* context,
                        const typename TTypes<T, 1>::ConstTensor& values,
                        const typename TTypes<T, 1>::ConstTensor& value_range,
                        int32 nbins, typename TTypes<Tout, 1>::Tensor& out) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    Tensor index_to_bin_tensor;

    TF_RETURN_IF_ERROR(context->forward_input_or_allocate_temp(
        {0}, DataTypeToEnum<int32>::value, TensorShape({values.size()}),
        &index_to_bin_tensor));
    auto index_to_bin = index_to_bin_tensor.flat<int32>();

    const double step = static_cast<double>(value_range(1) - value_range(0)) /
                        static_cast<double>(nbins);

    // The calculation maps each value x in `values` to a bin index. With
    // value_range = [a, b]:
    //   step = (b - a) / nbins
    //   bin(x) = floor((x - a) / step), clamped to [0, nbins - 1]
    // The resulting bin indices are then counted into the output histogram.
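    // Illustrative, hand-computed example (not part of the original source):
    // with value_range = [0.0, 10.0] and nbins = 5, step = 2.0, so x = 3.7
    // lands in bin floor(3.7 / 2.0) = 1; x = -1.0 is clamped up to bin 0 by
    // cwiseMax below; and x = 10.0 gives floor(10.0 / 2.0) = 5, which
    // cwiseMin clamps down to bin nbins - 1 = 4.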
    index_to_bin.device(d) =
        ((values.cwiseMax(value_range(0)) - values.constant(value_range(0)))
             .template cast<double>() /
         step)
            .template cast<int32>()
            .cwiseMin(nbins - 1);

    out.setZero();
    for (int32 i = 0; i < index_to_bin.size(); i++) {
      out(index_to_bin(i)) += Tout(1);
    }
    return Status::OK();
  }
};

}  // namespace functor

template <typename Device, typename T, typename Tout>
class HistogramFixedWidthOp : public OpKernel {
 public:
  explicit HistogramFixedWidthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& values_tensor = ctx->input(0);
    const Tensor& value_range_tensor = ctx->input(1);
    const Tensor& nbins_tensor = ctx->input(2);

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(value_range_tensor.shape()),
                errors::InvalidArgument("value_range should be a vector."));
    OP_REQUIRES(ctx, (value_range_tensor.shape().num_elements() == 2),
                errors::InvalidArgument(
                    "value_range should be a vector of 2 elements."));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(nbins_tensor.shape()),
                errors::InvalidArgument("nbins should be a scalar."));

    const auto values = values_tensor.flat<T>();
    const auto value_range = value_range_tensor.flat<T>();
    const auto nbins = nbins_tensor.scalar<int32>()();

    OP_REQUIRES(
        ctx, (value_range(0) < value_range(1)),
        errors::InvalidArgument("value_range should satisfy value_range[0] < "
                                "value_range[1], but got '[",
                                value_range(0), ", ", value_range(1), "]'"));
    OP_REQUIRES(
        ctx, (nbins > 0),
        errors::InvalidArgument("nbins should be a positive number, but got '",
                                nbins, "'"));

    Tensor* out_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({nbins}), &out_tensor));
    auto out = out_tensor->flat<Tout>();

    OP_REQUIRES_OK(
        ctx, functor::HistogramFixedWidthFunctor<Device, T, Tout>::Compute(
                 ctx, values, value_range, nbins, out));
  }
};

#define REGISTER_KERNELS(type)                                           \
  REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth")                    \
                              .Device(DEVICE_CPU)                        \
                              .TypeConstraint<type>("T")                 \
                              .TypeConstraint<int32>("dtype"),           \
                          HistogramFixedWidthOp<CPUDevice, type, int32>) \
  REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth")                    \
                              .Device(DEVICE_CPU)                        \
                              .TypeConstraint<type>("T")                 \
                              .TypeConstraint<int64>("dtype"),           \
                          HistogramFixedWidthOp<CPUDevice, type, int64>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#define REGISTER_KERNELS(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth")          \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("value_range")       \
                              .HostMemory("nbins")             \
                              .TypeConstraint<type>("T")       \
                              .TypeConstraint<int32>("dtype"), \
                          HistogramFixedWidthOp<GPUDevice, type, int32>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA

}  // end namespace tensorflow