/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

namespace xla {
namespace gpu {
namespace {

class Visitor : public DfsHloVisitorWithDefault {
 public:
  explicit Visitor(HloComputation* computation) : computation_(computation) {}

  static bool Run(HloComputation* computation) {
    Visitor visitor(computation);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleBatchNormInference(HloInstruction* batch_norm) override;
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

 private:
  bool changed_ = false;
  HloComputation* computation_;
};

// cudnn defines CUDNN_BN_MIN_EPSILON = 1e-5 as the minimum acceptable epsilon
// for calls to its batchnorm ops.
bool EpsilonInRange(HloInstruction* batch_norm) {
  return batch_norm->epsilon() >= 1e-5;
}

Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) {
  if (batch_norm->operand(0)->shape().element_type() != F32) {
    VLOG(1) << "Not rewriting op with non-F32 element type: "
            << batch_norm->ToString();
    return Status::OK();
  }

  // cudnn errors out on zero-sized inputs.
  if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) {
    return Status::OK();
  }

  if (!EpsilonInRange(batch_norm)) {
    return Status::OK();
  }

  HloInstruction* epsilon =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          Literal::CreateR0(batch_norm->epsilon())));
  HloInstruction* feature_index =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          Literal::CreateR0(batch_norm->feature_index())));

  std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
                                        batch_norm->operands().end());
  operands.push_back(epsilon);
  operands.push_back(feature_index);

  std::unique_ptr<HloInstruction> libcall = HloInstruction::CreateCustomCall(
      batch_norm->shape(), operands,
      kCudnnBatchNormForwardInferenceCallTarget);
  TF_RETURN_IF_ERROR(
      computation_->ReplaceWithNewInstruction(batch_norm, std::move(libcall)));
  changed_ = true;
  return Status::OK();
}

Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) {
  if (batch_norm->operand(0)->shape().element_type() != F32) {
    VLOG(1) << "Not rewriting op with non-F32 element type: "
            << batch_norm->ToString();
    return Status::OK();
  }

  // cudnn errors out on zero-sized inputs.
  if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) {
    return Status::OK();
  }

  if (!EpsilonInRange(batch_norm)) {
    return Status::OK();
  }

  HloInstruction* epsilon =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          Literal::CreateR0(batch_norm->epsilon())));
  HloInstruction* feature_index =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          Literal::CreateR0(batch_norm->feature_index())));

  std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
                                        batch_norm->operands().end());
  operands.push_back(epsilon);
  operands.push_back(feature_index);

  HloInstruction* libcall =
      computation_->AddInstruction(HloInstruction::CreateCustomCall(
          batch_norm->shape(), operands,
          kCudnnBatchNormForwardTrainingCallTarget));

  // The cudnn libcall returns a tuple
  //   {output, mean, rsqrt(variance + epsilon)},
  // but the batchnorm HLO returns {output, mean, variance}. Fix it up.
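  // Since inverse_stddev = rsqrt(variance + epsilon) is
  // (variance + epsilon)^-0.5, raising it to the power -2 recovers
  // variance + epsilon, and subtracting epsilon then yields the plain
  // variance.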
  HloInstruction* inverse_stddev =
      computation_->AddInstruction(HloInstruction::CreateGetTupleElement(
          libcall->shape().tuple_shapes(2), libcall, 2));
  HloInstruction* variance_plus_epsilon =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev,
          computation_->AddInstruction(
              HloInstruction::CreateConstant(Literal::CreateR0<float>(-2)))));
  HloInstruction* variance =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          variance_plus_epsilon->shape(), HloOpcode::kSubtract,
          variance_plus_epsilon, epsilon));

  // Repackage the results.
  std::unique_ptr<HloInstruction> new_tuple = HloInstruction::CreateTuple({
      computation_->AddInstruction(HloInstruction::CreateGetTupleElement(
          libcall->shape().tuple_shapes(0), libcall, 0)),
      computation_->AddInstruction(HloInstruction::CreateGetTupleElement(
          libcall->shape().tuple_shapes(1), libcall, 1)),
      variance,
  });

  TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
      batch_norm, std::move(new_tuple)));
  changed_ = true;
  return Status::OK();
}

Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) {
  if (batch_norm->operand(0)->shape().element_type() != F32) {
    VLOG(1) << "Not rewriting op with non-F32 element type: "
            << batch_norm->ToString();
    return Status::OK();
  }

  // cudnn errors out on zero-sized inputs.
  if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) {
    return Status::OK();
  }

  if (!EpsilonInRange(batch_norm)) {
    return Status::OK();
  }

  HloInstruction* epsilon =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          Literal::CreateR0(batch_norm->epsilon())));
  HloInstruction* feature_index =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          Literal::CreateR0(batch_norm->feature_index())));

  // The cudnn libcall expects its input to be rsqrt(variance + epsilon), but
  // the batchnorm HLO takes plain variance as input. Fix it up.
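  // Since rsqrt(x) == x^-0.5, adding epsilon to the variance operand and
  // raising the sum to the power -0.5 yields exactly the value cudnn expects.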
  HloInstruction* var_plus_epsilon =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          batch_norm->operand(3)->shape(), HloOpcode::kAdd,
          batch_norm->mutable_operand(3), epsilon));
  HloInstruction* inverse_stddev =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon,
          computation_->AddInstruction(
              HloInstruction::CreateConstant(Literal::CreateR0<float>(-.5)))));

  std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
                                        batch_norm->operands().end());
  operands[3] = inverse_stddev;
  operands.push_back(epsilon);
  operands.push_back(feature_index);

  std::unique_ptr<HloInstruction> libcall = HloInstruction::CreateCustomCall(
      batch_norm->shape(), operands, kCudnnBatchNormBackwardCallTarget);

  TF_RETURN_IF_ERROR(
      computation_->ReplaceWithNewInstruction(batch_norm, std::move(libcall)));
  changed_ = true;
  return Status::OK();
}

}  // anonymous namespace

StatusOr<bool> CudnnBatchNormRewriter::Run(HloModule* module) {
  VLOG(2) << "CudnnBatchNormRewriter::Run(), before:";
  XLA_VLOG_LINES(2, module->ToString());

  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (Visitor::Run(comp)) {
      changed = true;
    }
  }

  VLOG(2) << "CudnnBatchNormRewriter::Run(), after:";
  XLA_VLOG_LINES(2, module->ToString());
  return changed;
}

}  // namespace gpu
}  // namespace xla