/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/deep_conv2d.h"

#include <stdlib.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/winograd_transform.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// DeepConv2D is a Conv2D implementation specialized for deep convolutions
// (i.e. a large 'in_depth' * 'out_depth' product; see the cost models below
// for details).
//
// DeepConv2D is implemented by computing the following equation:
//
//   y = C[Ad * Bg]
//
//   C: output transform matrix
//   A: input data transform matrix
//   B: filter transform matrix
//   d: vectorized data tile
//   g: vectorized filter tile
//   y: vectorized output tile
//
// The transform matrices and the input, filter and output tile sizes are all
// specified by the DeepConv2DTransform implementation selected at the
// start of the DeepConv2D call, based on convolution parameters.

// Approximate cost models for direct and deep convolutions.
static int64 GetDeepConvCost(int input_tile_rows, int input_tile_cols,
                             int out_tile_rows, int out_tile_cols,
                             int in_depth, int out_depth, int out_rows,
                             int out_cols) {
  // Input transform cost.
  const int64 input_tile_spatial_size = input_tile_rows * input_tile_cols;
  const int64 input_transform_cost =
      input_tile_spatial_size * input_tile_spatial_size * in_depth;

  // Element-wise products (each product is a MatMul across depth).
  const int64 product_cost = input_tile_spatial_size * in_depth * out_depth;

  // Output transform cost.
  const int64 output_tile_spatial_size = out_tile_rows * out_tile_cols;
  const int64 output_transform_cost =
      output_tile_spatial_size * input_tile_spatial_size * out_depth;

  // Calculate the number of input tiles to process.
  const int64 row_tiles = (out_rows + out_tile_rows - 1) / out_tile_rows;
  const int64 col_tiles = (out_cols + out_tile_cols - 1) / out_tile_cols;
  const int64 num_tiles = row_tiles * col_tiles;

  // Return total cost.
  return num_tiles *
         (input_transform_cost + product_cost + output_transform_cost);
}

static int64 GetDirectConvCost(int filter_rows, int filter_cols, int in_depth,
                               int out_depth, int out_rows, int out_cols) {
  return filter_rows * filter_cols * in_depth * out_depth * out_rows *
         out_cols;
}
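
// Illustrative cost comparison (example numbers, not computed here): for the
// F(2x2, 3x3) Winograd transform used below (4x4 input tiles, 2x2 output
// tiles), with in_depth = out_depth = 64 and a 32x32 output,
// GetDeepConvCost = 256 * (16*16*64 + 16*64*64 + 4*16*64) ~= 2.2e7, while
// GetDirectConvCost = 3*3*64*64*32*32 ~= 3.8e7, so the deep path is chosen.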

// Reads the environment variable 'env_var_name'.
// Returns 'true' if the environment variable is enabled, 'false' otherwise.
static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) {
  const char* tf_env_var_val = getenv(env_var_name);
  if (tf_env_var_val != nullptr) {
    StringPiece tf_env_var_val_str(tf_env_var_val);
    if (tf_env_var_val_str == "0") {
      return false;
    }
    return true;
  }
  return default_val;
}

// Returns true if the convolution can be computed efficiently by DeepConv2D,
// returns false otherwise.
// TODO(andydavis) Add support for other filter sizes and strides.
// TODO(andydavis) Add support for autotuning.
bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
                      int filter_cols, int in_depth, int out_depth,
                      int out_rows, int out_cols) {
  // Check if convolution parameters are supported.
  // TODO(andydavis) Add support for multiple filter sizes and strides.
  if (stride_rows > 1 || stride_cols > 1 || filter_rows != 3 ||
      filter_cols != 3) {
    return false;
  }

  // Check if deep convolution is enabled by environment variable.
  // NOTE: If this environment variable name changes, update conv_ops_test.py.
  if (!ReadBoolFromEnvVar("TF_USE_DEEP_CONV2D", false)) {
    return false;
  }

  // Check if the flop cost of deep convolution is less than that of direct
  // convolution.
  WinogradTransform<float> t;
  const int64 deep_conv_cost = GetDeepConvCost(
      t.input_shape().rows, t.input_shape().cols, t.output_shape().rows,
      t.output_shape().cols, in_depth, out_depth, out_rows, out_cols);
  const int64 direct_conv_cost = GetDirectConvCost(
      filter_rows, filter_cols, in_depth, out_depth, out_rows, out_cols);

  VLOG(2) << "CanUseDeepConv2D"
          << " deep_conv_cost: " << deep_conv_cost
          << " direct_conv_cost: " << direct_conv_cost << " deep_direct_ratio: "
          << (static_cast<float>(deep_conv_cost) /
              static_cast<float>(direct_conv_cost))
          << " use_deep_conv: " << (deep_conv_cost < direct_conv_cost);
  return deep_conv_cost < direct_conv_cost;
}

typedef Eigen::ThreadPoolDevice CPUDevice;

// Copies data from 'filter_in' to 'filter_buf' along the 'in_depth' dimension.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_buf:
//   [base_filter_rows, base_filter_cols, in_depth]
//
template <typename T>
struct CopyFilterDepth {
  void operator()(const Conv2DArgs& args, const T* filter_in, T* filter_buf) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static constexpr int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64 vectorized_size = args.in_depth / kPacketSize;
    const int64 scalar_size = args.in_depth % kPacketSize;
    const int64 input_stride = args.out_depth * kPacketSize;

    // Copy vectorized portion of depth dimension.
    for (int64 d = 0; d < vectorized_size; ++d) {
      auto v = Eigen::internal::pgather<T, Packet>(
          filter_in + d * input_stride, args.out_depth);
      Eigen::internal::pstoreu<T>(filter_buf + d * kPacketSize, v);
    }
    // Copy scalar portion of inner dimension.
    const int64 in_scalar_base = vectorized_size * input_stride;
    const int64 buf_scalar_base = vectorized_size * kPacketSize;
    for (int64 d = 0; d < scalar_size; ++d) {
      filter_buf[buf_scalar_base + d] =
          filter_in[in_scalar_base + d * args.out_depth];
    }
  }
};
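
// Illustration of the strided gather in CopyFilterDepth: 'filter_in' is laid
// out as [filter_rows, filter_cols, in_depth, out_depth], so consecutive
// 'in_depth' values for one output channel are 'out_depth' elements apart.
// With kPacketSize = 4 and out_depth = 8 (example values), a single pgather
// packs the elements at offsets {0, 8, 16, 24} into one packet, which is then
// stored contiguously in 'filter_buf'.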

// Computes the transform of 'num_filters' filters from 'filter_in', starting
// at 'od_start'. Intermediate results (i.e. the output of
// MatMul('transform_matrix', 'filter_in')) are stored in 'out_buffer'. The
// final result is copied from 'out_buffer' to 'filter_out' at the coordinate
// stride required by the transformed filter data layout.
//
// filter_in:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]

template <typename T>
struct ComputeFilterRangeTransform {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;
  static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform, const int64 od_start,
                  const int64 num_filters, const int64 shard_rows,
                  const int64 shard_cols, const T* filter_in,
                  const int64 in_stride, const int64 out_stride,
                  const T* transform_matrix, T* out_buffer, T* filter_out) {
    namespace ei = Eigen::internal;

    const int64 in_depth = args.in_depth;
    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;
    const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;
    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    // Compute the transform of 'num_filters' by 'transform_matrix'.
    ConstMatrixMap A(transform_matrix, tile_spatial_size,
                     base_filter_spatial_size);
    ConstMatrixMap B(filter_in, base_filter_spatial_size, in_stride);
    MatrixMap C(out_buffer, tile_spatial_size, in_stride);

    C.noalias() = A * B;
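
    // Shape note (for the F(2x2, 3x3) Winograd transform): the product above
    // is A[16 x 9] * B[9 x in_stride], so a single GEMM applies the filter
    // transform to every filter, shard and input-depth value in the range.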

    // Copy 'out_buffer' to 'filter_out' at the required filter output stride.
    const int64 scalar_size = in_depth % kPacketSize;
    const int64 vectorized_size = in_depth / kPacketSize;

    const int64 shard_stride = args.in_depth;
    const int64 out_depth_stride = shard_rows * shard_cols * shard_stride;

    for (int64 od = 0; od < num_filters; ++od) {
      const int64 out_depth_buf_base = od * out_depth_stride;
      const int64 out_depth_base = (od_start + od) * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64 s_r = 0; s_r < shard_rows; ++s_r) {
        for (int64 s_c = 0; s_c < shard_cols; ++s_c) {
          const int64 shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64 i = 0; i < tile_spatial_size; ++i) {
            const int64 in_base =
                i * in_stride + out_depth_buf_base + shard_base;
            const int64 out_base = i * out_stride + out_depth_base + shard_base;
            // Copy vectorized portion of 'in_depth'.
            for (int64 d = 0; d < vectorized_size; ++d) {
              auto v =
                  ei::ploadu<Packet>(out_buffer + in_base + d * kPacketSize);
              ei::pstoreu<T>(filter_out + out_base + d * kPacketSize, v);
            }
            // Copy scalar portion of 'in_depth'.
            const int64 scalar_base = vectorized_size * kPacketSize;
            for (int64 d = 0; d < scalar_size; ++d) {
              filter_out[out_base + scalar_base + d] =
                  out_buffer[in_base + scalar_base + d];
            }
          }
        }
      }
    }
  }
};
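
// Filter sharding note (illustrative): filters larger than the base filter
// are decomposed into shards stepped by the output tile stride (2 for
// F(2x2, 3x3)). For example, a 5x5 filter has residual 5 - 3 = 2 in each
// dimension, giving shard_rows = shard_cols = 1 + (2 + 2 - 1) / 2 = 2,
// i.e. four shards.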

// Transforms 'num_filters' filters from 'filter_in', starting at 'od_start'.
// For each filter in 'num_filters', copies data for all filter shards from
// 'filter_in' into 'filter_buf', adding zero-padding as needed.
// Calls ComputeFilterRangeTransform to compute the filter transform of the
// data in 'filter_buf' by 'transform_matrix', storing the result in
// 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// filter_buffer:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]
//

template <typename T>
struct TransformFilterRange {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform, const int64 od_start,
                  const int64 od_limit, const T* filter_in,
                  const T* transform_matrix, T* out_buffer, T* filter_buf,
                  T* filter_out) {
    const int64 num_filters = od_limit - od_start;
    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;
    const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;

    // Compute the number of filter shards.
    const int64 residual_row =
        std::max(0LL, args.filter_rows - base_filter_rows);
    const int64 shard_rows = 1 + (residual_row + 2 - 1) / 2;

    const int64 residual_col =
        std::max(0LL, args.filter_cols - base_filter_cols);
    const int64 shard_cols = 1 + (residual_col + 2 - 1) / 2;

    // Compute the strides to be used for input and output IO.
    const int64 shard_stride = args.in_depth;
    const int64 out_depth_stride = shard_rows * shard_cols * shard_stride;
    const int64 coord_stride = out_depth_stride * args.out_depth;
    const int64 filter_buf_stride =
        num_filters * shard_rows * shard_cols * args.in_depth;
    const int64 tile_stride_rows = transform->output_shape().rows;
    const int64 tile_stride_cols = transform->output_shape().cols;

    const int64 filter_buf_size = base_filter_spatial_size * num_filters *
                                  shard_rows * shard_cols * args.in_depth;
    memset(filter_buf, 0, sizeof(T) * filter_buf_size);

    // Copy the filter range into 'filter_buf'.
    for (int64 od = 0; od < num_filters; ++od) {
      const int64 out_depth_base = od * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64 s_r = 0; s_r < shard_rows; ++s_r) {
        const int64 row_offset = s_r == 0 ? 0 : 1;

        for (int64 s_c = 0; s_c < shard_cols; ++s_c) {
          const int64 col_offset = s_c == 0 ? 0 : 1;
          const int64 f_r_start = s_r * tile_stride_rows;
          const int64 f_c_start = s_c * tile_stride_cols;

          const int64 shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64 b_r = row_offset; b_r < base_filter_rows; ++b_r) {
            const int64 f_r = f_r_start + b_r;
            if (f_r >= args.filter_rows) continue;

            for (int64 b_c = col_offset; b_c < base_filter_cols; ++b_c) {
              const int64 f_c = f_c_start + b_c;
              if (f_c >= args.filter_cols) continue;

              const int64 in_index =
                  args.out_depth *
                      (args.in_depth * (f_r * args.filter_cols + f_c)) +
                  (od_start + od);

              const int64 buf_index =
                  filter_buf_stride * (b_r * base_filter_cols + b_c) +
                  out_depth_base + shard_base;

              CopyFilterDepth<T>()(args, filter_in + in_index,
                                   filter_buf + buf_index);
            }
          }
        }
      }
    }

    // Compute the filter transform of the data in 'filter_buf' by
    // 'transform_matrix'. Intermediate results are stored in 'out_buffer'.
    // Final results are stored in 'filter_out'.
    ComputeFilterRangeTransform<T>()(args, transform, od_start, num_filters,
                                     shard_rows, shard_cols, filter_buf,
                                     filter_buf_stride, coord_stride,
                                     transform_matrix, out_buffer, filter_out);
  }
};

// Transforms all filters from 'filter_in', storing the result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
template <typename T>
struct TransformFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 filter_shards_row, const int64 filter_shards_col,
                  const T* filter_in, T* filter_out) {
    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;

    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;
    const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;

    const int64 filter_shards_total = filter_shards_row * filter_shards_col;

    // Calculate the filter transform batch size based on cache/filter sizes.

    // Cache budget (based on L2 cache size = 256KB).
    // TODO(andydavis) Read cache size from system.
    const int64 cache_size = (256LL << 10) / sizeof(T);

    // Fixed cost.
    const int64 filter_transform_matrix_size =
        tile_spatial_size * base_filter_spatial_size;

    // Per-filter costs.
    const int64 filter_total_size =
        base_filter_spatial_size * in_depth * filter_shards_total;

    const int64 filter_transform_buffer_size =
        base_filter_spatial_size * filter_shards_total * in_depth;

    const int64 filter_out_buf_size =
        tile_spatial_size * filter_shards_total * in_depth;

    // Total per-filter costs.
    const int64 per_filter_cost =
        filter_total_size + filter_transform_buffer_size + filter_out_buf_size;
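
    // Illustrative budget (float, 3x3 base filter, single shard,
    // in_depth = 64): cache_size = 65536 elements, fixed cost = 16 * 9 = 144,
    // per-filter cost = (9 + 9 + 16) * 64 = 2176, so the computation below
    // batches roughly (65536 - 144) / 2176 = 30 filters at a time.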

    // Remove the fixed cost and divide by the per-filter cost.
    const int64 num_filters_cache = std::max(
        1LL, (cache_size - filter_transform_matrix_size) / per_filter_cost);
    const int64 num_filters_transform = std::min(out_depth, num_filters_cache);

    // Allocate a buffer for the filter transform matrix:
    //   [tile_spatial_size, base_filter_spatial_size]
    Tensor filter_transform_matrix;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_spatial_size, base_filter_spatial_size}),
                 &filter_transform_matrix));
    T* transform_matrix = filter_transform_matrix.template flat<T>().data();
    transform->GetFilterTransformMatrix(
        tile_spatial_size, base_filter_spatial_size, transform_matrix);

    auto shard = [&ctx, &args, &transform, &base_filter_rows, &base_filter_cols,
                  &num_filters_transform, &in_depth, &out_depth,
                  &filter_shards_row, &filter_shards_col, &tile_spatial_size,
                  &filter_in, &transform_matrix,
                  &filter_out](int64 start, int64 limit) {
      // Allocate a buffer for the pre-processed filter:
      //   [base_filter_rows, base_filter_cols, num_filters_transform, in_depth]
      //
      Tensor filter_transform_buffer;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(
                         DataTypeToEnum<T>::value,
                         TensorShape({base_filter_rows, base_filter_cols,
                                      num_filters_transform, filter_shards_row,
                                      filter_shards_col, in_depth}),
                         &filter_transform_buffer));
      T* filter_buf = filter_transform_buffer.template flat<T>().data();

      // Allocate a buffer for the output filter transform matrix:
      //   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
      Tensor filter_output_buffer;
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DataTypeToEnum<T>::value,
              TensorShape({tile_spatial_size, num_filters_transform,
                           filter_shards_row, filter_shards_col, in_depth}),
              &filter_output_buffer));
      T* out_buffer = filter_output_buffer.template flat<T>().data();

      const int64 num_filters = limit - start;
      const int64 od_unroll = num_filters_transform;
      // NOTE: 'start' is always 0 here because Shard() below is invoked with
      // max_parallelism = 1 (a single shard), so computing the unroll limit
      // from 'num_filters' alone is valid.
      const int64 od_unroll_limit = (num_filters / od_unroll) * od_unroll;

      for (int64 od = start; od < od_unroll_limit; od += od_unroll) {
        TransformFilterRange<T>()(args, transform, od, od + od_unroll,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }

      if (od_unroll_limit < limit) {
        TransformFilterRange<T>()(args, transform, od_unroll_limit, limit,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    const int64 shard_cost = args.filter_rows * args.filter_cols * in_depth *
                             filter_shards_total * tile_spatial_size;
    // TODO(andydavis) Resolve performance of multi-threaded filter transforms.
    Shard(1, worker_threads.workers, out_depth, shard_cost, shard);
  }
};

// Packs transformed filters stored in 'lhs_input' into 'lhs_block' in a
// gemm-kernel friendly data layout.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].

template <typename T>
class GemmFilterPacker {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::RowMajor>
      LhsMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;
  Eigen::internal::gemm_pack_lhs<T, int64, LhsMapper, Traits::mr,
                                 Traits::LhsProgress, Eigen::RowMajor>
      pack_lhs;

  GemmFilterPacker(const int64 rows, const int64 depth, const T* lhs_input,
                   T* lhs_block)
      : rows_(rows),
        depth_(depth),
        lhs_block_(lhs_block),
        lhs_mapper_(lhs_input, depth_) {}

  void Run() { pack_lhs(lhs_block_, lhs_mapper_, depth_, rows_); }

 private:
  const int64 rows_;
  const int64 depth_;
  T* lhs_block_;
  LhsMapper lhs_mapper_;
};

// Packs the transformed filters stored in 'filter_transform_data' into
// 'packed_filters' to be used by GemmState.
template <typename T>
struct PackFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const int64 tile_spatial_size, const int64 filter_shards_row,
                  const int64 filter_shards_col, const T* filter_transform_data,
                  std::vector<Tensor>* packed_filters) {
    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;
    const int64 num_filters = filter_shards_row * filter_shards_col * out_depth;

    auto shard = [&ctx, &packed_filters, &filter_transform_data,
                  &tile_spatial_size, &in_depth, &out_depth, &filter_shards_row,
                  &filter_shards_col, &num_filters](int64 start, int64 limit) {
      const int64 filter_coord_stride = num_filters * in_depth;
      for (int64 i = start; i < limit; ++i) {
        // Allocate filter buffer [out_depth, shard_rows, shard_cols, in_depth].
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                    TensorShape({out_depth, filter_shards_row,
                                                 filter_shards_col, in_depth}),
                                    &(*packed_filters)[i]));
        T* packed_filter = (*packed_filters)[i].template flat<T>().data();
        // Pack filters.
        GemmFilterPacker<T> packer(
            num_filters, in_depth,
            filter_transform_data + i * filter_coord_stride, packed_filter);
        packer.Run();
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, tile_spatial_size,
          num_filters * in_depth, shard);
  }
};

// Computes the product of filters stored in 'lhs_block' and input tiles
// stored in 'rhs_block', storing the output in 'out_buffer'.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].
//
// Data layout for 'rhs_block':
//   [num_tiles, in_depth]
//
// Data layout for 'out_buffer':
//   [num_tiles, out_depth, shard_rows, shard_cols]

template <typename T>
class GemmState {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::ColMajor>
      RhsMapper;
  typedef Eigen::internal::blas_data_mapper<T, int64, Eigen::ColMajor>
      OutputMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;

  Eigen::internal::gemm_pack_rhs<T, int64, RhsMapper, Traits::nr,
                                 Eigen::ColMajor>
      pack_rhs;
  Eigen::internal::gebp_kernel<T, T, int64, OutputMapper, Traits::mr,
                               Traits::nr, false, false>
      gebp;

  GemmState(const int64 rows, const int64 cols, const int64 depth,
            const int64 out_buffer_size, const T* lhs_block, const T* rhs_input,
            T* rhs_block, T* out_buffer)
      : rows_(rows),
        cols_(cols),
        depth_(depth),
        out_buffer_size_(out_buffer_size),
        lhs_block_(lhs_block),
        rhs_block_(rhs_block),
        out_buffer_(out_buffer),
        rhs_mapper_(rhs_input, depth_),
        out_mapper_(out_buffer, rows_) {}

  void PackRhs() { pack_rhs(rhs_block_, rhs_mapper_, depth_, cols_); }

  void Compute() {
    memset(out_buffer_, 0, sizeof(T) * out_buffer_size_);
    gebp(out_mapper_, lhs_block_, rhs_block_, rows_, depth_, cols_, 1.0);
  }

 private:
  const int64 rows_;
  const int64 cols_;
  const int64 depth_;
  const int64 out_buffer_size_;
  const T* lhs_block_;
  T* rhs_block_;
  T* out_buffer_;
  RhsMapper rhs_mapper_;
  OutputMapper out_mapper_;
};
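
// Note on GemmState: it drives Eigen's internal GEBP primitives directly. The
// filters (lhs) are packed once by GemmFilterPacker and reused across all
// tiles, so only the rhs (input tiles) is re-packed per tile coordinate.
// 'out_buffer_' is zeroed in Compute() because the gebp kernel accumulates
// into the output.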

// Copies an input tile from 'input' into 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
//
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct CopyInputTile {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 num_tiles, const int64 in_r_start,
                  const int64 in_c_start, const T* input, T* tile_buffer) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 coord_stride = num_tiles * args.in_depth;

    // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
    const int64 input_vectorized_size =
        (args.in_depth / kPacketSize) * kPacketSize;
    const int64 input_scalar_size = args.in_depth % kPacketSize;

    for (int64 r = 0; r < tile_rows; ++r) {
      const int64 in_r = in_r_start + r;
      if (in_r < 0 || in_r >= args.in_rows) continue;

      for (int64 c = 0; c < tile_cols; ++c) {
        const int64 in_c = in_c_start + c;
        if (in_c < 0 || in_c >= args.in_cols) continue;

        auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth;
        auto* tile = tile_buffer + coord_stride * (r * tile_cols + c);
        // Copy vectorized portion of depth dimension.
        for (int64 d = 0; d < input_vectorized_size; d += kPacketSize) {
          auto v = Eigen::internal::ploadu<Packet>(in + d);
          Eigen::internal::pstoreu<T>(tile, v);
          tile += kPacketSize;
        }
        // Copy scalar portion of inner dimension.
        for (int64 d = 0; d < input_scalar_size; ++d) {
          tile[d] = in[input_vectorized_size + d];
        }
      }
    }
  }
};

// Transforms 'num_tiles' tiles from 'input' by 'transform_matrix', storing the
// final result in 'tile_transform'.
// Intermediate results are stored in 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]
// tile_transform_matrix:
//   [tile_spatial_size, tile_spatial_size]
// tile_transform:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct TransformInputTiles {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 num_tiles, const int64 in_r_start,
                  const int64 in_c_start, const T* input,
                  const T* transform_matrix, T* tile_buffer,
                  T* tile_transform) {
    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;
    const int64 tile_stride_cols = transform->output_shape().cols;
    const int64 coord_stride = num_tiles * args.in_depth;
    const int64 num_tiles_stride = args.in_depth;

    memset(tile_buffer, 0, sizeof(T) * tile_spatial_size * coord_stride);
    const int64 in_r = in_r_start;
    for (int64 t = 0; t < num_tiles; ++t) {
      const int64 num_tiles_base = t * num_tiles_stride;
      const int64 in_c = in_c_start + t * tile_stride_cols;
      CopyInputTile<T>()(args, transform, num_tiles, in_r, in_c, input,
                         tile_buffer + num_tiles_base);
    }

    ConstMatrixMap A(transform_matrix, tile_spatial_size, tile_spatial_size);
    ConstMatrixMap B(tile_buffer, tile_spatial_size, coord_stride);
    MatrixMap C(tile_transform, tile_spatial_size, coord_stride);

    C.noalias() = A * B;
  }
};
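
// Note: the 'num_tiles' tiles transformed together above are adjacent along
// the column dimension, each advanced by 'tile_stride_cols' (the output tile
// width), so for F(2x2, 3x3) neighboring 4x4 input tiles overlap by two
// columns.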

// Transforms output tiles from 'out_buffer' by 'out_transform_matrix', storing
// the final result in 'output' (intermediate results are stored in
// 'out_transform_buffer').
//
// out_buffer:
//   [tile_rows, tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// out_transform_buffer:
//   [out_tile_rows, out_tile_cols, num_tiles, out_depth, shard_rows,
//    shard_cols]
//
// output:
//   [out_rows, out_cols, out_depth]

template <typename T>
struct TransformOutputTile {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 num_tiles, const int64 in_r, const int64 in_c,
                  const int64 filter_shards_row, const int64 filter_shards_col,
                  const T* out_transform_matrix, const T* out_buffer,
                  T* out_transform_buffer, T* output) {
    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    const int64 out_buf_stride =
        num_tiles * args.out_depth * filter_shards_row * filter_shards_col;

    const int64 out_tile_rows = transform->output_shape().rows;
    const int64 out_tile_cols = transform->output_shape().cols;
    const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;

    // Compute the output transform.
    ConstMatrixMap A(out_transform_matrix, out_tile_spatial_size,
                     tile_spatial_size);
    ConstMatrixMap B(out_buffer, tile_spatial_size, out_buf_stride);
    MatrixMap C(out_transform_buffer, out_tile_spatial_size, out_buf_stride);

    C.noalias() = A * B;

    const int64 tile_stride_rows = transform->output_shape().rows;
    const int64 tile_stride_cols = transform->output_shape().cols;

    const int64 out_depth_stride = filter_shards_row * filter_shards_col;
    const int64 num_tiles_stride = args.out_depth * out_depth_stride;

    // Copy the transformed output from 'out_transform_buffer' to the proper
    // index in 'output'. Note that some outputs at boundaries can be discarded.
    for (int64 t = 0; t < num_tiles; ++t) {
      const int64 tile_base = t * num_tiles_stride;

      for (int64 od = 0; od < args.out_depth; ++od) {
        const int64 out_depth_base = od * out_depth_stride;

        // TODO(andydavis) Update filter sharding scheme in the next CL.
        for (int64 sr = 0; sr < filter_shards_row; ++sr) {
          for (int64 sc = 0; sc < filter_shards_col; ++sc) {
            const int64 shard_base = sr * filter_shards_col + sc;
            const int64 out_buf_base = tile_base + out_depth_base + shard_base;

            // Calculate output indices and outputs to drop (if needed).
            const int64 out_r_start =
                in_r + args.pad_rows - sr * tile_stride_rows;
            // NOTE: The tile index 't' is used in the index calculation for
            // 'out_c_start' because 'num_tiles' progresses along the column
            // dimension.
            const int64 out_c_start = (in_c + t * tile_stride_cols) +
                                      args.pad_cols - sc * tile_stride_cols;

            if (out_r_start < 0 || out_r_start >= args.out_rows ||
                out_c_start < 0 || out_c_start >= args.out_cols) {
              continue;  // Skip unneeded outputs.
            }

            // Increment the output if this is not the first filter shard.
            const bool inc_output = !(sr == 0 && sc == 0);

            for (int64 ot_row = 0; ot_row < out_tile_rows; ++ot_row) {
              const int64 out_r = out_r_start + ot_row;
              if (out_r >= args.out_rows) continue;

              for (int64 ot_col = 0; ot_col < out_tile_cols; ++ot_col) {
                const int64 out_c = out_c_start + ot_col;
                if (out_c >= args.out_cols) continue;

                // Calculate the out tile index.
                const int64 out_buf_index = ot_row * out_tile_cols + ot_col;
                // Read the output value from the buffer.
                const T out_val =
                    out_transform_buffer[out_buf_base +
                                         out_buf_index * out_buf_stride];
                // Calculate the output index.
                const int64 output_index =
                    args.out_depth * (out_r * args.out_cols + out_c) + od;
                // Update the output.
                if (inc_output) {
                  output[output_index] += out_val;
                } else {
                  output[output_index] = out_val;
                }
              }
            }
          }
        }
      }
    }
  }
};

template <typename T>
struct Conv2DState {
  Conv2DState(const int64 tile_spatial_size, const int64 filter_shards_row,
              const int64 filter_shards_col, const T* input,
              const T* tile_transform_matrix, const T* output_transform_matrix,
              T* buffer1, T* buffer2, T* packed_tile_buffer,
              T* gemm_output_buffer)
      : tile_spatial_size(tile_spatial_size),
        filter_shards_row(filter_shards_row),
        filter_shards_col(filter_shards_col),
        input(input),
        tile_transform_matrix(tile_transform_matrix),
        output_transform_matrix(output_transform_matrix),
        buffer1(buffer1),
        buffer2(buffer2),
        packed_tile_buffer(packed_tile_buffer),
        gemm_output_buffer(gemm_output_buffer) {}

  const int64 tile_spatial_size;
  const int64 filter_shards_row;
  const int64 filter_shards_col;
  const T* input;
  const T* tile_transform_matrix;
  const T* output_transform_matrix;
  T* buffer1;
  T* buffer2;
  T* packed_tile_buffer;
  T* gemm_output_buffer;
};
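
// Note on buffer reuse in Conv2DState: 'buffer1' first holds copied input
// tiles and is then reused for gemm output, while 'buffer2' first holds
// transformed input tiles and is then reused for transformed output tiles
// (see the buffer size calculations in DeepConv2D below).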

// Computes Conv2D for 'num_tiles' input tiles from 'input' starting at
// (in_r, in_c), storing the results of the computation in 'output'.
// Details:
//   *) Transforms 'num_tiles' input tiles into 'buffer2'.
//   *) Computes the point-wise MatMuls of 'num_tiles' input tiles with all
//      filters.
//   *) Transforms the output tiles, and stores the result to 'output'.

// TODO(andydavis) Maybe pass Conv2DState into TransformInput/Output functions.
template <typename T>
struct ComputeConv2D {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const Conv2DState<T>& cs, const int64 in_r, const int64 in_c,
                  const int64 num_tiles,
                  const std::vector<Tensor>& packed_filters, const T* input,
                  T* output) {
    // Transform the input tiles.
    TransformInputTiles<T>()(args, transform, num_tiles, in_r, in_c, input,
                             cs.tile_transform_matrix, cs.buffer1, cs.buffer2);

    // Compute the element-wise products (each a MatMul): input tiles X
    // filters.
    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;
    const int64 num_filters =
        cs.filter_shards_row * cs.filter_shards_col * out_depth;
    const int64 tile_coord_stride = num_tiles * in_depth;
    const int64 gemm_out_buf_size = num_tiles * num_filters;
    const int64 gemm_out_buf_bytes = gemm_out_buf_size * sizeof(T);

    for (int64 i = 0; i < cs.tile_spatial_size; ++i) {
      GemmState<T> gemm(num_filters, num_tiles, in_depth, gemm_out_buf_size,
                        packed_filters[i].template flat<T>().data(),
                        cs.buffer2 + i * tile_coord_stride,
                        cs.packed_tile_buffer, cs.gemm_output_buffer);
      // Pack the tile buffer.
      gemm.PackRhs();
      // Compute the product.
      gemm.Compute();
      // Copy to the larger output buffer without alignment requirements.
      memcpy(cs.buffer1 + i * gemm_out_buf_size, cs.gemm_output_buffer,
             gemm_out_buf_bytes);
    }

    // Transform the output.
    TransformOutputTile<T>()(args, transform, num_tiles, in_r, in_c,
                             cs.filter_shards_row, cs.filter_shards_col,
                             cs.output_transform_matrix, cs.buffer1, cs.buffer2,
                             output);
  }
};

namespace functor {

// Conv2D operation specialized for deep convolutions (i.e. a large
// in_depth * out_depth product).
// Details:
//   *) Transforms and packs filters from 'filter' in parallel.
//   *) Computes Conv2D parallelized across the 'batch' dimension.
//   *) Each thread loops over the images in its batch shard, copying
//      'num_tiles' input tiles into a local buffer, and computing the Conv2D
//      output of these tiles by all filters.

// TODO(andydavis) Improve the performance of boundary cases where the input
// tile extends past the limit, and wasted outputs are computed. This overhead
// is at most 2/n, where 'n' is max(out_rows, out_cols), and so is worse for
// smaller spatial sizes.
// TODO(andydavis) Improve the performance of sharded filters.
template <typename T>
struct DeepConv2D<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
                  const T* filter, T* output) {
    // TODO(andydavis) Add function to select the transform based on conv
    // params.
    std::unique_ptr<DeepConv2DTransform<T>> transform(new WinogradTransform<T>);

    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;

    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    const int64 out_tile_rows = transform->output_shape().rows;
    const int64 out_tile_cols = transform->output_shape().cols;
    const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;

    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;

    const int64 filter_residual_row =
        std::max(0LL, args.filter_rows - base_filter_rows);
    const int64 filter_shards_row = 1 + (filter_residual_row + 2 - 1) / 2;

    const int64 filter_residual_col =
        std::max(0LL, args.filter_cols - base_filter_cols);
    const int64 filter_shards_col = 1 + (filter_residual_col + 2 - 1) / 2;

    // Allocate a buffer for the transformed filters.
    Tensor filter_transform;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_rows, tile_cols, out_depth,
                              filter_shards_row, filter_shards_col, in_depth}),
                 &filter_transform));
    T* filter_transform_data = filter_transform.template flat<T>().data();

    // Transform the filters.
    TransformFilters<T>()(ctx, args, transform.get(), filter_shards_row,
                          filter_shards_col, filter, filter_transform_data);

    // Pack the filters.
    std::vector<Tensor> packed_filters(tile_spatial_size);
    PackFilters<T>()(ctx, args, tile_spatial_size, filter_shards_row,
                     filter_shards_col, filter_transform_data, &packed_filters);

    // Allocate a buffer for the tile transform matrix.
    Tensor tile_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            TensorShape({tile_spatial_size, tile_spatial_size}),
                            &tile_transform_matrix_tensor));
    T* tile_transform_matrix =
        tile_transform_matrix_tensor.template flat<T>().data();
    transform->GetInputTransformMatrix(tile_spatial_size, tile_spatial_size,
                                       tile_transform_matrix);

    // Allocate a buffer for the output transform matrix.
    Tensor output_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({out_tile_spatial_size,
                                                        tile_spatial_size}),
                                           &output_transform_matrix_tensor));
    T* output_transform_matrix =
        output_transform_matrix_tensor.template flat<T>().data();
    transform->GetOutputTransformMatrix(
        out_tile_spatial_size, tile_spatial_size, output_transform_matrix);

    auto shard = [&ctx, &args, &transform, &packed_filters, &in_depth,
                  out_depth, tile_rows, tile_cols, out_tile_rows, out_tile_cols,
                  filter_shards_row, filter_shards_col, tile_spatial_size,
                  &input, &tile_transform_matrix, &output_transform_matrix,
                  &output](int64 batch_start, int64 batch_limit) {
      const int64 row_tiles =
          (args.out_rows + out_tile_rows - 1) / out_tile_rows +
          filter_shards_row - 1;
      const int64 col_tiles =
          (args.out_cols + out_tile_cols - 1) / out_tile_cols +
          filter_shards_col - 1;

      // Calculate the number of tiles to process together.
      const int64 filter_shard_size = filter_shards_row * filter_shards_col;
      const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;

      // Cache budget (based on L2 cache size = 256KB).
      // TODO(andydavis) Read cache size from the system.
      const int64 cache_size = (256LL << 10) / sizeof(T);

      // Fixed costs.
      const int64 tile_transform_matrix_size =
          tile_spatial_size * tile_spatial_size;
      const int64 output_transform_matrix_size =
          out_tile_spatial_size * tile_spatial_size;
      // Calculate the cache reserve size.
      const int64 filter_depth_size = in_depth * out_depth * filter_shard_size;
      const bool small_filter = ((filter_depth_size * 100) / cache_size) <= 25;
      const int64 cache_reserve_size = small_filter ? filter_depth_size : 1024;
      // Calculate the total fixed cost.
      const int64 total_fixed_cost = tile_transform_matrix_size +
                                     output_transform_matrix_size +
                                     cache_reserve_size;
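
      // Illustrative budget (float, in_depth = out_depth = 64, single filter
      // shard): fixed cost = 16*16 + 4*16 + 64*64 = 4416 (the filter fits in
      // at most 25% of the cache, so it is reserved in full); the per-tile
      // cost below is 16*64 + 16*64 + 64 + 64 = 2176, giving
      // num_tiles_cache = (65536 - 4416) / 2176 = 28 tiles per pass.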

      // Per-tile costs.
      const int64 buffer1_per_tile_size =
          tile_spatial_size * std::max(in_depth, out_depth * filter_shard_size);
      const int64 buffer2_per_tile_size =
          std::max(tile_spatial_size * in_depth,
                   out_tile_spatial_size * out_depth * filter_shard_size);
      const int64 packed_tile_per_tile_size = in_depth;
      const int64 gemm_out_per_tile_size = out_depth * filter_shard_size;
      const int64 total_per_tile_cost =
          buffer1_per_tile_size + buffer2_per_tile_size +
          packed_tile_per_tile_size + gemm_out_per_tile_size;

      const int64 num_tiles_cache =
          std::max(4LL, (cache_size - total_fixed_cost) / total_per_tile_cost);
      const int64 num_tiles = std::min(num_tiles_cache, col_tiles);

      // Allocate temporary buffer 'buffer1', which is first used for copying
      // input tiles, then re-used to buffer gemm output. Calculate the
      // required buffer size for 'buffer1' as the max buffer size required
      // between copying input tiles and buffering gemm product output.
      //   buffer1: [max(buf1_tile_size, buf1_out_size)]
      const int64 buffer1_tile_size = tile_spatial_size * num_tiles * in_depth;
      const int64 buffer1_out_size =
          tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64 buffer1_size = std::max(buffer1_tile_size, buffer1_out_size);
      Tensor buffer1_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer1_size}),
                                             &buffer1_tensor));
      T* buffer1 = buffer1_tensor.template flat<T>().data();

      // Allocate temporary buffer 'buffer2', which is first used for
      // transformed input tiles, then re-used for transformed output tiles.
      // Calculate the required buffer size for 'buffer2' as the max required
      // buffer between the input and output transform buffer sizes.
      const int64 buffer2_tile_transform_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64 buffer2_out_transform_size =
          out_tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64 buffer2_size =
          std::max(buffer2_tile_transform_size, buffer2_out_transform_size);
      Tensor buffer2_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer2_size}),
                                             &buffer2_tensor));
      T* buffer2 = buffer2_tensor.template flat<T>().data();

      // Allocate a temporary buffer to store packed tiles for one coordinate.
      //   packed tile buffer: [num_tiles, in_depth].
      Tensor packed_tile_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, in_depth}),
                                             &packed_tile_tensor));
      T* packed_tile_buffer = packed_tile_tensor.template flat<T>().data();

      // Allocate a temporary buffer for the gemm output.
      //   gemm output buffer: [num_tiles, out_depth, shard_rows, shard_cols].
      Tensor gemm_output_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, out_depth,
                                                          filter_shards_row,
                                                          filter_shards_col}),
                                             &gemm_output_tensor));
      T* gemm_output_buffer = gemm_output_tensor.template flat<T>().data();

      // Capture the state needed for the ComputeConv2D inner loop.
      Conv2DState<T> conv_state(tile_spatial_size, filter_shards_row,
                                filter_shards_col, input, tile_transform_matrix,
                                output_transform_matrix, buffer1, buffer2,
                                packed_tile_buffer, gemm_output_buffer);

      const int64 row_pad = args.pad_rows;
      const int64 col_pad = args.pad_cols;
      const int64 unroll_col_limit = (col_tiles / num_tiles) * num_tiles;

      const int64 input_image_size = args.in_rows * args.in_cols * in_depth;
      const int64 output_image_size = args.out_rows * args.out_cols * out_depth;

      const int64 tile_stride_rows = transform->output_shape().rows;
      const int64 tile_stride_cols = transform->output_shape().cols;

      for (int64 b = batch_start; b < batch_limit; ++b) {
        const int64 in_base = b * input_image_size;
        const int64 out_base = b * output_image_size;

        for (int64 tile_r = 0; tile_r < row_tiles; ++tile_r) {
          const int64 in_r = tile_r * tile_stride_rows - row_pad;

          // Process unrolled tiles.
          for (int64 tile_c = 0; tile_c < unroll_col_limit;
               tile_c += num_tiles) {
            const int64 in_c = tile_c * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               num_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
          // Process the remaining tiles.
          if (unroll_col_limit < col_tiles) {
            const int64 rem_tiles = col_tiles - unroll_col_limit;
            const int64 in_c = unroll_col_limit * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               rem_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
        }
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int64 shard_cost = args.out_rows * args.out_cols * args.out_depth *
                             tile_spatial_size * args.in_depth;
    Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
          shard_cost, shard);
  }
};

}  // namespace functor

template struct functor::DeepConv2D<CPUDevice, float>;

}  // namespace tensorflow