/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"

#include <algorithm>
#include <numeric>
#include <tuple>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {

namespace {

bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
  if (dnums.input_spatial_dimensions_size() > 3) {
    return false;
  }

  // cuDNN does not accept zero-element arguments.
  if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) ||
      ShapeUtil::HasZeroElements(conv->operand(1)->shape())) {
    return false;
  }

  if (window_util::HasWindowReversal(conv->window())) {
    return false;
  }
  return true;
}

// Try to match a backward filter pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
    HloInstruction* conv) {
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
  // Step 1: match the instruction pattern without considering the paddings and
  // dimension numbers just yet. We may need some generic pattern matcher
  // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
  //
  // Backward filter convolution is implemented in XLA as the forward
  // convolution of padded activations and dilated gradients. Padding on
  // activations and dilation on gradients are specified in the "window" field
  // of the forward convolution.
  //
  //        activations  gradients
  //              \         /
  //               v       v
  //              Convolution
  //                 conv
  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());

  // Step 2: match paddings and dimension numbers of the forward convolution.
  const ConvolutionDimensionNumbers& conv_dnums =
      conv->convolution_dimension_numbers();
  auto input_batch_dim = conv_dnums.input_batch_dimension();
  auto input_feature_dim = conv_dnums.input_feature_dimension();
  auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
  auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
  auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension();
  auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
  auto output_batch_dim = conv_dnums.output_batch_dimension();
  auto output_feature_dim = conv_dnums.output_feature_dimension();
  auto output_spatial_dims = conv_dnums.output_spatial_dimensions();

  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have stride of 1.";
      return no_match_result;
    }
    if (window_dim.base_dilation() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have no base (LHS) dilation.";
      return no_match_result;
    }
    if (window_dim.padding_low() < 0) {
      VLOG(1) << "Padding low should be non-negative.";
      return no_match_result;
    }
    if (window_dim.window_reversal()) {
      VLOG(1) << "Window reversal field not supported";
      return no_match_result;
    }
    // Padding high will be checked in Step 3.
  }
  if (input_batch_dim == output_batch_dim &&
      !window_util::HasWindowDilation(conv->window())) {
    VLOG(1) << conv->ToString()
            << " is a regular forward convolution. No need "
               "to fold it to a backward filter convolution.";
    return no_match_result;
  }

  // Step 3: fuse the matched HLOs into a backward convolution instruction.
  //
  // Compute the window of the backward convolution.
  Window backward_conv_window;
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    WindowDimension* dim = backward_conv_window.add_dimensions();
    // The window size of the backward convolution equals the output size of
    // the forward convolution.
    int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]);
    dim->set_size(filter_size);
    // The window stride equals the window dilation of the forward convolution.
    dim->set_stride(conv->window().dimensions(i).window_dilation());
    // The window's low padding is the same as the low padding of the
    // activations.
    dim->set_padding_low(conv->window().dimensions(i).padding_low());

    int64 input_size =
        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
    int64 output_size = conv->window().dimensions(i).size();
    // Compute the range of the amount of valid high padding. We first compute
    // min_padding_high, the amount of padding on the right/bottom to ensure
    // the last patch ends at the border, i.e.,
    //
    //   input_size + dim->padding_low() + min_padding_high
    //     = (output_size - 1) * stride + filter_size
    //
    // Because convolution ignores trailing incomplete windows, any amount of
    // padding high from min_padding_high to min_padding_high+stride-1
    // (max_padding_high) has the same effect.
    int64 padded_input_size = filter_size + (output_size - 1) * dim->stride();
    int64 min_padding_high =
        padded_input_size - input_size - dim->padding_low();
    int64 max_padding_high = min_padding_high + dim->stride() - 1;
    CHECK_GE(dim->padding_low(), 0);
    // In practice, since cuDNN convolution only supports even padding, we make
    // the amount of high padding the same as the amount of low padding as long
    // as it is between min_padding_high and max_padding_high. If it is not in
    // that range, we pick the one that's closest to dim->padding_low() and let
    // PadInsertion canonicalize the resultant backward convolution later.
    // Picking the closest one minimizes the cost of the kPad instruction to be
    // inserted by PadInsertion.
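    //
    // Illustrative example (numbers chosen here purely for exposition, not
    // taken from any particular model): with filter_size = 3, output_size = 4,
    // stride = 2, input_size = 8, and padding_low = 1,
    //   padded_input_size = 3 + (4 - 1) * 2 = 9
    //   min_padding_high  = 9 - 8 - 1      = 0
    //   max_padding_high  = 0 + 2 - 1      = 1
    // Since padding_low = 1 lies in [0, 1], the high padding is set to 1 and
    // the resulting backward convolution uses even padding.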
    if (dim->padding_low() >= min_padding_high &&
        dim->padding_low() <= max_padding_high) {
      dim->set_padding_high(dim->padding_low());
    } else {
      if (dim->padding_low() < min_padding_high) {
        dim->set_padding_high(min_padding_high);
      } else {
        dim->set_padding_high(max_padding_high);
      }
    }
    if (dim->padding_high() < 0) {
      LOG(ERROR)
          << "Fusing this pattern to backward filter convolution would cause "
             "negative padding ("
          << dim->padding_high()
          << ") on right/bottom of the weight gradients, which is not "
             "supported by PadInsertion (b/32744257). Falling back to "
             "unfused convolution for instruction: "
          << conv->ToString();
      return no_match_result;
    }
  }

  // Restore the dimension numbers of the backward convolution from the forward
  // convolution. The two activation dimensions are reversed (batch and
  // feature).
  ConvolutionDimensionNumbers backward_conv_dnums;
  backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
  backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
  }
  backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
  backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
  for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
  }
  // The dimension numbering of the output of the forward convolution (before
  // transposition) is the same as that of the activations (according to the
  // semantics of kConvolution). The batch dimension of the activations should
  // be treated as the input feature dimension, and the feature dimension
  // should be treated as the output feature.
  backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
  backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
  for (int i = 0; i < output_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
  }
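  // Illustrative example (a hypothetical NHWC/HWIO layout, not implied by the
  // code above): for a forward convolution with input dnums
  // {batch=0, spatial={1,2}, feature=3}, kernel dnums
  // {spatial={0,1}, input_feature=2, output_feature=3}, and output dnums
  // {batch=0, spatial={1,2}, feature=3}, the remapping above yields a backward
  // filter convolution with input {batch=3, feature=0, spatial={1,2}},
  // output {batch=2, feature=3, spatial={0,1}}, and kernel
  // {input_feature=0, output_feature=3, spatial={1,2}}.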

  return std::make_tuple(true, backward_conv_window, backward_conv_dnums);
}

// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
    HloInstruction* conv) {
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers());

  // Match instruction pattern.
  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
  HloInstruction* reverse_filter = conv->mutable_operand(1);

  // Match the reverse of the filter.
  ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
  const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
  if (reverse_filter->opcode() == HloOpcode::kReverse) {
    if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
        !std::is_permutation(kernel_spatial_dims.begin(),
                             kernel_spatial_dims.end(),
                             reverse_filter->dimensions().begin())) {
      VLOG(1)
          << "Backward input convolution should reverse all kernel dimensions.";
      return no_match_result;
    }
  } else {
    // Possibly 1x1 filter.
    for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
      if (conv->window().dimensions(i).size() != 1) {
        VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
                << reverse_filter->ToString();
        return no_match_result;
      }
    }
    if (!window_util::HasBaseDilation(conv->window())) {
      VLOG(1) << conv->ToString()
              << " is a regular forward convolution. No need "
                 "to fold it to a backward input convolution.";
      return no_match_result;
    }
  }

  // Match padding and dilation of the forward convolution.
  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have stride of 1.";
      return no_match_result;
    }
    if (window_dim.window_dilation() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have no window dilation.";
      return no_match_result;
    }
    if (window_dim.window_reversal()) {
      VLOG(1) << "Window reversal field not supported";
      return no_match_result;
    }
  }

  const auto& input_spatial_dims = dnums.input_spatial_dimensions();
  const auto& output_spatial_dims = dnums.output_spatial_dimensions();
  CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
  CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());

  const Window& old_window = conv->window();
  Window new_window = old_window;
  for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
    // Restore backward convolution's padding config from the matched pattern.
    // See the comment in tensorflow/core/kernels/conv_grad_ops.h for how we
    // convert backward input convolution to a variant of forward convolution.
    //
    // The stride of the backward convolution
    //   = the base dilation factor of the forward convolution.
    auto dim = new_window.mutable_dimensions(i);
    dim->set_stride(old_window.dimensions(i).base_dilation());

    // The low padding = kernel_size - 1 - low padding on the gradients.
    // Make sure the low padding is not negative.
    auto kernel_size = old_window.dimensions(i).size();
    auto backward_padding_low =
        kernel_size - 1 - old_window.dimensions(i).padding_low();
    if (backward_padding_low < 0) {
      LOG(ERROR)
          << "The low padding of the backward convolution would be negative ("
          << backward_padding_low
          << "), which isn't supported by PadInsertion for now (b/32744257).";
      return no_match_result;
    }
    dim->set_padding_low(backward_padding_low);
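
    // Illustrative example (numbers chosen purely for exposition): if the
    // matched window has size (kernel_size) 3, base_dilation 2, and
    // padding_low 1, the backward convolution gets stride = 2 and
    //   backward_padding_low = 3 - 1 - 1 = 1.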

    // Compute the range of the amount of padding on the right/bottom of the
    // activations. XLA's convolution requires all patches to be within the
    // padded base. This gives us flexibility to choose the amount of high
    // padding from a set of values without changing the result of the backward
    // convolution. The minimum amount (min_padding_high) makes the last patch
    // end at the border. The maximum amount (max_padding_high) equals
    // min_padding_high+stride-1 -- max_padding_high+1 would cause the output
    // size to change.
    auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
    auto output_size =
        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
    auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
    auto total_pad_size = padded_input_size - unpadded_input_size;
    auto min_padding_high = total_pad_size - backward_padding_low;
    auto max_padding_high = min_padding_high + dim->stride() - 1;

    if (backward_padding_low >= min_padding_high &&
        backward_padding_low <= max_padding_high) {
      // In the best case (most likely), if backward_padding_low is in the
      // range of the amounts of valid high padding, we choose
      // backward_padding_low because cuDNN supports even padding only.
      dim->set_padding_high(backward_padding_low);
    } else {
      // Otherwise, we choose the amount that's closest to
      // backward_padding_low, and PadInsertion will later insert kSlice
      // instructions to enforce even padding.
      //
      // For example, consider the backward convolution pattern
      //
      //       ab     xy
      //       | pad  | reverse
      //      .a.b    yx
      //         \   /
      //          ABC
      //
      // The amount of low padding on activations (in backward convolution) is
      //   backward_padding_low = kernel_size - 1 - forward_padding_low
      //                        = 2 - 1 - 1 = 0
      //
      // The amount of padding high must be between 1 and 2, in order to make
      // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
      // the range of [1,2], so we pick the closest valid amount of padding
      // high, which is 1 in this case. Therefore, we fuse the above pattern to
      //
      //   ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
      if (backward_padding_low < min_padding_high) {
        dim->set_padding_high(min_padding_high);
      } else {
        dim->set_padding_high(max_padding_high);
      }
    }
    // PadInsertion doesn't handle backward input convolution with negative
    // padding for now. So fall back to unfused convolution in case of negative
    // padding. For example,
    //   ABCD = Conv(abc, reverse(xy), padding_high=2)
    // could be fused to
    //   ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
    // with positive padding low but negative padding high.
    if (dim->padding_high() < 0) {
      LOG(ERROR) << "Fusing this pattern to backward convolution would cause "
                    "negative padding ("
                 << dim->padding_high()
                 << ") on right/bottom of the activations, which is not "
                    "supported by PadInsertion (b/32744257). Falling back to "
                    "unfused convolution for instruction: "
                 << conv->ToString();
      return no_match_result;
    }
  }

  // Fuse the matched HLOs into a backward convolution instruction.
  //
  // If the reverse is omitted (for 1x1 filters) in the original pattern, we
  // add it back in the fusion instruction so that later passes (such as
  // PadInsertion) can handle such fusion instructions easily.
  if (reverse_filter->opcode() != HloOpcode::kReverse) {
    reverse_filter = reverse_filter->parent()->AddInstruction(
        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
                                      AsInt64Slice(kernel_spatial_dims)));
    TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
  }
  dnums.set_kernel_input_feature_dimension(
      conv->convolution_dimension_numbers().kernel_output_feature_dimension());
  dnums.set_kernel_output_feature_dimension(
      conv->convolution_dimension_numbers().kernel_input_feature_dimension());

  return std::make_tuple(true, new_window, dnums);
}

// Tries to rewrite a single convolution into a call to cudnn.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);

  HloInstruction* custom_call = [&]() -> HloInstruction* {
    bool match;
    Window window;
    ConvolutionDimensionNumbers dnums;

    std::tie(match, window, dnums) = MatchBackwardFilter(conv);
    if (match) {
      return CreateCudnnConvBackwardFilter(
          conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
          window, dnums);
    }

    std::tie(match, window, dnums) = MatchBackwardInput(conv);
    if (match) {
      // Backward input conv subsumes the conv plus the reverse in operand 1.
      HloInstruction* reverse = conv->mutable_operand(1);
      CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
      HloInstruction* rhs = reverse->mutable_operand(0);

      return CreateCudnnConvBackwardInput(
          conv->shape(), conv->mutable_operand(0), rhs, window, dnums);
    }

    // If all else fails, try a forward convolution.
    if (CanImplementAsCudnnForwardConv(conv)) {
      return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
                                    conv->mutable_operand(1), conv->window(),
                                    conv->convolution_dimension_numbers());
    }

    return nullptr;
  }();

  if (custom_call == nullptr) {
    return false;
  }

  // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
  // the conv result and replace `conv` with it.
  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
      conv,
      HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
  return true;
}

// Rewrites the convolutions in the given computation into calls to cudnn.
// Returns true if it made any changes.
StatusOr<bool> RunOnComputation(HloComputation* computation) {
  std::vector<HloInstruction*> convs;
  for (auto* hlo : computation->instructions()) {
    if (hlo->opcode() == HloOpcode::kConvolution) {
      convs.push_back(hlo);
    }
  }

  bool changed = false;
  for (HloInstruction* conv : convs) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv));
    changed |= result;
  }
  return changed;
}
}  // namespace

StatusOr<bool> CudnnConvolutionRewriter::Run(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla