1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/convolution_group_converter.h" 17 18 #include <memory> 19 #include <vector> 20 21 #include "absl/memory/memory.h" 22 #include "tensorflow/compiler/xla/literal.h" 23 #include "tensorflow/compiler/xla/literal_util.h" 24 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 25 #include "tensorflow/compiler/xla/service/hlo_computation.h" 26 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 27 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 28 #include "tensorflow/compiler/xla/shape_util.h" 29 #include "tensorflow/compiler/xla/status_macros.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/compiler/xla/util.h" 32 #include "tensorflow/compiler/xla/xla_data.pb.h" 33 #include "tensorflow/core/lib/core/errors.h" 34 #include "tensorflow/core/lib/core/status.h" 35 #include "tensorflow/core/platform/logging.h" 36 37 namespace xla { 38 39 namespace { 40 41 // ConvolutionVisitor traverses the HLO computation and rewrites Convolution 42 // operations with feature_group_count > 1 into convolutions with 43 // feature_group_count = 1. 44 class ConvolutionVisitor : public DfsHloVisitorWithDefault { 45 public: 46 // Default visitor action is to do nothing and return OK. 
47 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { 48 return Status::OK(); 49 } 50 51 Status HandleConvolution(HloInstruction* convolution) override; 52 53 Status HandleBatchGroupCount(HloInstruction* convolution); 54 55 // Runs the visitor on a computation. 56 static bool Run(HloComputation* computation, 57 std::function<bool(HloInstruction*)> is_cost_viable, 58 bool convert_batch_groups_only, 59 bool canonicalize_depthwise_filter); 60 61 // Returns whether any convolution ops were rewritten. 62 const bool changed() const { return changed_; } 63 64 ~ConvolutionVisitor() override = default; 65 66 private: 67 explicit ConvolutionVisitor( 68 HloComputation* computation, 69 std::function<bool(HloInstruction*)> is_cost_viable, 70 bool convert_batch_groups_only, 71 bool canonicalize_depthwise_filter = false) 72 : computation_(computation), 73 filter_expansion_(!canonicalize_depthwise_filter), 74 convert_batch_groups_only_(convert_batch_groups_only), 75 is_cost_viable_(is_cost_viable) {} 76 77 // Current HloComputation instance the ConvolutionVisitor is traversing. 78 HloComputation* computation_; 79 80 // Whether rewrite has occurred. 81 bool changed_ = false; 82 83 // Whether filter expansion is required. 84 bool filter_expansion_; 85 86 // Decides whether to convert batch groups or feature groups. 
87 bool convert_batch_groups_only_; 88 89 // std::function<std::vector<LloValue*>(int64, int64)> chunk_fetcher 90 std::function<bool(HloInstruction*)> is_cost_viable_; 91 }; 92 93 bool ConvolutionVisitor::Run( 94 HloComputation* computation, 95 std::function<bool(HloInstruction*)> is_cost_viable, 96 bool convert_batch_groups_only, bool canonicalize_depthwise_filter) { 97 ConvolutionVisitor visitor(computation, is_cost_viable, 98 convert_batch_groups_only, 99 canonicalize_depthwise_filter); 100 TF_CHECK_OK(computation->Accept(&visitor)); 101 return visitor.changed_; 102 } 103 104 Shape ExpandedFilterShape(const Shape& shape, int64 group_count, 105 int64 input_feature_dim) { 106 int64 num_dims = shape.dimensions_size(); 107 CHECK_GE(num_dims, 2); 108 Shape expanded_shape = shape; 109 expanded_shape.set_dimensions( 110 input_feature_dim, shape.dimensions(input_feature_dim) * group_count); 111 return expanded_shape; 112 } 113 114 // Returns a vector with 'group_count' many groups, where the i-th group 115 // consists of 'group_size' times the value i. 116 std::vector<int32> GetMaskIds(int64 group_size, int64 group_count) { 117 std::vector<int32> values; 118 for (int i = 0; i < group_count; ++i) { 119 for (int j = 0; j < group_size; ++j) { 120 values.push_back(i); 121 } 122 } 123 return values; 124 } 125 126 // Create a mask for grouped convolution that will make a normal convolution 127 // produce the same results as a grouped convolution. 
For a [2, 1, 6] 128 // filter this returns a [2, 3, 6] mask 129 // 1 1 0 0 0 0 130 // 0 0 1 1 0 0 131 // 0 0 0 0 1 1 132 // 133 // 1 1 0 0 0 0 134 // 0 0 1 1 0 0 135 // 0 0 0 0 1 1 136 // 137 // The first step is to create a rank 1 constant: 138 // 0 1 2 139 // 140 // This is broadcasted to 141 // 0 0 0 0 0 0 142 // 1 1 1 1 1 1 143 // 2 2 2 2 2 2 144 // 145 // 0 0 0 0 0 0 146 // 1 1 1 1 1 1 147 // 2 2 2 2 2 2 148 // 149 // Then we create another rank 1 constant 150 // 0 0 1 1 2 2 151 // 152 // This is broadcasted to 153 // 0 0 1 1 2 2 154 // 0 0 1 1 2 2 155 // 0 0 1 1 2 2 156 // 157 // 0 0 1 1 2 2 158 // 0 0 1 1 2 2 159 // 0 0 1 1 2 2 160 // 161 // Finally we use the Eq op of these two broadcasted constants and get the 162 // desired mask. 163 HloInstruction* GetExpandedFilterMask( 164 const Shape& filter_shape, int64 kernel_input_feature_dim, 165 int64 kernel_output_feature_dim, int64 group_count, 166 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>& 167 add_instruction) { 168 Shape expanded_filter_shape = 169 ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim); 170 Shape mask_shape = ShapeUtil::MakeShape( 171 S32, AsInt64Slice(expanded_filter_shape.dimensions())); 172 int64 output_feature = filter_shape.dimensions(kernel_output_feature_dim); 173 int64 group_size = filter_shape.dimensions(kernel_input_feature_dim); 174 175 // Create a 'input_feature' sized linspace and 'output_feature' sized linspace 176 // that will be broadcasted into perpendicular dimensions and compared. 
177 const std::vector<int32> input_feature_filter_mask = 178 GetMaskIds(group_size, group_count); 179 const std::vector<int32> output_feature_filter_mask = 180 GetMaskIds(output_feature / group_count, group_count); 181 auto mask1 = add_instruction(HloInstruction::CreateConstant( 182 LiteralUtil::CreateR1<int32>(input_feature_filter_mask))); 183 auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast( 184 mask_shape, mask1, {kernel_input_feature_dim})); 185 auto mask2 = add_instruction(HloInstruction::CreateConstant( 186 LiteralUtil::CreateR1<int32>(output_feature_filter_mask))); 187 auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast( 188 mask_shape, mask2, {kernel_output_feature_dim})); 189 190 // Compare the broadcasted output feature linspace to the input feature 191 // linspace to create a diagonal predicate. 192 Shape predicate_shape = ShapeUtil::MakeShape( 193 PRED, AsInt64Slice(expanded_filter_shape.dimensions())); 194 return add_instruction(HloInstruction::CreateCompare( 195 predicate_shape, broadcasted_mask1, broadcasted_mask2, 196 ComparisonDirection::kEq)); 197 } 198 199 // This function handles batch_group_counts which are relevant only for 200 // depthwise backprop filter convolutions. 
Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
  auto dim_numbers = convolution->convolution_dimension_numbers();
  auto activation = convolution->mutable_operand(0);
  auto filter = convolution->mutable_operand(1);
  int64 batch_group_count = convolution->batch_group_count();

  // Nothing to rewrite for an ungrouped batch dimension.
  if (batch_group_count == 1) {
    return Status::OK();
  }

  VLOG(2) << "Dealing with batch_group_count " << batch_group_count
          << " for convolution " << convolution->ToString() << "\n";

  // Helper that appends an instruction to the current computation and
  // returns the raw pointer the computation now owns.
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation_->AddInstruction(std::move(inst));
  };

  int64 input_batch_dimension = dim_numbers.input_batch_dimension();
  int64 output_batch_dimension = dim_numbers.output_batch_dimension();
  int64 output_feature_dimension = dim_numbers.output_feature_dimension();

  int64 input_batch = activation->shape().dimensions(input_batch_dimension);

  // We are not yet supporting batch_group of sizes greater than 1.
  // (i.e. the batch dimension must be fully consumed by the grouping.)
  TF_RET_CHECK(input_batch == batch_group_count);

  // Only rewrite when the backend says the grouped form is not viable, or
  // when the caller explicitly requested filter expansion.
  if (!is_cost_viable_(convolution) || filter_expansion_) {
    // We first obtain the expanded the filter (which is the convolution
    // output). The batch dimension is the expanded one (which originally
    // represents kernel input feature dimension). We mask the filter to zero
    // out the expanded regions. Next we reduce the filter in the batch
    // dimension to obtain the original filter size.

    // Note the mask is built over the convolution OUTPUT shape here (the
    // output of a backprop-filter convolution is the expanded filter).
    HloInstruction* filter_mask =
        GetExpandedFilterMask(convolution->shape(), output_batch_dimension,
                              output_feature_dimension, batch_group_count, add);
    auto expanded_filter_shape = ExpandedFilterShape(
        convolution->shape(), batch_group_count, output_batch_dimension);

    // Run the convolution ungrouped; it produces the expanded filter.
    auto new_convolution = add(HloInstruction::CreateConvolve(
        expanded_filter_shape, activation, filter,
        /*feature_group_count=*/1, /*batch_group_count=*/1,
        convolution->window(), dim_numbers, convolution->precision_config()));

    // Select(mask, conv, 0): zero out every entry the grouped convolution
    // would not have produced.
    auto zero = add(HloInstruction::CreateConstant(
        LiteralUtil::Zero(expanded_filter_shape.element_type())));
    auto zero_filter =
        add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));

    auto new_filter = add(HloInstruction::CreateTernary(
        expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution,
        zero_filter));

    PrimitiveType reduce_type = new_filter->shape().element_type();
    auto reduce_window_shape = new_convolution->shape();
    // The reduce-window collapses the (expanded) batch dimension to size 1.
    reduce_window_shape.set_dimensions(output_batch_dimension, 1);

    // Ensure that data input to reduce window uses at least 32 bits.
    // (Accumulating in a narrower type would lose precision.)
    if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) {
      reduce_type = F32;
      reduce_window_shape.set_element_type(F32);
      Shape convert_shape = new_filter->shape();
      convert_shape.set_element_type(F32);
      new_filter =
          add(HloInstruction::CreateConvert(convert_shape, new_filter));
    }

    auto zero_literal = LiteralUtil::Zero(reduce_type);
    auto zero_scalar =
        add(HloInstruction::CreateConstant(std::move(zero_literal)));

    // Builds the scalar "lhs + rhs" computation used by the reduce-window.
    auto reduce_function = [&]() -> HloComputation* {
      HloComputation::Builder b("add_computation");
      Shape shape = ShapeUtil::MakeShape(reduce_type, {});
      auto lhs =
          b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
      auto rhs =
          b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
      auto scalar_op = b.AddInstruction(
          HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));
      return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    };

    // Create the reduce window: a window of size batch_group_count (with
    // matching stride) along the output batch dimension, identity elsewhere.
    Window window;
    for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) {
      auto* dim = window.add_dimensions();
      dim->set_padding_low(0);
      dim->set_padding_high(0);
      dim->set_window_dilation(1);
      dim->set_base_dilation(1);
      if (i == output_batch_dimension) {
        dim->set_stride(batch_group_count);
        dim->set_size(batch_group_count);
      } else {
        dim->set_stride(1);
        dim->set_size(1);
      }
    }
    auto reduce_window = add(HloInstruction::CreateReduceWindow(
        reduce_window_shape, new_filter, zero_scalar, window,
        reduce_function()));

    Shape convert_back_shape = reduce_window->shape();
    convert_back_shape.set_element_type(activation->shape().element_type());

    // Convert reduced data back to the original data type.
    auto reduce_window_converted =
        HloInstruction::CreateConvert(convert_back_shape, reduce_window);

    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        convolution, std::move(reduce_window_converted)));
    changed_ = true;
  }

  return Status::OK();
}

Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
  // In batch-group mode this visitor only rewrites batch groups and leaves
  // feature groups untouched.
  if (convert_batch_groups_only_) {
    return HandleBatchGroupCount(convolution);
  }

  // Helper that appends an instruction to the current computation.
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation_->AddInstruction(std::move(inst));
  };

  int64 group_count = convolution->feature_group_count();
  if (group_count == 1) {
    return Status::OK();
  }

  // Tentatively mark as changed; the depthwise-separable early-out below
  // resets this if we end up leaving the op alone.
  changed_ = true;
  auto dim_numbers = convolution->convolution_dimension_numbers();
  auto filter = convolution->mutable_operand(1);
  int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension();
  int64 group_size = filter->shape().dimensions(kernel_input_feature_dim);
  int64 kernel_output_feature_dim =
      dim_numbers.kernel_output_feature_dimension();
  auto expanded_filter_shape = ExpandedFilterShape(filter->shape(), group_count,
                                                  kernel_input_feature_dim);
  // NOTE: this adds the mask instructions to the computation up front, even
  // on paths below that never use them; later DCE is presumably expected to
  // clean up the unused ones.
  HloInstruction* filter_mask =
      GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim,
                            kernel_output_feature_dim, group_count, add);
  HloInstruction* expanded_filter;

  if (group_size == 1) {
    bool depthwise_separable =
        (group_count == filter->shape().dimensions(kernel_output_feature_dim));
    // If the code generator handles depthwise separable convolutions
    // inherently, then no filter expansion is needed.
    if (!filter_expansion_ && depthwise_separable) {
      changed_ = false;
      return Status::OK();
    }
    // We want to repeat 'filter' in the 'input_feature_dim' dimension
    // 'group_count' times. Dropping the size-1 dimension and broadcasting
    // over every other dimension achieves exactly that tiling.
    Shape reshaped_filter_shape =
        ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape());
    auto reshaped_filter =
        add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
    std::vector<int64> broadcast_dims;
    for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) {
      if (i == kernel_input_feature_dim) {
        continue;
      }
      broadcast_dims.push_back(i);
    }
    expanded_filter = add(HloInstruction::CreateBroadcast(
        expanded_filter_shape, reshaped_filter, broadcast_dims));

    // Mask out the off-diagonal blocks so the tiled filter behaves like the
    // original grouped filter.
    auto zero = add(HloInstruction::CreateConstant(
        LiteralUtil::Zero(expanded_filter_shape.element_type())));
    auto zero_filter =
        add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
    auto new_filter = add(HloInstruction::CreateTernary(
        expanded_filter_shape, HloOpcode::kSelect, filter_mask, expanded_filter,
        zero_filter));

    auto new_convolution = HloInstruction::CreateConvolve(
        convolution->shape(), convolution->mutable_operand(0), new_filter,
        /*feature_group_count=*/1, /*batch_group_count=*/1,
        convolution->window(), dim_numbers, convolution->precision_config());
    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        convolution, std::move(new_convolution)));
  } else {
    int64 activation_input_feature_dim = dim_numbers.input_feature_dimension();

    int64 output_feature =
        filter->shape().dimensions(kernel_output_feature_dim);

    // If group_count == output_feature, then we map those grouped convolutions
    // onto depthwise convolution. This is done by adding an additional spatial
    // dimension to the activations, kernel, and the output.
    // E.g., we would turn
    // [2, 12]{B, IF} conv [3, 4]{IF, OF} into
    // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the
    // additional spatial dimension. The generated convolution output will be
    // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}.

    if (group_count == output_feature && !filter_expansion_) {
      auto filter = convolution->mutable_operand(1);
      auto activation = convolution->mutable_operand(0);

      // Add spatial dimension to the activation, and reshape.
      Shape reshaped_activation_shape = activation->shape();
      ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape);

      int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1;

      reshaped_activation_shape.set_dimensions(activation_input_feature_dim,
                                               group_count);
      activation = add(
          HloInstruction::CreateReshape(reshaped_activation_shape, activation));

      // Add spatial dimension to the filter, and reshape.
      Shape reshaped_filter_shape = filter->shape();
      ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape);

      filter =
          add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));

      Shape new_output_shape = convolution->shape();
      ShapeUtil::AppendMajorDimension(1, &new_output_shape);

      // Edit convolution dimension numbers. Note that kernel_input_feature_dim
      // now becomes a spatial dimension, and the newly added dimension of size
      // 1 is the new kernel_input_feature_dim.
      dim_numbers.add_input_spatial_dimensions(new_spatial_dim);
      dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dim);
      dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim);
      dim_numbers.add_output_spatial_dimensions(new_spatial_dim);

      // Add window for the new spatial dimension.
      Window new_window = convolution->window();
      auto* dim = new_window.add_dimensions();
      dim->set_window_dilation(1);
      dim->set_base_dilation(1);
      dim->set_stride(1);
      dim->set_size(group_size);

      // The rewritten op keeps feature_group_count == group_count; it is now
      // a depthwise convolution the backend handles natively.
      auto new_convolution = add(HloInstruction::CreateConvolve(
          new_output_shape, activation, filter, group_count,
          /*batch_group_count=*/1, new_window, dim_numbers,
          convolution->precision_config()));

      // Delete the extra spatial dimension, and reshape.
      Shape reshaped_convolution_shape =
          ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape());
      auto reshaped_convolution = HloInstruction::CreateReshape(
          reshaped_convolution_shape, new_convolution);

      TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
          convolution, std::move(reshaped_convolution)));

    } else {
      // The filter expansion mechanism adds zeroes in the kernel.
      // For an OF = 12, IF = 6, and kernel IF = 2, the expanded filter mask
      // would look like (IF on the Y-axis, OF on the X-axis)
      // 1 1 1 1 0 0 0 0 0 0 0 0
      // 1 1 1 1 0 0 0 0 0 0 0 0
      // 0 0 0 0 1 1 1 1 0 0 0 0
      // 0 0 0 0 1 1 1 1 0 0 0 0
      // 0 0 0 0 0 0 0 0 1 1 1 1
      // 0 0 0 0 0 0 0 0 1 1 1 1
      //
      // Instead of convolving the above with the input, we instead slice the
      // kernel into three kernels, each containing islands of 1s from the
      // filter above. We also slice the activations in the IF dimension with
      // each slice of size = group_size. For each slice, we perform
      // convolutions, and concatenate the generated outputs in the output OF
      // dimension.

      std::vector<HloInstruction*> sliced_convolutions;
      auto activation = convolution->mutable_operand(0);
      std::vector<int64> slice_strides(filter->shape().dimensions_size(), 1);
      std::vector<int64> filter_slice_starts(filter->shape().dimensions_size(),
                                             0);
      std::vector<int64> filter_slice_limits(
          filter->shape().dimensions().begin(),
          filter->shape().dimensions().end());
      std::vector<int64> activation_slice_starts(
          activation->shape().dimensions_size(), 0);
      std::vector<int64> activation_slice_limits(
          activation->shape().dimensions().begin(),
          activation->shape().dimensions().end());

      int64 output_feature =
          filter->shape().dimensions(kernel_output_feature_dim);
      auto output_feature_dim = dim_numbers.output_feature_dimension();
      int64 filter_slice_width = output_feature / group_count;

      int64 activation_input_feature_dim =
          dim_numbers.input_feature_dimension();

      // One ungrouped convolution per group; the start/limit vectors are
      // reused across iterations, with only the sliced dimension updated.
      for (int64 i = 0; i < group_count; i++) {
        filter_slice_starts[kernel_output_feature_dim] = i * filter_slice_width;
        filter_slice_limits[kernel_output_feature_dim] =
            (i + 1) * filter_slice_width;
        auto filter_sliced_shape = filter->shape();
        filter_sliced_shape.set_dimensions(kernel_output_feature_dim,
                                           filter_slice_width);
        auto filter_slice = add(HloInstruction::CreateSlice(
            filter_sliced_shape, filter, filter_slice_starts,
            filter_slice_limits, slice_strides));

        activation_slice_starts[activation_input_feature_dim] = i * group_size;
        activation_slice_limits[activation_input_feature_dim] =
            (i + 1) * group_size;
        auto activation_sliced_shape = activation->shape();
        activation_sliced_shape.set_dimensions(activation_input_feature_dim,
                                               group_size);
        auto activation_slice = add(HloInstruction::CreateSlice(
            activation_sliced_shape, activation, activation_slice_starts,
            activation_slice_limits, slice_strides));

        auto conv_slice_shape = convolution->shape();
        conv_slice_shape.set_dimensions(output_feature_dim, filter_slice_width);

        auto new_convolution = add(HloInstruction::CreateConvolve(
            conv_slice_shape, activation_slice, filter_slice,
            /*feature_group_count=*/1, /*batch_group_count=*/1,
            convolution->window(), dim_numbers,
            convolution->precision_config()));

        sliced_convolutions.push_back(new_convolution);
      }

      // Stitch the per-group outputs back together along the output feature
      // dimension.
      auto new_conv = HloInstruction::CreateConcatenate(
          convolution->shape(), sliced_convolutions, output_feature_dim);
      TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
          convolution, std::move(new_conv)));
    }
  }

  return Status::OK();
}

}  // namespace

StatusOr<bool> ConvolutionGroupConverter::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "ConvolutionGroupConverter::Run(), before:\n" + module->ToString());
  bool changed = false;
  // Run the visitor over every non-fusion computation in the module.
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (ConvolutionVisitor::Run(comp, is_cost_viable_,
                                convert_batch_groups_only_,
                                filter_expansion_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(
      2, "ConvolutionGroupConverter::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla