Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <limits>
#include <numeric>
#include <vector>
     23 
     24 #include "tensorflow/core/framework/op_kernel.h"
     25 #include "tensorflow/core/framework/tensor.h"
     26 #include "tensorflow/core/framework/tensor_shape.h"
     27 #include "tensorflow/core/framework/tensor_types.h"
     28 #include "tensorflow/core/lib/core/errors.h"
     29 #include "tensorflow/core/lib/core/threadpool.h"
     30 #include "tensorflow/core/lib/gtl/array_slice.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/macros.h"
     33 #include "tensorflow/core/platform/types.h"
     34 
     35 namespace tensorflow {
     36 namespace {
     37 using errors::InvalidArgument;
     38 
     39 class PmfToCdfOp : public OpKernel {
     40  public:
     41   explicit PmfToCdfOp(OpKernelConstruction* context) : OpKernel(context) {
     42     OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
     43     OP_REQUIRES(
     44         context, 0 < precision_ && precision_ <= 16,
     45         InvalidArgument("`precision` must be in [1, 16]: ", precision_));
     46   }
     47 
     48   void Compute(OpKernelContext* context) override {
     49     const Tensor& pmf_tensor = context->input(0);
     50 
     51     TensorShape shape = pmf_tensor.shape();
     52     OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(shape),
     53                 InvalidArgument("`pmf` should be at least 1-D."));
     54     OP_REQUIRES(
     55         context, shape.dim_size(shape.dims() - 1) > 1,
     56         InvalidArgument("`pmf` size should be at least 2 in the last axis."));
     57     shape.set_dim(shape.dims() - 1, shape.dim_size(shape.dims() - 1) + 1);
     58 
     59     Tensor* cdf_tensor;
     60     OP_REQUIRES_OK(context, context->allocate_output(0, shape, &cdf_tensor));
     61 
     62     auto pmf = pmf_tensor.flat_inner_dims<float, 2>();
     63     auto cdf = cdf_tensor->flat_inner_dims<int32, 2>();
     64     CHECK_EQ(pmf.dimension(0), cdf.dimension(0));
     65     CHECK_EQ(pmf.dimension(1) + 1, cdf.dimension(1));
     66 
     67     const double n = pmf.dimension(1);
     68     const int64 cost_per_unit = static_cast<int64>(50.0 * n * std::log2(n));
     69     thread::ThreadPool* thread_pool =
     70         context->device()->tensorflow_cpu_worker_threads()->workers;
     71     thread_pool->ParallelFor(
     72         pmf.dimension(0), cost_per_unit,
     73         [this, pmf, &cdf](int64 start, int64 limit) {
     74           const gtl::ArraySlice<float>::size_type pmf_size = pmf.dimension(1);
     75           for (int64 i = start; i < limit; ++i) {
     76             cdf(i, 0) = 0;
     77             PerShard({&pmf(i, 0), pmf_size}, {&cdf(i, 1), pmf_size});
     78           }
     79         });
     80   }
     81 
     82  private:
     83   struct PenaltyItem {
     84     PenaltyItem(int32* p, double mass) : pointer(p), mass(mass) {
     85       penalty = ComputeNextPenalty();
     86     }
     87 
     88     void Decrease() {
     89       CHECK_GT(*pointer, 1);
     90       --*pointer;
     91       penalty = ComputeNextPenalty();
     92     }
     93 
     94     friend bool operator<(const PenaltyItem& lhs, const PenaltyItem& rhs) {
     95       return lhs.penalty < rhs.penalty;
     96     }
     97 
     98     double ComputeNextPenalty() {
     99       if (*pointer <= 1) {
    100         return std::numeric_limits<double>::infinity();
    101       }
    102       return mass * (std::log2(*pointer) - std::log2(*pointer - 1));
    103     }
    104 
    105     int32* pointer;
    106     double mass;
    107     double penalty;
    108   };
    109 
    110   struct GainItem {
    111     GainItem(int32* p, double mass) : pointer(p), mass(mass) {
    112       gain = ComputeNextGain();
    113     }
    114 
    115     void Increase() {
    116       CHECK_GT(*pointer, 0);
    117       ++*pointer;
    118       gain = ComputeNextGain();
    119     }
    120 
    121     friend bool operator>(const GainItem& lhs, const GainItem& rhs) {
    122       return lhs.gain > rhs.gain;
    123     }
    124 
    125     double ComputeNextGain() {
    126       // Never increment zero value to non-zero value.
    127       if (*pointer < 1) {
    128         return -std::numeric_limits<double>::infinity();
    129       }
    130       return mass * (std::log2(*pointer + 1) - std::log2(*pointer));
    131     }
    132 
    133     int32* pointer;
    134     double mass;
    135     double gain;
    136   };
    137 
    138   void PerShard(gtl::ArraySlice<float> pmf,
    139                 gtl::MutableArraySlice<int32> cdf) const {
    140     CHECK_EQ(pmf.size(), cdf.size());
    141 
    142     const int32 normalizer = 1 << precision_;
    143     std::transform(pmf.begin(), pmf.end(), cdf.begin(),
    144                    [normalizer](float mass) {
    145                      int32 value = std::rint(mass * normalizer);
    146                      // NOTE: Consider checking if mass > 0.
    147                      value = std::max(value, 1);
    148                      return value;
    149                    });
    150 
    151     int32 sum = std::accumulate(cdf.begin(), cdf.end(), 0);
    152     if (sum > normalizer) {
    153       std::vector<PenaltyItem> queue;
    154       queue.reserve(cdf.size());
    155       for (int i = 0; i < cdf.size(); ++i) {
    156         queue.emplace_back(&cdf[i], pmf[i]);
    157       }
    158 
    159       std::sort(queue.begin(), queue.end());
    160       while (sum-- > normalizer) {
    161         queue[0].Decrease();
    162         // Performs a linear search because this find_if is likely to return
    163         // iterator very close to the begin.
    164         auto iter = std::find_if(
    165             std::next(queue.begin()), queue.end(),
    166             [&queue](const PenaltyItem& rhs) { return queue[0] < rhs; });
    167         std::rotate(queue.begin(), std::next(queue.begin()), iter);
    168       }
    169     } else if (sum < normalizer) {
    170       std::vector<GainItem> queue;
    171       queue.reserve(cdf.size());
    172       for (int i = 0; i < cdf.size(); ++i) {
    173         queue.emplace_back(&cdf[i], pmf[i]);
    174       }
    175 
    176       std::sort(queue.begin(), queue.end(), std::greater<GainItem>());
    177       while (sum++ < normalizer) {
    178         queue[0].Increase();
    179         // Performs a linear search because this find_if is likely to return
    180         // iterator very close to the begin.
    181         auto iter = std::find_if(
    182             std::next(queue.begin()), queue.end(),
    183             [&queue](const GainItem& rhs) { return queue[0] > rhs; });
    184         std::rotate(queue.begin(), std::next(queue.begin()), iter);
    185       }
    186     }
    187     std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
    188   }
    189 
    190   int precision_;
    191 };
    192 
// Registers PmfToCdfOp as the CPU implementation of the "PmfToQuantizedCdf"
// op.
REGISTER_KERNEL_BUILDER(Name("PmfToQuantizedCdf").Device(DEVICE_CPU),
                        PmfToCdfOp);
    195 }  // namespace
    196 }  // namespace tensorflow
    197