/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements convolution operations with other kernels baked into the
// processing, to optimize latency and memory usage.

#define EIGEN_USE_THREADS

#include <string.h>
#include <map>
#include <vector>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/gemm_functors.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

namespace {

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 16 megabytes as a reasonable limit for Android
// and other platforms using Eigen, and 1MB for iOS devices, from
// experimentation.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
const size_t kMaxChunkSize = (1 * 1024 * 1024);
#else
const size_t kMaxChunkSize = (16 * 1024 * 1024);
#endif
const size_t kResizeCacheSize = (8 * 1024 * 1024);
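
// As an illustration of the chunking arithmetic (the figures here are
// hypothetical): with float data, a 3x3 filter, and 256 input channels, each
// im2col patch occupies 3 * 3 * 256 * sizeof(float) = 9,216 bytes, so a 16MB
// chunk holds 16 * 1024 * 1024 / 9,216 = 1,820 patches before the buffer is
// refilled for the next chunk.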

// Lookup method used when resizing.
enum SamplingMode {
  BILINEAR = 0,
  NEAREST = 1,
};

// Simple utility function used by FusedConv to multithread basic workloads. To
// use it, pass begin and end values for the full workload and a std::function
// that receives a subset of that through the begin and end values for each
// worker's task. The division of the full workload into worker tasks is
// handled by the multithreading logic. Here's an example of how to use it:
// std::vector<float> my_vector(100);
// ...
// FusedConvParallelFor(context, 0, 100,
//                      [&my_vector](int64 task_begin, int64 task_end) {
//   for (int64 current = task_begin; current != task_end; ++current) {
//     my_vector[current] *= 10.0f;
//   }
// });
void FusedConvParallelFor(
    OpKernelContext* context, int64 begin, int64 end,
    const std::function<void(int64, int64)>& task_function) {
// On iOS, the thread management imposes a very big performance penalty, so
// just call the function directly with no multithreading.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
  task_function(begin, end);
#else
  auto& worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
  thread::ThreadPool* thread_pool = worker_threads.workers;
  const int64 total_elements = end - begin;
  // This is a bit of an arbitrary number, but was found to work well for
  // typical models we've been profiling on various devices.
  const int64 element_cost = 10000000;
  thread_pool->ParallelFor(
      total_elements, element_cost,
      [begin, task_function](int64 begin_offset, int64 end_offset) {
        const int64 task_begin = begin + begin_offset;
        const int64 task_end = begin + end_offset;
        task_function(task_begin, task_end);
      });
#endif
}

// Holds the state needed for the resizing subtasks.
template <class T1>
struct ResizeTaskParameters {
  ResizeTaskParameters() : st(false) {}

  int cache_height;
  T1* resize_cache;
  int cache_line_width;
  int input_width;
  int input_depth;
  int top_padding;
  int pad_offset;
  int64 resized_height;
  ImageResizerState st;
  const T1* input_batch_start;
  int64 cache_start_x;
  int64 cache_end_x;
  int left_padding;
  int64 resized_width;
  int64 padded_width;
  int64 padded_height;
};

template <class T1>
struct PerCacheLineParameters {
  PerCacheLineParameters() {}
  PerCacheLineParameters(const PerCacheLineParameters<T1>& other)
      : cache_line_start(other.cache_line_start),
        input_top_row_start(other.input_top_row_start),
        input_bottom_row_start(other.input_bottom_row_start),
        y_lerp(other.y_lerp) {}

  T1* cache_line_start;
  const T1* input_top_row_start;
  const T1* input_bottom_row_start;
  T1 y_lerp;
};

// Helper class to simplify bilinear filtering.
template <class T1>
struct SampleRect {
  EIGEN_ALWAYS_INLINE SampleRect(const T1* in_top_left, const T1* in_top_right,
                                 const T1* in_bottom_left,
                                 const T1* in_bottom_right)
      : top_left(in_top_left),
        top_right(in_top_right),
        bottom_left(in_bottom_left),
        bottom_right(in_bottom_right) {}

  EIGEN_ALWAYS_INLINE T1 BilinearSample(int channel, T1 x_lerp,
                                        T1 y_lerp) const {
    const T1 top =
        top_left[channel] + (top_right[channel] - top_left[channel]) * x_lerp;
    const T1 bottom = bottom_left[channel] +
                      (bottom_right[channel] - bottom_left[channel]) * x_lerp;
    return top + (bottom - top) * y_lerp;
  }

  const T1* top_left;
  const T1* top_right;
  const T1* bottom_left;
  const T1* bottom_right;
};
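
// A worked example of BilinearSample above, with made-up values: for corner
// values top_left = 0, top_right = 4, bottom_left = 8, bottom_right = 12,
// sampling at x_lerp = 0.5, y_lerp = 0.25 gives top = 0 + (4 - 0) * 0.5 = 2,
// bottom = 8 + (12 - 8) * 0.5 = 10, and a result of 2 + (10 - 2) * 0.25 = 4.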

// Calculates parameters which remain constant through a resize cache row.
template <class T1>
EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters(
    int64 cache_height, int64 cache_y, T1* resize_cache,
    int64 cache_line_width, int64 input_width, int64 input_depth,
    int64 top_padding, int64 pad_offset, int64 resized_height,
    const ImageResizerState& st, const T1* input_batch_start) {
  PerCacheLineParameters<T1> result;
  // The cache is organized so that the real y values of the resized image map
  // onto the actual cache values through a modulo scheme. This means that as
  // we progress downwards through the image, we keep reusing a small cache
  // and so keep memory usage down.
  int64 cache_index_y;
  if (cache_y < 0) {
    cache_index_y = cache_height + (cache_y % cache_height);
  } else {
    cache_index_y = cache_y % cache_height;
  }
  result.cache_line_start =
      resize_cache + (cache_index_y * cache_line_width * input_depth);
  // This part is implementing the mirror padding that happens before resizing.
  float in_y = (cache_y - top_padding);
  if (in_y < 0) {
    in_y = -(in_y + 1.0f - pad_offset);
  } else if (in_y >= resized_height) {
    in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
  }
  // Here's where we do the actual resize.
  in_y *= st.height_scale;
  const int64 top_y_index = static_cast<int64>(std::floor(in_y));
  const int64 bottom_y_index =
      std::min(static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
  // Lerp is used for bilinear filtering when that's needed.
  result.y_lerp = in_y - top_y_index;
  // Which rows of the original input image to pull the values from.
  result.input_top_row_start =
      input_batch_start + (top_y_index * input_width * input_depth);
  result.input_bottom_row_start =
      input_batch_start + (bottom_y_index * input_width * input_depth);
  return result;
}
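
// To make the modulo scheme above concrete (with a deliberately tiny,
// hypothetical cache): if cache_height = 4, rows y = 0, 1, 2, 3, 4, 5 land in
// cache rows 0, 1, 2, 3, 0, 1, and a padded row of y = -2 lands in cache row
// 4 + (-2 % 4) = 2, so negative y values from padding stay in bounds too.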

template <class T1>
struct PerCachePixelParameters {
  PerCachePixelParameters() {}
  PerCachePixelParameters(const PerCachePixelParameters<T1>& other)
      : cache_line_pixel(other.cache_line_pixel),
        left_x_index(other.left_x_index),
        right_x_index(other.right_x_index),
        x_lerp(other.x_lerp) {}

  T1* cache_line_pixel;
  int64 left_x_index;
  int64 right_x_index;
  T1 x_lerp;
};

// Pulls out common parameters used for every resized pixel.
template <class T1>
EIGEN_ALWAYS_INLINE PerCachePixelParameters<T1>
CalculatePerCachePixelParameters(int64 cache_x, int64 cache_start_x,
                                 T1* cache_line_start, int64 input_depth,
                                 int64 left_padding, int64 pad_offset,
                                 int64 resized_width,
                                 const ImageResizerState& st) {
  PerCachePixelParameters<T1> result;
  // Figure out where we're going to store the results of our transform.
  const int cache_index_x = cache_x - cache_start_x;
  result.cache_line_pixel = cache_line_start + (cache_index_x * input_depth);
  // Implement mirror padding by flipping in_x if it's off the edge.
  float in_x = (cache_x - left_padding);
  if (in_x < 0) {
    in_x = -(in_x + 1.0f - pad_offset);
  } else if (in_x >= resized_width) {
    in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
  }
  // Resize the x parameters.
  in_x *= st.width_scale;
  // Get the x coordinates for the left and right pixels to pull from.
  result.left_x_index = static_cast<int64>(std::floor(in_x));
  result.right_x_index =
      std::min(static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
  // This x_lerp is used to blend pixels in bilinear filtering.
  result.x_lerp = in_x - result.left_x_index;
  return result;
}
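
// To illustrate the flip logic above with hypothetical numbers: with
// resized_width = 10, REFLECT mode (pad_offset = 1) maps an off-edge
// in_x = -1 to 1 and in_x = 10 to 8, mirroring without repeating the edge
// pixel, while SYMMETRIC mode (pad_offset = 0) maps in_x = -1 to 0 and
// in_x = 10 to 9, repeating the edge pixel.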

// Combines bilinear resizing and mirror padding into the im2col
// transformation stage of convolution.
template <class T1, class T2, class T3, class TGemmFunctor,
          SamplingMode SampleMode>
class FusedResizeAndPadConvFunctor {
 public:
  void operator()(OpKernelContext* context, const Tensor& input,
                  int input_batches, int resized_height, int resized_width,
                  int padded_height, int padded_width, int input_depth,
                  const T2* filter_data, int filter_height, int filter_width,
                  int filter_count, int stride_rows, int stride_cols,
                  Padding padding, T3* output_data, int output_height,
                  int output_width, const ImageResizerState& st,
                  int top_padding, int bottom_padding, int left_padding,
                  int right_padding, int pad_offset) {
    if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
        (input_depth <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
                   << input_batches << ", " << padded_height << ", "
                   << padded_width << ", " << input_depth;
      return;
    }
    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
                   << filter_width << ", " << filter_height << ", "
                   << filter_count;
      return;
    }
    if ((output_width <= 0) || (output_height <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad output width or height: "
                   << output_width << ", " << output_height;
      return;
    }
    OP_REQUIRES(
        context, ((SampleMode == NEAREST) || (SampleMode == BILINEAR)),
        errors::InvalidArgument("Bad sample mode passed in", SampleMode));

    // These calculations define how the patches will be positioned within the
    // input image. The actual definitions are quite complex, and rely on the
    // previously-calculated output size.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width +
           1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           padded_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - padded_height) /
          2;
    }

    ResizeTaskParameters<T1> task_params;
    task_params.input_depth = input_depth;
    task_params.top_padding = top_padding;
    task_params.pad_offset = pad_offset;
    task_params.resized_height = resized_height;
    task_params.st = st;
    task_params.left_padding = left_padding;
    task_params.resized_width = resized_width;
    task_params.padded_width = padded_width;
    task_params.padded_height = padded_height;

    // The im2col buffer has # of patches rows, and # of filters cols.
    // It's laid out like this, in row major order in memory:
    //         < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory,
    // followed by the width, then the height. This is the standard memory
    // order in the image world if it helps to visualize it.
    const int filter_value_count = filter_width * filter_height * input_depth;

    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
                errors::InvalidArgument("Im2Col patch too large for buffer"));
    const size_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, kMaxChunkSize>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, kMaxChunkSize>**)> creator =
        [](Im2ColBufferResource<T1, kMaxChunkSize>** resource) {
          *resource = new Im2ColBufferResource<T1, kMaxChunkSize>();
          return Status::OK();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));

    // Create a resize cache memory buffer that will hold the rows of
    // transformed and mirror padded input pixels, ready to be copied
    // into filter patches by im2col.
    // It's laid out like this, in row major order in memory:
    //          < cache line width >
    //    ^    +--------------------+
    //  cache  |                    |
    // height  |                    |
    //    v    +--------------------+
    // Each cache row contains a cache_line_width number of resized pixels,
    // each with input_depth channels. The cache height is typically less than
    // the full height the resized image would be, so it's filled up
    // incrementally as we progress downwards through the input creating
    // im2col patches.
    task_params.cache_start_x = -filter_left_offset;
    task_params.cache_end_x =
        (((output_width - 1) * stride_cols) - filter_left_offset) +
        filter_width;
    task_params.cache_line_width =
        task_params.cache_end_x - task_params.cache_start_x;
    task_params.cache_height =
        kResizeCacheSize / (task_params.cache_line_width * input_depth);
    const int needed_resize_cache_count =
        filter_height * task_params.cache_line_width * input_depth;
    OP_REQUIRES(context,
                (needed_resize_cache_count * sizeof(T1)) <= kResizeCacheSize,
                errors::InvalidArgument("Input too large for resize cache"));
    Im2ColBufferResource<T1, kResizeCacheSize>* resize_cache_resource;
    std::function<Status(Im2ColBufferResource<T1, kResizeCacheSize>**)>
        resize_creator =
            [](Im2ColBufferResource<T1, kResizeCacheSize>** resource) {
              *resource = new Im2ColBufferResource<T1, kResizeCacheSize>();
              return Status::OK();
            };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "resize_cache",
                                &resize_cache_resource, resize_creator));
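
    // To put numbers (purely illustrative) on the cache geometry just
    // computed: with SAME padding, stride 1, a 3x3 filter, and a
    // 224-pixel-wide padded input, filter_left_offset is 1, so the cache
    // covers x = -1 to 225 and cache_line_width is 226 resized pixels per
    // cache row.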

    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this
    // is aimed at have intra-op parallelism as their focus though, so it
    // shouldn't be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    // This buffer is used as a fairly heavy-weight cache for the resized and
    // mirrored inputs to the im2col operation. The problem is that we want to
    // keep the memory usage down by not rendering the fully resized and
    // padded input tensor to the convolution into an entire buffer. The first
    // approach to avoid this was to fold the bilinear filtering and padding
    // spatial transformations into the im2col lookup itself. This
    // successfully reduced memory usage, but because im2col can access an
    // individual pixel for many different patches, the extra overhead of
    // doing the same bilinear lookups repeatedly became too expensive.
    // The resize cache is designed to avoid this problem by keeping a
    // horizontal slice of the resized and padded input to the im2col
    // precalculated, so that repeated accesses to the same pixel from
    // different filter patches can just be copied from this cache. It's
    // organized as a horizontal slice stretching across the whole virtual
    // image, and at least as high as the filter window. As patch processing
    // moves across a row, all the pixels it needs are present, and before a
    // new row of patches is started, any previously calculated rows that are
    // still needed are kept, with new rows calculated as required.
    mutex_lock resize_lock_buffer(resize_cache_resource->mu);
    core::ScopedUnref unref_resized_cache(resize_cache_resource);
    task_params.resize_cache = resize_cache_resource->data;

    const T1* input_data = input.flat<T1>().data();
    const int64 input_height = input.shape().dim_sizes()[1];
    task_params.input_width = input.shape().dim_sizes()[2];

    int end_cached_lines = std::numeric_limits<int>::min();
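
    // The loop below fills the cache in blocks: whenever the rows needed for
    // the next band of patches (in_y_origin through in_y_origin +
    // filter_height) aren't all cached yet, it computes every row from the
    // first uncached one up to cache_height rows ahead of the patch origin,
    // so each resized row is calculated once and then reused by all the patch
    // rows that overlap it.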

    for (int batch = 0; batch < input_batches; ++batch) {
      task_params.input_batch_start =
          input_data +
          (batch * input_height * task_params.input_width * input_depth);
      const int in_y_end =
          ((output_height * stride_rows) - filter_top_offset) + filter_height;
      for (int out_y = 0; out_y < output_height; ++out_y) {
        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
        const int cache_start_y = std::max(in_y_origin, end_cached_lines);
        const int cache_end_y = std::min(
            in_y_end, std::max((in_y_origin + task_params.cache_height),
                               end_cached_lines));
        if (end_cached_lines < (in_y_origin + filter_height)) {
          // This call breaks up the work required for calculating the mirror
          // padding and resizing across multiple threads.
          FusedConvParallelFor(
              context, cache_start_y, cache_end_y,
              [task_params](int64 task_cache_start_y, int64 task_cache_end_y) {
                // This is a long and confusing function, but it's been laid
                // out this way to help with performance on some intensive
                // models. What it's doing is populating a cache of the
                // original input image, after it's been bilinear resized and
                // had its edges mirrored. This allows the following im2col
                // code to access the transformed pixels from this cache,
                // without having to repeatedly apply the expensive bilinear
                // calculations as the same pixels are accessed by different
                // patches.
                // This is most effective when the stride is small and the
                // filter size is large, since that's when pixels are reused
                // most frequently as patches overlap.
                for (int cache_y = task_cache_start_y;
                     cache_y < task_cache_end_y; ++cache_y) {
                  // We organize the cache as a series of rows, each
                  // containing all the transformed pixels for a given line in
                  // the image. This cache is big enough to hold at least a
                  // filter's height worth of rows, but typically more,
                  // limited by the size of the cache buffer.
                  // We don't allocate an entire image's worth of rows though,
                  // because we're trying to keep memory usage down, so as we
                  // progress downwards through the im2col we periodically
                  // refresh the cache so that the next lines that are needed
                  // for that operation are always present.
                  // Work out the parameters that remain constant across the
                  // row we're calculating.
                  PerCacheLineParameters<float> line_params(
                      CalculatePerCacheLineParameters<float>(
                          task_params.cache_height, cache_y,
                          task_params.resize_cache,
                          task_params.cache_line_width,
                          task_params.input_width, task_params.input_depth,
                          task_params.top_padding, task_params.pad_offset,
                          task_params.resized_height, task_params.st,
                          task_params.input_batch_start));
                  // Iterate through the resize cache row we're filling in.
                  for (int cache_x = task_params.cache_start_x;
                       cache_x < task_params.cache_end_x; ++cache_x) {
                    // Figure out what we need for the cache pixel we're
                    // populating.
                    PerCachePixelParameters<T1> pixel_params(
                        CalculatePerCachePixelParameters<T1>(
                            cache_x, task_params.cache_start_x,
                            line_params.cache_line_start,
                            task_params.input_depth, task_params.left_padding,
                            task_params.pad_offset, task_params.resized_width,
                            task_params.st));
                    // If the access is off the left, right, top, or bottom of
                    // the resized image, the conv padding means we should set
                    // it to zero.
                    if ((cache_x < 0) ||
                        (cache_x >= task_params.padded_width) ||
                        (cache_y < 0) ||
                        (cache_y >= task_params.padded_height)) {
                      std::fill_n(pixel_params.cache_line_pixel,
                                  task_params.input_depth, T1(0));
                    } else {
                      // There are two different sampling strategies for
                      // resizing. When using nearest, we can just do a
                      // straight copy of the pixel closest to our sample
                      // point, but bilinear requires a more complex
                      // calculation.
                      if (SampleMode == NEAREST) {
                        const T1* input_top_left_pixel =
                            line_params.input_top_row_start +
                            (pixel_params.left_x_index *
                             task_params.input_depth);

                        std::copy_n(input_top_left_pixel,
                                    task_params.input_depth,
                                    pixel_params.cache_line_pixel);
                      } else {
                        const SampleRect<T1> rect(
                            line_params.input_top_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_top_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth));
                        for (int in_channel = 0;
                             in_channel < task_params.input_depth;
                             ++in_channel) {
                          pixel_params.cache_line_pixel[in_channel] =
                              rect.BilinearSample(in_channel,
                                                  pixel_params.x_lerp,
                                                  line_params.y_lerp);
                        }
                      }
                    }
                  }
                }
              });
          end_cached_lines = cache_end_y;
        }
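        // Below, each output row's patches are copied out of the cache and
        // multiplied against the filter weights with a GEMM. As a purely
        // illustrative sizing: a 3x3 filter over 8 input channels with 16
        // output filters gives k = filter_value_count = 3 * 3 * 8 = 72 and
        // n = filter_count = 16, so a full chunk of m patches becomes an
        // [m x 72] * [72 x 16] matrix product.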
        for (int out_x = 0; out_x < output_width; ++out_x) {
          const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
          const int patch_index = (batch * output_width * output_height) +
                                  (out_y * output_width) + out_x;
          const int patch_index_within_chunk = patch_index % patches_per_chunk;
          T1* im2col_patch_start =
              im2col_buffer + (patch_index_within_chunk * filter_value_count);
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            T1* im2col_row_start =
                im2col_patch_start +
                (filter_y * filter_width * task_params.input_depth);
            const int conv_in_y = in_y_origin + filter_y;
            int cache_index_y;
            if (conv_in_y < 0) {
              cache_index_y = task_params.cache_height +
                              (conv_in_y % task_params.cache_height);
            } else {
              cache_index_y = conv_in_y % task_params.cache_height;
            }
            T1* cache_line_start =
                task_params.resize_cache +
                (cache_index_y * task_params.cache_line_width *
                 task_params.input_depth);
            T1* cache_filter_row_start =
                cache_line_start + ((in_x_origin - task_params.cache_start_x) *
                                    task_params.input_depth);
            std::copy_n(cache_filter_row_start,
                        (filter_width * task_params.input_depth),
                        im2col_row_start);
          }
          const bool is_last_in_chunk =
              (patch_index_within_chunk == (patches_per_chunk - 1));
          const bool is_last_overall =
              ((batch == (input_batches - 1)) &&
               (out_y == (output_height - 1)) &&
               (out_x == (output_width - 1)));
          if (is_last_in_chunk || is_last_overall) {
            // Now we've assembled a set of image patches into a matrix, apply
            // a GEMM matrix multiply of the patches as rows, times the filter
            // weights in columns, to get partial results in the output
            // matrix.
            const int how_many_patches = patch_index_within_chunk + 1;
            const int m = how_many_patches;
            const int n = filter_count;
            const int k = filter_value_count;
            const int lda = filter_value_count;
            const int ldb = filter_count;
            const int ldc = filter_count;
            const size_t start_patch_index =
                patch_index - (how_many_patches - 1);
            T3* chunk_output_data =
                output_data + (start_patch_index * filter_count);
            TGemmFunctor gemm_functor;
            gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data,
                         ldb, chunk_output_data, ldc);
          }
        }
      }
    }
  }
};

}  // namespace

// Implements a version of convolution with bilinear resizing and mirror
// padding included.
template <class T, class TConvFunctor, bool DoResize>
class FusedResizeConv2DUsingGemmOp : public OpKernel {
 public:
  explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
      : OpKernel(context) {
    if (DoResize) {
      OP_REQUIRES_OK(context, context->GetAttr("resize_align_corners",
                                               &align_corners_));
    }
    MirrorPadMode mode;
    OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));

    switch (mode) {
      case MirrorPadMode::SYMMETRIC: {
        offset_ = 0;
        break;
      }
      case MirrorPadMode::REFLECT: {
        offset_ = 1;
        break;
      }
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "mode must be either REFLECT or SYMMETRIC."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
    const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);
    OP_REQUIRES(context, (input.shape().num_elements() > 0),
                errors::InvalidArgument("Input tensor can't be empty"));

    ImageResizerState st(false);
    if (DoResize) {
      st = ImageResizerState(align_corners_);
      st.ValidateAndCalculateOutputSize(context, input);
      if (!context->status().ok()) return;
    } else {
      // Set up the resize parameters to do no scaling at all.
      st.batch_size = input.dim_size(0);
      st.out_height = input.dim_size(1);
      st.out_width = input.dim_size(2);
      st.in_height = input.dim_size(1);
      st.in_width = input.dim_size(2);
      st.channels = input.dim_size(3);
      st.height_scale = 1.0f;
      st.width_scale = 1.0f;
    }
    TensorShape resized_shape(
        {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
    int paddings_index;
    int filter_index;
    if (DoResize) {
      paddings_index = 2;
      filter_index = 3;
    } else {
      paddings_index = 1;
      filter_index = 2;
    }
    const Tensor& paddings = context->input(paddings_index);
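
    // The index switch above reflects the two op signatures: with resizing,
    // the inputs are (input, size, paddings, filter), while the pad-only
    // variant has no size tensor, so its inputs are (input, paddings,
    // filter).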

    const int dims = resized_shape.dims();
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsMatrix(paddings.shape()) &&
            paddings.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                paddings.shape().DebugString()));
    const int fixed_dims =
        (allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1)
            ? 1
            : dims;
    OP_REQUIRES(
        context, fixed_dims == paddings.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            fixed_dims, " ", paddings.shape().DebugString(), " ",
            resized_shape.DebugString()));
    OP_REQUIRES(
        context, dims == paddings.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            dims, " ", paddings.shape().DebugString(), " ",
            resized_shape.DebugString()));

    OP_REQUIRES(
        context, dims == 4,
        errors::InvalidArgument(
            "Fused mirror padding only supports four-dimensional inputs, but ",
            dims, " requested"));

    // Compute the shape of the output tensor, and allocate it.
    TensorShape padded_shape;
    TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
    for (int d = 0; d < dims; ++d) {
      const int32 before =
          paddings_matrix(d, 0);  // Pad before existing elements.
      const int32 after =
          paddings_matrix(d, 1);  // Pad after existing elements.
      OP_REQUIRES(context, before >= 0 && after >= 0,
                  errors::InvalidArgument(
                      "paddings must be non-negative: ", before, " ", after));
      if (offset_ == 0) {  // SYMMETRIC mode.
        OP_REQUIRES(
            context,
            before <= resized_shape.dim_size(d) &&
                after <= resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be no greater "
                                    "than the dimension size: ",
                                    before, ", ", after, " greater than ",
                                    resized_shape.dim_size(d)));
      } else if (offset_ == 1) {  // REFLECT mode.
        OP_REQUIRES(
            context,
            before < resized_shape.dim_size(d) &&
                after < resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be less than"
                                    " the dimension size: ",
                                    before, ", ", after, " not less than ",
                                    resized_shape.dim_size(d)));
      }
      padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
    }
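
    // The asymmetry between the two checks above comes from the padding
    // semantics: for a dimension of size 4, SYMMETRIC mode repeats the edge
    // element and so allows up to 4 elements of padding per side, while
    // REFLECT mode skips the edge element and so allows at most 3.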

    OP_REQUIRES(
        context,
        ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, "
            "not batches: ",
            paddings.DebugString()));
    OP_REQUIRES(
        context,
        ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, "
            "not channels: ",
            paddings.DebugString()));
    const int32 top_padding = paddings_matrix(1, 0);
    const int32 bottom_padding = paddings_matrix(1, 1);
    const int32 left_padding = paddings_matrix(2, 0);
    const int32 right_padding = paddings_matrix(2, 1);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(filter_index);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, padded_shape.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        padded_shape.DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // We only check the first three dims, since the depth is accessed as an
    // int64 below.
    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64 in_depth = padded_shape.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 padded_rows_raw = padded_shape.dim_size(1);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int padded_rows = static_cast<int>(padded_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));
    const int resized_rows = static_cast<int>(resized_shape.dim_size(1));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 padded_cols_raw = padded_shape.dim_size(2);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int padded_cols = static_cast<int>(padded_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));
    const int resized_cols = static_cast<int>(resized_shape.dim_size(2));

    // The first dimension for input is batch.
    const int64 batch_raw = padded_shape.dim_size(0);
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride from the second and third dimensions only
    // (we do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
    const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
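
    // For reference, GetWindowedOutputSize computes the standard convolution
    // output dimensions: with VALID padding out_rows is
    // ceil((padded_rows - filter_rows + 1) / stride_rows), and with SAME
    // padding it's ceil(padded_rows / stride_rows).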
    TensorShape out_shape =
        ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(context, (out_shape.num_elements() > 0),
                errors::InvalidArgument("Output tensor can't be empty"));

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "FusedConv2D: " << name() << ", in_depth = " << in_depth
            << ", padded_cols = " << padded_cols
            << ", resized_cols = " << resized_cols
            << ", filter_cols = " << filter_cols
            << ", padded_rows = " << padded_rows
            << ", resized_rows = " << resized_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", out_depth = " << out_depth << ", DoResize=" << DoResize;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }
    TConvFunctor conv_functor;
    conv_functor(context, input, batch, resized_rows, resized_cols,
                 padded_rows, padded_cols, in_depth, filter.flat<T>().data(),
                 filter_rows, filter_cols, out_depth, stride_rows, stride_cols,
                 padding_, output->flat<T>().data(), out_rows, out_cols, st,
                 top_padding, bottom_padding, left_padding, right_padding,
                 offset_);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  bool align_corners_;
  int offset_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
};

#define REGISTER_FUSED(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedResizeAndPadConv2D")                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T"),                                        \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       BILINEAR>,                         \
          true>);

TF_CALL_float(REGISTER_FUSED);

#define REGISTER_PAD_ONLY_FUSED(T)                                       \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       NEAREST>,                          \
          false>);

TF_CALL_float(REGISTER_PAD_ONLY_FUSED);

}  // namespace tensorflow