// Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include <algorithm>
     19 #include <array>
     20 #include <limits>
     21 #include <type_traits>
     22 #include <vector>
     23 
     24 #include "tensorflow/contrib/coder/kernels/range_coder.h"
     25 #include "tensorflow/contrib/coder/kernels/range_coder_ops_util.h"
     26 #include "tensorflow/core/framework/op_kernel.h"
     27 #include "tensorflow/core/framework/tensor.h"
     28 #include "tensorflow/core/framework/tensor_shape.h"
     29 #include "tensorflow/core/framework/tensor_types.h"
     30 #include "tensorflow/core/lib/core/errors.h"
     31 #include "tensorflow/core/lib/core/status.h"
     32 #include "tensorflow/core/lib/gtl/array_slice.h"
     33 #include "tensorflow/core/platform/logging.h"
     34 #include "tensorflow/core/platform/macros.h"
     35 #include "tensorflow/core/platform/types.h"
     36 
     37 namespace tensorflow {
     38 namespace {
// A helper class to iterate over data and cdf simultaneously, while cdf is
// broadcasted to data.
//
// T is the data element type, U is the cdf element type, and N is the number
// of (merged) data axes. `cdf` has N + 1 axes: N axes broadcast against the
// data axes (an axis of size <= 1 is treated as broadcasting), plus an
// innermost axis holding one CDF slice per element.
//
// NOTE: Moving this class out of anonymous namespace impacts compiler
// optimization and affects performance. When moving this code around (e.g.,
// into a library header), be sure to check the benchmark tests.
template <typename T, typename U, int N>
class BroadcastRange {
 public:
  // `data_pointer`/`cdf_pointer` point at the first element of each buffer.
  // `data_shape` must have exactly N dims and `cdf_shape` exactly N + 1.
  BroadcastRange(T* data_pointer, gtl::ArraySlice<int64> data_shape,
                 const U* cdf_pointer, gtl::ArraySlice<int64> cdf_shape)
      : data_pointer_(data_pointer), cdf_pointer_(cdf_pointer) {
    CHECK(!data_shape.empty());
    CHECK_EQ(data_shape.size(), N);
    CHECK_EQ(cdf_shape.size(), N + 1);

    std::copy(data_shape.begin(), data_shape.end(), &data_shape_[0]);
    data_index_.fill(0);

    // Default displacement (no axis wrap-around): advancing data by one
    // element advances cdf by one innermost slice.
    const int64 innermost_stride = cdf_shape[N];
    cdf_displace_.fill(innermost_stride);

    // Pre-compute the pointer displacement for cdf.
    int64 stride = innermost_stride;
    for (int i = N - 1; i >= 0; --i) {
      const bool broadcasting = (cdf_shape[i] <= 1);

      // When the data linear index advances by one, the cdf linear index
      // advances by `innermost_stride`.
      //
      // Suppose that the i-th axis coordinate of data increased by one, and
      // that i-th axis is broadcasting. The cdf linear index should be wound
      // back by i-th axis stride, so that i-th axis coordinate of cdf is
      // effectively kept at 0.
      if (broadcasting) {
        cdf_displace_[i] -= stride;
      }
      stride *= cdf_shape[i];
    }
  }

  // Returns the pointers to the current iterating locations to data and cdf
  // tensors, and advances the iterator by one element.
  //
  // Note that this function does not track whether data pointer is running past
  // the end of data buffer. The caller has to make sure Next() is called no
  // more than that.
  std::pair<T*, const U*> Next() {
    std::pair<T*, const U*> return_value = {data_pointer_, cdf_pointer_};

    // Increment the multi-dimensional index, carrying into axis i - 1 whenever
    // axis i wraps. After the loop, `i` is the outermost axis whose coordinate
    // changed (i == 0 when every inner axis wrapped, or when N == 1).
    int i = N - 1;
    for (; i > 0; --i) {
      ++data_index_[i];
      if (data_index_[i] < data_shape_[i]) {
        break;
      }
      data_index_[i] = 0;
    }

    // Advance data pointer by one.
    data_pointer_ += 1;

    // For cdf pointer, it's more complicated because of broadcasting. When i-th
    // coordinate increase by one, and if i-th axis is broadcasting, then we
    // need to rewind back the pointer so that the effective i-th axis
    // coordinate for cdf is always 0. This value is precomputed as
    // cdf_displace_.
    cdf_pointer_ += cdf_displace_[i];
    return return_value;
  }

 private:
  std::array<int64, N> data_shape_;    // Merged shape of the data tensor.
  std::array<int64, N> cdf_displace_;  // cdf pointer delta when axis i carries.
  std::array<int64, N> data_index_;    // Current multi-dimensional position.

  T* data_pointer_;
  const U* cdf_pointer_;
};
    117 
    118 Status CheckCdfShape(const TensorShape& data_shape,
    119                      const TensorShape& cdf_shape) {
    120   if (TF_PREDICT_FALSE(cdf_shape.dims() != data_shape.dims() + 1)) {
    121     return errors::InvalidArgument(
    122         "`cdf` should have one more axis than `data`: data shape=",
    123         data_shape.DebugString(), ", cdf shape=", cdf_shape.DebugString());
    124   }
    125 
    126   if (TF_PREDICT_FALSE(cdf_shape.dim_size(cdf_shape.dims() - 1) <= 1)) {
    127     return errors::InvalidArgument(
    128         "The last dimension of `cdf` should be > 1: ", cdf_shape.DebugString());
    129   }
    130 
    131   return Status::OK();
    132 }
    133 
// Non-incremental encoder op -------------------------------------------------
//
// Encodes every int16 element of `data` with the range coder, using per-symbol
// CDF slices from `cdf` (broadcast against the data shape), and produces a
// single scalar string containing the whole encoded stream.
class RangeEncodeOp : public OpKernel {
 public:
  explicit RangeEncodeOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
    // `precision` is the bit width of the coded probabilities; values outside
    // [1, 16] are rejected at construction time.
    OP_REQUIRES(context, 0 < precision_ && precision_ <= 16,
                errors::InvalidArgument("`precision` must be in [1, 16]: ",
                                        precision_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& cdf = context->input(1);

    // `cdf` must have one extra trailing axis (the CDF slice) of size > 1.
    OP_REQUIRES_OK(context, CheckCdfShape(data.shape(), cdf.shape()));

    // Merge adjacent axes of data/cdf (see MergeAxes in
    // range_coder_ops_util.h) so the implementation can be dispatched on a
    // small fixed rank below.
    std::vector<int64> data_shape, cdf_shape;
    OP_REQUIRES_OK(
        context, MergeAxes(data.shape(), cdf.shape(), &data_shape, &cdf_shape));

    // The encoded output is a single scalar string.
    Tensor* output_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape{}, &output_tensor));
    string* output = &output_tensor->scalar<string>()();

    // Dispatch on the merged rank; ranks above 6 are rejected.
    switch (data_shape.size()) {
#define RANGE_ENCODE_CASE(dims)                                                \
  case dims: {                                                                 \
    RangeEncodeImpl<dims>(data.flat<int16>(), data_shape,                      \
                          cdf.flat_inner_dims<int32, 2>(), cdf_shape, output); \
  } break
      RANGE_ENCODE_CASE(1);
      RANGE_ENCODE_CASE(2);
      RANGE_ENCODE_CASE(3);
      RANGE_ENCODE_CASE(4);
      RANGE_ENCODE_CASE(5);
      RANGE_ENCODE_CASE(6);
#undef RANGE_ENCODE_CASE
      default:
        context->CtxFailure(errors::InvalidArgument(
            "Irregular broadcast pattern: ", data.shape().DebugString(), ", ",
            cdf.shape().DebugString()));
        return;
    }
  }

 private:
  // Encodes all of `data` into `output`. N is the merged rank. Each data value
  // is used as an index into its broadcast cdf slice, and the interval
  // [cdf_slice[index], cdf_slice[index + 1]) is fed to the range encoder.
  template <int N>
  void RangeEncodeImpl(TTypes<int16>::ConstFlat data,
                       gtl::ArraySlice<int64> data_shape,
                       TTypes<int32>::ConstMatrix cdf,
                       gtl::ArraySlice<int64> cdf_shape, string* output) const {
    const int64 data_size = data.size();
    const int64 cdf_size = cdf.size();
    const int64 chip_size = cdf.dimension(1);

    BroadcastRange<const int16, int32, N> view{data.data(), data_shape,
                                               cdf.data(), cdf_shape};
    RangeEncoder encoder{precision_};
    for (int64 linear = 0; linear < data_size; ++linear) {
      const auto pair = view.Next();

      // Both cdf_slice[index] and cdf_slice[index + 1] are read below, hence
      // the `index + 1 < chip_size` bound.
      const int64 index = *pair.first;
      DCHECK_GE(index, 0);
      DCHECK_LT(index + 1, chip_size);

      const int32* cdf_slice = pair.second;
      DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size);

      const int32 lower = cdf_slice[index];
      const int32 upper = cdf_slice[index + 1];
      encoder.Encode(lower, upper, output);
    }

    encoder.Finalize(output);
  }

  int precision_;  // Probability bit width, validated to be in [1, 16].
};

REGISTER_KERNEL_BUILDER(Name("RangeEncode").Device(DEVICE_CPU), RangeEncodeOp);
    215 
// Non-incremental decoder op -------------------------------------------------
//
// Inverse of RangeEncodeOp: decodes a scalar string `encoded` into an int16
// tensor of shape `shape`, using the same `cdf` (and the same `precision`
// attribute value) that was used for encoding.
class RangeDecodeOp : public OpKernel {
 public:
  explicit RangeDecodeOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
    // `precision` must match the encoder's value for decoding to be correct;
    // only the [1, 16] range is validated here.
    OP_REQUIRES(context, 0 < precision_ && precision_ <= 16,
                errors::InvalidArgument("`precision` must be in [1, 16]: ",
                                        precision_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& encoded_tensor = context->input(0);
    const Tensor& shape = context->input(1);
    const Tensor& cdf = context->input(2);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(encoded_tensor.shape()),
                errors::InvalidArgument("Invalid `encoded` shape: ",
                                        encoded_tensor.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(shape.shape()),
                errors::InvalidArgument("Invalid `shape` shape: ",
                                        shape.shape().DebugString()));
    // The output shape is supplied at runtime as an int32 vector.
    TensorShape output_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(shape.vec<int32>(),
                                                        &output_shape));
    OP_REQUIRES_OK(context, CheckCdfShape(output_shape, cdf.shape()));

    // Merge adjacent axes of output/cdf (see MergeAxes in
    // range_coder_ops_util.h) so the implementation can be dispatched on a
    // small fixed rank below.
    std::vector<int64> data_shape, cdf_shape;
    OP_REQUIRES_OK(
        context, MergeAxes(output_shape, cdf.shape(), &data_shape, &cdf_shape));

    const string& encoded = encoded_tensor.scalar<string>()();

    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // Dispatch on the merged rank; ranks above 6 are rejected.
    switch (data_shape.size()) {
#define RANGE_DECODE_CASE(dim)                                              \
  case dim: {                                                               \
    RangeDecodeImpl<dim>(output->flat<int16>(), data_shape,                 \
                         cdf.flat_inner_dims<int32>(), cdf_shape, encoded); \
  } break
      RANGE_DECODE_CASE(1);
      RANGE_DECODE_CASE(2);
      RANGE_DECODE_CASE(3);
      RANGE_DECODE_CASE(4);
      RANGE_DECODE_CASE(5);
      RANGE_DECODE_CASE(6);
#undef RANGE_DECODE_CASE
      default:
        context->CtxFailure(errors::InvalidArgument(
            "Irregular broadcast pattern: ", output_shape.DebugString(), ", ",
            cdf.shape().DebugString()));
        return;
    }
  }

 private:
  // Decodes `encoded` into `output`. N is the merged rank. For each output
  // element, the decoder is handed the element's broadcast cdf slice and
  // returns the decoded symbol index.
  template <int N>
  void RangeDecodeImpl(TTypes<int16>::Flat output,
                       gtl::ArraySlice<int64> output_shape,
                       TTypes<int32>::ConstMatrix cdf,
                       gtl::ArraySlice<int64> cdf_shape,
                       const string& encoded) const {
    BroadcastRange<int16, int32, N> view{output.data(), output_shape,
                                         cdf.data(), cdf_shape};

    RangeDecoder decoder{encoded, precision_};

    const int64 output_size = output.size();
    const int64 cdf_size = cdf.size();
    const auto chip_size =
        static_cast<gtl::ArraySlice<int32>::size_type>(cdf.dimension(1));

    for (int64 i = 0; i < output_size; ++i) {
      const auto pair = view.Next();

      int16* data = pair.first;
      DCHECK_LT(data, output.data() + output_size);

      const int32* cdf_slice = pair.second;
      DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size);

      *data = decoder.Decode(gtl::ArraySlice<int32>{cdf_slice, chip_size});
    }
  }

  int precision_;  // Probability bit width; must match the encoder's.
};

REGISTER_KERNEL_BUILDER(Name("RangeDecode").Device(DEVICE_CPU), RangeDecodeOp);
    306 }  // namespace
    307 }  // namespace tensorflow
    308