/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"

#include <algorithm>
#include <array>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla { 47 namespace { 48 49 // Returns whether operand is a literal with the given value. 50 bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { 51 return operand->opcode() == HloOpcode::kConstant && 52 operand->literal().IsAll(value); 53 } 54 55 bool IsAll(const HloInstruction* op, int8 value) { 56 if (IsLiteralWithValue(op, value)) { 57 return true; 58 } 59 if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) { 60 return true; 61 } 62 return false; 63 } 64 65 // Returns whether the given transpose produces a result which is bit-wise 66 // identical to its operand and thus may be replaced with a bitcast. 67 bool TransposeIsBitcast(const HloInstruction* transpose) { 68 CHECK_EQ(HloOpcode::kTranspose, transpose->opcode()); 69 const HloInstruction* operand = transpose->operand(0); 70 return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(), 71 transpose->dimensions()); 72 } 73 74 // Returns true if the given reshape produces a result which is bit-wise 75 // identical to its operand and thus may be replaced with a bitcast. 76 // 77 // This function is conservative -- even if this function returns false, the 78 // reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. 79 bool ReshapeIsBitcast( 80 const HloInstruction* reshape, 81 const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { 82 CHECK_EQ(HloOpcode::kReshape, reshape->opcode()); 83 84 const HloInstruction* operand = reshape->operand(0); 85 // Can't insert bitcasts if the compiler used a memory layout which isn't 86 // compatible. 87 return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) && 88 valid_bitcast_callback(operand->shape(), reshape->shape()); 89 } 90 91 // Adds a scalar computation to the module to enable optimizations with dot 92 // converting into reduction. 
93 HloComputation* CreateScalarBinaryComputation(HloModule* module, 94 PrimitiveType primitive_type, 95 HloOpcode opcode) { 96 HloComputation::Builder b("scalar_computation"); 97 auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( 98 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); 99 auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( 100 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); 101 auto scalar_op = b.AddInstruction( 102 HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), 103 opcode, scalar_lhs, scalar_rhs)); 104 HloComputation* scalar_computation = 105 module->AddEmbeddedComputation(b.Build(scalar_op)); 106 return scalar_computation; 107 } 108 } // namespace 109 110 // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain 111 // algebraic expressions to simplified forms. Note: This only supports 112 // simplifications that simply look at the operands of an instruction. For the 113 // more general case a worklist based approach would be needed. 114 class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { 115 public: 116 // Default visitor action is to do nothing and return OK. 
117 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { 118 return Status::OK(); 119 } 120 121 Status HandleAdd(HloInstruction* add) override; 122 123 Status HandleBitcast(HloInstruction* bitcast) override; 124 125 Status HandleBroadcast(HloInstruction* broadcast) override; 126 127 Status HandleConcatenate(HloInstruction* concatenate) override; 128 129 Status HandleConstant(HloInstruction* constant) override; 130 131 Status HandleCopy(HloInstruction* copy) override; 132 133 Status HandleConvert(HloInstruction* convert) override; 134 135 Status HandleComplex(HloInstruction* complex) override; 136 137 Status HandleReal(HloInstruction* real) override; 138 139 Status HandleImag(HloInstruction* imag) override; 140 141 Status HandleConvolution(HloInstruction* convolution) override; 142 143 Status HandleDivide(HloInstruction* divide) override; 144 145 Status HandleDot(HloInstruction* dot) override; 146 147 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 148 149 Status HandleLog(HloInstruction* log) override; 150 151 Status HandleMultiply(HloInstruction* multiply) override; 152 153 Status HandlePad(HloInstruction* pad) override; 154 155 Status HandlePower(HloInstruction* power) override; 156 157 Status HandleReshape(HloInstruction* reshape) override; 158 159 Status HandleReduce(HloInstruction* reduce) override; 160 161 Status HandleReduceWindow(HloInstruction* reduce_window) override; 162 163 Status HandleReverse(HloInstruction* reverse) override; 164 Status HandleSlice(HloInstruction* slice) override; 165 Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; 166 Status HandleDynamicUpdateSlice( 167 HloInstruction* dynamic_update_slice) override; 168 169 Status HandleTranspose(HloInstruction* transpose) override; 170 171 Status HandleSubtract(HloInstruction* sub) override; 172 173 Status HandleMaximum(HloInstruction* maximum) override; 174 Status HandleMinimum(HloInstruction* minimum) override; 175 176 // Returns 
whether algebraic simplification has occurred. 177 const bool changed() const { return changed_; } 178 179 // Runs the visitor on a computation. 180 static bool Run( 181 HloComputation* computation, bool is_layout_sensitive, 182 AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, 183 bool enable_dot_strength_reduction, bool enable_conv_simplification); 184 185 private: 186 explicit AlgebraicSimplifierVisitor( 187 HloComputation* computation, bool is_layout_sensitive, 188 AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, 189 bool enable_dot_strength_reduction, bool enable_conv_simplification) 190 : computation_(computation), 191 is_layout_sensitive_(is_layout_sensitive), 192 valid_bitcast_callback_(std::move(valid_bitcast_callback)), 193 enable_dot_strength_reduction_(enable_dot_strength_reduction), 194 enable_conv_simplification_(enable_conv_simplification) {} 195 196 // Transforms Dots where at least one input is a vector or has a degenerate 197 // dimension and converts it into a multiply and reduce. This should enable 198 // more fusion than leaving the nodes as Dot operations. 199 StatusOr<bool> HandleDotStrengthReduction(HloInstruction* dot); 200 201 // Reshapes an instruction to rank 1 if it is not already rank 1. 202 HloInstruction* Flatten(HloInstruction* hlo) { 203 if (ShapeUtil::Rank(hlo->shape()) == 1) { 204 return hlo; 205 } 206 return computation_->AddInstruction(HloInstruction::CreateReshape( 207 ShapeUtil::MakeShape(hlo->shape().element_type(), 208 {ShapeUtil::ElementsIn(hlo->shape())}), 209 hlo)); 210 } 211 212 // Helper method to perform and add reduction in a single dimension. 
213 HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { 214 HloInstruction* zero = computation_->AddInstruction( 215 HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); 216 HloComputation* AddReduce_computation = CreateScalarBinaryComputation( 217 computation_->parent(), F32, HloOpcode::kAdd); 218 Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); 219 return computation_->AddInstruction(HloInstruction::CreateReduce( 220 shape, hlo, zero, {dim}, AddReduce_computation)); 221 } 222 223 // Convenience method for replacing an instruction with a bitcast. 224 void ReplaceWithBitcast(HloInstruction* instruction); 225 226 // Replace old instruction with new instruction if old and new instructions 227 // have the same shape. Updates uses and root instruction. Returns whether a 228 // replacement was made. 229 bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction, 230 HloInstruction* new_instruction); 231 232 // Returns whether the shape of the output of the given instructions are the 233 // same for the purposes of simplification. If is_layout_sensitive_ is true, 234 // then this tests shape equality including layout (ShapeUtil::Equal). If 235 // is_layout_sensitive_ is false, then the tests shape compatibility 236 // (ShapeUtil::Compatible). 237 bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; 238 239 // Returns whether it was possible to transform `root` to a clamp instruction. 240 // With min a minimum instruction, max a maximum instruction, min_operand a 241 // operand of min and max_operand a operand of max. 242 // Precondition: root is either a minimum or a maximum. 243 bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min, 244 HloInstruction* min_operand, 245 HloInstruction* operand, HloInstruction* max, 246 HloInstruction* max_operand); 247 248 // A Reshape or Broadcast that feeds an element-wise operation with a unique 249 // non-scalar operand can sink to after the operation. 
250 StatusOr<bool> TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( 251 HloInstruction* reshape_or_broadcast); 252 253 // Replaces the existing HLO instruction old_instruction, with 254 // new_instruction, and marks the optimizer status as changed. 255 // Returns the Status representing the result of the replace operation. 256 Status ReplaceWithNewInstruction( 257 HloInstruction* old_instruction, 258 std::unique_ptr<HloInstruction> new_instruction) { 259 VLOG(3) << "Replacing instruction:"; 260 VLOG(3) << " old: " << old_instruction->ToString(); 261 VLOG(3) << " new: " << new_instruction->ToString(); 262 TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( 263 old_instruction, std::move(new_instruction))); 264 changed_ = true; 265 return Status::OK(); 266 } 267 268 // Replaces the existing HLO instruction old_instruction, with 269 // new_instruction, and marks the optimizer status as changed. 270 // Returns the Status representing the result of the replace operation. 271 Status ReplaceInstruction(HloInstruction* old_instruction, 272 HloInstruction* new_instruction) { 273 VLOG(3) << "Replacing instruction:"; 274 VLOG(3) << " old: " << old_instruction->ToString(); 275 VLOG(3) << " new: " << new_instruction->ToString(); 276 TF_RETURN_IF_ERROR( 277 computation_->ReplaceInstruction(old_instruction, new_instruction)); 278 changed_ = true; 279 return Status::OK(); 280 } 281 282 StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot); 283 StatusOr<HloInstruction*> OptimizeDotOfConcatHelper( 284 const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, 285 HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); 286 287 // Current HloComputation instance the AlgebraicSimplifierVisitor is 288 // traversing. 289 HloComputation* computation_; 290 291 // Whether algebraic simplification has occurred. 292 bool changed_ = false; 293 294 // Whether layout is considered during transformation. 
295 bool is_layout_sensitive_; 296 297 // Callback used to determine if a bitcast is possible. 298 AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; 299 300 // Disable dot strength reduction on platforms where it causes a slowdown. 301 bool enable_dot_strength_reduction_; 302 303 // Disable convolution simplication on platforms where it causes a slowdown. 304 bool enable_conv_simplification_; 305 }; 306 307 bool AlgebraicSimplifierVisitor::Run( 308 HloComputation* computation, bool is_layout_sensitive, 309 AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, 310 bool enable_dot_strength_reduction, bool enable_conv_simplification) { 311 AlgebraicSimplifierVisitor visitor( 312 computation, is_layout_sensitive, std::move(valid_bitcast_callback), 313 enable_dot_strength_reduction, enable_conv_simplification); 314 TF_CHECK_OK(computation->Accept(&visitor)); 315 return visitor.changed_; 316 } 317 318 bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, 319 const HloInstruction* rhs) const { 320 if (is_layout_sensitive_) { 321 return ShapeUtil::Equal(lhs->shape(), rhs->shape()); 322 } else { 323 return ShapeUtil::Compatible(lhs->shape(), rhs->shape()); 324 } 325 } 326 327 void AlgebraicSimplifierVisitor::ReplaceWithBitcast( 328 HloInstruction* instruction) { 329 CHECK_EQ(1, instruction->operand_count()); 330 CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()), 331 ShapeUtil::ElementsIn(instruction->operand(0)->shape())); 332 CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()), 333 ShapeUtil::ByteSizeOf(instruction->operand(0)->shape())); 334 335 auto bitcast = computation_->AddInstruction( 336 HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast, 337 instruction->mutable_operand(0))); 338 TF_CHECK_OK(ReplaceInstruction(instruction, bitcast)); 339 } 340 341 bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( 342 HloInstruction* old_instruction, HloInstruction* new_instruction) { 343 if 
(!SameShape(old_instruction, new_instruction)) { 344 return false; 345 } 346 TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction)); 347 return true; 348 } 349 350 Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { 351 auto lhs = add->mutable_operand(0); 352 auto rhs = add->mutable_operand(1); 353 // A + 0 => A 354 VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); 355 if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { 356 return Status::OK(); 357 } 358 // 0 + A => A 359 VLOG(10) << "trying transform [0 + A => A]: " << add->ToString(); 360 if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) { 361 return Status::OK(); 362 } 363 364 // Canonicalization: Put constants on the right. This makes the reassociation 365 // rules below simpler. 366 VLOG(10) << "trying transform [Const + A => A + Const]"; 367 if (lhs->IsConstant() && !rhs->IsConstant()) { 368 return ReplaceWithNewInstruction( 369 add, 370 HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs)); 371 } 372 373 // Reassociate to allow constant folding. 374 // 375 // Note: This is not general. For example, we won't reassociate 376 // 377 // (A + C1) + (B + C2) => A + B + (C1 + C2). 
378 // 379 VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; 380 if (rhs->IsConstant() && lhs->opcode() == HloOpcode::kAdd && 381 !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) { 382 auto* c1 = lhs->mutable_operand(1); 383 auto* c2 = rhs; 384 TF_ASSIGN_OR_RETURN( 385 Shape sum_of_constants_shape, 386 ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, c1, c2)); 387 388 auto* sum_of_constants = 389 computation_->AddInstruction(HloInstruction::CreateBinary( 390 sum_of_constants_shape, HloOpcode::kAdd, c1, c2)); 391 return ReplaceWithNewInstruction( 392 add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, 393 lhs->mutable_operand(0), 394 sum_of_constants)); 395 } 396 397 return Status::OK(); 398 } 399 400 Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { 401 // If a bitcast feeds a bitcast, make it a single bitcast. 402 if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) { 403 return ReplaceWithNewInstruction( 404 bitcast, HloInstruction::CreateUnary( 405 bitcast->shape(), HloOpcode::kBitcast, 406 bitcast->mutable_operand(0)->mutable_operand(0))); 407 } 408 // All bitcasts can be eliminated (assuming layout constraints are 409 // satisified). 410 ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0)); 411 return Status::OK(); 412 } 413 414 Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { 415 // If a copy feeds a copy, make it a single copy. 416 if (copy->operand(0)->opcode() == HloOpcode::kCopy) { 417 return ReplaceWithNewInstruction( 418 copy, HloInstruction::CreateUnary( 419 copy->shape(), HloOpcode::kCopy, 420 copy->mutable_operand(0)->mutable_operand(0))); 421 } 422 // All copies can be eliminated (assuming layout constraints are satisified). 
423 ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); 424 return Status::OK(); 425 } 426 427 Status AlgebraicSimplifierVisitor::HandleConcatenate( 428 HloInstruction* concatenate) { 429 tensorflow::gtl::ArraySlice<HloInstruction*> operands( 430 concatenate->operands()); 431 if (operands.size() == 1) { 432 // Unary concatenates are useless. 433 ReplaceInstructionIfSameShape(concatenate, operands[0]); 434 return Status::OK(); 435 } 436 // Filter out and remove empty operands. 437 std::vector<HloInstruction*> nonempty_operands; 438 for (HloInstruction* operand : operands) { 439 if (!ShapeUtil::HasZeroElements(operand->shape())) { 440 nonempty_operands.push_back(operand); 441 } 442 } 443 if (nonempty_operands.size() < operands.size()) { 444 HloInstruction* replacement; 445 if (nonempty_operands.empty()) { 446 replacement = operands[0]; 447 } else if (nonempty_operands.size() == 1) { 448 replacement = nonempty_operands[0]; 449 } else { 450 replacement = 451 computation_->AddInstruction(concatenate->CloneWithNewOperands( 452 concatenate->shape(), nonempty_operands)); 453 } 454 VLOG(10) << "trying to replace " << concatenate->ToString() << " with " 455 << replacement->ToString(); 456 ReplaceInstructionIfSameShape(concatenate, replacement); 457 } else if (operands.size() == 2) { 458 // A binary concat with a broadcasted scalar as an operand can be converted 459 // into a pad which is simpler to fold into other operations. 
460 bool is_effective_low_pad = 461 operands[0]->opcode() == HloOpcode::kBroadcast && 462 ShapeUtil::IsScalar(operands[0]->operand(0)->shape()); 463 bool is_effective_high_pad = 464 operands[1]->opcode() == HloOpcode::kBroadcast && 465 ShapeUtil::IsScalar(operands[1]->operand(0)->shape()); 466 if (!is_effective_low_pad && !is_effective_high_pad) { 467 return Status::OK(); 468 } 469 PaddingConfig padding_config; 470 for (int64 dim = 0; dim < ShapeUtil::Rank(operands[0]->shape()); ++dim) { 471 auto padding_config_dim = padding_config.add_dimensions(); 472 padding_config_dim->set_edge_padding_high(0); 473 padding_config_dim->set_edge_padding_low(0); 474 padding_config_dim->set_interior_padding(0); 475 if (dim == concatenate->concatenate_dimension()) { 476 if (is_effective_low_pad) { 477 padding_config_dim->set_edge_padding_low( 478 operands[0]->shape().dimensions(dim)); 479 } else { 480 padding_config_dim->set_edge_padding_high( 481 operands[1]->shape().dimensions(dim)); 482 } 483 } 484 } 485 int64 operand_to_pad = is_effective_low_pad ? 1 : 0; 486 int64 pad_value_operand = is_effective_low_pad ? 
0 : 1; 487 HloInstruction* pad = 488 computation_->AddInstruction(HloInstruction::CreatePad( 489 concatenate->shape(), operands[operand_to_pad], 490 operands[pad_value_operand]->mutable_operand(0), padding_config)); 491 return ReplaceInstruction(concatenate, pad); 492 } 493 return Status::OK(); 494 } 495 496 static HloInstruction* BuildTupleConstant(HloComputation* computation, 497 const Literal& literal) { 498 if (ShapeUtil::IsTuple(literal.shape())) { 499 std::vector<HloInstruction*> elems; 500 elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); 501 for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { 502 elems.push_back( 503 BuildTupleConstant(computation, LiteralView::Create(literal, {i}))); 504 } 505 return computation->AddInstruction(HloInstruction::CreateTuple(elems)); 506 } else { 507 return computation->AddInstruction( 508 HloInstruction::CreateConstant(literal.CloneToUnique())); 509 } 510 } 511 512 Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { 513 // Tuple constants aren't directly supported by any backend. Expand them into 514 // explicit Tuple instructions. 515 if (ShapeUtil::IsTuple(constant->shape())) { 516 return ReplaceInstruction( 517 constant, BuildTupleConstant(computation_, constant->literal())); 518 } 519 return Status::OK(); 520 } 521 522 Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { 523 auto lhs = sub->mutable_operand(0); 524 auto rhs = sub->mutable_operand(1); 525 // A - 0 => A 526 VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); 527 if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { 528 return Status::OK(); 529 } 530 531 // Canonicalize subtraction of a constant to addition. 
532 VLOG(10) << "trying transform [A - Const => A + (-Const)]"; 533 if (rhs->IsConstant() && !lhs->IsConstant()) { 534 HloInstruction* negative_const = computation_->AddInstruction( 535 HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); 536 return ReplaceWithNewInstruction( 537 sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs, 538 negative_const)); 539 } 540 541 return Status::OK(); 542 } 543 544 Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { 545 auto lhs = divide->mutable_operand(0); 546 auto rhs = divide->mutable_operand(1); 547 // A/1 => A 548 VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); 549 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) { 550 return Status::OK(); 551 } 552 553 // exp(A)/exp(B) => exp(A-B) 554 if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { 555 VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString(); 556 HloInstruction* subtract = 557 computation_->AddInstruction(HloInstruction::CreateBinary( 558 divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0), 559 rhs->mutable_operand(0))); 560 return ReplaceWithNewInstruction( 561 divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, 562 subtract)); 563 } 564 565 // A/exp(B) => A*exp(-B) 566 if (rhs->opcode() == HloOpcode::kExp) { 567 VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString(); 568 HloInstruction* negate = 569 computation_->AddInstruction(HloInstruction::CreateUnary( 570 divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(0))); 571 HloInstruction* new_exp = computation_->AddInstruction( 572 HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate)); 573 return ReplaceWithNewInstruction( 574 divide, HloInstruction::CreateBinary( 575 divide->shape(), HloOpcode::kMultiply, lhs, new_exp)); 576 } 577 578 // A/pow(B,C) => A*pow(B,-C) 579 if (rhs->opcode() == HloOpcode::kPower) { 580 
VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); 581 // The output shape of the created negate operator should be the same as the 582 // input. 583 const Shape& negate_shape = rhs->operand(1)->shape(); 584 HloInstruction* negate = 585 computation_->AddInstruction(HloInstruction::CreateUnary( 586 negate_shape, HloOpcode::kNegate, rhs->mutable_operand(1))); 587 // And the power operator should retain the output shape of the old one. 588 const Shape& new_power_shape = rhs->shape(); 589 HloInstruction* new_power = computation_->AddInstruction( 590 HloInstruction::CreateBinary(new_power_shape, HloOpcode::kPower, 591 rhs->mutable_operand(0), negate)); 592 return ReplaceWithNewInstruction( 593 divide, HloInstruction::CreateBinary( 594 divide->shape(), HloOpcode::kMultiply, lhs, new_power)); 595 } 596 597 // Simplifying integral division would produce unexpected results. 598 if (ShapeUtil::ElementIsIntegral(divide->shape())) { 599 return Status::OK(); 600 } 601 602 // A / Const => A * (1 / Const) 603 // 604 // (Backends can do this transformation, but generally only if the constant is 605 // a scalar.) 
606 if (lhs->opcode() != HloOpcode::kConstant && 607 rhs->opcode() == HloOpcode::kConstant) { 608 HloInstruction* one = 609 computation_->AddInstruction(HloInstruction::CreateConstant( 610 Literal::One(lhs->shape().element_type()).CloneToUnique())); 611 HloInstruction* inverse = 612 computation_->AddInstruction(HloInstruction::CreateBinary( 613 rhs->shape(), HloOpcode::kDivide, one, rhs)); 614 return ReplaceWithNewInstruction( 615 divide, HloInstruction::CreateBinary( 616 divide->shape(), HloOpcode::kMultiply, lhs, inverse)); 617 } 618 619 // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) 620 if (lhs->opcode() == HloOpcode::kDivide && 621 rhs->opcode() == HloOpcode::kDivide) { 622 TF_ASSIGN_OR_RETURN( 623 const Shape a_times_d_shape, 624 ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, 625 lhs->operand(0), rhs->operand(1))); 626 auto a_times_d = computation_->AddInstruction(HloInstruction::CreateBinary( 627 a_times_d_shape, HloOpcode::kMultiply, lhs->mutable_operand(0), 628 rhs->mutable_operand(1))); 629 TF_ASSIGN_OR_RETURN( 630 const Shape b_times_c_shape, 631 ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, 632 lhs->operand(1), rhs->operand(0))); 633 auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( 634 b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), 635 rhs->mutable_operand(0))); 636 return ReplaceWithNewInstruction( 637 divide, HloInstruction::CreateBinary( 638 divide->shape(), HloOpcode::kDivide, a_times_d, b_times_c)); 639 } 640 641 // (A / B) / C => A / (B * C) 642 if (lhs->opcode() == HloOpcode::kDivide) { 643 TF_ASSIGN_OR_RETURN(const Shape b_times_c_shape, 644 ShapeInference::InferBinaryOpShape( 645 HloOpcode::kMultiply, lhs->operand(1), rhs)); 646 auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( 647 b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); 648 return ReplaceWithNewInstruction( 649 divide, 650 
HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, 651 lhs->mutable_operand(0), b_times_c)); 652 } 653 654 // A / (B / C) => (A*C) / B 655 if (rhs->opcode() == HloOpcode::kDivide) { 656 TF_ASSIGN_OR_RETURN(const Shape a_times_c_shape, 657 ShapeInference::InferBinaryOpShape( 658 HloOpcode::kMultiply, lhs, rhs->operand(1))); 659 auto a_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( 660 a_times_c_shape, HloOpcode::kMultiply, lhs, rhs->mutable_operand(1))); 661 return ReplaceWithNewInstruction( 662 divide, 663 HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, 664 a_times_c, rhs->mutable_operand(0))); 665 } 666 667 return Status::OK(); 668 } 669 670 StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction( 671 HloInstruction* dot) { 672 HloInstruction* lhs = dot->mutable_operand(0); 673 HloInstruction* rhs = dot->mutable_operand(1); 674 int64 lhs_collapsing_dim = 675 dot->dot_dimension_numbers().lhs_contracting_dimensions(0); 676 if (lhs->IsRank2Transpose()) { 677 lhs = lhs->mutable_operand(0); 678 lhs_collapsing_dim = 1 - lhs_collapsing_dim; 679 } 680 const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; 681 682 int64 rhs_collapsing_dim = 683 dot->dot_dimension_numbers().rhs_contracting_dimensions(0); 684 if (rhs->IsRank2Transpose()) { 685 rhs = rhs->mutable_operand(0); 686 rhs_collapsing_dim = 1 - rhs_collapsing_dim; 687 } 688 const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; 689 690 auto reshape_if_necessary = [&](HloInstruction* hlo) { 691 if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { 692 return hlo; 693 } 694 return computation_->AddInstruction( 695 HloInstruction::CreateReshape(dot->shape(), hlo)); 696 }; 697 698 auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, 699 int64 dim) { 700 return computation_->AddInstruction( 701 HloInstruction::CreateBroadcast(shape, hlo, {dim})); 702 }; 703 704 auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { 705 
return computation_->AddInstruction(HloInstruction::CreateBinary( 706 local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs)); 707 }; 708 709 // Strength reduce dot(a[K] , b[K]) = 710 // reshape(result.shape, 711 // reduce_sum(multiply(a, b), {0})) 712 if (ShapeUtil::Rank(rhs->shape()) == 1 && 713 ShapeUtil::Rank(lhs->shape()) == 1) { 714 TF_RETURN_IF_ERROR( 715 ReplaceInstruction(dot, reshape_if_necessary(AddReduce( 716 multiply(Flatten(lhs), Flatten(rhs)), 0)))); 717 return true; 718 } 719 720 if (ShapeUtil::IsEffectiveScalar(rhs->shape()) && 721 ShapeUtil::IsEffectiveScalar(lhs->shape())) { 722 TF_RETURN_IF_ERROR(ReplaceInstruction( 723 dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs))))); 724 return true; 725 } 726 727 // Simplify outer product into multiply with implicit broadcasting. 728 // 729 // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) 730 if (ShapeUtil::Rank(rhs->shape()) == 2 && 731 rhs->shape().dimensions(rhs_collapsing_dim) == 1) { 732 TF_RETURN_IF_ERROR(ReplaceInstruction( 733 dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), 734 broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); 735 return true; 736 } 737 738 // Strength reduce dot(a[1, K], b) = 739 // reshape(result.shape, 740 // reduce_sum( 741 // multiply(broadcast(reshape(a, [K]), {0}), b), 742 // {0}) 743 // ) 744 // ) 745 if (ShapeUtil::Rank(lhs->shape()) == 1 || 746 (ShapeUtil::Rank(lhs->shape()) == 2 && 747 lhs->shape().dimensions(lhs_kept_dim) == 1)) { 748 if (ShapeUtil::Rank(rhs->shape()) == 1) { 749 TF_RETURN_IF_ERROR(ReplaceInstruction( 750 dot, 751 reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0)))); 752 return true; 753 } 754 TF_RETURN_IF_ERROR(ReplaceInstruction( 755 dot, reshape_if_necessary( 756 AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), 757 rhs_collapsing_dim), 758 rhs), 759 rhs_collapsing_dim)))); 760 return true; 761 } 762 763 // Strength reduce dot(a, b[K, 1]) = 764 // reshape(result.shape, 765 // 
reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) 766 // ) 767 if (ShapeUtil::Rank(rhs->shape()) == 1 || 768 (ShapeUtil::Rank(rhs->shape()) == 2 && 769 rhs->shape().dimensions(rhs_kept_dim) == 1)) { 770 TF_RETURN_IF_ERROR(ReplaceInstruction( 771 dot, reshape_if_necessary(AddReduce( 772 multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), 773 lhs_collapsing_dim)), 774 lhs_collapsing_dim)))); 775 return true; 776 } 777 return false; 778 } 779 780 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat( 781 HloInstruction* dot) { 782 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); 783 if (dnums.lhs_contracting_dimensions_size() != 1 || 784 dnums.lhs_batch_dimensions_size() != 0) { 785 return nullptr; 786 } 787 788 const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); 789 const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); 790 HloInstruction* lhs = dot->mutable_operand(0); 791 HloInstruction* rhs = dot->mutable_operand(1); 792 793 TF_ASSIGN_OR_RETURN( 794 HloInstruction * optimized_lhs_concat, 795 OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs, 796 rhs_contracting_dim, /*swapped=*/false)); 797 if (optimized_lhs_concat) { 798 return optimized_lhs_concat; 799 } 800 801 return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs, 802 lhs_contracting_dim, /*swapped=*/true); 803 } 804 805 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( 806 const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, 807 HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { 808 bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && 809 lhs->concatenate_dimension() == lhs_contracting_dim && 810 rhs->opcode() == HloOpcode::kConstant; 811 if (!can_optimize) { 812 return nullptr; 813 } 814 815 // We're replacing this: 816 // 817 // +-----+-----+-----+ +-------------------+ 818 // | | | | | | 819 // | | | | | 
R_0 | 820 // | | | | | | 821 // | | | | +-------------------+ 822 // | | | | | | 823 // | L_0 | L_1 | L_2 | * | R_1 | 824 // | | | | | | 825 // | | | | +-------------------+ 826 // | | | | | | 827 // | | | | | R_2 | 828 // | | | | | | 829 // +-----+-----+-----+ +-------------------+ 830 // 831 // with this: 832 // 833 // [Sum over i] 834 // 835 // +-----+ +-------------------+ 836 // | | | | 837 // | | * | R_i | 838 // | | | | 839 // | | +-------------------+ 840 // | | 841 // | L_i | 842 // | | 843 // | | 844 // | | 845 // | | 846 // | | 847 // +-----+ 848 // 849 // where the LHS is a concatenate operation (so we can "split" the LHS tensor 850 // for free) and the RHS is a constant tensor (and thus can be split at 851 // compile time). In the future, we may also want to do this when both the 852 // LHS and the RHS are concatenate operations that line up along the dimension 853 // being contracted over. 854 // 855 // We should be able to generalize this transform to work on a non-constant 856 // RHS when/if we have in-place slices or support input-fusing slices into 857 // Dots. 858 859 // Dimension numbers for the new dot instructions we'll create (L_i * R_i in 860 // the diagram above). 861 DotDimensionNumbers new_dot_dnums; 862 new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim 863 : lhs_contracting_dim); 864 new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim 865 : rhs_contracting_dim); 866 867 // Here we use the MKN notation, where the contracted dimension has K 868 // elements and the two non-contracted dimensions have M and N elements. 
869 HloInstruction* add_result = nullptr; 870 int64 rhs_contracting_dim_offset = 0; 871 int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim); 872 for (HloInstruction* concat_op : lhs->operands()) { 873 int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim); 874 Shape rhs_slice_shape(rhs->shape()); 875 rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k); 876 877 std::array<int64, 2> start_indices; 878 start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset; 879 start_indices[1 - rhs_contracting_dim] = 0; 880 881 std::array<int64, 2> limit_indices; 882 limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k; 883 limit_indices[1 - rhs_contracting_dim] = n; 884 885 HloInstruction* rhs_slice = 886 computation_->AddInstruction(HloInstruction::CreateSlice( 887 rhs_slice_shape, rhs, /*start_indices=*/start_indices, 888 /*limit_indices=*/limit_indices, /*strides=*/{1, 1})); 889 890 // TODO(b/69062148): We can get rid of `swapped` once all backends support 891 // "non-canonical" contraction dimensions (that contracts dimension 1 of the 892 // LHS with dimension 0 of the RHS). But for now we keep the same 893 // contraction dimensions as the incoming dot operation to ensure the new 894 // dot operations can be lowered. 
895 HloInstruction *new_dot_lhs, *new_dot_rhs; 896 if (swapped) { 897 new_dot_lhs = rhs_slice; 898 new_dot_rhs = concat_op; 899 } else { 900 new_dot_lhs = concat_op; 901 new_dot_rhs = rhs_slice; 902 } 903 904 auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( 905 dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums)); 906 907 if (add_result) { 908 add_result = computation_->AddInstruction(HloInstruction::CreateBinary( 909 dot_shape, HloOpcode::kAdd, add_result, new_dot)); 910 } else { 911 add_result = new_dot; 912 } 913 914 rhs_contracting_dim_offset += sub_k; 915 } 916 917 return add_result; 918 } 919 920 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { 921 auto lhs = dot->mutable_operand(0); 922 auto rhs = dot->mutable_operand(1); 923 924 // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or 925 // below. 926 if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || 927 ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { 928 return Status::OK(); 929 } 930 931 // Replace a zero element dot with a broadcast of the constant 0. 
932 if (ShapeUtil::HasZeroElements(dot->shape()) || 933 ShapeUtil::HasZeroElements(lhs->shape()) || 934 ShapeUtil::HasZeroElements(rhs->shape())) { 935 auto zero = computation_->AddInstruction( 936 HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); 937 return ReplaceWithNewInstruction( 938 dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); 939 } 940 941 TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, 942 OptimizeDotOfConcat(dot)); 943 if (dot_of_concat_optimized) { 944 VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., " 945 "constant)...)"; 946 return ReplaceInstruction(dot, dot_of_concat_optimized); 947 } 948 949 if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { 950 TF_ASSIGN_OR_RETURN(bool did_strength_reduction, 951 HandleDotStrengthReduction(dot)); 952 if (did_strength_reduction) { 953 return Status::OK(); 954 } 955 } 956 957 // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). 958 if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { 959 DotDimensionNumbers dot_dimension_numbers; 960 dot_dimension_numbers.add_lhs_contracting_dimensions(1); 961 dot_dimension_numbers.add_rhs_contracting_dimensions(0); 962 auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( 963 ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), 964 rhs->mutable_operand(0), lhs->mutable_operand(0), 965 dot_dimension_numbers)); 966 return ReplaceWithNewInstruction( 967 dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); 968 } 969 970 return Status::OK(); 971 } 972 973 Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { 974 auto lhs = multiply->mutable_operand(0); 975 auto rhs = multiply->mutable_operand(1); 976 // A*1 => A 977 VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); 978 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { 979 return Status::OK(); 980 } 981 // 1*A => A 982 VLOG(10) << "trying transform [1*A => 
A]: " << multiply->ToString(); 983 if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) { 984 return Status::OK(); 985 } 986 987 // exp(A) * exp(B) => exp(A+B) 988 if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { 989 auto add = computation_->AddInstruction(HloInstruction::CreateBinary( 990 multiply->shape(), HloOpcode::kAdd, lhs->mutable_operand(0), 991 rhs->mutable_operand(0))); 992 return ReplaceWithNewInstruction( 993 multiply, 994 HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); 995 } 996 return Status::OK(); 997 } 998 999 Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { 1000 // ln(exp(A)) => A 1001 VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); 1002 auto operand = log->mutable_operand(0); 1003 if (operand->opcode() == HloOpcode::kExp && 1004 ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) { 1005 return Status::OK(); 1006 } 1007 1008 // ln(pow(A,B)) => B*ln(A) 1009 if (operand->opcode() == HloOpcode::kPower) { 1010 auto new_log = computation_->AddInstruction(HloInstruction::CreateUnary( 1011 log->shape(), HloOpcode::kLog, operand->mutable_operand(0))); 1012 return ReplaceWithNewInstruction( 1013 log, 1014 HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, 1015 new_log, operand->mutable_operand(1))); 1016 } 1017 1018 return Status::OK(); 1019 } 1020 1021 Status AlgebraicSimplifierVisitor::HandleGetTupleElement( 1022 HloInstruction* get_tuple_element) { 1023 auto operand = get_tuple_element->mutable_operand(0); 1024 if (operand->opcode() == HloOpcode::kTuple) { 1025 // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i 1026 VLOG(10) << "trying transform " 1027 << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: " 1028 << get_tuple_element->ToString(); 1029 if (ReplaceInstructionIfSameShape( 1030 get_tuple_element, 1031 operand->mutable_operand(get_tuple_element->tuple_index()))) { 1032 return Status::OK(); 
1033 } 1034 } 1035 return Status::OK(); 1036 } 1037 1038 namespace { 1039 1040 // Return whether the given reshape instruction leaves the dimensions at the 1041 // given input indices unmodified, and returns their output indices. 1042 // 1043 // Example: 1044 // input_dim_indices = {2, 3} 1045 // input shape = T[a, b, x, y, cd] 1046 // output shape = T[ab, x, 1, y, c, d] 1047 // return value = {1, 3} 1048 // 1049 // Precondition: input_dim_indices is sorted. 1050 std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified( 1051 const HloInstruction* hlo, 1052 tensorflow::gtl::ArraySlice<int64> input_dim_indices) { 1053 CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); 1054 CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); 1055 1056 std::vector<int64> output_dim_indices; 1057 std::vector<std::pair<int64, int64>> unmodified_dims = 1058 ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(), 1059 hlo->shape()); 1060 size_t i = 0; // index to unmodified_dims 1061 for (int64 input_dim_index : input_dim_indices) { 1062 // Search unmodified_dims for input_dim_index. We can search from the last 1063 // matching position because input_dim_indices is guaranteed to be sorted. 1064 while (i < unmodified_dims.size() && 1065 unmodified_dims[i].first < input_dim_index) { 1066 ++i; 1067 } 1068 if (i >= unmodified_dims.size() || 1069 unmodified_dims[i].first != input_dim_index) { 1070 return std::make_pair(false, std::vector<int64>()); 1071 } 1072 output_dim_indices.push_back(unmodified_dims[i].second); 1073 } 1074 return std::make_pair(true, output_dim_indices); 1075 } 1076 1077 // Returns true if the output of "instruction" is a permutation of the 1078 // elements of "operand". Precondition: "operand" is an operand of 1079 // "instruction". 
1080 bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, 1081 HloInstruction* operand) { 1082 DCHECK(!instruction->OperandIndices(operand).empty()); 1083 switch (instruction->opcode()) { 1084 case HloOpcode::kReshape: 1085 case HloOpcode::kReverse: 1086 case HloOpcode::kSort: 1087 case HloOpcode::kTranspose: 1088 return true; 1089 default: 1090 return false; 1091 } 1092 } 1093 1094 // Returns true if the output of "instruction" is a subset of the elements of 1095 // "operand". Precondition: "operand" is an operand of "instruction". 1096 bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, 1097 HloInstruction* operand) { 1098 std::vector<int64> operand_indices = instruction->OperandIndices(operand); 1099 CHECK(!operand_indices.empty()); 1100 if (operand_indices.size() != 1) { 1101 return false; 1102 } 1103 int64 operand_index = operand_indices[0]; 1104 switch (instruction->opcode()) { 1105 case HloOpcode::kSlice: 1106 CHECK_EQ(0, operand_index); 1107 return true; 1108 case HloOpcode::kDynamicSlice: 1109 return operand_index == 0; 1110 default: 1111 return false; 1112 } 1113 } 1114 1115 } // namespace 1116 1117 Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { 1118 auto operand = broadcast->mutable_operand(0); 1119 // A degenerate broadcast of a reshape that does not change the number of 1120 // elements can be replaced by a reshape. 1121 if (std::is_sorted(broadcast->dimensions().begin(), 1122 broadcast->dimensions().end()) && 1123 ShapeUtil::ElementsIn(broadcast->shape()) == 1124 ShapeUtil::ElementsIn(operand->shape())) { 1125 VLOG(10) << "transform broadcast(X) -> reshape(X) where " 1126 "n(broadcast(X)) == n(X)"; 1127 return ReplaceWithNewInstruction( 1128 broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand)); 1129 } 1130 1131 // A degenerate broadcast that has the same input and output rank can be 1132 // converted into a transpose. 
1133 if (ShapeUtil::Rank(broadcast->shape()) == 1134 ShapeUtil::Rank(operand->shape()) && 1135 ShapeUtil::ElementsIn(broadcast->shape()) == 1136 ShapeUtil::ElementsIn(operand->shape())) { 1137 VLOG(10) << "transform broadcast(X) -> transpose(X) where " 1138 "n(broadcast(X)) == n(X)"; 1139 return ReplaceWithNewInstruction( 1140 broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand, 1141 broadcast->dimensions())); 1142 } 1143 1144 // A broadcast of a reshape which merely inserts 1-sized dimensions can 1145 // elide its operand. 1146 { 1147 bool merely_inserts_or_deletes_1_sized_dimensions; 1148 std::vector<int64> inserted_indices, deleted_indices; 1149 std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices, 1150 inserted_indices) = 1151 operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); 1152 if (merely_inserts_or_deletes_1_sized_dimensions && 1153 deleted_indices.empty()) { 1154 std::reverse(inserted_indices.begin(), inserted_indices.end()); 1155 auto dims = broadcast->dimensions(); 1156 for (auto inserted_index : inserted_indices) { 1157 dims.erase(dims.begin() + inserted_index); 1158 } 1159 return ReplaceWithNewInstruction( 1160 broadcast, 1161 HloInstruction::CreateBroadcast(broadcast->shape(), 1162 operand->mutable_operand(0), dims)); 1163 } 1164 } 1165 1166 // A Broadcast that feeds a unary element-wise operation can sink the 1167 // broadcast after the unary element-wise operation. 1168 TF_ASSIGN_OR_RETURN( 1169 bool sink_succeeded, 1170 TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); 1171 changed_ |= sink_succeeded; 1172 if (sink_succeeded) { 1173 return Status::OK(); 1174 } 1175 1176 // A scalar broadcast feeding an instruction which only permutes (reshape, 1177 // transpose, sort, reverse) or selects a subset of operand elements (slice, 1178 // dynamic slice) can be replaced with a broadcast directly to the output 1179 // shape of the instruction. 
1180 if (ShapeUtil::IsScalar(operand->shape())) { 1181 for (HloInstruction* user : broadcast->users()) { 1182 // Skip if the broadcast user has no uses itself. 1183 if (user->user_count() == 0 && user != computation_->root_instruction()) { 1184 continue; 1185 } 1186 if (OutputIsPermutationOfOperandElements(user, broadcast) || 1187 OutputIsSubsetOfOperandElements(user, broadcast)) { 1188 VLOG(10) << "transform permuting/subset of a scalar broadcast into " 1189 << "a single broadcast"; 1190 HloInstruction* new_broadcast = computation_->AddInstruction( 1191 HloInstruction::CreateBroadcast(user->shape(), operand, {})); 1192 // Use HloInstruction::ReplaceAllUsesWith instead of 1193 // HloComputation::ReplaceWithNewInstruction because we are replacing an 1194 // instruction other than the visited instruction. 1195 changed_ = true; 1196 return user->ReplaceAllUsesWith(new_broadcast); 1197 } 1198 } 1199 } 1200 return Status::OK(); 1201 } 1202 1203 // A conversion to the same element type as the operand is a nop and can be 1204 // removed. A conversion of a constant can be simplified by making a new 1205 // constant. 
1206 Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { 1207 PrimitiveType src_type = convert->operand(0)->shape().element_type(); 1208 PrimitiveType dest_type = convert->shape().element_type(); 1209 if (src_type == dest_type) { 1210 return ReplaceInstruction(convert, convert->mutable_operand(0)); 1211 } 1212 return Status::OK(); 1213 } 1214 1215 // Complex(Real(c), Imag(c)) -> c 1216 Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { 1217 auto real = complex->mutable_operand(0); 1218 auto imag = complex->mutable_operand(1); 1219 if (real->opcode() == HloOpcode::kReal && 1220 imag->opcode() == HloOpcode::kImag && 1221 real->operand(0) == imag->operand(0)) { 1222 return ReplaceInstruction(complex, real->mutable_operand(0)); 1223 } 1224 return Status::OK(); 1225 } 1226 1227 // Real(Complex(r, i)) -> r 1228 Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { 1229 auto operand = real->mutable_operand(0); 1230 if (operand->opcode() == HloOpcode::kComplex) { 1231 return ReplaceInstruction(real, operand->mutable_operand(0)); 1232 } 1233 return Status::OK(); 1234 } 1235 1236 // Imag(Complex(r, i)) -> i 1237 Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { 1238 auto operand = imag->mutable_operand(0); 1239 if (operand->opcode() == HloOpcode::kComplex) { 1240 return ReplaceInstruction(imag, operand->mutable_operand(1)); 1241 } 1242 return Status::OK(); 1243 } 1244 1245 Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { 1246 if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) { 1247 return ReplaceWithNewInstruction( 1248 pad, HloInstruction::CreateBroadcast(pad->shape(), 1249 pad->mutable_operand(1), {})); 1250 } 1251 // Eliminate nop pads (padding all zero), and replace a pad with negative 1252 // padding with a pad with non-negative padding followed by a slice. 
1253 bool all_zero = true; 1254 bool has_negative = false; 1255 for (auto& padding_dimension : pad->padding_config().dimensions()) { 1256 if (padding_dimension.edge_padding_low() < 0 || 1257 padding_dimension.edge_padding_high() < 0) { 1258 has_negative = true; 1259 } 1260 if (padding_dimension.edge_padding_low() != 0 || 1261 padding_dimension.edge_padding_high() != 0) { 1262 all_zero = false; 1263 } 1264 } 1265 1266 if (all_zero) { 1267 ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0)); 1268 return Status::OK(); 1269 } 1270 1271 if (has_negative) { 1272 // Pad has negative padding. Replace with a pad with the non-negative 1273 // padding followed by a slice which effectively performs the negative 1274 // padding. 1275 // TODO(b/34628603): Add support for negative padding in the backends, or 1276 // change kPad semantics to disallow negative padding and use slice 1277 // instead. 1278 1279 // First construct the padding config with non-negative entries and the 1280 // compute the shape of this new pad instruction. 1281 PaddingConfig nonzero_padding = pad->padding_config(); 1282 for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) { 1283 PaddingConfig::PaddingConfigDimension* padding_dimension = 1284 nonzero_padding.mutable_dimensions(i); 1285 // Set negative padding to zero. 1286 if (padding_dimension->edge_padding_low() < 0) { 1287 padding_dimension->set_edge_padding_low(0); 1288 } 1289 if (padding_dimension->edge_padding_high() < 0) { 1290 padding_dimension->set_edge_padding_high(0); 1291 } 1292 } 1293 TF_ASSIGN_OR_RETURN(Shape nonzero_pad_shape, 1294 ShapeInference::InferPadShape(pad->operand(0)->shape(), 1295 pad->operand(1)->shape(), 1296 nonzero_padding)); 1297 // Copy the layout from the original pad instructions. The new pad and the 1298 // slice instruction should all have the same layout. 
1299 TF_RETURN_IF_ERROR( 1300 LayoutUtil::CopyLayoutBetweenShapes(pad->shape(), &nonzero_pad_shape)); 1301 HloInstruction* nonzero_pad = computation_->AddInstruction( 1302 HloInstruction::CreatePad(nonzero_pad_shape, pad->mutable_operand(0), 1303 pad->mutable_operand(1), nonzero_padding)); 1304 1305 // Second, construct the slice instruction to perform the negative padding. 1306 std::vector<int64> start_indices; 1307 std::vector<int64> end_indices; 1308 std::vector<int64> strides; 1309 for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) { 1310 const PaddingConfig::PaddingConfigDimension& padding_dimension = 1311 pad->padding_config().dimensions(i); 1312 int64 start = 0; 1313 if (padding_dimension.edge_padding_low() < 0) { 1314 start = -1 * padding_dimension.edge_padding_low(); 1315 } 1316 int64 end = nonzero_pad_shape.dimensions(i); 1317 if (padding_dimension.edge_padding_high() < 0) { 1318 end += padding_dimension.edge_padding_high(); 1319 } 1320 start_indices.push_back(start); 1321 end_indices.push_back(end); 1322 strides.push_back(1); 1323 } 1324 1325 // Verify that the slice shape matches the pad shape. 
1326 TF_ASSIGN_OR_RETURN( 1327 Shape inferred_slice_shape, 1328 ShapeInference::InferSliceShape(nonzero_pad_shape, start_indices, 1329 end_indices, strides)); 1330 TF_RET_CHECK(ShapeUtil::Compatible(inferred_slice_shape, pad->shape())); 1331 1332 std::unique_ptr<HloInstruction> slice = HloInstruction::CreateSlice( 1333 pad->shape(), nonzero_pad, start_indices, end_indices, strides); 1334 return ReplaceWithNewInstruction(pad, std::move(slice)); 1335 } 1336 1337 return Status::OK(); 1338 } 1339 1340 Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { 1341 VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); 1342 auto lhs = power->mutable_operand(0); 1343 auto rhs = power->mutable_operand(1); 1344 if (IsAll(rhs, 0)) { 1345 auto one = HloInstruction::CreateConstant( 1346 Literal::One(power->shape().element_type()).CloneToUnique()); 1347 std::unique_ptr<HloInstruction> ones; 1348 if (ShapeUtil::IsScalar(power->shape())) { 1349 ones = std::move(one); 1350 } else { 1351 ones = HloInstruction::CreateBroadcast( 1352 power->shape(), computation_->AddInstruction(std::move(one)), {}); 1353 } 1354 return ReplaceWithNewInstruction(power, std::move(ones)); 1355 } 1356 1357 VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString(); 1358 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) { 1359 return Status::OK(); 1360 } 1361 1362 // pow(exp(A),B) => exp(A*B) 1363 if (lhs->opcode() == HloOpcode::kExp) { 1364 auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary( 1365 power->shape(), HloOpcode::kMultiply, lhs->operands()[0], rhs)); 1366 return ReplaceWithNewInstruction( 1367 power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp, 1368 a_times_b)); 1369 } 1370 VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString(); 1371 if (IsAll(rhs, 2)) { 1372 return ReplaceWithNewInstruction( 1373 power, HloInstruction::CreateBinary(power->shape(), 1374 HloOpcode::kMultiply, lhs, 
lhs)); 1375 } 1376 1377 VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); 1378 if (IsAll(rhs, -1)) { 1379 auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( 1380 Literal::One(rhs->shape().element_type()).CloneToUnique())); 1381 1382 // Explicitly broadcast scalar 1 to the output shape, to avoid implicit 1383 // broadcast in divide HLO as we are trying to eliminate implicit 1384 // broadcasting at HLO level. 1385 auto* broadcast_one = computation_->AddInstruction( 1386 HloInstruction::CreateBroadcast(power->shape(), one, {})); 1387 return ReplaceWithNewInstruction( 1388 power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, 1389 broadcast_one, lhs)); 1390 } 1391 1392 VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: " 1393 << power->ToString(); 1394 1395 // Don't perform this optimization if either of the exponents is complex; this 1396 // identity is true only for real-valued exponents. In addition, we cowardly 1397 // refuse to do this transformation if the two expontents have different 1398 // element types. 
1399 if (lhs->opcode() == HloOpcode::kPower && 1400 !ShapeUtil::ElementIsComplex(lhs->operand(1)->shape()) && 1401 !ShapeUtil::ElementIsComplex(rhs->shape()) && 1402 ShapeUtil::SameElementType(lhs->operand(1)->shape(), rhs->shape())) { 1403 auto exponent_product = 1404 computation_->AddInstruction(HloInstruction::CreateBinary( 1405 rhs->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); 1406 return ReplaceWithNewInstruction( 1407 power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kPower, 1408 lhs->mutable_operand(0), 1409 exponent_product)); 1410 } 1411 1412 return Status::OK(); 1413 } 1414 1415 StatusOr<bool> AlgebraicSimplifierVisitor:: 1416 TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( 1417 HloInstruction* reshape_or_broadcast) { 1418 bool changed = false; 1419 if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) { 1420 return false; 1421 } 1422 HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); 1423 for (HloInstruction* user : reshape_or_broadcast->users()) { 1424 if (user->user_count() == 0 && user != computation_->root_instruction()) { 1425 continue; 1426 } 1427 // Do not move reshapes or broadcasts past copies since the shape the copy 1428 // will operate on will change. 1429 if (user->opcode() == HloOpcode::kCopy) { 1430 continue; 1431 } 1432 // Do not change the shape of fusion nodes in case there a multiple shapes 1433 // inside the fusion node already. 1434 if (user->opcode() == HloOpcode::kFusion) { 1435 continue; 1436 } 1437 if (!user->IsElementwise()) { 1438 continue; 1439 } 1440 1441 int64 reshape_or_broadcast_operand_index = -1; 1442 // Find the unique non-scalar operand or continue if there isn't one. 
1443 int64 scalar_count = 0; 1444 for (int64 i = 0; i < user->operand_count(); ++i) { 1445 if (ShapeUtil::IsScalar(user->operand(i)->shape())) { 1446 ++scalar_count; 1447 } else { 1448 reshape_or_broadcast_operand_index = i; 1449 } 1450 } 1451 if (scalar_count != user->operand_count() - 1) { 1452 continue; 1453 } 1454 VLOG(4) << "Sinking reshape or broadcast after user:"; 1455 VLOG(4) << " old reshape/broadcast: " << reshape_or_broadcast->ToString(); 1456 VLOG(4) << " old user: " << user->ToString(); 1457 CHECK_EQ(user->operand(reshape_or_broadcast_operand_index), 1458 reshape_or_broadcast); 1459 auto new_user_operands = user->operands(); 1460 new_user_operands[reshape_or_broadcast_operand_index] = operand; 1461 auto new_user = computation_->AddInstruction(user->CloneWithNewOperands( 1462 ShapeUtil::MakeShapeWithLayout( 1463 user->shape().element_type(), 1464 AsInt64Slice(operand->shape().dimensions()), 1465 LayoutUtil::MinorToMajor(operand->shape())), 1466 new_user_operands)); 1467 VLOG(4) << " new user: " << new_user->ToString(); 1468 HloInstruction* new_reshape_or_broadcast = nullptr; 1469 if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) { 1470 new_reshape_or_broadcast = 1471 computation_->AddInstruction(HloInstruction::CreateReshape( 1472 ShapeUtil::MakeShapeWithLayout( 1473 user->shape().element_type(), 1474 AsInt64Slice(reshape_or_broadcast->shape().dimensions()), 1475 LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), 1476 new_user)); 1477 } else { 1478 TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast); 1479 new_reshape_or_broadcast = 1480 computation_->AddInstruction(HloInstruction::CreateBroadcast( 1481 ShapeUtil::MakeShapeWithLayout( 1482 user->shape().element_type(), 1483 AsInt64Slice(reshape_or_broadcast->shape().dimensions()), 1484 LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), 1485 new_user, reshape_or_broadcast->dimensions())); 1486 } 1487 VLOG(4) << " new reshape/broadcast: " 1488 << 
new_reshape_or_broadcast->ToString(); 1489 TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast)); 1490 changed = true; 1491 } 1492 return changed; 1493 } 1494 1495 Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { 1496 auto operand = reshape->mutable_operand(0); 1497 1498 // Reshape directly to empty constant if the shape contains zero-element 1499 // dimension. 1500 if (ShapeUtil::HasZeroElements(reshape->shape())) { 1501 auto empty_constant = HloInstruction::CreateConstant( 1502 Literal::CreateFromShape(reshape->shape())); 1503 1504 return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); 1505 } 1506 1507 // Delete no-op reshapes, i.e. where shape = operand shape. 1508 if (SameShape(reshape, operand)) { 1509 VLOG(10) << "deleting no-op reshape"; 1510 return ReplaceInstruction(reshape, operand); 1511 } 1512 1513 // Merge reshapes. 1514 if (HloOpcode::kReshape == operand->opcode()) { 1515 return ReplaceWithNewInstruction( 1516 reshape, HloInstruction::CreateReshape(reshape->shape(), 1517 operand->mutable_operand(0))); 1518 } 1519 1520 if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { 1521 auto opt_dims = ReshapeLeavesDimensionsUnmodified( 1522 reshape, reshape->operand(0)->dimensions()); 1523 if (opt_dims.first) { 1524 return ReplaceWithNewInstruction( 1525 reshape, 1526 HloInstruction::CreateBroadcast( 1527 reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), 1528 opt_dims.second)); 1529 } 1530 } 1531 1532 // A Reshape that feeds a unary element-wise operation can sink the 1533 // reshape after the unary element-wise operation. 1534 TF_ASSIGN_OR_RETURN( 1535 bool sink_succeeded, 1536 TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape)); 1537 changed_ |= sink_succeeded; 1538 if (sink_succeeded) { 1539 return Status::OK(); 1540 } 1541 1542 // Make this a bitcast if possible. 
1543 if (is_layout_sensitive_ && 1544 ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { 1545 ReplaceWithBitcast(reshape); 1546 return Status::OK(); 1547 } 1548 1549 return Status::OK(); 1550 } 1551 1552 Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { 1553 // When all the dimensions to reverse are trivial (i.e. the bound is 1), 1554 // there is nothing to be done. 1555 auto dim_is_one = [&](int64 i) -> bool { 1556 return reverse->shape().dimensions(i) == 1; 1557 }; 1558 if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(), 1559 dim_is_one)) { 1560 return ReplaceInstruction(reverse, reverse->mutable_operand(0)); 1561 } 1562 return Status::OK(); 1563 } 1564 1565 Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { 1566 // Delete no-op slices, i.e. where shape = operand shape. 1567 if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) { 1568 return Status::OK(); 1569 } 1570 return Status::OK(); 1571 } 1572 1573 Status AlgebraicSimplifierVisitor::HandleDynamicSlice( 1574 HloInstruction* dynamic_slice) { 1575 auto operand = dynamic_slice->mutable_operand(0); 1576 auto start_indices = dynamic_slice->operand(1); 1577 if (ShapeUtil::IsScalar(dynamic_slice->shape())) { 1578 return ReplaceInstruction(dynamic_slice, operand); 1579 } 1580 // DynamicSlice where operand has the same size as the output and 1581 // start_indices are all zero is simply equal to operand. 1582 if (IsAll(start_indices, 0) && SameShape(operand, dynamic_slice)) { 1583 return ReplaceInstruction(dynamic_slice, operand); 1584 } 1585 return Status::OK(); 1586 } 1587 1588 Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( 1589 HloInstruction* dynamic_update_slice) { 1590 auto update = dynamic_update_slice->mutable_operand(1); 1591 auto start_indices = dynamic_update_slice->operand(2); 1592 // DynamicUpdateSlice on a scalar just passes through the update argument. 
1593 if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { 1594 return ReplaceInstruction(dynamic_update_slice, update); 1595 } 1596 1597 // DynamicUpdateSlice where operand and update have the same size and 1598 // start_indices are all zero is simply equal to update. 1599 // 1600 // (We require start_indices to be all zero because we want this optimization 1601 // not to affect the visible behavior of this op even when the indices are out 1602 // of range. Currently dynamic-update-slice wraps out-of-range indices, so 1603 // we can only remove the op if its indices never wrap.) 1604 if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) { 1605 return ReplaceInstruction(dynamic_update_slice, update); 1606 } 1607 return Status::OK(); 1608 } 1609 1610 Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { 1611 auto arg = reduce->mutable_operand(0); 1612 auto init_value = reduce->mutable_operand(1); 1613 tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions()); 1614 HloComputation* function = reduce->to_apply(); 1615 if (ShapeUtil::HasZeroElements(arg->shape()) || 1616 ShapeUtil::HasZeroElements(reduce->shape())) { 1617 return ReplaceWithNewInstruction( 1618 reduce, 1619 HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); 1620 } 1621 1622 // A Transpose feeding a reduce can simply permute the reduction dimensions 1623 // field if the output of the reduce is a vector or scalar. Higher ranked 1624 // result may require a transpose of the output. 
1625 if (ShapeUtil::Rank(reduce->shape()) <= 1 && 1626 arg->opcode() == HloOpcode::kTranspose) { 1627 auto transpose_dimensions = arg->dimensions(); 1628 std::vector<int64> new_reduce_dimensions; 1629 for (auto dim : dimensions) { 1630 new_reduce_dimensions.push_back(transpose_dimensions[dim]); 1631 } 1632 return ReplaceWithNewInstruction( 1633 reduce, HloInstruction::CreateReduce( 1634 reduce->shape(), arg->mutable_operand(0), init_value, 1635 new_reduce_dimensions, function)); 1636 } 1637 1638 // A reshape that collapses multiple dimensions into a dimension being 1639 // reduced can just reduce all of those dimensions instead of doing a 1640 // collapsing reshape before a reduction. 1641 if (arg->opcode() == HloOpcode::kReshape) { 1642 std::vector<std::pair<int64, int64>> unmodified_dims = 1643 ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(), 1644 arg->shape()); 1645 std::vector<bool> arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true); 1646 std::vector<bool> arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false); 1647 for (auto dim : dimensions) { 1648 arg_dim_in_output[dim] = false; 1649 } 1650 for (auto dim_pair : unmodified_dims) { 1651 arg_dim_unmodified[dim_pair.second] = true; 1652 } 1653 // The goal is to verify that all dimensions that are not removed in the 1654 // reduce are unmodified by the reshape. 
For example: 1655 // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2]) 1656 bool can_move_reshape_into_reduce = true; 1657 for (int64 i = 0; i < arg_dim_in_output.size(); ++i) { 1658 if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) { 1659 can_move_reshape_into_reduce = false; 1660 } 1661 } 1662 if (can_move_reshape_into_reduce) { 1663 changed_ = true; 1664 std::unordered_set<int64> dimensions_not_to_reduce; 1665 for (auto dim_pair : unmodified_dims) { 1666 if (arg_dim_in_output[dim_pair.second]) { 1667 dimensions_not_to_reduce.insert(dim_pair.first); 1668 } 1669 } 1670 std::vector<int64> new_reduce_dimensions; 1671 for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) { 1672 if (dimensions_not_to_reduce.count(i) == 0) { 1673 new_reduce_dimensions.push_back(i); 1674 } 1675 } 1676 return ReplaceWithNewInstruction( 1677 reduce, HloInstruction::CreateReduce( 1678 reduce->shape(), arg->mutable_operand(0), init_value, 1679 new_reduce_dimensions, function)); 1680 } 1681 } 1682 if (ShapeUtil::ElementsIn(reduce->shape()) == 1683 ShapeUtil::ElementsIn(arg->shape()) || 1684 ShapeUtil::HasZeroElements(arg->shape())) { 1685 auto reshape = computation_->AddInstruction( 1686 HloInstruction::CreateReshape(reduce->shape(), arg)); 1687 return ReplaceWithNewInstruction( 1688 reduce, HloInstruction::CreateMap(reduce->shape(), 1689 {reshape, init_value}, function)); 1690 } 1691 return Status::OK(); 1692 } 1693 1694 Status AlgebraicSimplifierVisitor::HandleReduceWindow( 1695 HloInstruction* reduce_window) { 1696 if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) { 1697 return ReplaceWithNewInstruction( 1698 reduce_window, 1699 HloInstruction::CreateBroadcast(reduce_window->shape(), 1700 reduce_window->mutable_operand(1), {})); 1701 } 1702 auto operand = reduce_window->mutable_operand(0); 1703 const Window& window = reduce_window->window(); 1704 auto function = reduce_window->to_apply(); 1705 if (ShapeUtil::IsScalar(operand->shape())) 
{ 1706 TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape())); 1707 return ReplaceWithNewInstruction( 1708 reduce_window, 1709 HloInstruction::CreateMap(reduce_window->shape(), 1710 {operand, reduce_window->mutable_operand(1)}, 1711 function)); 1712 } 1713 1714 VLOG(10) << "Considering folding Pad: " << operand->ToString() 1715 << "\ninto reduce-window: " << reduce_window->ToString(); 1716 1717 // This optimization folds a pad op into reduce_window. 1718 if (operand->opcode() != HloOpcode::kPad) { 1719 VLOG(10) << "Not folding pad into reduce-window as there is no pad."; 1720 return Status::OK(); 1721 } 1722 1723 // Do not fold interior padding into ReduceWindow since the backends do not 1724 // support it. 1725 const PaddingConfig& pad_config = operand->padding_config(); 1726 if (HasInteriorPadding(pad_config)) { 1727 VLOG(10) << "Not folding pad into reduce-window due to interior padding."; 1728 return Status::OK(); 1729 } 1730 1731 // If reduce_window already has padding, the pad value of the pad op and the 1732 // init value of reduce_window must match to allow folding the pad. 1733 const HloInstruction* pad_value = operand->operand(1); 1734 const HloInstruction* reduce_init_value = reduce_window->operand(1); 1735 if (pad_value != reduce_init_value) { 1736 // The pad value is usually a constant, so we handle that case and do not 1737 // try to get more fancy about proving equivalence in cases beyond that. 1738 if (pad_value->opcode() != HloOpcode::kConstant || 1739 reduce_init_value->opcode() != HloOpcode::kConstant || 1740 pad_value->literal() != reduce_init_value->literal()) { 1741 VLOG(10) << "Not folding pad into reduce-window due to different pad " 1742 "values."; 1743 return Status::OK(); 1744 } 1745 } 1746 1747 // If the pad puts a single non-identity value in each window that we're 1748 // reducing, then this is a broadcast. 
1749 HloInstruction* pad_operand = operand->mutable_operand(0); 1750 auto is_effective_broadcast = [&] { 1751 if (window_util::HasStride(window)) { 1752 VLOG(10) << "Window has stride."; 1753 return false; 1754 } 1755 if (!window_util::HasSymmetricPadding(pad_config)) { 1756 VLOG(10) << "Window has uneven padding."; 1757 return false; 1758 } 1759 for (int64 i = 0; i < pad_config.dimensions_size(); ++i) { 1760 const auto& pad_dimension = pad_config.dimensions(i); 1761 if ((pad_dimension.edge_padding_low() != 0 || 1762 pad_dimension.edge_padding_high() != 0) && 1763 pad_operand->shape().dimensions(i) != 1) { 1764 VLOG(10) << "Found non-trivial dimension being padded: " << i; 1765 return false; 1766 } 1767 } 1768 VLOG(10) << "Found to be padding trivial dimensions only."; 1769 1770 for (int64 i = 0; i < window.dimensions_size(); ++i) { 1771 const auto& pad_dimension = pad_config.dimensions(i); 1772 const WindowDimension& window_dimension = window.dimensions(i); 1773 bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 || 1774 pad_dimension.edge_padding_high() != 0); 1775 if (dimension_has_padding && 1776 window_dimension.size() < pad_dimension.edge_padding_low() + 1) { 1777 VLOG(10) << "Found window did not cover single unpadded element in " 1778 "dimension: " 1779 << i; 1780 return false; 1781 } 1782 if (pad_operand->shape().dimensions(i) != 1 && 1783 window_dimension.size() != 1) { 1784 VLOG(10) << "Found window covers more than one element in non-trivial " 1785 "dimension: " 1786 << i; 1787 return false; 1788 } 1789 } 1790 VLOG(10) << "Found window covers a single unpadded element."; 1791 return true; 1792 }; 1793 if (is_effective_broadcast()) { 1794 VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast."; 1795 auto fadd = [this](std::unique_ptr<HloInstruction> x) { 1796 return computation_->AddInstruction(std::move(x)); 1797 }; 1798 return ReplaceWithNewInstruction( 1799 reduce_window, HloInstruction::CreateBroadcastSequence( 1800 
/*output_shape=*/reduce_window->shape(), 1801 /*operand=*/pad_operand, fadd)); 1802 } 1803 1804 // Carry out the folding of the pad into reduce_window. 1805 VLOG(10) << "Folding pad into reduce-window."; 1806 Window new_window = window; 1807 const int64 rank = ShapeUtil::Rank(reduce_window->shape()); 1808 TF_RET_CHECK(pad_config.dimensions_size() == rank); 1809 TF_RET_CHECK(window.dimensions_size() == rank); 1810 for (int64 i = 0; i < rank; ++i) { 1811 const auto& pad_dim = pad_config.dimensions(i); 1812 auto& window_dim = *new_window.mutable_dimensions(i); 1813 window_dim.set_padding_low(window_dim.padding_low() + 1814 pad_dim.edge_padding_low()); 1815 window_dim.set_padding_high(window_dim.padding_high() + 1816 pad_dim.edge_padding_high()); 1817 } 1818 return ReplaceWithNewInstruction( 1819 reduce_window, HloInstruction::CreateReduceWindow( 1820 /*shape=*/reduce_window->shape(), 1821 /*operand=*/pad_operand, 1822 /*init_value=*/reduce_window->mutable_operand(1), 1823 /*window=*/new_window, 1824 /*reduce_computation=*/function)); 1825 } 1826 1827 Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { 1828 auto operand = transpose->mutable_operand(0); 1829 if (std::is_sorted(transpose->dimensions().begin(), 1830 transpose->dimensions().end())) { 1831 VLOG(10) << "deleting no-op transpose"; 1832 return ReplaceInstruction(transpose, operand); 1833 } 1834 1835 if (HloOpcode::kTranspose == operand->opcode()) { 1836 return ReplaceWithNewInstruction( 1837 transpose, HloInstruction::CreateTranspose( 1838 transpose->shape(), operand->mutable_operand(0), 1839 ComposePermutations(operand->dimensions(), 1840 transpose->dimensions()))); 1841 } 1842 1843 if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { 1844 ReplaceWithBitcast(transpose); 1845 return Status::OK(); 1846 } 1847 1848 return Status::OK(); 1849 } 1850 1851 Status AlgebraicSimplifierVisitor::HandleConvolution( 1852 HloInstruction* convolution) { 1853 auto lhs = 
convolution->mutable_operand(0); 1854 auto rhs = convolution->mutable_operand(1); 1855 if (ShapeUtil::HasZeroElements(lhs->shape()) || 1856 ShapeUtil::HasZeroElements(rhs->shape())) { 1857 return ReplaceWithNewInstruction( 1858 convolution, 1859 HloInstruction::CreateBroadcast( 1860 convolution->shape(), 1861 computation_->AddInstruction(HloInstruction::CreateConvert( 1862 ShapeUtil::MakeShape(convolution->shape().element_type(), {}), 1863 computation_->AddInstruction( 1864 HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))), 1865 {})); 1866 } 1867 const auto& window = convolution->window(); 1868 if (!enable_conv_simplification_) { 1869 return Status::OK(); 1870 } 1871 // HandleConvolution tries to replace a convolution with a DOT instruction. 1872 // 1873 // Only add when bitcasts can be used: 1874 // - if bitcasts are not supported, then reshapes could be used but will 1875 // end up with another copy. 1876 // - if bitcasts are supported, the simplifier will be called again with 1877 // bitcasts_ == true. 1878 1879 // TODO(cwhipkey): b/31337498, make this layout insensitive. 1880 if (!is_layout_sensitive_) { 1881 return Status::OK(); 1882 } 1883 1884 const ConvolutionDimensionNumbers& dnums = 1885 convolution->convolution_dimension_numbers(); 1886 const Shape& input_shape = lhs->shape(); 1887 const Shape& filter_shape = rhs->shape(); 1888 const Shape& convolution_shape = convolution->shape(); 1889 TF_RET_CHECK(LayoutUtil::HasLayout(input_shape)); 1890 TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape)); 1891 TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape)); 1892 1893 // Require the spatial dimensions in the kernel to have a bound of one. 1894 for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { 1895 if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) { 1896 return Status::OK(); 1897 } 1898 } 1899 1900 // Stride ignores part of the output, which matrix multiplication does not do, 1901 // so require no stride. 
Padding and base (lhs) dilation both implicitly 1902 // extend the data, which matrix multiplication also does not do, so require 1903 // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect 1904 // for a 1x1 window, so window dilation is no problem. 1905 if (window_util::HasStride(window) || window_util::HasPadding(window) || 1906 window_util::HasBaseDilation(window)) { 1907 return Status::OK(); 1908 } 1909 1910 // Also, the shapes must align for a rowmajor matmul: 1911 // - the input and output have the same layout. 1912 // - for input/output, the channel dimension must be the most minor. Other 1913 // spatial dims can be in any order. 1914 // - for filters, the input channel dimension must be more major than the 1915 // output channel dimension. The width+height don't matter because 1916 // they are 1. 1917 // 1918 // These constraints are harsh. If the channel dimension is the most major 1919 // and/or the layout of input/output feature dimensions are reversed, we can 1920 // still convert Conv into more efficient Matmul with operand transposition 1921 // (such as the transposition flags in cuBLAS SGEMM). 1922 if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || 1923 LayoutUtil::Minor(input_shape.layout(), 0) != 1924 dnums.input_feature_dimension() || 1925 LayoutUtil::Minor(convolution_shape.layout(), 0) != 1926 dnums.output_feature_dimension() || 1927 // The input feature dimension should come later in the minor-to-major 1928 // order. 
1929 (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), 1930 dnums.kernel_input_feature_dimension()) < 1931 PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), 1932 dnums.kernel_output_feature_dimension()))) { 1933 return Status::OK(); 1934 } 1935 1936 auto add_bitcast = [&](Shape shape, HloInstruction* operand) { 1937 std::vector<int64> dims(operand->shape().dimensions_size()); 1938 std::iota(dims.begin(), dims.end(), 0); 1939 return computation_->AddInstruction( 1940 HloInstruction::CreateUnary(shape, HloOpcode::kBitcast, operand)); 1941 }; 1942 1943 // Replace it with a dot, with bitcasts around it to get the right shape. 1944 const int64 input_channels = 1945 input_shape.dimensions(dnums.input_feature_dimension()); 1946 const int64 output_channels = 1947 filter_shape.dimensions(dnums.kernel_output_feature_dimension()); 1948 1949 // Computes the product of the non-feature dimensions. 1950 int64 conv_width = 1; 1951 for (int i = 0; i < input_shape.dimensions_size(); ++i) { 1952 if (i != dnums.input_feature_dimension()) { 1953 conv_width *= input_shape.dimensions(i); 1954 } 1955 } 1956 1957 // We already checked feature_dimension is most minor, so data in input_shape 1958 // and row-major {conv_width,input_channels} are bitwise identical. 1959 const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( 1960 input_shape.element_type(), {conv_width, input_channels}); 1961 // We already checked input_feature_dimension is more major than 1962 // output_feature_dimension, so data in filter_shape and row-major 1963 // {input_channels,output_channels} are bitwise identical. 1964 const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( 1965 filter_shape.element_type(), {input_channels, output_channels}); 1966 const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( 1967 convolution_shape.element_type(), {conv_width, output_channels}); 1968 1969 // We cannot insert bitcasts if the layouts will not be compatible. 
1970 // TODO(b/33178038): Consider inserting a transpose if a bitcast would be 1971 // invalid. 1972 if (!valid_bitcast_callback_(input_shape, new_input_shape) || 1973 !valid_bitcast_callback_(filter_shape, new_filter_shape) || 1974 !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { 1975 return Status::OK(); 1976 } 1977 1978 auto new_lhs = add_bitcast(new_input_shape, lhs); 1979 auto new_rhs = add_bitcast(new_filter_shape, rhs); 1980 DotDimensionNumbers dot_dimension_numbers; 1981 dot_dimension_numbers.add_lhs_contracting_dimensions(1); 1982 dot_dimension_numbers.add_rhs_contracting_dimensions(0); 1983 auto dot = computation_->AddInstruction(HloInstruction::CreateDot( 1984 dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); 1985 return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); 1986 } 1987 1988 bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( 1989 HloInstruction* root, HloInstruction* min, HloInstruction* min_operand, 1990 HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) { 1991 // Ensure shapes of min and max operand are equal to match current shape 1992 // inference. 1993 if (!SameShape(min_operand, max_operand)) { 1994 return false; 1995 } 1996 1997 auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp, 1998 max_operand, operand, min_operand); 1999 TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp))); 2000 return true; 2001 } 2002 2003 Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { 2004 // Match the following tree: 2005 // min_operand operand 2006 // \ / 2007 // max_operand min 2008 // \ / 2009 // max 2010 // where max_operand and min_operand are scalar constants. 
2011 { 2012 HloInstruction* min; 2013 HloInstruction* max_operand; 2014 HloInstruction* min_operand; 2015 HloInstruction* operand; 2016 2017 if (hlo_query::MatchBinaryInstructionOperandOpcode( 2018 HloOpcode::kMinimum, maximum, 2019 /*matching_operand=*/&min, 2020 /*other_operand=*/&max_operand) && 2021 hlo_query::MatchBinaryInstructionOperand( 2022 hlo_query::IsScalarConstant, min, 2023 /*matching_operand=*/&min_operand, 2024 /*other_operand=*/&operand) && 2025 TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum, 2026 max_operand)) { 2027 return Status::OK(); 2028 } 2029 } 2030 2031 return Status::OK(); 2032 } 2033 2034 Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { 2035 // Match the following tree: 2036 // max_operand operand 2037 // \ / 2038 // min_operand max 2039 // \ / 2040 // min 2041 // where max_operand and min_operand are scalar constants. 2042 { 2043 HloInstruction* max; 2044 HloInstruction* max_operand; 2045 HloInstruction* min_operand; 2046 HloInstruction* operand; 2047 2048 if (hlo_query::MatchBinaryInstructionOperandOpcode( 2049 HloOpcode::kMaximum, minimum, 2050 /*matching_operand=*/&max, 2051 /*other_operand=*/&min_operand) && 2052 hlo_query::MatchBinaryInstructionOperand( 2053 hlo_query::IsScalarConstant, max, 2054 /*matching_operand=*/&max_operand, 2055 /*other_operand=*/&operand) && 2056 TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max, 2057 max_operand)) { 2058 return Status::OK(); 2059 } 2060 } 2061 2062 return Status::OK(); 2063 } 2064 2065 StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) { 2066 XLA_VLOG_LINES(2, 2067 "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); 2068 bool changed = false; 2069 for (auto* comp : module->MakeNonfusionComputations()) { 2070 if (AlgebraicSimplifierVisitor::Run( 2071 comp, is_layout_sensitive_, valid_bitcast_callback_, 2072 enable_dot_strength_reduction_, enable_conv_simplification_)) { 2073 changed = true; 2074 } 
2075 } 2076 XLA_VLOG_LINES(2, 2077 "AlgebraicSimplifier::Run(), after:\n" + module->ToString()); 2078 return changed; 2079 } 2080 2081 } // namespace xla 2082