/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements quantized eight-bit versions of the convolution operations.

#include <algorithm>
#include <vector>

#define EIGEN_USE_THREADS

#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"

namespace tensorflow {

// This functor implements the convolution operation in as simple a form as
// possible. It won't give great performance, but it is very useful for
// stepping through and instrumenting for debugging, creating minimal
// benchmarks to prototype with, and sharing with teams that want to run this
// outside of our environment.
// With that in mind, I've avoided using anything except pretty standard C++
// types. This is especially noticeable in the data access through raw array
// indexing. It's deliberate in this case though, since it makes the underlying
// memory order very explicit, which is important for both inspecting memory
// contents during debugging and for specifying what we expect to others.
// The memory layout of the data is, from biggest stride to smallest:
// input_data = [input_batches, input_height, input_width, input_depth]
// filter_data = [filter_height, filter_width, input_depth, filter_count]
// output_data = [input_batches, output_height, output_width, filter_count]
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, int input_offset, const T2* filter_data,
                  int filter_height, int filter_width, int filter_count,
                  int filter_offset, int stride, Padding padding,
                  T3* output_data, int output_height, int output_width,
                  int output_shift, int output_offset, int output_mult) {
    // Set up some constants we need for the output down-shifting and
    // saturation.
    const int32 highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
    const int32 lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest());

    // When we're converting the 32 bit accumulator to a lower bit depth, we
    // need to add on 0.5 in fixed-point terms to make the operation round half
    // up towards positive infinity, rather than a floor.
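    // For example, if output_offset == 0, output_mult == 1 and
    // output_shift == 3, the rounding term is (1 << 2) == 4, so an
    // accumulated total of 21 becomes (21 + 4) >> 3 == 3 rather than the
    // floored 21 >> 3 == 2.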
    // We also need to watch out for the case when there's no down shift,
    // because a left shift by a negative number gives undefined results.
    const int32 rounding = (output_shift < 1) ? 0 : (1 << (output_shift - 1));

    // The two different padding modes we support can be a bit confusing. SAME
    // means we're trying to produce an output image that's the same size as
    // the input. It's complicated by stride, which shrinks the output image by
    // a factor, but it means we end up sampling from outside the borders of
    // the input. These out-of-bounds values are read as zeroes. VALID means
    // only produce output values where the filters can read all their values
    // from within the input image. It effectively removes the margins of the
    // output image compared to the one produced by SAME. Stride complicates
    // this definition though, because it can result in the right and bottom
    // filter patches sampling from outside the borders if it's greater than 1.
    // Most of the logic for sorting this all out is done before this function,
    // when we calculate the output size, but the positioning of the origin of
    // the filters is different between the two modes, since SAME positions the
    // first filter off the edge of the input.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height + 1) /
          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height) / 2;
    }

    // If we've got multiple images in our input, work through each of them.
    for (int batch = 0; batch < input_batches; ++batch) {
      // Walk through all the output image values, sliding the filter to
      // different positions in the input.
      for (int out_y = 0; out_y < output_height; ++out_y) {
        for (int out_x = 0; out_x < output_width; ++out_x) {
          // Each filter kernel produces one output channel.
          for (int out_channel = 0; out_channel < filter_count;
               ++out_channel) {
            // We're going to calculate a single output value, which means we
            // need to multiply a three dimensional kernel of weights against
            // the current location within the input image.
            /*
              *-------------------------------...
              |\ ^
              | \in_depth
              |  \ v
              |   *-------------------------------...
              |   |          ^
              |   |       in_y_origin
              |   |          v   \
              |   |<in_x_origin>*---*^
              |   |            \|   |filter_height
              .   |             *---*v
              .   |             <--->
                  .           filter_width
                  .
            */
            const int in_x_origin = (out_x * stride) - filter_left_offset;
            const int in_y_origin = (out_y * stride) - filter_top_offset;
            int32 total = 0;
            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
                for (int in_channel = 0; in_channel < input_depth;
                     ++in_channel) {
                  const int in_x = in_x_origin + filter_x;
                  const int in_y = in_y_origin + filter_y;
                  int32 input_value;
                  // If the location is outside the bounds of the input image,
                  // use zero as a default value.
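                  // Note that input_value is in the offset-subtracted domain,
                  // so 0 here stands for a real-valued zero rather than the
                  // quantized encoding of zero.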
                  if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                      (in_y < input_height)) {
                    const T1 input_source_value =
                        input_data[(batch * input_height * input_width *
                                    input_depth) +
                                   (in_y * input_width * input_depth) +
                                   (in_x * input_depth) + in_channel];
                    // We're promoting the T1 type to a higher bit depth here
                    // as we do the subtraction.
                    input_value =
                        static_cast<int32>(input_source_value) - input_offset;
                  } else {
                    input_value = 0;
                  }
                  const T2 filter_source_value =
                      filter_data[(filter_y * filter_width * input_depth *
                                   filter_count) +
                                  (filter_x * input_depth * filter_count) +
                                  (in_channel * filter_count) + out_channel];
                  // Another promotion to 32 bit, as above.
                  const int32 filter_value =
                      static_cast<int32>(filter_source_value) - filter_offset;
                  total += (input_value * filter_value);
                }
              }
            }
            // Here we're applying scale factors to compress the 32 bit
            // accumulated total to a potentially lower bit depth.
            const int32 output =
                ((((total + output_offset) * output_mult) + rounding) >>
                 output_shift);
            // We need to saturate the results against the largest and smallest
            // values that can be represented in this type.
            const int32 top_clamped_output = std::min(output, highest);
            const int32 clamped_output = std::max(top_clamped_output, lowest);
            output_data[(batch * output_height * output_width * filter_count) +
                        (out_y * output_width * filter_count) +
                        (out_x * filter_count) + out_channel] = clamped_output;
          }
        }
      }
    }
  }
};

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 1 megabyte as a reasonable limit, from
// experimentation.
const size_t kMaxChunkSize = (1 * 1024 * 1024);

// Implements convolution as a two stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
template <class T1, class T2, class T3>
class Im2ColConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, int input_offset, const T2* filter_data,
                  int filter_height, int filter_width, int filter_count,
                  int filter_offset, int stride, Padding padding,
                  T3* output_data, int output_height, int output_width,
                  int output_shift, int output_offset, int output_mult) {
    if (input_offset < 0) {
      // Only log the first few occurrences of this warning.
      static int warning_count = 0;
      if (warning_count < 10) {
        ++warning_count;
        LOG(WARNING)
            << "For kernel '" << context->op_kernel().name()
            << "' from input '" << context->op_kernel().requested_input(0)
            << "': Zero is not representable in the quantized range used by"
            << " the input. This means QuantizedConv2D has to fall back to a"
            << " slow implementation, since the border of zero values can't"
            << " be represented easily. You should try to construct graphs"
            << " that avoid this situation.";
      }
      ReferenceConvFunctor<T1, T2, T3> conv_functor;
      conv_functor(context, input_data, input_batches, input_height,
                   input_width, input_depth, input_offset, filter_data,
                   filter_height, filter_width, filter_count, filter_offset,
                   stride, padding, output_data, output_height, output_width,
                   output_shift, output_offset, output_mult);
      return;
    }

    CHECK_GT(output_width, 0);
    CHECK_GT(output_height, 0);
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height + 1) /
          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height) / 2;
    }

    // The im2col buffer has # of patches rows, and # of filters cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
    const int filter_value_count = filter_width * filter_height * input_depth;
    const int64 patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    const int64 chunk_value_count =
        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
    // TODO(petewarden) - Memory allocation can be very slow on Android. Can we
    // optimize this by keeping the scratch buffer around?
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
        creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
#ifdef _MSC_VER
          // MSVC complains about the capture of chunk_value_count which oddly
          // works fine in conv_ops_using_gemm.cc for example.
          // Define chunk_value_count inside the lambda for now.
          const int64 chunk_value_count =
              (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
#endif
          *resource = new Im2ColBufferResource<T1, chunk_value_count>();
          return Status::OK();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));
    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it
    // shouldn't be an issue.
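    // Hold the buffer's lock for the rest of this call, since the chunked
    // im2col loop below writes into the shared scratch data.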
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    const int64 patch_count = (input_batches * output_height * output_width);
    const int64 chunk_count =
        (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;

    for (int64 chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
      const int64 patch_index_start = chunk_index * patches_per_chunk;
      const int64 patch_index_end =
          std::min(patch_index_start + patches_per_chunk, patch_count);
      for (int64 patch_index = patch_index_start;
           patch_index < patch_index_end; ++patch_index) {
        const int64 batch = patch_index / (output_height * output_width);
        const int64 out_y = (patch_index / output_width) % output_height;
        const int64 out_x = patch_index % output_width;
        const T1* input_batch_start =
            input_data + (batch * input_height * input_width * input_depth);
        const int in_y_origin = (out_y * stride) - filter_top_offset;
        const int in_x_origin = (out_x * stride) - filter_left_offset;
        const int patch_index_within_chunk = patch_index % patches_per_chunk;
        T1* im2col_patch_start =
            im2col_buffer + (patch_index_within_chunk * filter_value_count);
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          T1* im2col_row_start =
              im2col_patch_start + (filter_y * filter_width * input_depth);
          // If we're off the top or the bottom of the input, fill the
          // whole row with zeroes.
          if ((in_y < 0) || (in_y >= input_height)) {
            // On Android, memset and memcpy are significantly faster than the
            // more modern std::fill and std::copy equivalents.
            memset(im2col_row_start, input_offset,
                   (filter_width * input_depth));
          } else {
            // What we're doing here is trying to copy and fill the im2col
            // buffer as efficiently as possible, using functions to set or
            // duplicate values en masse. We know we don't have to worry about
            // vertical edges because we dealt with that case above, so we
            // just need to handle filters that overlap the left or right
            // edges. Here's what that looks like:
            //
            // < left_zero_count > < center_copy_count > < right_zero_count >
            // +------------------+---------------------+--------------------+
            // |     (filter)     |       (image)       |      (filter)      |
            // +------------------+---------------------+--------------------+
            // in_x_origin        0                 input_width       in_x_end
            //
            // In reality it's unlikely that a filter patch will be wider
            // than an input, but this shows all the edge cases.
            // We use memset() to set the left and right sections to zeroes
            // and memcpy() to copy over the input data for the center. These
            // are preferred to std::fill and std::copy because they're much
            // faster on Android.
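            // For example, with filter_width == 3, in_x_origin == -1 and
            // input_width == 4, this works out to left_zero_count == 1,
            // center_copy_count == 2 and right_zero_count == 0.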
            const int in_x_end = in_x_origin + filter_width;
            const int left_zero_count = std::max(0, 0 - in_x_origin);
            const int right_zero_count = std::max(0, in_x_end - input_width);
            const int center_copy_count =
                filter_width - (left_zero_count + right_zero_count);
            if (left_zero_count > 0) {
              T1* im2col_left_start = im2col_row_start;
              memset(im2col_left_start, input_offset,
                     (left_zero_count * input_depth));
            }
            if (center_copy_count > 0) {
              const T1* input_row_start =
                  input_batch_start + (in_y * input_width * input_depth) +
                  (std::max(0, in_x_origin) * input_depth);
              T1* im2col_center_start =
                  im2col_row_start + (left_zero_count * input_depth);
              memcpy(im2col_center_start, input_row_start,
                     (center_copy_count * input_depth));
            }
            if (right_zero_count > 0) {
              T1* im2col_right_start =
                  im2col_row_start +
                  ((left_zero_count + center_copy_count) * input_depth);
              memset(im2col_right_start, input_offset,
                     (right_zero_count * input_depth));
            }
          }
        }
      }
      // Now we've assembled a set of image patches into a matrix, apply a
      // GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get partial results in the output matrix.
      const int how_many_patches = patch_index_end - patch_index_start;
      const bool transpose_a = false;
      const bool transpose_b = false;
      const bool transpose_c = false;
      const int m = how_many_patches;
      const int n = filter_count;
      const int k = filter_value_count;
      const int lda = filter_value_count;
      const int ldb = filter_count;
      const int ldc = filter_count;
      T3* chunk_output_data = output_data + (patch_index_start * filter_count);

      if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
          std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
          (output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
          (transpose_c == false) && (k <= 2048)) {
        meta::QuantizedGemm(context, transpose_a, transpose_b, im2col_buffer,
                            filter_data, chunk_output_data, m, n, k,
                            -input_offset, -filter_offset, lda, ldb, ldc);
      } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
                 std::is_same<T3, qint32>() && (output_offset == 0) &&
                 (output_mult == 1) && (output_shift == 0)) {
        // The gemmlowp optimized library only works for a particular set of
        // data types, so check if we meet those requirements and fall back to
        // a slower reference implementation if not.
        const uint8* im2col_data_as_uint8 = &(im2col_buffer->value);
        const uint8* filter_data_as_uint8 = &(filter_data->value);
        int32* output_data_as_int32 = &(chunk_output_data->value);
        // All of the transpose_* variables are currently compile-time consts,
        // so we could just hard-code these values too, but that would break
        // if anybody changed those values in the future (e.g. to match the
        // ability of MatMul to specify them as attributes). We're using a
        // verbose approach of deriving the order values from the transpose
        // variables to be able to catch any changes like that.
        static const gemmlowp::MapOrder ResultOrder =
            !transpose_c ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        static const gemmlowp::MapOrder LhsOrder =
            !transpose_a ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        static const gemmlowp::MapOrder RhsOrder =
            !transpose_b ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(
            im2col_data_as_uint8, m, k, lda);
        gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(
            filter_data_as_uint8, k, n, ldb);
        gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
            output_data_as_int32, m, n, ldc);
        const std::tuple<> empty_pipeline = {};

        auto& worker_threads =
            *(context->device()->tensorflow_cpu_worker_threads());
        TensorflowGemmContext context(worker_threads.num_threads,
                                      worker_threads.workers);
        gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                         gemmlowp::DefaultL8R8BitDepthParams>(
            &context, lhs, rhs, &result, -input_offset, -filter_offset,
            empty_pipeline);
        // Since gemmlowp uses assembly to write to the output, msan won't
        // detect the output buffer as written to, so we mark it manually.
        TF_ANNOTATE_MEMORY_IS_INITIALIZED(output_data_as_int32,
                                          m * n * sizeof(int32));
      } else {
        ReferenceGemm<T1, T2, T3>(
            transpose_a, transpose_b, transpose_c, m, n, k, im2col_buffer,
            input_offset, lda, filter_data, filter_offset, ldb,
            chunk_output_data, output_shift, output_offset, output_mult, ldc);
      }
    }
  }
};

template <class T1, class T2, class T3,
          template <class TF1, class TF2, class TF3> class ConvFunctor>
class QuantizedConv2DOp : public OpKernel {
 public:
  explicit QuantizedConv2DOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, strides_[1] == strides_[2],
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    std::vector<int32> dilations;
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations));
    OP_REQUIRES(context, dilations.size() == 4,
                errors::InvalidArgument("Dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, dilations[1] == 1 && dilations[2] == 1,
                errors::InvalidArgument(
                    "Current implementation only supports dilated rate as 1 "
                    "in the row and column dimensions."));
    OP_REQUIRES(context, (dilations[0] == 1 && dilations[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    const float min_input = context->input(2).flat<float>()(0);
    const float max_input = context->input(3).flat<float>()(0);
    const float min_filter = context->input(4).flat<float>()(0);
    const float max_filter = context->input(5).flat<float>()(0);
    const int32 offset_input =
        FloatToQuantizedUnclamped<T1>(0.0f, min_input, max_input);
    const int32 offset_filter =
        FloatToQuantizedUnclamped<T2>(0.0f, min_filter, max_filter);
    const int32 offset_output = 0;
    const int32 mult_output = 1;
    const int32 shift_output = 0;

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64 in_depth = input.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int64 out_depth = filter.dim_size(3);

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 input_rows = input.dim_size(1);
    const int64 filter_rows = filter.dim_size(0);

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 input_cols = input.dim_size(2);
    const int64 filter_cols = filter.dim_size(1);

    // The first dimension for input is batch.
    const int64 batch = input.dim_size(0);

    // For now we take the stride from the second dimension only (we
    // assume row = col stride, and do not support striding on the
    // batch or depth dimension).
    const int stride = strides_[1];

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride,
                                         padding_, &out_cols, &pad_cols));
    CHECK_GT(batch, 0);
    CHECK_GT(out_rows, 0);
    CHECK_GT(out_cols, 0);
    CHECK_GT(out_depth, 0);
    TensorShape out_shape({batch, out_rows, out_cols, out_depth});

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // This will call different implementations (e.g. reference or optimized)
    // depending on the template parameter.
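    // For the registration at the bottom of this file, for example, this
    // resolves to Im2ColConvFunctor<quint8, quint8, qint32>.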
    ConvFunctor<T1, T2, T3> conv_functor;
    conv_functor(context, input.flat<T1>().data(), batch, input_rows,
                 input_cols, in_depth, offset_input, filter.flat<T2>().data(),
                 filter_rows, filter_cols, out_depth, offset_filter, stride,
                 padding_, output->flat<T3>().data(), out_rows, out_cols,
                 shift_output, offset_output, mult_output);

    float min_output_value;
    float max_output_value;
    QuantizationRangeForMultiplication<T1, T2, T3>(
        min_input, max_input, min_filter, max_filter, &min_output_value,
        &max_output_value);

    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    output_min->flat<float>()(0) = min_output_value;

    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
    output_max->flat<float>()(0) = max_output_value;
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
};

// Right now we only support taking two eight bit inputs, and returning the
// results as signed 32-bit integers.
REGISTER_KERNEL_BUILDER(
    Name("QuantizedConv2D")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<quint8>("Tfilter")
        .TypeConstraint<qint32>("out_type"),
    QuantizedConv2DOp<quint8, quint8, qint32, Im2ColConvFunctor>);

}  // namespace tensorflow