1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include <algorithm> 19 #include <array> 20 #include <limits> 21 #include <type_traits> 22 #include <vector> 23 24 #include "tensorflow/contrib/coder/kernels/range_coder.h" 25 #include "tensorflow/contrib/coder/kernels/range_coder_ops_util.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/framework/tensor_shape.h" 29 #include "tensorflow/core/framework/tensor_types.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/lib/gtl/array_slice.h" 33 #include "tensorflow/core/platform/logging.h" 34 #include "tensorflow/core/platform/macros.h" 35 #include "tensorflow/core/platform/types.h" 36 37 namespace tensorflow { 38 namespace { 39 // A helper class to iterate over data and cdf simultaneously, while cdf is 40 // broadcasted to data. 41 // NOTE: Moving this class out of anonymous namespace impacts compiler 42 // optimization and affects performance. When moving this code around (e.g., 43 // into a library header), be sure to check the benchmark tests. 
// T: data element type; U: cdf element type; N: number of (merged) data axes.
// The cdf tensor has N + 1 axes; the last axis holds the per-element cdf
// lookup table, and any of the first N axes with size <= 1 is broadcast
// against the corresponding data axis.
template <typename T, typename U, int N>
class BroadcastRange {
 public:
  BroadcastRange(T* data_pointer, gtl::ArraySlice<int64> data_shape,
                 const U* cdf_pointer, gtl::ArraySlice<int64> cdf_shape)
      : data_pointer_(data_pointer), cdf_pointer_(cdf_pointer) {
    CHECK(!data_shape.empty());
    CHECK_EQ(data_shape.size(), N);
    CHECK_EQ(cdf_shape.size(), N + 1);

    std::copy(data_shape.begin(), data_shape.end(), &data_shape_[0]);
    data_index_.fill(0);

    // Size of one cdf lookup table (the innermost axis of cdf); this is also
    // how far the cdf pointer moves per data element when no axis rewinds.
    const int64 innermost_stride = cdf_shape[N];
    cdf_displace_.fill(innermost_stride);

    // Pre-compute the pointer displacement for cdf.
    int64 stride = innermost_stride;
    for (int i = N - 1; i >= 0; --i) {
      const bool broadcasting = (cdf_shape[i] <= 1);

      // When the data linear index advances by one, the cdf linear index
      // advances by `innermost_stride`.
      //
      // Suppose that the i-th axis coordinate of data increased by one, and
      // that the i-th axis is broadcasting. The cdf linear index should be
      // wound back by the i-th axis stride, so that the i-th axis coordinate
      // of cdf is effectively kept at 0.
      if (broadcasting) {
        cdf_displace_[i] -= stride;
      }
      stride *= cdf_shape[i];
    }
  }

  // Returns the pointers to the current iterating locations to data and cdf
  // tensors, then advances the iterator by one data element.
  //
  // Note that this function does not track whether data pointer is running
  // past the end of data buffer. The caller has to make sure Next() is called
  // no more than that.
  std::pair<T*, const U*> Next() {
    std::pair<T*, const U*> return_value = {data_pointer_, cdf_pointer_};

    // Increment the multi-dimensional index with carry propagation: find the
    // lowest-order axis (largest i) whose coordinate does not wrap around;
    // every axis below it resets to 0. Axis 0's coordinate is intentionally
    // never tracked — the caller guarantees we do not iterate past the end.
    int i = N - 1;
    for (; i > 0; --i) {
      ++data_index_[i];
      if (data_index_[i] < data_shape_[i]) {
        break;
      }
      data_index_[i] = 0;
    }

    // Advance data pointer by one.
    data_pointer_ += 1;

    // For cdf pointer, it's more complicated because of broadcasting. When the
    // i-th coordinate increases by one, and if the i-th axis is broadcasting,
    // then we need to rewind back the pointer so that the effective i-th axis
    // coordinate for cdf is always 0. This value is precomputed as
    // cdf_displace_.
    cdf_pointer_ += cdf_displace_[i];
    return return_value;
  }

 private:
  std::array<int64, N> data_shape_;
  // Net cdf-pointer displacement to apply when axis i is the highest axis
  // whose coordinate changed (computed in the constructor).
  std::array<int64, N> cdf_displace_;
  std::array<int64, N> data_index_;

  T* data_pointer_;
  const U* cdf_pointer_;
};

// Returns OK iff `cdf` has exactly one more axis than `data` and the last
// (innermost) dimension of `cdf` is larger than one. Does not check the
// broadcast compatibility of the leading axes.
Status CheckCdfShape(const TensorShape& data_shape,
                     const TensorShape& cdf_shape) {
  if (TF_PREDICT_FALSE(cdf_shape.dims() != data_shape.dims() + 1)) {
    return errors::InvalidArgument(
        "`cdf` should have one more axis than `data`: data shape=",
        data_shape.DebugString(), ", cdf shape=", cdf_shape.DebugString());
  }

  if (TF_PREDICT_FALSE(cdf_shape.dim_size(cdf_shape.dims() - 1) <= 1)) {
    return errors::InvalidArgument(
        "The last dimension of `cdf` should be > 1: ", cdf_shape.DebugString());
  }

  return Status::OK();
}

// Non-incremental encoder op -------------------------------------------------
//
// Range-encodes an int16 data tensor against a broadcastable int32 cdf tensor
// into a single scalar string output.
class RangeEncodeOp : public OpKernel {
 public:
  explicit RangeEncodeOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
    OP_REQUIRES(context, 0 < precision_ && precision_ <= 16,
                errors::InvalidArgument("`precision` must be in [1, 16]: ",
                                        precision_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& cdf = context->input(1);

    OP_REQUIRES_OK(context, CheckCdfShape(data.shape(), cdf.shape()));

    // MergeAxes (declared in range_coder_ops_util.h) rewrites the two shapes
    // into merged forms; the dispatch below handles up to six merged axes.
    std::vector<int64> data_shape, cdf_shape;
    OP_REQUIRES_OK(
        context, MergeAxes(data.shape(), cdf.shape(), &data_shape, &cdf_shape));

    // The output is a scalar string holding the encoded bitstream.
    Tensor* output_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape{}, &output_tensor));
    string* output = &output_tensor->scalar<string>()();

    // Dispatch on the merged rank so that BroadcastRange's N is a
    // compile-time constant.
    switch (data_shape.size()) {
#define RANGE_ENCODE_CASE(dims)                                          \
  case dims: {                                                           \
    RangeEncodeImpl<dims>(data.flat<int16>(), data_shape,                \
                          cdf.flat_inner_dims<int32, 2>(), cdf_shape,    \
                          output);                                       \
  } break
      RANGE_ENCODE_CASE(1);
      RANGE_ENCODE_CASE(2);
      RANGE_ENCODE_CASE(3);
      RANGE_ENCODE_CASE(4);
      RANGE_ENCODE_CASE(5);
      RANGE_ENCODE_CASE(6);
#undef RANGE_ENCODE_CASE
      default:
        context->CtxFailure(errors::InvalidArgument(
            "Irregular broadcast pattern: ", data.shape().DebugString(), ", ",
            cdf.shape().DebugString()));
        return;
    }
  }

 private:
  // Encodes every element of `data` against its (broadcast) cdf slice and
  // appends the resulting bitstream to `output`.
  template <int N>
  void RangeEncodeImpl(TTypes<int16>::ConstFlat data,
                       gtl::ArraySlice<int64> data_shape,
                       TTypes<int32>::ConstMatrix cdf,
                       gtl::ArraySlice<int64> cdf_shape, string* output) const {
    const int64 data_size = data.size();
    const int64 cdf_size = cdf.size();
    // Length of one cdf lookup table.
    const int64 chip_size = cdf.dimension(1);

    BroadcastRange<const int16, int32, N> view{data.data(), data_shape,
                                               cdf.data(), cdf_shape};
    RangeEncoder encoder{precision_};
    for (int64 linear = 0; linear < data_size; ++linear) {
      const auto pair = view.Next();

      // Each data value is used as an index into its cdf table;
      // [cdf_slice[index], cdf_slice[index + 1]) is the coding interval.
      const int64 index = *pair.first;
      DCHECK_GE(index, 0);
      DCHECK_LT(index + 1, chip_size);

      const int32* cdf_slice = pair.second;
      DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size);

      const int32 lower = cdf_slice[index];
      const int32 upper = cdf_slice[index + 1];
      encoder.Encode(lower, upper, output);
    }

    // Flush any pending state of the range coder into the output string.
    encoder.Finalize(output);
  }

  int precision_;  // Range coder precision in bits; validated to be in [1, 16].
};

REGISTER_KERNEL_BUILDER(Name("RangeEncode").Device(DEVICE_CPU), RangeEncodeOp);

// Non-incremental decoder op -------------------------------------------------
//
// Inverse of RangeEncodeOp: decodes a scalar string bitstream back into an
// int16 tensor of the requested shape, using the same cdf tensor.
class RangeDecodeOp : public OpKernel {
 public:
  explicit RangeDecodeOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
    OP_REQUIRES(context, 0 < precision_ && precision_ <= 16,
                errors::InvalidArgument("`precision` must be in [1, 16]: ",
                                        precision_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& encoded_tensor = context->input(0);
    const Tensor& shape = context->input(1);
    const Tensor& cdf = context->input(2);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(encoded_tensor.shape()),
                errors::InvalidArgument("Invalid `encoded` shape: ",
                                        encoded_tensor.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(shape.shape()),
                errors::InvalidArgument("Invalid `shape` shape: ",
                                        shape.shape().DebugString()));
    // The output shape is supplied at runtime as an int32 vector input.
    TensorShape output_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(shape.vec<int32>(),
                                                        &output_shape));
    OP_REQUIRES_OK(context, CheckCdfShape(output_shape, cdf.shape()));

    // Same axis-merging step as in RangeEncodeOp, applied to the requested
    // output shape.
    std::vector<int64> data_shape, cdf_shape;
    OP_REQUIRES_OK(
        context, MergeAxes(output_shape, cdf.shape(), &data_shape, &cdf_shape));

    const string& encoded = encoded_tensor.scalar<string>()();

    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // Dispatch on the merged rank so that BroadcastRange's N is a
    // compile-time constant.
    switch (data_shape.size()) {
#define RANGE_DECODE_CASE(dim)                                              \
  case dim: {                                                               \
    RangeDecodeImpl<dim>(output->flat<int16>(), data_shape,                 \
                         cdf.flat_inner_dims<int32>(), cdf_shape, encoded); \
  } break
      RANGE_DECODE_CASE(1);
      RANGE_DECODE_CASE(2);
      RANGE_DECODE_CASE(3);
      RANGE_DECODE_CASE(4);
      RANGE_DECODE_CASE(5);
      RANGE_DECODE_CASE(6);
#undef RANGE_DECODE_CASE
      default:
        context->CtxFailure(errors::InvalidArgument(
            "Irregular broadcast pattern: ", output_shape.DebugString(), ", ",
            cdf.shape().DebugString()));
        return;
    }
  }

 private:
  // Decodes `encoded` element-by-element into `output`, feeding each decode
  // step the (broadcast) cdf slice for that element.
  template <int N>
  void RangeDecodeImpl(TTypes<int16>::Flat output,
                       gtl::ArraySlice<int64> output_shape,
                       TTypes<int32>::ConstMatrix cdf,
                       gtl::ArraySlice<int64> cdf_shape,
                       const string& encoded) const {
    BroadcastRange<int16, int32, N> view{output.data(), output_shape,
                                         cdf.data(), cdf_shape};

    RangeDecoder decoder{encoded, precision_};

    const int64 output_size = output.size();
    const int64 cdf_size = cdf.size();
    // Length of one cdf lookup table, cast to ArraySlice's size type for the
    // slice construction below.
    const auto chip_size =
        static_cast<gtl::ArraySlice<int32>::size_type>(cdf.dimension(1));

    for (int64 i = 0; i < output_size; ++i) {
      const auto pair = view.Next();

      int16* data = pair.first;
      DCHECK_LT(data, output.data() + output_size);

      const int32* cdf_slice = pair.second;
      DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size);

      *data = decoder.Decode(gtl::ArraySlice<int32>{cdf_slice, chip_size});
    }
  }

  int precision_;  // Range coder precision in bits; validated to be in [1, 16].
};

REGISTER_KERNEL_BUILDER(Name("RangeDecode").Device(DEVICE_CPU), RangeDecodeOp);
}  // namespace
}  // namespace tensorflow