1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" 17 18 #include <algorithm> 19 #include <cmath> 20 #include <functional> 21 #include <iterator> 22 #include <memory> 23 #include <numeric> 24 #include <string> 25 #include <utility> 26 #include <vector> 27 28 #include "absl/algorithm/container.h" 29 #include "absl/container/flat_hash_map.h" 30 #include "absl/container/flat_hash_set.h" 31 #include "absl/container/inlined_vector.h" 32 #include "absl/memory/memory.h" 33 #include "absl/strings/str_cat.h" 34 #include "absl/types/optional.h" 35 #include "absl/types/span.h" 36 #include "tensorflow/compiler/xla/layout_util.h" 37 #include "tensorflow/compiler/xla/literal.h" 38 #include "tensorflow/compiler/xla/literal_util.h" 39 #include "tensorflow/compiler/xla/primitive_util.h" 40 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 41 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" 42 #include "tensorflow/compiler/xla/service/hlo_computation.h" 43 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" 44 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 45 #include "tensorflow/compiler/xla/service/hlo_instructions.h" 46 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 47 #include 
"tensorflow/compiler/xla/service/hlo_query.h" 48 #include "tensorflow/compiler/xla/service/pattern_matcher.h" 49 #include "tensorflow/compiler/xla/shape.h" 50 #include "tensorflow/compiler/xla/shape_util.h" 51 #include "tensorflow/compiler/xla/status_macros.h" 52 #include "tensorflow/compiler/xla/types.h" 53 #include "tensorflow/compiler/xla/util.h" 54 #include "tensorflow/compiler/xla/window_util.h" 55 #include "tensorflow/compiler/xla/xla_data.pb.h" 56 #include "tensorflow/core/lib/core/bits.h" 57 #include "tensorflow/core/lib/core/errors.h" 58 #include "tensorflow/core/lib/core/status.h" 59 #include "tensorflow/core/platform/logging.h" 60 #include "tensorflow/core/platform/types.h" 61 62 namespace xla { 63 64 namespace { 65 66 namespace m = match; 67 68 bool IsAll(const HloInstruction* op, int8 value) { 69 switch (op->opcode()) { 70 case HloOpcode::kBroadcast: 71 return IsAll(op->operand(0), value); 72 case HloOpcode::kConstant: 73 return op->literal().IsAll(value); 74 default: 75 return false; 76 } 77 } 78 79 // Checks whether `op` is a floating-point constant or broadcast of a constant 80 // of the form +/- 2^k for some integer k positive, negative, or zero. Such 81 // values are interesting because multiplying by a power of 2 just moves the 82 // exponent. 83 bool IsAllFpConstantPowerOf2(const HloInstruction* op) { 84 // Unwrap the broadcast if necessary. 
85 const HloInstruction* c; 86 if (!Match(op, m::ConstantEffectiveScalar(&c)) && 87 !Match(op, m::Broadcast(m::Constant(&c).WithShape( 88 m::Shape().IsEffectiveScalar())))) { 89 return false; 90 } 91 auto val = [&]() -> absl::optional<double> { 92 switch (c->shape().element_type()) { 93 case BF16: 94 return static_cast<double>(c->literal().GetFirstElement<bfloat16>()); 95 case F16: 96 return static_cast<double>(c->literal().GetFirstElement<Eigen::half>()); 97 case F32: 98 return c->literal().GetFirstElement<float>(); 99 case F64: 100 return c->literal().GetFirstElement<double>(); 101 default: 102 // Cowardly refuse to consider complex types. 103 return absl::nullopt; 104 } 105 }(); 106 if (!val) { 107 return false; 108 } 109 110 int exp; 111 double mantissa = std::frexp(*val, &exp); 112 // frexp returns a value in the range (-1, -0.5] U [0.5, 1). A return value 113 // of +/-0.5 therefore indicates that the floating point value is a power of 114 // 2. 115 return mantissa == 0.5 || mantissa == -0.5; 116 } 117 118 // Returns whether the given transpose produces a result which is bit-wise 119 // identical to its operand and thus may be replaced with a bitcast. 120 bool TransposeIsBitcast(const HloInstruction* transpose) { 121 CHECK_EQ(HloOpcode::kTranspose, transpose->opcode()); 122 const HloInstruction* operand = transpose->operand(0); 123 return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(), 124 transpose->dimensions()); 125 } 126 127 // Recursive helper for method below. 128 HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper( 129 HloInstruction* instr, HloInstruction* operand, 130 const AlgebraicSimplifierOptions& options) { 131 // Can't replace chain of copies and reshapes with bitcasts if the compiler 132 // used a memory layout which isn't compatible. 
133 if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) { 134 return operand; 135 } 136 137 // If the operand is a copy or reshape try to see if the operand's operand 138 // would produce a bitcast with initial instruction. 139 if (HloOpcode::kReshape == operand->opcode() || 140 HloOpcode::kCopy == operand->opcode()) { 141 return BitcastingOperandOfReshapeOrCopyChainHelper( 142 instr, operand->mutable_operand(0), options); 143 } 144 return nullptr; 145 } 146 147 // Returns an operand of a chain of reshapes and copies that is bit-wise 148 // identical to first reshape or copy in the chain. 149 HloInstruction* BitcastingOperandOfReshapeOrCopyChain( 150 HloInstruction* instr, const AlgebraicSimplifierOptions& options) { 151 if (!options.is_layout_sensitive()) { 152 return nullptr; 153 } 154 CHECK(HloOpcode::kReshape == instr->opcode() || 155 HloOpcode::kCopy == instr->opcode()); 156 return BitcastingOperandOfReshapeOrCopyChainHelper( 157 instr, instr->mutable_operand(0), options); 158 } 159 160 bool IsUnstridedSlice(const HloInstruction* hlo) { 161 return absl::c_all_of(hlo->slice_strides(), 162 [](int64 stride) { return stride == 1; }); 163 } 164 165 // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain 166 // algebraic expressions to simplified forms. Note: This only supports 167 // simplifications that simply look at the operands of an instruction. For the 168 // more general case a worklist based approach would be needed. 169 class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { 170 public: 171 // Default visitor action is to do nothing and return OK. 
172 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { 173 return Status::OK(); 174 } 175 176 Status HandleAdd(HloInstruction* add) override; 177 178 Status HandleAnd(HloInstruction* logical_and) override; 179 180 Status HandleBitcast(HloInstruction* bitcast) override; 181 182 Status HandleBitcastConvert(HloInstruction* bitcast) override; 183 184 Status HandleBroadcast(HloInstruction* broadcast) override; 185 186 Status HandleConcatenate(HloInstruction* concatenate) override; 187 188 Status HandleConstant(HloInstruction* constant) override; 189 190 Status HandleCopy(HloInstruction* copy) override; 191 192 Status HandleConvert(HloInstruction* convert) override; 193 194 Status HandleComplex(HloInstruction* complex) override; 195 196 Status HandleReal(HloInstruction* real) override; 197 198 Status HandleImag(HloInstruction* imag) override; 199 200 Status HandleIota(HloInstruction* instruction) override; 201 202 Status HandleConvolution(HloInstruction* convolution) override; 203 204 Status HandleDivide(HloInstruction* divide) override; 205 206 Status HandleDot(HloInstruction* dot) override; 207 208 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 209 210 Status HandleLog(HloInstruction* log) override; 211 212 Status HandleMultiply(HloInstruction* multiply) override; 213 214 Status HandleNegate(HloInstruction* negate) override; 215 216 Status HandleNot(HloInstruction* logical_not) override; 217 218 Status HandleOr(HloInstruction* logical_or) override; 219 220 Status HandlePad(HloInstruction* pad) override; 221 222 Status HandlePower(HloInstruction* power) override; 223 224 Status HandleRemainder(HloInstruction* remainder) override; 225 226 Status HandleReshape(HloInstruction* reshape) override; 227 228 Status HandleReduce(HloInstruction* reduce) override; 229 230 Status HandleReduceWindow(HloInstruction* reduce_window) override; 231 232 Status HandleReverse(HloInstruction* reverse) override; 233 Status 
HandleSlice(HloInstruction* slice) override; 234 Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; 235 Status HandleDynamicUpdateSlice( 236 HloInstruction* dynamic_update_slice) override; 237 238 Status HandleSelect(HloInstruction* select) override; 239 240 Status HandleSort(HloInstruction* sort) override; 241 242 Status HandleTranspose(HloInstruction* transpose) override; 243 244 Status HandleSubtract(HloInstruction* sub) override; 245 246 Status HandleMap(HloInstruction* map) override; 247 248 // Returns whether algebraic simplification has occurred. 249 const bool changed() const { return changed_; } 250 251 // Runs the visitor on a computation. 252 static bool Run(HloComputation* computation, 253 const AlgebraicSimplifierOptions& options); 254 255 private: 256 explicit AlgebraicSimplifierVisitor(HloComputation* computation, 257 const AlgebraicSimplifierOptions& options) 258 : computation_(computation), options_(options) {} 259 260 // Transforms Dots where at least one input is a vector or has a degenerate 261 // dimension and converts it into a multiply and reduce. This should enable 262 // more fusion than leaving the nodes as Dot operations. 263 StatusOr<bool> HandleDotStrengthReduction(HloInstruction* dot); 264 265 // Removes dimension dim from hlo. 266 HloInstruction* StripDim(HloInstruction* hlo, int64 dim) { 267 CHECK_EQ(hlo->shape().dimensions(dim), 1); 268 return computation_->AddInstruction(HloInstruction::CreateReshape( 269 ShapeUtil::DeleteDimension(dim, hlo->shape()), hlo)); 270 } 271 272 // Reshapes an instruction to rank 1 if it is not already rank 1. 
273 HloInstruction* Flatten(HloInstruction* hlo) { 274 if (hlo->shape().rank() == 1) { 275 return hlo; 276 } 277 return computation_->AddInstruction(HloInstruction::CreateReshape( 278 ShapeUtil::MakeShape(hlo->shape().element_type(), 279 {ShapeUtil::ElementsIn(hlo->shape())}), 280 hlo)); 281 } 282 283 // Converts to primitive type if the input hlo is not that type, otherwise 284 // returns the original hlo. 285 HloInstruction* AsType(HloInstruction* hlo, 286 const PrimitiveType element_type) { 287 if (hlo->shape().element_type() == element_type) { 288 return hlo; 289 } 290 return computation_->AddInstruction(HloInstruction::CreateConvert( 291 ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); 292 } 293 294 // Transposes a dot operand such that the batch dimensions are the msot major, 295 // and the contracting dimensions are most minor. 296 StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor( 297 HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions, 298 absl::Span<const int64> contracting_dimensions) { 299 std::vector<int64> transpose_dimensions(batch_dimensions.begin(), 300 batch_dimensions.end()); 301 for (int64 i = 0; i < dot_operand->shape().rank(); ++i) { 302 if (!(absl::c_linear_search(batch_dimensions, i) || 303 absl::c_linear_search(contracting_dimensions, i))) { 304 transpose_dimensions.push_back(i); 305 } 306 } 307 transpose_dimensions.insert(transpose_dimensions.end(), 308 contracting_dimensions.begin(), 309 contracting_dimensions.end()); 310 return MakeTransposeHlo(dot_operand, transpose_dimensions); 311 } 312 313 // Helper method to perform and add reduction on a list of dimensions. 
314 HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims) { 315 HloInstruction* zero = 316 computation_->AddInstruction(HloInstruction::CreateConstant( 317 LiteralUtil::Zero(hlo->shape().element_type()).Clone())); 318 HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); 319 Shape shape = ShapeUtil::FilterDimensions( 320 [&](int64 dim) { return !absl::c_linear_search(dims, dim); }, 321 hlo->shape()); 322 return computation_->AddInstruction(HloInstruction::CreateReduce( 323 shape, hlo, zero, dims, AddReduce_computation)); 324 } 325 326 HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { 327 return AddReduce(hlo, std::vector<int64>{dim}); 328 } 329 330 // Convenience method for replacing an instruction with a bitcast. If operand 331 // is not null, then the bitcast will use the specified operand instead of the 332 // operand of the instruction. 333 void ReplaceWithBitcast(HloInstruction* instruction, 334 HloInstruction* operand = nullptr); 335 336 // Replace old instruction with new instruction if old and new instructions 337 // have the same shape. Updates uses and root instruction. Returns whether a 338 // replacement was made. 339 bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction, 340 HloInstruction* new_instruction); 341 342 // Returns whether the shape of the output of the given instructions are the 343 // same for the purposes of simplification. If options_.is_layout_sensitive() 344 // is true, then this tests shape equality including layout 345 // (ShapeUtil::Equal). If options_.is_layout_sensitive() is false, then the 346 // tests shape compatibility (ShapeUtil::Compatible). 347 bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; 348 349 // Returns whether it was possible to transform `root` to a clamp instruction. 350 // With min a minimum instruction, max a maximum instruction, min_operand a 351 // operand of min and max_operand a operand of max. 
352 // Precondition: root is either a minimum or a maximum. 353 bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min, 354 HloInstruction* min_operand, 355 HloInstruction* operand, HloInstruction* max, 356 HloInstruction* max_operand); 357 358 // A Broadcast that feeds an element-wise operation with a unique non-scalar 359 // operand can sink to after the operation. 360 StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( 361 HloInstruction* broadcast); 362 363 // Replaces the existing HLO instruction old_instruction, with 364 // new_instruction, and marks the optimizer status as changed. 365 // Returns the Status representing the result of the replace operation. 366 Status ReplaceWithNewInstruction( 367 HloInstruction* old_instruction, 368 std::unique_ptr<HloInstruction> new_instruction) { 369 VLOG(3) << "Replacing instruction:"; 370 VLOG(3) << " old: " << old_instruction->ToString(); 371 VLOG(3) << " new: " << new_instruction->ToString(); 372 TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( 373 old_instruction, std::move(new_instruction))); 374 changed_ = true; 375 return Status::OK(); 376 } 377 378 // Replaces the existing HLO instruction old_instruction, with 379 // new_instruction, and marks the optimizer status as changed. 380 // Returns the Status representing the result of the replace operation. 
381 Status ReplaceInstruction(HloInstruction* old_instruction, 382 HloInstruction* new_instruction) { 383 VLOG(3) << "Replacing instruction:"; 384 VLOG(3) << " old: " << old_instruction->ToString(); 385 VLOG(3) << " new: " << new_instruction->ToString(); 386 TF_RETURN_IF_ERROR( 387 computation_->ReplaceInstruction(old_instruction, new_instruction)); 388 changed_ = true; 389 return Status::OK(); 390 } 391 392 StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot); 393 StatusOr<HloInstruction*> OptimizeDotOfConcatHelper( 394 const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, 395 HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); 396 397 StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot); 398 399 HloComputation* GetOrCreateScalarAddComputation() { 400 if (scalar_add_computation_) { 401 return scalar_add_computation_; 402 } 403 404 HloComputation::Builder b("scalar_add_computation"); 405 Shape shape = ShapeUtil::MakeShape(F32, {}); 406 auto scalar_lhs = b.AddInstruction( 407 HloInstruction::CreateParameter(0, shape, "scalar_lhs")); 408 auto scalar_rhs = b.AddInstruction( 409 HloInstruction::CreateParameter(1, shape, "scalar_rhs")); 410 auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( 411 shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); 412 scalar_add_computation_ = 413 computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); 414 return scalar_add_computation_; 415 } 416 417 // Tries to fold a kPad in the input or filter into the convolution 418 // instruction's window. 419 StatusOr<bool> FoldConvInputPad(HloInstruction* convolution); 420 StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution); 421 422 // Tries to use a kDot in place of the given convolution. 423 StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution); 424 425 // Tries to simplify a slice where the result of the slice is a scalar. 
426 StatusOr<bool> TrySimplifyScalarSlice(HloInstruction* slice); 427 428 // Tries to convert slice(reshape(X)) into reshape(slice(X)) 429 StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice); 430 431 // Current HloComputation instance the AlgebraicSimplifierVisitor is 432 // traversing. 433 HloComputation* computation_; 434 435 // The backend-specific options selected for the algebraic simplifier. 436 const AlgebraicSimplifierOptions& options_; 437 438 // Whether algebraic simplification has occurred. 439 bool changed_ = false; 440 441 // Cached computation for adding two scalar F32. 442 HloComputation* scalar_add_computation_ = nullptr; 443 }; 444 445 } // namespace 446 447 bool AlgebraicSimplifierVisitor::Run( 448 HloComputation* computation, const AlgebraicSimplifierOptions& options) { 449 AlgebraicSimplifierVisitor visitor(computation, options); 450 TF_CHECK_OK(computation->Accept(&visitor)); 451 return visitor.changed_; 452 } 453 454 bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, 455 const HloInstruction* rhs) const { 456 if (options_.is_layout_sensitive()) { 457 return ShapeUtil::Equal(lhs->shape(), rhs->shape()); 458 } else { 459 return ShapeUtil::Compatible(lhs->shape(), rhs->shape()); 460 } 461 } 462 463 void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction, 464 HloInstruction* operand) { 465 CHECK_EQ(1, instruction->operand_count()); 466 if (operand == nullptr) { 467 operand = instruction->mutable_operand(0); 468 } 469 CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()), 470 ShapeUtil::ElementsIn(operand->shape())); 471 CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()), 472 ShapeUtil::ByteSizeOf(operand->shape())); 473 474 auto bitcast = computation_->AddInstruction(HloInstruction::CreateUnary( 475 instruction->shape(), HloOpcode::kBitcast, operand)); 476 TF_CHECK_OK(ReplaceInstruction(instruction, bitcast)); 477 } 478 479 bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( 480 
HloInstruction* old_instruction, HloInstruction* new_instruction) { 481 if (!SameShape(old_instruction, new_instruction)) { 482 return false; 483 } 484 TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction)); 485 return true; 486 } 487 488 Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { 489 HloInstruction *lhs, *rhs; 490 CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs)))); 491 492 // A + 0 => A 493 VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); 494 if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { 495 return Status::OK(); 496 } 497 // 0 + A => A 498 VLOG(10) << "trying transform [0 + A => A]: " << add->ToString(); 499 if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) { 500 return Status::OK(); 501 } 502 503 // Canonicalization: Put constants on the right. This makes the reassociation 504 // rules below simpler. 505 VLOG(10) << "trying transform [Const + A => A + Const]"; 506 if (Match(add, m::Add(m::Constant(), m::NonConstant()))) { 507 return ReplaceWithNewInstruction( 508 add, 509 HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs)); 510 } 511 512 // Reassociate to allow constant folding. 513 // 514 // Note: This is not general. For example, we won't reassociate 515 // 516 // (A + C1) + (B + C2) => A + B + (C1 + C2). 517 // 518 VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; 519 HloInstruction *a, *c1, *c2; 520 if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)), 521 m::Constant(&c2)))) { 522 TF_ASSIGN_OR_RETURN(auto* sum_of_constants, 523 MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); 524 return ReplaceWithNewInstruction( 525 add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a, 526 sum_of_constants)); 527 } 528 529 // A*C + B*C => (A+B)*C 530 // 531 // - If A, B, and C are integers, do this unconditionally. Proof of 532 // correctness: https://rise4fun.com/Alive/u9X. 
533 // 534 // - If A, B, and C are floating point, do this if C is a scalar constant or 535 // broadcast of scalar constant and is equal to +/- 2^k for some (possibly 536 // negative) integer k. 537 // 538 // Multiplying by a power of 2 just moves the exponent, so our answer is 539 // exact modulo rounding of intermediate results so long as 540 // 541 // - none of the three products has an exponent which underflows (so the 542 // result is 0 or denormal), and 543 // - none of the three products overflows to inf. 544 // 545 // Proof: See algebraic_simplifier_proof_distributive_property.py. 546 // 547 // We deem these differences in rounding, underflow, and overflow 548 // acceptable in the ML context. 549 HloInstruction *b, *c; 550 if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) && 551 Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) || 552 (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && 553 Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && 554 (ShapeUtil::ElementIsIntegral(add->shape()) || 555 IsAllFpConstantPowerOf2(c))) { 556 return ReplaceWithNewInstruction( 557 add, HloInstruction::CreateBinary( 558 add->shape(), HloOpcode::kMultiply, 559 computation_->AddInstruction(HloInstruction::CreateBinary( 560 add->shape(), HloOpcode::kAdd, a, b)), 561 c)); 562 } 563 return Status::OK(); 564 } 565 566 Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { 567 HloInstruction *lhs, *rhs; 568 CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); 569 // Simplify logical and 570 if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && 571 ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { 572 // A && True => A 573 VLOG(10) << "trying transform [A && True => A]: " 574 << logical_and->ToString(); 575 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) { 576 return Status::OK(); 577 } 578 // True && A => A 579 VLOG(10) << "trying transform [True && A => A]: " 580 << logical_and->ToString(); 581 if 
(IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) { 582 return Status::OK(); 583 } 584 585 // A && False => False 586 VLOG(10) << "trying transform [A && False => False]: " 587 << logical_and->ToString(); 588 if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) { 589 return Status::OK(); 590 } 591 592 // False && A => False 593 VLOG(10) << "trying transform [False && A => False]: " 594 << logical_and->ToString(); 595 if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) { 596 return Status::OK(); 597 } 598 } 599 600 return Status::OK(); 601 } 602 603 Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { 604 // If a bitcast feeds a bitcast, make it a single bitcast. 605 HloInstruction* op; 606 if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) { 607 return ReplaceWithNewInstruction( 608 bitcast, 609 HloInstruction::CreateUnary(bitcast->shape(), HloOpcode::kBitcast, op)); 610 } 611 // All bitcasts can be eliminated (assuming layout constraints are 612 // satisified). 613 ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0)); 614 return Status::OK(); 615 } 616 617 Status AlgebraicSimplifierVisitor::HandleBitcastConvert( 618 HloInstruction* bitcast) { 619 // Eliminate bitcast converts between same shape. 620 ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0)); 621 return Status::OK(); 622 } 623 624 Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { 625 // If a copy feeds a copy, make it a single copy. 626 HloInstruction* op; 627 if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) { 628 return ReplaceWithNewInstruction( 629 copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op)); 630 } 631 // All copies can be eliminated (assuming layout constraints are satisified). 
632 if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) { 633 return Status::OK(); 634 } 635 636 if (HloInstruction* bitcast_operand = 637 BitcastingOperandOfReshapeOrCopyChain(copy, options_)) { 638 ReplaceWithBitcast(copy, bitcast_operand); 639 } 640 641 return Status::OK(); 642 } 643 644 Status AlgebraicSimplifierVisitor::HandleConcatenate( 645 HloInstruction* concatenate) { 646 absl::Span<HloInstruction* const> operands(concatenate->operands()); 647 if (operands.size() == 1) { 648 // Unary concatenates are useless. 649 ReplaceInstructionIfSameShape(concatenate, operands[0]); 650 return Status::OK(); 651 } 652 // Filter out and remove empty operands. 653 std::vector<HloInstruction*> nonempty_operands; 654 for (HloInstruction* operand : operands) { 655 if (!ShapeUtil::IsZeroElementArray(operand->shape())) { 656 nonempty_operands.push_back(operand); 657 } 658 } 659 if (nonempty_operands.size() < operands.size()) { 660 HloInstruction* replacement; 661 if (nonempty_operands.empty()) { 662 replacement = operands[0]; 663 } else if (nonempty_operands.size() == 1) { 664 replacement = nonempty_operands[0]; 665 } else { 666 replacement = 667 computation_->AddInstruction(concatenate->CloneWithNewOperands( 668 concatenate->shape(), nonempty_operands)); 669 } 670 VLOG(10) << "trying to replace " << concatenate->ToString() << " with " 671 << replacement->ToString(); 672 ReplaceInstructionIfSameShape(concatenate, replacement); 673 return Status::OK(); 674 } 675 676 // Check if we can merge "adjacent" slice operands which take slices from the 677 // same other op. For simplicity we only merge unstrided slices. 
678 int64 concatenate_dimension = concatenate->concatenate_dimension(); 679 for (int64 i = 0; i < operands.size(); ++i) { 680 if (operands[i]->opcode() != HloOpcode::kSlice || 681 !IsUnstridedSlice(operands[i])) { 682 continue; 683 } 684 int64 slice_end = operands[i]->slice_limits(concatenate_dimension); 685 HloInstruction* slice_operand = operands[i]->mutable_operand(0); 686 int64 j = i + 1; 687 while (j < operands.size() && operands[j]->opcode() == HloOpcode::kSlice && 688 IsUnstridedSlice(operands[j]) && 689 operands[j]->operand(0) == slice_operand && 690 operands[j]->slice_starts(concatenate_dimension) == slice_end) { 691 // Check that all the slice_start values are the same in all other 692 // dimensions. This implies that the slice_limit values are also the same, 693 // because operands of concatenate need to have the same shape, and we 694 // already checked that the slices are unstrided. 695 bool same_other_starts = true; 696 for (int64 k = 0; k < operands[j]->slice_starts().size(); ++k) { 697 if (k == concatenate_dimension) { 698 continue; 699 } 700 if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) { 701 same_other_starts = false; 702 break; 703 } 704 } 705 if (!same_other_starts) { 706 break; 707 } 708 slice_end = operands[j]->slice_limits(concatenate_dimension); 709 ++j; 710 } 711 if (j - i > 1) { 712 Shape new_slice_shape = operands[i]->shape(); 713 new_slice_shape.set_dimensions( 714 concatenate_dimension, 715 slice_end - operands[i]->slice_starts(concatenate_dimension)); 716 auto new_limit_indices = operands[i]->slice_limits(); 717 new_limit_indices[concatenate_dimension] = slice_end; 718 auto new_slice_op = 719 computation_->AddInstruction(HloInstruction::CreateSlice( 720 new_slice_shape, slice_operand, 721 /*start_indices=*/operands[i]->slice_starts(), 722 /*limit_indices=*/new_limit_indices, 723 /*strides=*/operands[i]->slice_strides())); 724 std::vector<HloInstruction*> new_operands; 725 for (int64 k = 0; k < i; ++k) { 726 
new_operands.push_back(operands[k]); 727 } 728 new_operands.push_back(new_slice_op); 729 for (int64 k = j; k < operands.size(); ++k) { 730 new_operands.push_back(operands[k]); 731 } 732 auto replacement = 733 computation_->AddInstruction(concatenate->CloneWithNewOperands( 734 concatenate->shape(), new_operands)); 735 ReplaceInstructionIfSameShape(concatenate, replacement); 736 return Status::OK(); 737 } 738 } 739 740 if (operands.size() == 2) { 741 // A binary concat with a broadcasted scalar as an operand can be converted 742 // into a pad which is simpler to fold into other operations. 743 bool is_effective_low_pad = Match( 744 operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); 745 bool is_effective_high_pad = Match( 746 operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); 747 if (!is_effective_low_pad && !is_effective_high_pad) { 748 return Status::OK(); 749 } 750 PaddingConfig padding_config; 751 for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) { 752 auto padding_config_dim = padding_config.add_dimensions(); 753 padding_config_dim->set_edge_padding_high(0); 754 padding_config_dim->set_edge_padding_low(0); 755 padding_config_dim->set_interior_padding(0); 756 if (dim == concatenate_dimension) { 757 if (is_effective_low_pad) { 758 padding_config_dim->set_edge_padding_low( 759 operands[0]->shape().dimensions(dim)); 760 } else { 761 padding_config_dim->set_edge_padding_high( 762 operands[1]->shape().dimensions(dim)); 763 } 764 } 765 } 766 int64 operand_to_pad = is_effective_low_pad ? 1 : 0; 767 int64 pad_value_operand = is_effective_low_pad ? 
0 : 1; 768 HloInstruction* pad = 769 computation_->AddInstruction(HloInstruction::CreatePad( 770 concatenate->shape(), operands[operand_to_pad], 771 operands[pad_value_operand]->mutable_operand(0), padding_config)); 772 return ReplaceInstruction(concatenate, pad); 773 } 774 return Status::OK(); 775 } 776 777 static HloInstruction* BuildTupleConstant(HloComputation* computation, 778 const LiteralSlice& literal) { 779 if (literal.shape().IsTuple()) { 780 std::vector<HloInstruction*> elems; 781 elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); 782 for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { 783 elems.push_back( 784 BuildTupleConstant(computation, LiteralSlice(literal, {i}))); 785 } 786 return computation->AddInstruction(HloInstruction::CreateTuple(elems)); 787 } else { 788 return computation->AddInstruction( 789 HloInstruction::CreateConstant(literal.Clone())); 790 } 791 } 792 793 Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { 794 // Tuple constants aren't directly supported by any backend. Expand them into 795 // explicit Tuple instructions. 796 if (constant->shape().IsTuple()) { 797 return ReplaceInstruction( 798 constant, BuildTupleConstant(computation_, constant->literal())); 799 } 800 801 if (constant->shape().element_type() == TOKEN) { 802 return Status::OK(); 803 } 804 805 // If a literal is all the same element replace it with a scalar broadcast. 806 if (ShapeUtil::ElementsIn(constant->shape()) > 1 && 807 constant->literal().IsAllFirst()) { 808 Literal unique_scalar( 809 LiteralUtil::GetFirstScalarLiteral(constant->literal())); 810 HloInstruction* scalar = computation_->AddInstruction( 811 HloInstruction::CreateConstant(std::move(unique_scalar))); 812 return ReplaceWithNewInstruction( 813 constant, 814 HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); 815 } 816 817 // If a literal is an increasing sequence from zero, replace it with an iota. 
818 if (constant->shape().rank() == 1 && 819 ShapeUtil::ElementsIn(constant->shape()) > 1 && 820 constant->literal().IsR1Iota()) { 821 return ReplaceWithNewInstruction( 822 constant, HloInstruction::CreateIota(constant->shape(), 0)); 823 } 824 return Status::OK(); 825 } 826 827 Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { 828 HloInstruction *lhs, *rhs; 829 CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs)))); 830 // A - 0 => A 831 VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); 832 if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { 833 return Status::OK(); 834 } 835 836 // Canonicalize subtraction of a constant to addition. 837 VLOG(10) << "trying transform [A - Const => A + (-Const)]"; 838 if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) { 839 HloInstruction* negative_const = computation_->AddInstruction( 840 HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); 841 return ReplaceWithNewInstruction( 842 sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs, 843 negative_const)); 844 } 845 846 return Status::OK(); 847 } 848 namespace { 849 template <typename T> 850 Status InvertConstant(const HloInstruction& constant, Literal* result) { 851 return result->Populate<T>([&](absl::Span<const int64> indices) { 852 return T{1.0} / constant.literal().Get<T>(indices); 853 }); 854 } 855 856 template <typename T> 857 std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide, 858 HloComputation* computation) { 859 HloInstruction *a, *b, *c; 860 CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); 861 862 if (ShapeUtil::ElementIsIntegral(divide->shape()) && 863 !Match(b, m::ConstantEffectiveScalar(&c)) && 864 !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { 865 return nullptr; 866 } 867 868 if (ShapeUtil::ElementIsSigned(divide->shape())) { 869 int64 b_value = c->literal().GetFirstElement<T>(); 870 if (b_value > 0 && 
IsPowerOfTwo(static_cast<uint64>(b_value))) { 871 // Handle negative dividends by negating the result of the division. 872 HloInstruction* zero_like_a = BroadcastZeros( 873 computation, a->shape().element_type(), a->shape().dimensions()); 874 875 auto* dividend_is_negative = 876 computation->AddInstruction(HloInstruction::CreateCompare( 877 ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a, 878 ComparisonDirection::kLt)); 879 880 auto* negated_dividend = computation->AddInstruction( 881 HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); 882 883 auto* abs_dividend = 884 computation->AddInstruction(HloInstruction::CreateTernary( 885 a->shape(), HloOpcode::kSelect, dividend_is_negative, 886 negated_dividend, a)); 887 888 int log2_abs_b_value = tensorflow::Log2Floor64(b_value); 889 890 auto* shift_amount = 891 computation->AddInstruction(HloInstruction::CreateConstant( 892 LiteralUtil::CreateR0<T>(log2_abs_b_value))); 893 if (!ShapeUtil::IsScalar(b->shape())) { 894 shift_amount = computation->AddInstruction( 895 HloInstruction::CreateBroadcast(b->shape(), shift_amount, {})); 896 } 897 898 auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( 899 divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend, 900 shift_amount)); 901 902 auto* neqated_quotient = 903 computation->AddInstruction(HloInstruction::CreateUnary( 904 quotient->shape(), HloOpcode::kNegate, quotient)); 905 906 return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect, 907 dividend_is_negative, 908 neqated_quotient, quotient); 909 } 910 } else { 911 uint64 b_value = c->literal().GetFirstElement<T>(); 912 if (IsPowerOfTwo(b_value)) { 913 int log2_abs_b_value = tensorflow::Log2Floor64(b_value); 914 HloInstruction* shift_amount = 915 computation->AddInstruction(HloInstruction::CreateConstant( 916 LiteralUtil::CreateR0<T>(log2_abs_b_value))); 917 if (!ShapeUtil::IsScalar(b->shape())) { 918 shift_amount = computation->AddInstruction( 919 
HloInstruction::CreateBroadcast(b->shape(), shift_amount, {})); 920 } 921 return HloInstruction::CreateBinary( 922 divide->shape(), HloOpcode::kShiftRightLogical, a, shift_amount); 923 } 924 } 925 926 return nullptr; 927 } 928 } // namespace 929 930 Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { 931 HloInstruction *a, *b, *c, *d; 932 CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); 933 // A/1 => A 934 VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); 935 if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) { 936 return Status::OK(); 937 } 938 939 // A / B => A >> log2(B) if B is a power of 2. 940 switch (divide->shape().element_type()) { 941 case S8: 942 if (std::unique_ptr<HloInstruction> shift = 943 TryDivideToShift<int8>(divide, computation_)) { 944 return ReplaceWithNewInstruction(divide, std::move(shift)); 945 } 946 break; 947 case S16: 948 if (std::unique_ptr<HloInstruction> shift = 949 TryDivideToShift<int16>(divide, computation_)) { 950 return ReplaceWithNewInstruction(divide, std::move(shift)); 951 } 952 break; 953 case S32: 954 if (std::unique_ptr<HloInstruction> shift = 955 TryDivideToShift<int32>(divide, computation_)) { 956 return ReplaceWithNewInstruction(divide, std::move(shift)); 957 } 958 break; 959 case S64: 960 if (std::unique_ptr<HloInstruction> shift = 961 TryDivideToShift<int64>(divide, computation_)) { 962 return ReplaceWithNewInstruction(divide, std::move(shift)); 963 } 964 break; 965 case U8: 966 if (std::unique_ptr<HloInstruction> shift = 967 TryDivideToShift<uint8>(divide, computation_)) { 968 return ReplaceWithNewInstruction(divide, std::move(shift)); 969 } 970 break; 971 case U16: 972 if (std::unique_ptr<HloInstruction> shift = 973 TryDivideToShift<uint16>(divide, computation_)) { 974 return ReplaceWithNewInstruction(divide, std::move(shift)); 975 } 976 break; 977 case U32: 978 if (std::unique_ptr<HloInstruction> shift = 979 TryDivideToShift<uint32>(divide, computation_)) { 980 
return ReplaceWithNewInstruction(divide, std::move(shift)); 981 } 982 break; 983 case U64: 984 if (std::unique_ptr<HloInstruction> shift = 985 TryDivideToShift<uint64>(divide, computation_)) { 986 return ReplaceWithNewInstruction(divide, std::move(shift)); 987 } 988 break; 989 default: 990 break; 991 } 992 993 Shape* shape; 994 // exp(A)/exp(B) => exp(A-B) 995 if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) 996 .WithShape(m::Shape(&shape)))) { 997 VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString(); 998 HloInstruction* subtract = computation_->AddInstruction( 999 HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b)); 1000 return ReplaceWithNewInstruction( 1001 divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract)); 1002 } 1003 1004 // A/exp(B) => A*exp(-B) 1005 if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) { 1006 VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString(); 1007 HloInstruction* negate = computation_->AddInstruction( 1008 HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b)); 1009 HloInstruction* new_exp = computation_->AddInstruction( 1010 HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate)); 1011 return ReplaceWithNewInstruction( 1012 divide, HloInstruction::CreateBinary(divide->shape(), 1013 HloOpcode::kMultiply, a, new_exp)); 1014 } 1015 1016 // A/pow(B,C) => A*pow(B,-C) 1017 if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) { 1018 VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); 1019 // The output shape of the created negate operator should be the same as the 1020 // input. 1021 const Shape& negate_shape = c->shape(); 1022 HloInstruction* negate = computation_->AddInstruction( 1023 HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c)); 1024 // And the power operator should retain the output shape of the old one. 
1025 const Shape& new_power_shape = b->shape(); 1026 HloInstruction* new_power = 1027 computation_->AddInstruction(HloInstruction::CreateBinary( 1028 new_power_shape, HloOpcode::kPower, b, negate)); 1029 return ReplaceWithNewInstruction( 1030 divide, HloInstruction::CreateBinary( 1031 divide->shape(), HloOpcode::kMultiply, a, new_power)); 1032 } 1033 1034 // A/sqrt(B) => A*rsqrt(X). 1035 if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) { 1036 auto* rsqrt = computation_->AddInstruction( 1037 HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b)); 1038 return ReplaceWithNewInstruction( 1039 divide, HloInstruction::CreateBinary(rsqrt->shape(), 1040 HloOpcode::kMultiply, a, rsqrt)); 1041 } 1042 1043 // A/rsqrt(B) => A*sqrt(B). 1044 if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) { 1045 auto* sqrt = computation_->AddInstruction( 1046 HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b)); 1047 return ReplaceWithNewInstruction( 1048 divide, HloInstruction::CreateBinary(sqrt->shape(), 1049 HloOpcode::kMultiply, a, sqrt)); 1050 } 1051 1052 // Simplifying integral division would produce unexpected results. 1053 if (ShapeUtil::ElementIsIntegral(divide->shape())) { 1054 return Status::OK(); 1055 } 1056 1057 // A / Const => A * (1 / Const) 1058 // 1059 // (Backends can do this transformation, but generally only if the constant is 1060 // a scalar.) 
1061 if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { 1062 Shape result_shape = b->literal().shape(); 1063 Literal new_literal(result_shape); 1064 switch (result_shape.element_type()) { 1065 case F16: 1066 TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal)); 1067 break; 1068 case F32: 1069 TF_RETURN_IF_ERROR(InvertConstant<float>(*b, &new_literal)); 1070 break; 1071 case BF16: 1072 TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*b, &new_literal)); 1073 break; 1074 case F64: 1075 TF_RETURN_IF_ERROR(InvertConstant<double>(*b, &new_literal)); 1076 break; 1077 case C64: 1078 TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal)); 1079 break; 1080 case C128: 1081 TF_RETURN_IF_ERROR(InvertConstant<complex128>(*b, &new_literal)); 1082 break; 1083 default: 1084 return Status::OK(); 1085 } 1086 auto inverse = computation_->AddInstruction( 1087 HloInstruction::CreateConstant((new_literal.Clone()))); 1088 TF_ASSIGN_OR_RETURN(auto new_divide, 1089 MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); 1090 return ReplaceInstruction(divide, new_divide); 1091 } 1092 1093 // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) 1094 if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), 1095 m::Divide(m::Op(&c), m::Op(&d))))) { 1096 TF_ASSIGN_OR_RETURN(auto a_times_d, 1097 MakeBinaryHlo(HloOpcode::kMultiply, a, d)); 1098 TF_ASSIGN_OR_RETURN(auto b_times_c, 1099 MakeBinaryHlo(HloOpcode::kMultiply, b, c)); 1100 TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide, 1101 a_times_d, b_times_c)); 1102 1103 return ReplaceInstruction(divide, new_divide); 1104 } 1105 1106 // (A / B) / C => A / (B * C) 1107 if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { 1108 TF_ASSIGN_OR_RETURN(auto b_times_c, 1109 MakeBinaryHlo(HloOpcode::kMultiply, b, c)); 1110 TF_ASSIGN_OR_RETURN(auto new_divide, 1111 MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c)); 1112 return ReplaceInstruction(divide, new_divide); 1113 } 1114 1115 // 
A / (B / C) => (A*C) / B 1116 if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { 1117 TF_ASSIGN_OR_RETURN(auto a_times_c, 1118 MakeBinaryHlo(HloOpcode::kMultiply, a, c)); 1119 TF_ASSIGN_OR_RETURN(auto new_divide, 1120 MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b)); 1121 return ReplaceInstruction(divide, new_divide); 1122 } 1123 1124 return Status::OK(); 1125 } 1126 1127 StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction( 1128 HloInstruction* dot) { 1129 HloInstruction *lhs, *rhs; 1130 CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); 1131 1132 const auto kept_dim = [](int64 rank, int64 contracting_dimension, 1133 absl::Span<const int64> batch_dimensions) -> int64 { 1134 for (int64 i = 0; i < rank; ++i) { 1135 if (i != contracting_dimension && 1136 !absl::c_linear_search(batch_dimensions, i)) { 1137 return i; 1138 } 1139 } 1140 return -1; 1141 }; 1142 1143 const int64 dot_rank = dot->shape().rank(); 1144 const int64 rhs_rank = rhs->shape().rank(); 1145 const int64 lhs_rank = lhs->shape().rank(); 1146 const auto& dnums = dot->dot_dimension_numbers(); 1147 if (dnums.rhs_contracting_dimensions_size() != 1) { 1148 return false; 1149 } 1150 if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { 1151 return false; 1152 } 1153 int64 lhs_collapsing_dim = dnums.lhs_contracting_dimensions(0); 1154 int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim, 1155 AsInt64Slice(dnums.lhs_batch_dimensions())); 1156 // If there is no non-contracting dimension in rank 2, do not strength reduce. 
if (lhs_kept_dim == -1 && lhs_rank > 1) {
    return false;
  }
  // Look through a rank-2 transpose on the lhs: use its operand directly and
  // swap the roles of the contracting and kept dimensions to compensate.
  if (lhs->IsRank2Transpose()) {
    lhs = lhs->mutable_operand(0);
    std::swap(lhs_collapsing_dim, lhs_kept_dim);
  }

  int64 rhs_collapsing_dim = dnums.rhs_contracting_dimensions(0);
  int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim,
                                AsInt64Slice(dnums.rhs_batch_dimensions()));
  // If there is no non-contracting dimension in rank 2, do not strength reduce.
  if (rhs_kept_dim == -1 && rhs_rank > 1) {
    return false;
  }
  // Same transpose look-through as for the lhs above.
  if (rhs->IsRank2Transpose()) {
    rhs = rhs->mutable_operand(0);
    std::swap(rhs_collapsing_dim, rhs_kept_dim);
  }

  // Converts `hlo` to the dot's element type and, if its dimensions differ
  // from the dot's, reshapes it to the dot's shape.
  auto reshape_if_necessary = [&](HloInstruction* hlo) {
    hlo = AsType(hlo, dot->shape().element_type());
    if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
      hlo = computation_->AddInstruction(
          HloInstruction::CreateReshape(dot->shape(), hlo));
    }
    return hlo;
  };

  // Reduces `hlo` along `dim` after converting it to F32.
  // NOTE(review): AddReduce is defined elsewhere; presumably a sum reduction,
  // matching the reduce_sum pseudo-code in the comments below -- confirm.
  auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
    return AddReduce(AsType(hlo, F32), dim);
  };

  // Adds a broadcast of `hlo` to `shape`, mapping hlo's dimensions to `dims`.
  auto broadcast = [&](HloInstruction* hlo, const Shape& shape,
                       absl::Span<const int64> dims) {
    return computation_->AddInstruction(
        HloInstruction::CreateBroadcast(shape, hlo, dims));
  };

  // Convenience wrapper around `broadcast` for a single broadcast dimension.
  auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape,
                              int64 dim) {
    return broadcast(hlo, shape, {dim});
  };

  // Adds an element-wise multiply whose result takes the shape of `local_lhs`.
  auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) {
    return computation_->AddInstruction(HloInstruction::CreateBinary(
        local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs));
  };

  // Strength reduce dot(a[K] , b[K]) =
  //  reshape(result.shape,
  //          reduce_sum(multiply(a, b), {0}))
  if (rhs_rank == 1 && lhs_rank == 1) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, rhs), 0))));
    return true;
  }

  // Both operands hold a single element: the dot is just their product.
  if (ShapeUtil::IsEffectiveScalar(rhs->shape()) &&
      ShapeUtil::IsEffectiveScalar(lhs->shape())) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs)))));
    return true;
  }

  // Simplify outer product into multiply with broadcasting.
  //
  // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N])
  if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0),
                      broadcast_to_dim(Flatten(rhs), dot->shape(), 1))));
    return true;
  }

  // Strength reduce dot(a[1, K], b) =
  //  reshape(result.shape,
  //    reduce_sum(
  //      multiply(broadcast(reshape(a, [K]), {0}), b),
  //      {0})
  //    )
  //  )
  if (lhs_rank == 1 ||
      (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) {
    if (rhs->shape().rank() == 1) {
      TF_RETURN_IF_ERROR(
          ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
                                      multiply(Flatten(lhs), rhs), 0))));
      return true;
    }
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(add_reduce_in_f32(
                 multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
                                           rhs_collapsing_dim),
                          rhs),
                 rhs_collapsing_dim))));
    return true;
  }

  // Strength reduce dot(a, b[K, 1]) =
  //  reshape(result.shape,
  //    reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0})
  //  )
  if (rhs_rank == 1 ||
      (rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(add_reduce_in_f32(
                 multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(),
                                                lhs_collapsing_dim)),
                 lhs_collapsing_dim))));
    return true;
  }

  // Only consider kDot with batch dimension.
  if (dot_rank <= 2) {
    return false;
  }

  CHECK_EQ(rhs_rank, lhs_rank);
  CHECK_EQ(dot_rank, lhs_rank);
  // If there is more than one non-contracting dimension or the batch dimensions
  // are not equal, bail out since transposes may be required to do a strength
  // reduction.
  if (dnums.rhs_batch_dimensions_size() + 2 != dot_rank ||
      !absl::c_equal(dnums.lhs_batch_dimensions(),
                     dnums.rhs_batch_dimensions())) {
    return false;
  }

  // Returns every dimension index in [0, rank) except `non_broadcast_dim`.
  auto broadcast_dims = [](int64 rank, int64 non_broadcast_dim) {
    absl::InlinedVector<int64, 8> dims;
    for (int64 i = 0; i < rank; ++i) {
      if (i != non_broadcast_dim) {
        dims.push_back(i);
      }
    }
    return dims;
  };

  // If the contracting dimension is 1, remove the degenerate dimensions from
  // the lhs and rhs, broadcast each to the result shape and multiply.
  if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 &&
      (rhs_kept_dim == rhs_rank - 1 ||
       (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) {
    CHECK_EQ(rhs->shape().dimensions(rhs_collapsing_dim), 1);
    // Dimensions after the stripped contracting dimension shift down by one.
    const int64 lhs_kept_dim_in_output =
        lhs_kept_dim > lhs_collapsing_dim ? (lhs_kept_dim - 1) : lhs_kept_dim;
    absl::InlinedVector<int64, 8> lhs_broadcast_dims;
    for (const int64 dim : dnums.lhs_batch_dimensions()) {
      lhs_broadcast_dims.push_back(dim > lhs_collapsing_dim ? (dim - 1) : dim);
    }
    // The rhs starts from the same (adjusted) batch dimensions as the lhs.
    absl::InlinedVector<int64, 8> rhs_broadcast_dims = lhs_broadcast_dims;
    lhs_broadcast_dims.push_back(lhs_kept_dim_in_output);
    absl::c_sort(lhs_broadcast_dims);
    rhs_broadcast_dims.push_back(dot_rank - 1);
    absl::c_sort(rhs_broadcast_dims);
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(
                 multiply(broadcast(StripDim(lhs, lhs_collapsing_dim),
                                    dot->shape(), lhs_broadcast_dims),
                          broadcast(StripDim(rhs, rhs_collapsing_dim),
                                    dot->shape(), rhs_broadcast_dims)))));
    return true;
  }

  // If the lhs and rhs non-contracting dimensions are both one, strip each one,
  // multiply and then reduce the collapsing dimension
  if (lhs->shape().dimensions(lhs_kept_dim) == 1 &&
      rhs->shape().dimensions(rhs_kept_dim) == 1 &&
      lhs_kept_dim == rhs_kept_dim) {
    auto new_lhs = StripDim(lhs, lhs_kept_dim);
    auto new_rhs = StripDim(rhs, rhs_kept_dim);
    const int64 reduce_dim = rhs_kept_dim < rhs_collapsing_dim
                                 ?
(rhs_collapsing_dim - 1) 1331 : rhs_collapsing_dim; 1332 TF_RETURN_IF_ERROR( 1333 ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( 1334 multiply(new_lhs, new_rhs), reduce_dim)))); 1335 return true; 1336 } 1337 1338 // If the lhs non-contracting dimensions is one, strip the one, brodcast to 1339 // the rhs shape, multiply and then reduce the collapsing dimension 1340 if (lhs->shape().dimensions(lhs_kept_dim) == 1) { 1341 auto new_lhs = broadcast(StripDim(lhs, lhs_kept_dim), rhs->shape(), 1342 broadcast_dims(rhs_rank, rhs_kept_dim)); 1343 TF_RETURN_IF_ERROR(ReplaceInstruction( 1344 dot, reshape_if_necessary(add_reduce_in_f32(multiply(new_lhs, rhs), 1345 rhs_collapsing_dim)))); 1346 return true; 1347 } 1348 1349 // If the rhs non-contracting dimensions is one, strip the one, brodcast to 1350 // the lhs shape, multiply and then reduce the collapsing dimension 1351 if (rhs->shape().dimensions(rhs_kept_dim) == 1) { 1352 auto new_rhs = broadcast(StripDim(rhs, rhs_kept_dim), lhs->shape(), 1353 broadcast_dims(lhs_rank, lhs_kept_dim)); 1354 TF_RETURN_IF_ERROR(ReplaceInstruction( 1355 dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, new_rhs), 1356 lhs_collapsing_dim)))); 1357 return true; 1358 } 1359 1360 return false; 1361 } 1362 1363 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat( 1364 HloInstruction* dot) { 1365 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); 1366 if (dnums.lhs_contracting_dimensions_size() != 1 || 1367 dnums.lhs_batch_dimensions_size() != 0) { 1368 return nullptr; 1369 } 1370 1371 const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); 1372 const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); 1373 HloInstruction *lhs, *rhs; 1374 CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); 1375 1376 TF_ASSIGN_OR_RETURN( 1377 HloInstruction * optimized_lhs_concat, 1378 OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs, 1379 rhs_contracting_dim, 
/*swapped=*/false)); 1380 if (optimized_lhs_concat) { 1381 return optimized_lhs_concat; 1382 } 1383 1384 return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs, 1385 lhs_contracting_dim, /*swapped=*/true); 1386 } 1387 1388 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( 1389 const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, 1390 HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { 1391 bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && 1392 lhs->concatenate_dimension() == lhs_contracting_dim && 1393 rhs->opcode() == HloOpcode::kConstant; 1394 if (!can_optimize) { 1395 return nullptr; 1396 } 1397 1398 // We're replacing this: 1399 // 1400 // +-----+-----+-----+ +-------------------+ 1401 // | | | | | | 1402 // | | | | | R_0 | 1403 // | | | | | | 1404 // | | | | +-------------------+ 1405 // | | | | | | 1406 // | L_0 | L_1 | L_2 | * | R_1 | 1407 // | | | | | | 1408 // | | | | +-------------------+ 1409 // | | | | | | 1410 // | | | | | R_2 | 1411 // | | | | | | 1412 // +-----+-----+-----+ +-------------------+ 1413 // 1414 // with this: 1415 // 1416 // [Sum over i] 1417 // 1418 // +-----+ +-------------------+ 1419 // | | | | 1420 // | | * | R_i | 1421 // | | | | 1422 // | | +-------------------+ 1423 // | | 1424 // | L_i | 1425 // | | 1426 // | | 1427 // | | 1428 // | | 1429 // | | 1430 // +-----+ 1431 // 1432 // where the LHS is a concatenate operation (so we can "split" the LHS tensor 1433 // for free) and the RHS is a constant tensor (and thus can be split at 1434 // compile time). In the future, we may also want to do this when both the 1435 // LHS and the RHS are concatenate operations that line up along the dimension 1436 // being contracted over. 1437 // 1438 // We should be able to generalize this transform to work on a non-constant 1439 // RHS when/if we have in-place slices or support input-fusing slices into 1440 // Dots. 
1441 1442 // Dimension numbers for the new dot instructions we'll create (L_i * R_i in 1443 // the diagram above). 1444 DotDimensionNumbers new_dot_dnums; 1445 new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim 1446 : lhs_contracting_dim); 1447 new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim 1448 : rhs_contracting_dim); 1449 1450 // Here we use the MKN notation, where the contracted dimension has K 1451 // elements and the two non-contracted dimensions have M and N elements. 1452 HloInstruction* add_result = nullptr; 1453 int64 rhs_contracting_dim_offset = 0; 1454 int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim); 1455 for (HloInstruction* concat_op : lhs->operands()) { 1456 int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim); 1457 Shape rhs_slice_shape(rhs->shape()); 1458 rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k); 1459 1460 std::array<int64, 2> start_indices; 1461 start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset; 1462 start_indices[1 - rhs_contracting_dim] = 0; 1463 1464 std::array<int64, 2> limit_indices; 1465 limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k; 1466 limit_indices[1 - rhs_contracting_dim] = n; 1467 1468 HloInstruction* rhs_slice = 1469 computation_->AddInstruction(HloInstruction::CreateSlice( 1470 rhs_slice_shape, rhs, /*start_indices=*/start_indices, 1471 /*limit_indices=*/limit_indices, /*strides=*/{1, 1})); 1472 1473 // TODO(b/69062148): We can get rid of `swapped` once all backends support 1474 // "non-canonical" contraction dimensions (that contracts dimension 1 of the 1475 // LHS with dimension 0 of the RHS). But for now we keep the same 1476 // contraction dimensions as the incoming dot operation to ensure the new 1477 // dot operations can be lowered. 
1478 HloInstruction *new_dot_lhs, *new_dot_rhs; 1479 if (swapped) { 1480 new_dot_lhs = rhs_slice; 1481 new_dot_rhs = concat_op; 1482 } else { 1483 new_dot_lhs = concat_op; 1484 new_dot_rhs = rhs_slice; 1485 } 1486 1487 auto* new_dot = computation_->AddInstruction( 1488 HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs, 1489 new_dot_dnums, dot.precision_config())); 1490 1491 if (add_result) { 1492 add_result = computation_->AddInstruction(HloInstruction::CreateBinary( 1493 dot.shape(), HloOpcode::kAdd, add_result, new_dot)); 1494 } else { 1495 add_result = new_dot; 1496 } 1497 1498 rhs_contracting_dim_offset += sub_k; 1499 } 1500 1501 return add_result; 1502 } 1503 1504 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather( 1505 HloInstruction* dot) { 1506 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); 1507 if (dnums.lhs_contracting_dimensions_size() != 1 || 1508 dnums.rhs_contracting_dimensions_size() != 1 || 1509 dnums.lhs_batch_dimensions_size() != 0 || 1510 dnums.rhs_batch_dimensions_size() != 0 || 1511 dot->shape().dimensions_size() != 2) { // dot output 2D 1512 VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations."; 1513 return nullptr; 1514 } 1515 1516 // Optimize either dot(DS(ctA), ctB)) or dot(ctB, DS(ctA)). 1517 // Currently a Gather is a DynamicSlice. 1518 auto is_dynamic_slice_constant_combination = 1519 [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) { 1520 // First operand is a DynamicSlice(Constant). 1521 if (a->opcode() != HloOpcode::kDynamicSlice) { 1522 return false; 1523 } 1524 auto* dynamic_slice_op = a->operand(0); 1525 if (dynamic_slice_op->opcode() != HloOpcode::kConstant) { 1526 return false; 1527 } 1528 // Second operand is a Constant. 1529 if (b->opcode() != HloOpcode::kConstant) { 1530 return false; 1531 } 1532 // The DynamicSlice output is a vector. 
1533 const Shape& dynamic_slice_shape = a->shape(); 1534 if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) { 1535 return false; 1536 } 1537 // Constant size is the same before and after slice in the contracting 1538 // dimension, otherwise we either must precompute for all possible slice 1539 // indices or dot is invalid. 1540 const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape(); 1541 if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) != 1542 dynamic_slice_shape.dimensions(a_contracting_dimension)) { 1543 return false; 1544 } 1545 return true; 1546 }; 1547 1548 HloInstruction* lhs = dot->mutable_operand(0); 1549 HloInstruction* rhs = dot->mutable_operand(1); 1550 int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); 1551 int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); 1552 1553 if (!is_dynamic_slice_constant_combination( 1554 lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) && 1555 !is_dynamic_slice_constant_combination( 1556 rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) { 1557 VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or " 1558 "dot(ctB, DS(ctA)), where the two constants have equal " 1559 "contracting dimensions."; 1560 return nullptr; 1561 } 1562 1563 // LHS is DynamicSlice: 1564 // input: dot(DS(ctA), ctB)) 1565 // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}. 1566 // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. 1567 // output: DS(dot(ctA, ctB)) 1568 // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}. 1569 1570 // RHS is DynamicSlice: 1571 // input: dot(ctA, DS(ctB)) 1572 // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}). 1573 // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. 1574 // output: DS(dot(ctA, ctB)) 1575 // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. 
1576 1577 bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; 1578 HloDynamicSliceInstruction* dynamic_slice = 1579 lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs) 1580 : Cast<HloDynamicSliceInstruction>(rhs); 1581 1582 // ctA: 1583 HloInstruction* left_operand = 1584 lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs; 1585 // ctB: 1586 HloInstruction* right_operand = 1587 lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0); 1588 // Build ctA x ctB. 1589 const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension); 1590 const int n = 1591 right_operand->shape().dimensions(1 - rhs_contracting_dimension); 1592 auto memoized_shape = 1593 ShapeUtil::MakeShape(dot->shape().element_type(), {m, n}); 1594 auto* memoized_inst = computation_->AddInstruction( 1595 HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, 1596 dnums, dot->precision_config())); 1597 // Get pair {start, 0} or {0, start}. 1598 // Position of start: 1599 int index_of_non_zero_start = lhs_is_dynamic_slice 1600 ? 1 - lhs_contracting_dimension 1601 : 1 - rhs_contracting_dimension; 1602 // Position of zero: 1603 int index_of_zero_start = 1 - index_of_non_zero_start; 1604 1605 // Slice out start and 0 components and reorder if necessary. 1606 auto indices_type = dynamic_slice->operand(1)->shape().element_type(); 1607 Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); 1608 Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); 1609 HloInstruction* non_zero_start = 1610 dynamic_slice->mutable_operand(1 + index_of_non_zero_start); 1611 HloInstruction* zero_start = 1612 dynamic_slice->mutable_operand(1 + index_of_zero_start); 1613 std::vector<HloInstruction*> new_start_indices; 1614 if (lhs_is_dynamic_slice) { 1615 new_start_indices = {non_zero_start, zero_start}; 1616 } else { 1617 new_start_indices = {zero_start, non_zero_start}; 1618 } 1619 1620 // Build DynamicSlice(ctA x ctB). 1621 const int new_slice_m = lhs_is_dynamic_slice ? 
1 : m; 1622 const int new_slice_n = lhs_is_dynamic_slice ? n : 1; 1623 auto* memoized_lookup = 1624 computation_->AddInstruction(HloInstruction::CreateDynamicSlice( 1625 dot->shape(), memoized_inst, new_start_indices, 1626 {new_slice_m, new_slice_n})); 1627 1628 return memoized_lookup; 1629 } 1630 1631 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { 1632 HloInstruction *lhs, *rhs; 1633 CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); 1634 if (options_.is_layout_sensitive()) { 1635 return Status::OK(); 1636 } 1637 // Replace a zero element dot with a broadcast of the constant 0. 1638 if (ShapeUtil::IsZeroElementArray(dot->shape()) || 1639 ShapeUtil::IsZeroElementArray(lhs->shape()) || 1640 ShapeUtil::IsZeroElementArray(rhs->shape())) { 1641 auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( 1642 LiteralUtil::Zero(dot->shape().element_type()))); 1643 return ReplaceWithNewInstruction( 1644 dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); 1645 } 1646 1647 // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are 1648 // rank 2 or below. 
1649 if (dot->shape().element_type() != F32 && 1650 dot->shape().element_type() != BF16) { 1651 return Status::OK(); 1652 } 1653 1654 // If there are no contracting dimensions, a dot can be rewritten as 1655 // mul(broadcast(transpose(x)),broadcast(transpose(y))) 1656 if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) { 1657 TF_ASSIGN_OR_RETURN( 1658 HloInstruction * new_lhs, 1659 NormalizeDotOperandToBatchMajorAndContractingMinor( 1660 lhs, 1661 AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), 1662 AsInt64Slice( 1663 dot->dot_dimension_numbers().lhs_contracting_dimensions()))); 1664 if (dot->shape().rank() != lhs->shape().rank()) { 1665 std::vector<int64> lhs_broadcast_dims(lhs->shape().rank()); 1666 absl::c_iota(lhs_broadcast_dims, 0); 1667 new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( 1668 dot->shape(), new_lhs, lhs_broadcast_dims)); 1669 } 1670 TF_ASSIGN_OR_RETURN( 1671 HloInstruction * new_rhs, 1672 NormalizeDotOperandToBatchMajorAndContractingMinor( 1673 rhs, 1674 AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), 1675 AsInt64Slice( 1676 dot->dot_dimension_numbers().rhs_contracting_dimensions()))); 1677 if (dot->shape().rank() != rhs->shape().rank()) { 1678 std::vector<int64> rhs_broadcast_dims( 1679 dot->dot_dimension_numbers().lhs_batch_dimensions_size()); 1680 absl::c_iota(rhs_broadcast_dims, 0); 1681 for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) { 1682 rhs_broadcast_dims.push_back(i); 1683 } 1684 new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( 1685 dot->shape(), new_rhs, rhs_broadcast_dims)); 1686 } 1687 return ReplaceWithNewInstruction( 1688 dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, 1689 new_lhs, new_rhs)); 1690 } 1691 1692 // If the lhs or rhs have only batch and contracting dimensions, a dot can be 1693 // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) 1694 if 
((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + 1695 dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1696 lhs->shape().rank()) || 1697 (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() + 1698 dot->dot_dimension_numbers().rhs_batch_dimensions_size() == 1699 rhs->shape().rank())) { 1700 TF_ASSIGN_OR_RETURN( 1701 HloInstruction * new_lhs, 1702 NormalizeDotOperandToBatchMajorAndContractingMinor( 1703 lhs, 1704 AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), 1705 AsInt64Slice( 1706 dot->dot_dimension_numbers().lhs_contracting_dimensions()))); 1707 TF_ASSIGN_OR_RETURN( 1708 HloInstruction * new_rhs, 1709 NormalizeDotOperandToBatchMajorAndContractingMinor( 1710 rhs, 1711 AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), 1712 AsInt64Slice( 1713 dot->dot_dimension_numbers().rhs_contracting_dimensions()))); 1714 1715 int64 lhs_outer_dims = 1716 lhs->shape().rank() - 1717 (dot->dot_dimension_numbers().lhs_batch_dimensions_size() + 1718 dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); 1719 int64 rhs_outer_dims = 1720 rhs->shape().rank() - 1721 (dot->dot_dimension_numbers().rhs_batch_dimensions_size() + 1722 dot->dot_dimension_numbers().rhs_contracting_dimensions_size()); 1723 CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0); 1724 if (rhs_outer_dims > 0) { 1725 std::vector<int64> lhs_broadcast_dims( 1726 dot->dot_dimension_numbers().lhs_batch_dimensions_size()); 1727 absl::c_iota(lhs_broadcast_dims, 0); 1728 lhs_broadcast_dims.resize(lhs->shape().rank()); 1729 std::iota(lhs_broadcast_dims.begin() + 1730 dot->dot_dimension_numbers().lhs_batch_dimensions_size(), 1731 lhs_broadcast_dims.end(), 1732 dot->dot_dimension_numbers().lhs_batch_dimensions_size() + 1733 rhs_outer_dims); 1734 new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( 1735 new_rhs->shape(), new_lhs, lhs_broadcast_dims)); 1736 } else if (lhs_outer_dims > 0) { 1737 std::vector<int64> 
rhs_broadcast_dims( 1738 dot->dot_dimension_numbers().rhs_batch_dimensions_size()); 1739 absl::c_iota(rhs_broadcast_dims, 0); 1740 rhs_broadcast_dims.resize(rhs->shape().rank()); 1741 std::iota(rhs_broadcast_dims.begin() + 1742 dot->dot_dimension_numbers().rhs_batch_dimensions_size(), 1743 rhs_broadcast_dims.end(), 1744 dot->dot_dimension_numbers().rhs_batch_dimensions_size() + 1745 lhs_outer_dims); 1746 new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( 1747 new_lhs->shape(), new_rhs, rhs_broadcast_dims)); 1748 } 1749 1750 TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, 1751 MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs)); 1752 std::vector<int64> reduce_dims( 1753 dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); 1754 new_dot = AsType(new_dot, F32); 1755 const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims); 1756 absl::c_iota( 1757 reduce_dims, 1758 outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); 1759 new_dot = AddReduce(new_dot, reduce_dims); 1760 new_dot = AsType(new_dot, dot->shape().element_type()); 1761 return ReplaceInstruction(dot, new_dot); 1762 } 1763 1764 if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || 1765 dot->shape().rank() > 2) { 1766 if (options_.enable_dot_strength_reduction() && 1767 !options_.is_layout_sensitive()) { 1768 TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status()); 1769 } 1770 return Status::OK(); 1771 } 1772 1773 TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, 1774 OptimizeDotOfConcat(dot)); 1775 if (dot_of_concat_optimized) { 1776 VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., " 1777 "constant)...)"; 1778 return ReplaceInstruction(dot, dot_of_concat_optimized); 1779 } 1780 1781 // Simplify dot(ConstA, Gather(Index, ConstB)) to: 1782 // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately 1783 // batched version of dot. 
1784 TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized, 1785 OptimizeDotOfGather(dot)); 1786 if (dot_of_gather_optimized) { 1787 VLOG(10) << "Replaced dot(constA, gather(i, constB)) with " 1788 "gather(i, dot*(constA, constB))"; 1789 return ReplaceInstruction(dot, dot_of_gather_optimized); 1790 } 1791 1792 if (options_.enable_dot_strength_reduction() && 1793 !options_.is_layout_sensitive()) { 1794 TF_ASSIGN_OR_RETURN(bool did_strength_reduction, 1795 HandleDotStrengthReduction(dot)); 1796 if (did_strength_reduction) { 1797 return Status::OK(); 1798 } 1799 } 1800 1801 // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). 1802 if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 && 1803 dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 && 1804 dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 && 1805 dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 && 1806 lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { 1807 DotDimensionNumbers dot_dimension_numbers; 1808 dot_dimension_numbers.add_lhs_contracting_dimensions(1); 1809 dot_dimension_numbers.add_rhs_contracting_dimensions(0); 1810 auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( 1811 ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), 1812 rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers, 1813 dot->precision_config())); 1814 return ReplaceWithNewInstruction( 1815 dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); 1816 } 1817 1818 return Status::OK(); 1819 } 1820 1821 Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { 1822 HloInstruction *lhs, *rhs; 1823 CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs)))); 1824 // A*1 => A 1825 VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); 1826 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { 1827 return Status::OK(); 1828 } 1829 // 1*A => A 1830 VLOG(10) << 
"trying transform [1*A => A]: " << multiply->ToString(); 1831 if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) { 1832 return Status::OK(); 1833 } 1834 1835 // 0*A => 0. Only applies for integral types for correct NaN-handling. 1836 if (IsAll(lhs, 0) && 1837 primitive_util::IsIntegralType(multiply->shape().element_type()) && 1838 ReplaceInstructionIfSameShape(multiply, lhs)) { 1839 return Status::OK(); 1840 } 1841 // A*0 => 0 1842 if (IsAll(rhs, 0) && 1843 primitive_util::IsIntegralType(multiply->shape().element_type()) && 1844 ReplaceInstructionIfSameShape(multiply, rhs)) { 1845 return Status::OK(); 1846 } 1847 1848 // exp(A) * exp(B) => exp(A+B) 1849 if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { 1850 auto add = computation_->AddInstruction(HloInstruction::CreateBinary( 1851 multiply->shape(), HloOpcode::kAdd, lhs, rhs)); 1852 return ReplaceWithNewInstruction( 1853 multiply, 1854 HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); 1855 } 1856 return Status::OK(); 1857 } 1858 1859 Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) { 1860 // negate(negate(x)) => x 1861 HloInstruction* x; 1862 if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) && 1863 ReplaceInstructionIfSameShape(negate, x)) { 1864 return Status::OK(); 1865 } 1866 return Status::OK(); 1867 } 1868 1869 Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) { 1870 // not(not(x)) => x 1871 HloInstruction* x; 1872 if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) && 1873 ReplaceInstructionIfSameShape(logical_not, x)) { 1874 return Status::OK(); 1875 } 1876 return Status::OK(); 1877 } 1878 1879 Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) { 1880 HloInstruction *lhs, *rhs; 1881 CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs)))); 1882 1883 // Simplify logical or 1884 if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && 1885 
ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { 1886 // A || True => True 1887 VLOG(10) << "trying transform [A || True => True]: " 1888 << logical_or->ToString(); 1889 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) { 1890 return Status::OK(); 1891 } 1892 // True || A => True 1893 VLOG(10) << "trying transform [True || A => True]: " 1894 << logical_or->ToString(); 1895 if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) { 1896 return Status::OK(); 1897 } 1898 1899 // A || False => A 1900 VLOG(10) << "trying transform [A || False => A]: " 1901 << logical_or->ToString(); 1902 if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) { 1903 return Status::OK(); 1904 } 1905 1906 // False || A => A 1907 VLOG(10) << "trying transform [False || A => A]: " 1908 << logical_or->ToString(); 1909 if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) { 1910 return Status::OK(); 1911 } 1912 } 1913 1914 return Status::OK(); 1915 } 1916 1917 Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { 1918 // ln(exp(A)) => A 1919 VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); 1920 HloInstruction *a, *b; 1921 if (Match(log, m::Log(m::Exp(m::Op(&a)))) && 1922 ReplaceInstructionIfSameShape(log, a)) { 1923 return Status::OK(); 1924 } 1925 1926 // ln(pow(A,B)) => B*ln(A) 1927 if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) { 1928 auto new_log = computation_->AddInstruction( 1929 HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a)); 1930 return ReplaceWithNewInstruction( 1931 log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, 1932 new_log, b)); 1933 } 1934 1935 return Status::OK(); 1936 } 1937 1938 Status AlgebraicSimplifierVisitor::HandleGetTupleElement( 1939 HloInstruction* get_tuple_element) { 1940 auto operand = get_tuple_element->mutable_operand(0); 1941 if (operand->opcode() == HloOpcode::kTuple) { 1942 // get_tuple_element(make_tuple({A_0, 
A_1, ..., A_n}), i) => A_i 1943 VLOG(10) << "trying transform " 1944 << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: " 1945 << get_tuple_element->ToString(); 1946 if (ReplaceInstructionIfSameShape( 1947 get_tuple_element, 1948 operand->mutable_operand(get_tuple_element->tuple_index()))) { 1949 return Status::OK(); 1950 } 1951 } 1952 return Status::OK(); 1953 } 1954 1955 namespace { 1956 1957 // Return whether the given reshape instruction leaves the dimensions at the 1958 // given input indices unmodified, and returns their output indices. 1959 // 1960 // Example: 1961 // input_dim_indices = {2, 3} 1962 // input shape = T[a, b, x, y, cd] 1963 // output shape = T[ab, x, 1, y, c, d] 1964 // return value = {1, 3} 1965 // 1966 // Precondition: input_dim_indices is sorted. 1967 absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified( 1968 const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) { 1969 CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); 1970 CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); 1971 1972 std::vector<int64> output_dim_indices; 1973 std::vector<std::pair<int64, int64>> unmodified_dims = 1974 ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(), 1975 hlo->shape()); 1976 size_t i = 0; // index to unmodified_dims 1977 for (int64 input_dim_index : input_dim_indices) { 1978 // Search unmodified_dims for input_dim_index. We can search from the last 1979 // matching position because input_dim_indices is guaranteed to be sorted. 1980 while (i < unmodified_dims.size() && 1981 unmodified_dims[i].first < input_dim_index) { 1982 ++i; 1983 } 1984 if (i >= unmodified_dims.size() || 1985 unmodified_dims[i].first != input_dim_index) { 1986 return absl::nullopt; 1987 } 1988 output_dim_indices.push_back(unmodified_dims[i].second); 1989 } 1990 return output_dim_indices; 1991 } 1992 1993 // Returns true if the output of "instruction" is a permutation of the 1994 // elements of "operand". 
Precondition: "operand" is an operand of 1995 // "instruction". 1996 bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, 1997 HloInstruction* operand) { 1998 DCHECK(!instruction->OperandIndices(operand).empty()); 1999 switch (instruction->opcode()) { 2000 case HloOpcode::kReshape: 2001 case HloOpcode::kReverse: 2002 case HloOpcode::kTranspose: 2003 return true; 2004 case HloOpcode::kSort: 2005 return (!instruction->shape().IsTuple()); 2006 default: 2007 return false; 2008 } 2009 } 2010 2011 // Returns true if the output of "instruction" is a subset of the elements of 2012 // "operand". Precondition: "operand" is an operand of "instruction". 2013 bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, 2014 HloInstruction* operand) { 2015 std::vector<int64> operand_indices = instruction->OperandIndices(operand); 2016 CHECK(!operand_indices.empty()); 2017 if (operand_indices.size() != 1) { 2018 return false; 2019 } 2020 int64 operand_index = operand_indices[0]; 2021 switch (instruction->opcode()) { 2022 case HloOpcode::kSlice: 2023 CHECK_EQ(0, operand_index); 2024 return true; 2025 case HloOpcode::kDynamicSlice: 2026 return operand_index == 0; 2027 default: 2028 return false; 2029 } 2030 } 2031 2032 } // namespace 2033 2034 Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { 2035 HloInstruction* operand; 2036 CHECK(Match(broadcast, m::Broadcast(m::Op(&operand)))); 2037 auto dims = broadcast->dimensions(); 2038 // A degenerate broadcast of a reshape that does not change the number of 2039 // elements can be replaced by a reshape. 
2040 if (std::is_sorted(dims.begin(), dims.end()) && 2041 ShapeUtil::ElementsIn(broadcast->shape()) == 2042 ShapeUtil::ElementsIn(operand->shape())) { 2043 VLOG(10) << "transform broadcast(X) -> reshape(X) where " 2044 "n(broadcast(X)) == n(X)"; 2045 return ReplaceWithNewInstruction( 2046 broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand)); 2047 } 2048 2049 // A degenerate broadcast that has the same input and output rank can be 2050 // converted into a transpose. 2051 if (broadcast->shape().rank() == operand->shape().rank() && 2052 ShapeUtil::ElementsIn(broadcast->shape()) == 2053 ShapeUtil::ElementsIn(operand->shape())) { 2054 VLOG(10) << "transform broadcast(X) -> transpose(X) where " 2055 "n(broadcast(X)) == n(X)"; 2056 return ReplaceWithNewInstruction( 2057 broadcast, 2058 HloInstruction::CreateTranspose(broadcast->shape(), operand, dims)); 2059 } 2060 2061 // A broadcast of a reshape which merely inserts 1-sized dimensions can 2062 // elide its operand. 2063 { 2064 bool merely_inserts_or_deletes_1_sized_dimensions; 2065 std::vector<int64> inserted_indices, deleted_indices; 2066 std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices, 2067 inserted_indices) = 2068 operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); 2069 if (merely_inserts_or_deletes_1_sized_dimensions && 2070 deleted_indices.empty()) { 2071 std::reverse(inserted_indices.begin(), inserted_indices.end()); 2072 for (auto inserted_index : inserted_indices) { 2073 dims.erase(dims.begin() + inserted_index); 2074 } 2075 return ReplaceWithNewInstruction( 2076 broadcast, 2077 HloInstruction::CreateBroadcast(broadcast->shape(), 2078 operand->mutable_operand(0), dims)); 2079 } 2080 } 2081 2082 // A Broadcast that feeds a unary element-wise operation can sink the 2083 // broadcast after the unary element-wise operation. 
2084 TF_ASSIGN_OR_RETURN( 2085 bool sink_succeeded, 2086 TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); 2087 changed_ |= sink_succeeded; 2088 if (sink_succeeded) { 2089 return Status::OK(); 2090 } 2091 2092 // A scalar broadcast feeding an instruction which only permutes (reshape, 2093 // transpose, sort, reverse) or selects a subset of operand elements (slice, 2094 // dynamic slice) can be replaced with a broadcast directly to the output 2095 // shape of the instruction. 2096 if (ShapeUtil::IsScalar(operand->shape())) { 2097 for (HloInstruction* user : broadcast->users()) { 2098 // Skip if the broadcast user has no uses itself. 2099 if (user->user_count() == 0 && user != computation_->root_instruction()) { 2100 continue; 2101 } 2102 if (OutputIsPermutationOfOperandElements(user, broadcast) || 2103 OutputIsSubsetOfOperandElements(user, broadcast)) { 2104 VLOG(10) << "transform permuting/subset of a scalar broadcast into " 2105 << "a single broadcast"; 2106 HloInstruction* new_broadcast = computation_->AddInstruction( 2107 HloInstruction::CreateBroadcast(user->shape(), operand, {})); 2108 // Use HloInstruction::ReplaceAllUsesWith instead of 2109 // HloComputation::ReplaceWithNewInstruction because we are replacing an 2110 // instruction other than the visited instruction. 2111 changed_ = true; 2112 return user->ReplaceAllUsesWith(new_broadcast); 2113 } 2114 } 2115 return Status::OK(); 2116 } 2117 2118 // broadcast(iota) -> iota. 2119 if (operand->opcode() == HloOpcode::kIota) { 2120 return ReplaceWithNewInstruction( 2121 broadcast, 2122 HloInstruction::CreateIota( 2123 broadcast->shape(), 2124 dims[Cast<HloIotaInstruction>(operand)->iota_dimension()])); 2125 } 2126 2127 // Merge two consecutive broadcasts into a single one. 
2128 if (operand->opcode() == HloOpcode::kBroadcast) { 2129 std::vector<int64> new_dimensions; 2130 for (auto dim : operand->dimensions()) { 2131 new_dimensions.push_back(dims[dim]); 2132 } 2133 return ReplaceWithNewInstruction( 2134 broadcast, 2135 HloInstruction::CreateBroadcast( 2136 broadcast->shape(), operand->mutable_operand(0), new_dimensions)); 2137 } 2138 return Status::OK(); 2139 } 2140 2141 // A conversion to the same element type as the operand is a nop and can be 2142 // removed. A conversion of a constant can be simplified by making a new 2143 // constant. 2144 Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { 2145 PrimitiveType src_type = convert->operand(0)->shape().element_type(); 2146 PrimitiveType dest_type = convert->shape().element_type(); 2147 if (src_type == dest_type) { 2148 return ReplaceInstruction(convert, convert->mutable_operand(0)); 2149 } 2150 return Status::OK(); 2151 } 2152 2153 // Complex(Real(c), Imag(c)) -> c 2154 Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { 2155 HloInstruction *c0, *c1; 2156 if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) && 2157 c0 == c1) { 2158 return ReplaceInstruction(complex, c0); 2159 } 2160 return Status::OK(); 2161 } 2162 2163 // Real(Complex(r, i)) -> r 2164 Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { 2165 HloInstruction* op; 2166 if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) { 2167 return ReplaceInstruction(real, op); 2168 } 2169 return Status::OK(); 2170 } 2171 2172 // Imag(Complex(r, i)) -> i 2173 Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { 2174 HloInstruction* op; 2175 if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) { 2176 return ReplaceInstruction(imag, op); 2177 } 2178 return Status::OK(); 2179 } 2180 2181 Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { 2182 // iota -> zero if the iota dimension never produces an 
element other than 2183 // zero. 2184 auto* iota = Cast<HloIotaInstruction>(instruction); 2185 if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { 2186 auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( 2187 LiteralUtil::Zero(iota->shape().element_type()).Clone())); 2188 return ReplaceWithNewInstruction( 2189 iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); 2190 } 2191 return Status::OK(); 2192 } 2193 2194 Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { 2195 if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) { 2196 return ReplaceWithNewInstruction( 2197 pad, HloInstruction::CreateBroadcast(pad->shape(), 2198 pad->mutable_operand(1), {})); 2199 } 2200 2201 // Interior padding on one sized dimensions have no effect. As a result it 2202 // makes other simplifications possible if there is no interior padding. 2203 if (HasInteriorPadding(pad->padding_config())) { 2204 PaddingConfig padding_config = pad->padding_config(); 2205 bool cleared_interior_padding = false; 2206 for (int64 i = 0; i < pad->shape().rank(); ++i) { 2207 if (padding_config.dimensions(i).interior_padding() > 0 && 2208 pad->operand(0)->shape().dimensions(i) == 1) { 2209 cleared_interior_padding = true; 2210 padding_config.mutable_dimensions(i)->set_interior_padding(0); 2211 } 2212 } 2213 if (cleared_interior_padding) { 2214 return ReplaceWithNewInstruction( 2215 pad, 2216 HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0), 2217 pad->mutable_operand(1), padding_config)); 2218 } 2219 } 2220 2221 // Eliminate nop pads (padding all zero), and replace a pad with negative 2222 // padding with a pad with non-negative padding followed by a slice. 
2223 bool all_zero = true; 2224 bool has_negative = false; 2225 for (auto& padding_dimension : pad->padding_config().dimensions()) { 2226 if (padding_dimension.edge_padding_low() < 0 || 2227 padding_dimension.edge_padding_high() < 0) { 2228 has_negative = true; 2229 } 2230 if (padding_dimension.edge_padding_low() != 0 || 2231 padding_dimension.edge_padding_high() != 0) { 2232 all_zero = false; 2233 } 2234 } 2235 2236 if (all_zero) { 2237 ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0)); 2238 return Status::OK(); 2239 } 2240 2241 if (has_negative) { 2242 // Pad has negative padding. Replace with a pad with the non-negative 2243 // padding followed by a slice which effectively performs the negative 2244 // padding. 2245 // TODO(b/34628603): Add support for negative padding in the backends, or 2246 // change kPad semantics to disallow negative padding and use slice 2247 // instead. 2248 2249 // First construct the padding config with non-negative entries and the 2250 // compute the shape of this new pad instruction. 2251 PaddingConfig nonzero_padding = pad->padding_config(); 2252 for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) { 2253 PaddingConfig::PaddingConfigDimension* padding_dimension = 2254 nonzero_padding.mutable_dimensions(i); 2255 // Set negative padding to zero. 2256 if (padding_dimension->edge_padding_low() < 0) { 2257 padding_dimension->set_edge_padding_low(0); 2258 } 2259 if (padding_dimension->edge_padding_high() < 0) { 2260 padding_dimension->set_edge_padding_high(0); 2261 } 2262 } 2263 2264 TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad, 2265 MakePadHlo(pad->mutable_operand(0), 2266 pad->mutable_operand(1), nonzero_padding)); 2267 // Copy the layout from the original pad instructions. The new pad and the 2268 // slice instruction should all have the same layout. 
2269 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( 2270 pad->shape(), nonzero_pad->mutable_shape())); 2271 2272 // Second, construct the slice instruction to perform the negative padding. 2273 std::vector<int64> start_indices; 2274 std::vector<int64> end_indices; 2275 std::vector<int64> strides; 2276 for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) { 2277 const PaddingConfig::PaddingConfigDimension& padding_dimension = 2278 pad->padding_config().dimensions(i); 2279 int64 start = 0; 2280 if (padding_dimension.edge_padding_low() < 0) { 2281 start = -1 * padding_dimension.edge_padding_low(); 2282 } 2283 int64 end = nonzero_pad->shape().dimensions(i); 2284 if (padding_dimension.edge_padding_high() < 0) { 2285 end += padding_dimension.edge_padding_high(); 2286 } 2287 start_indices.push_back(start); 2288 end_indices.push_back(end); 2289 strides.push_back(1); 2290 } 2291 2292 TF_ASSIGN_OR_RETURN( 2293 HloInstruction * slice, 2294 MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides)); 2295 2296 // Verify that the slice shape matches the pad shape. 
2297 TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape())); 2298 2299 return ReplaceInstruction(pad, slice); 2300 } 2301 2302 return Status::OK(); 2303 } 2304 2305 Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { 2306 VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); 2307 HloInstruction *lhs, *rhs; 2308 CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); 2309 if (IsAll(rhs, 0)) { 2310 auto one = HloInstruction::CreateConstant( 2311 LiteralUtil::One(power->shape().element_type()).Clone()); 2312 std::unique_ptr<HloInstruction> ones; 2313 if (ShapeUtil::IsScalar(power->shape())) { 2314 ones = std::move(one); 2315 } else { 2316 ones = HloInstruction::CreateBroadcast( 2317 power->shape(), computation_->AddInstruction(std::move(one)), {}); 2318 } 2319 return ReplaceWithNewInstruction(power, std::move(ones)); 2320 } 2321 2322 VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString(); 2323 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) { 2324 return Status::OK(); 2325 } 2326 2327 // pow(exp(A),B) => exp(A*B) 2328 HloInstruction *a, *b; 2329 if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) { 2330 auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary( 2331 power->shape(), HloOpcode::kMultiply, a, b)); 2332 return ReplaceWithNewInstruction( 2333 power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp, 2334 a_times_b)); 2335 } 2336 VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString(); 2337 if (IsAll(rhs, 2)) { 2338 return ReplaceWithNewInstruction( 2339 power, HloInstruction::CreateBinary(power->shape(), 2340 HloOpcode::kMultiply, lhs, lhs)); 2341 } 2342 2343 VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); 2344 if (IsAll(rhs, -1)) { 2345 auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( 2346 LiteralUtil::One(rhs->shape().element_type()).Clone())); 2347 2348 // Explicitly 
broadcast scalar 1 to the output shape, to avoid implicit 2349 // broadcast in divide HLO as we are trying to eliminate implicit 2350 // broadcasting at HLO level. 2351 auto* broadcast_one = computation_->AddInstruction( 2352 HloInstruction::CreateBroadcast(power->shape(), one, {})); 2353 return ReplaceWithNewInstruction( 2354 power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, 2355 broadcast_one, lhs)); 2356 } 2357 2358 VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: " 2359 << power->ToString(); 2360 2361 // Don't perform this optimization if either of the exponents is complex; this 2362 // identity is true only for real-valued exponents. In addition, we cowardly 2363 // refuse to do this transformation if the two expontents have different 2364 // element types. 2365 if (lhs->opcode() == HloOpcode::kPower && 2366 !ShapeUtil::ElementIsComplex(lhs->operand(1)->shape()) && 2367 !ShapeUtil::ElementIsComplex(rhs->shape()) && 2368 ShapeUtil::SameElementType(lhs->operand(1)->shape(), rhs->shape())) { 2369 auto exponent_product = 2370 computation_->AddInstruction(HloInstruction::CreateBinary( 2371 rhs->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); 2372 return ReplaceWithNewInstruction( 2373 power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kPower, 2374 lhs->mutable_operand(0), 2375 exponent_product)); 2376 } 2377 2378 return Status::OK(); 2379 } 2380 2381 StatusOr<bool> 2382 AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( 2383 HloInstruction* broadcast) { 2384 TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast); 2385 bool changed = false; 2386 if (ShapeUtil::IsScalar(broadcast->shape())) { 2387 return false; 2388 } 2389 HloInstruction* operand = broadcast->mutable_operand(0); 2390 for (HloInstruction* user : broadcast->users()) { 2391 if (user->user_count() == 0 && user != computation_->root_instruction()) { 2392 continue; 2393 } 2394 // Do not move reshapes or 
broadcasts past copies since the shape the copy 2395 // will operate on will change. 2396 if (user->opcode() == HloOpcode::kCopy) { 2397 continue; 2398 } 2399 // Do not change the shape of fusion nodes in case there a multiple shapes 2400 // inside the fusion node already. 2401 if (user->opcode() == HloOpcode::kFusion) { 2402 continue; 2403 } 2404 if (!user->IsElementwise()) { 2405 continue; 2406 } 2407 2408 // Find the unique non-scalar operand or continue if there isn't one. 2409 int64 scalar_broadcast_count = 0; 2410 int64 broadcast_use_count = 0; 2411 for (HloInstruction* user_operand : user->operands()) { 2412 if (user_operand->opcode() == HloOpcode::kBroadcast && 2413 ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { 2414 ++scalar_broadcast_count; 2415 } else if (broadcast == user_operand) { 2416 ++broadcast_use_count; 2417 } 2418 } 2419 if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) { 2420 continue; 2421 } 2422 std::vector<HloInstruction*> new_operands; 2423 new_operands.reserve(user->operand_count()); 2424 2425 for (HloInstruction* user_operand : user->operands()) { 2426 if (user_operand->opcode() == HloOpcode::kBroadcast && 2427 ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { 2428 new_operands.push_back( 2429 computation_->AddInstruction(HloInstruction::CreateBroadcast( 2430 ShapeUtil::ChangeElementType( 2431 operand->shape(), user_operand->shape().element_type()), 2432 user_operand->mutable_operand(0), {}))); 2433 } else { 2434 CHECK_EQ(broadcast, user_operand); 2435 new_operands.push_back(operand); 2436 } 2437 } 2438 VLOG(4) << "Sinking broadcast after user:"; 2439 VLOG(4) << " old broadcast: " << broadcast->ToString(); 2440 VLOG(4) << " old user: " << user->ToString(); 2441 HloInstruction* new_user = 2442 computation_->AddInstruction(user->CloneWithNewOperands( 2443 ShapeUtil::ChangeElementType(operand->shape(), 2444 user->shape().element_type()), 2445 new_operands)); 2446 VLOG(4) << " new user: " << 
new_user->ToString(); 2447 HloInstruction* new_broadcast = 2448 computation_->AddInstruction(HloInstruction::CreateBroadcast( 2449 user->shape(), new_user, broadcast->dimensions())); 2450 VLOG(4) << " new broadcast: " << new_broadcast->ToString(); 2451 TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); 2452 changed = true; 2453 } 2454 return changed; 2455 } 2456 2457 namespace { 2458 template <typename T> 2459 std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder, 2460 HloComputation* computation) { 2461 HloInstruction *a, *b, *c; 2462 CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); 2463 2464 if (ShapeUtil::ElementIsIntegral(remainder->shape()) && 2465 !Match(b, m::ConstantEffectiveScalar(&c)) && 2466 !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { 2467 return nullptr; 2468 } 2469 2470 if (ShapeUtil::ElementIsSigned(remainder->shape())) { 2471 int64 b_value = c->literal().GetFirstElement<T>(); 2472 if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) { 2473 // Handle negative dividends by negating the result of the division. 
2474 HloInstruction* zero_like_a = BroadcastZeros( 2475 computation, a->shape().element_type(), a->shape().dimensions()); 2476 2477 auto* dividend_is_negative = 2478 computation->AddInstruction(HloInstruction::CreateCompare( 2479 ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a, 2480 ComparisonDirection::kLt)); 2481 2482 auto* negated_dividend = computation->AddInstruction( 2483 HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); 2484 2485 auto* abs_dividend = 2486 computation->AddInstruction(HloInstruction::CreateTernary( 2487 a->shape(), HloOpcode::kSelect, dividend_is_negative, 2488 negated_dividend, a)); 2489 2490 auto* mask_amount = 2491 computation->AddInstruction(HloInstruction::CreateConstant( 2492 LiteralUtil::CreateR0<T>(b_value - 1))); 2493 if (!ShapeUtil::IsScalar(b->shape())) { 2494 mask_amount = computation->AddInstruction( 2495 HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); 2496 } 2497 2498 auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( 2499 remainder->shape(), HloOpcode::kAnd, abs_dividend, mask_amount)); 2500 2501 auto* neqated_quotient = 2502 computation->AddInstruction(HloInstruction::CreateUnary( 2503 quotient->shape(), HloOpcode::kNegate, quotient)); 2504 2505 return HloInstruction::CreateTernary( 2506 remainder->shape(), HloOpcode::kSelect, dividend_is_negative, 2507 neqated_quotient, quotient); 2508 } 2509 } else { 2510 uint64 b_value = c->literal().GetFirstElement<T>(); 2511 if (IsPowerOfTwo(b_value)) { 2512 HloInstruction* mask_amount = 2513 computation->AddInstruction(HloInstruction::CreateConstant( 2514 LiteralUtil::CreateR0<T>(b_value - 1))); 2515 if (!ShapeUtil::IsScalar(b->shape())) { 2516 mask_amount = computation->AddInstruction( 2517 HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); 2518 } 2519 return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd, 2520 a, mask_amount); 2521 } 2522 } 2523 return nullptr; 2524 } 2525 } // namespace 
2526 2527 Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { 2528 HloInstruction *a, *b; 2529 CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); 2530 2531 // A % B => A & (B - 1) if B is a power of 2. 2532 switch (remainder->shape().element_type()) { 2533 case S8: 2534 if (std::unique_ptr<HloInstruction> shift = 2535 TryRemainderToAnd<int8>(remainder, computation_)) { 2536 return ReplaceWithNewInstruction(remainder, std::move(shift)); 2537 } 2538 break; 2539 case S16: 2540 if (std::unique_ptr<HloInstruction> shift = 2541 TryRemainderToAnd<int16>(remainder, computation_)) { 2542 return ReplaceWithNewInstruction(remainder, std::move(shift)); 2543 } 2544 break; 2545 case S32: 2546 if (std::unique_ptr<HloInstruction> shift = 2547 TryRemainderToAnd<int32>(remainder, computation_)) { 2548 return ReplaceWithNewInstruction(remainder, std::move(shift)); 2549 } 2550 break; 2551 case S64: 2552 if (std::unique_ptr<HloInstruction> shift = 2553 TryRemainderToAnd<int64>(remainder, computation_)) { 2554 return ReplaceWithNewInstruction(remainder, std::move(shift)); 2555 } 2556 break; 2557 case U8: 2558 if (std::unique_ptr<HloInstruction> shift = 2559 TryRemainderToAnd<uint8>(remainder, computation_)) { 2560 return ReplaceWithNewInstruction(remainder, std::move(shift)); 2561 } 2562 break; 2563 case U16: 2564 if (std::unique_ptr<HloInstruction> shift = 2565 TryRemainderToAnd<uint16>(remainder, computation_)) { 2566 return ReplaceWithNewInstruction(remainder, std::move(shift)); 2567 } 2568 break; 2569 case U32: 2570 if (std::unique_ptr<HloInstruction> shift = 2571 TryRemainderToAnd<uint32>(remainder, computation_)) { 2572 return ReplaceWithNewInstruction(remainder, std::move(shift)); 2573 } 2574 break; 2575 case U64: 2576 if (std::unique_ptr<HloInstruction> shift = 2577 TryRemainderToAnd<uint64>(remainder, computation_)) { 2578 return ReplaceWithNewInstruction(remainder, std::move(shift)); 2579 } 2580 break; 2581 default: 2582 break; 2583 } 2584 
2585 return Status::OK(); 2586 } 2587 2588 Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { 2589 auto operand = reshape->mutable_operand(0); 2590 2591 // Reshape directly to empty constant if the shape contains zero-element 2592 // dimension. 2593 if (ShapeUtil::IsZeroElementArray(reshape->shape())) { 2594 // If the instruction doesn't have a layout, use a default layout for 2595 // the literal result. 2596 Shape reshaped_shape = reshape->shape(); 2597 if (!LayoutUtil::HasLayout(reshaped_shape)) { 2598 LayoutUtil::SetToDefaultLayout(&reshaped_shape); 2599 } 2600 auto empty_constant = HloInstruction::CreateConstant( 2601 Literal::CreateFromShape(reshaped_shape)); 2602 2603 return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); 2604 } 2605 2606 // Delete no-op reshapes, i.e. where shape = operand shape. 2607 if (SameShape(reshape, operand)) { 2608 VLOG(10) << "deleting no-op reshape"; 2609 return ReplaceInstruction(reshape, operand); 2610 } 2611 2612 // Merge reshapes. 2613 if (HloOpcode::kReshape == operand->opcode()) { 2614 return ReplaceWithNewInstruction( 2615 reshape, HloInstruction::CreateReshape(reshape->shape(), 2616 operand->mutable_operand(0))); 2617 } 2618 2619 if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { 2620 *operand->mutable_shape() = reshape->shape(); 2621 return ReplaceInstruction(reshape, operand); 2622 } 2623 2624 if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { 2625 auto opt_dims = ReshapeLeavesDimensionsUnmodified( 2626 reshape, reshape->operand(0)->dimensions()); 2627 if (opt_dims.has_value()) { 2628 return ReplaceWithNewInstruction( 2629 reshape, 2630 HloInstruction::CreateBroadcast( 2631 reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), 2632 *opt_dims)); 2633 } 2634 } 2635 2636 // reshape(iota) -> iota. 
2637 if (operand->opcode() == HloOpcode::kIota) { 2638 auto* iota = Cast<HloIotaInstruction>(operand); 2639 auto opt_dims = 2640 ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()}); 2641 if (opt_dims.has_value()) { 2642 CHECK_EQ(opt_dims->size(), 1); 2643 return ReplaceWithNewInstruction( 2644 reshape, 2645 HloInstruction::CreateIota(reshape->shape(), opt_dims->front())); 2646 } 2647 } 2648 2649 // Make this a bitcast if possible. 2650 if (HloInstruction* bitcast_operand = 2651 BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) { 2652 ReplaceWithBitcast(reshape, bitcast_operand); 2653 } 2654 return Status::OK(); 2655 } 2656 2657 Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { 2658 // When all the dimensions to reverse are trivial (i.e. the bound is 1), 2659 // there is nothing to be done. 2660 auto dim_is_one = [&](int64 i) -> bool { 2661 return reverse->shape().dimensions(i) == 1; 2662 }; 2663 if (absl::c_all_of(reverse->dimensions(), dim_is_one)) { 2664 return ReplaceInstruction(reverse, reverse->mutable_operand(0)); 2665 } 2666 return Status::OK(); 2667 } 2668 2669 StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( 2670 HloInstruction* slice) { 2671 // Only try to do this for effective scalars. We could do the same for slicing 2672 // out larger pieces of padding (replacing with a broadcast of the padding 2673 // value), but this is probably not worth it. 2674 if (!ShapeUtil::IsEffectiveScalar(slice->shape())) { 2675 return false; 2676 } 2677 2678 if (slice->operand(0)->opcode() == HloOpcode::kPad) { 2679 VLOG(10) << "Trying to simplify scalar slice of pad"; 2680 // Check there's no internal padding. Again, we could handle that too, since 2681 // everything is statically known, but it's not worth it. 
2682 auto pad = Cast<HloPadInstruction>(slice->mutable_operand(0)); 2683 auto padding_config = pad->padding_config(); 2684 int64 rank = padding_config.dimensions_size(); 2685 if (HasInteriorPadding(padding_config)) { 2686 VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; 2687 return false; 2688 } 2689 2690 // Check whether the scalar we're slicing out falls into the padding. 2691 bool in_padding = [&]() { 2692 for (int64 i = 0; i < rank; ++i) { 2693 int64 start = slice->slice_starts(i); 2694 int64 low = padding_config.dimensions(i).edge_padding_low(); 2695 int64 data = pad->operand(0)->shape().dimensions(i); 2696 if (start < low || start >= low + data) { 2697 return true; 2698 } 2699 } 2700 return false; 2701 }(); 2702 2703 if (in_padding) { 2704 VLOG(10) << "Folding scalar slice of pad into padding value"; 2705 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( 2706 slice, HloInstruction::CreateReshape(slice->shape(), 2707 pad->mutable_padding_value()))); 2708 return true; 2709 } else { 2710 // We already know the output of the slice is scalar. If the padded 2711 // value is scalar, and it's not in the padding, then it's exactly the 2712 // output value. 2713 bool replaced = 2714 ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); 2715 if (replaced) { 2716 VLOG(10) << "Folding scalar slice of pad into padded value"; 2717 } else { 2718 VLOG(10) << "Not folding scalar slice of pad into padded value as they " 2719 "have different shapes."; 2720 } 2721 return replaced; 2722 } 2723 } 2724 2725 if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { 2726 VLOG(10) << "Trying to simplify scalar slice of concat"; 2727 // Only do this for R1, there's no chance of this being useful otherwise. 
2728 if (slice->shape().rank() != 1) { 2729 VLOG(10) << "Not folding, slice is not rank 1"; 2730 return false; 2731 } 2732 HloConcatenateInstruction* concat = 2733 Cast<HloConcatenateInstruction>(slice->mutable_operand(0)); 2734 int64 operand_start = 0; 2735 int64 operand_num = 0; 2736 // Weird loop structure to avoid annoying off-by-one errors. 2737 while (true) { 2738 TF_RET_CHECK(operand_num < concat->operand_count()); 2739 const HloInstruction* operand = concat->operand(operand_num); 2740 int64 next_operand_start = operand_start + operand->shape().dimensions(0); 2741 if (next_operand_start > slice->slice_starts(0)) { 2742 break; 2743 } 2744 operand_start = next_operand_start; 2745 operand_num++; 2746 } 2747 2748 bool replaced = ReplaceInstructionIfSameShape( 2749 slice, concat->mutable_operand(operand_num)); 2750 if (replaced) { 2751 VLOG(10) << "Folding scalar slice of concat into concat operand"; 2752 } else { 2753 VLOG(10) << "Folding scalar slice of concat into slice of concat operand"; 2754 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( 2755 slice, HloInstruction::CreateSlice( 2756 slice->shape(), concat->mutable_operand(operand_num), 2757 {slice->slice_starts(0) - operand_start}, 2758 {slice->slice_starts(0) - operand_start + 1}, 2759 slice->slice_strides()))); 2760 } 2761 return true; 2762 } 2763 2764 return false; 2765 } 2766 2767 StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( 2768 HloInstruction* slice) { 2769 CHECK_EQ(slice->opcode(), HloOpcode::kSlice); 2770 if (!IsUnstridedSlice(slice)) { 2771 return false; 2772 } 2773 HloInstruction* reshape = slice->mutable_operand(0); 2774 if (reshape->opcode() != HloOpcode::kReshape) { 2775 return false; 2776 } 2777 HloInstruction* new_slice_operand = reshape->mutable_operand(0); 2778 int64 slice_rank = slice->shape().rank(); 2779 std::vector<int64> sliced_dims; 2780 for (int64 i = 0; i < slice_rank; ++i) { 2781 if (slice->slice_starts(i) != 0 || 2782 slice->slice_limits(i) != 
reshape->shape().dimensions(i)) { 2783 sliced_dims.push_back(i); 2784 } 2785 } 2786 2787 if (sliced_dims.size() == 1 && sliced_dims[0] == 0 && 2788 slice->slice_starts(0) == 0) { 2789 const Shape& new_slice_shape = new_slice_operand->shape(); 2790 const int64 rank = new_slice_shape.rank(); 2791 std::vector<int64> new_slice_starts(rank, 0); 2792 std::vector<int64> new_slice_stides(rank, 1); 2793 std::vector<int64> new_slice_limits(new_slice_shape.dimensions().begin(), 2794 new_slice_shape.dimensions().end()); 2795 int64 slice_elements = ShapeUtil::ElementsIn(slice->shape()); 2796 for (int64 i = rank - 1; i >= 0; --i) { 2797 if (slice_elements >= new_slice_limits[i]) { 2798 if (slice_elements % new_slice_limits[i] != 0) { 2799 return false; 2800 } 2801 slice_elements /= new_slice_limits[i]; 2802 } else { 2803 new_slice_limits[i] = slice_elements; 2804 slice_elements = 1; 2805 } 2806 } 2807 HloInstruction* new_slice = 2808 computation_->AddInstruction(HloInstruction::CreateSlice( 2809 ShapeUtil::MakeShape(new_slice_shape.element_type(), 2810 new_slice_limits), 2811 new_slice_operand, new_slice_starts, new_slice_limits, 2812 new_slice_stides)); 2813 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( 2814 slice, HloInstruction::CreateReshape(slice->shape(), new_slice))); 2815 return true; 2816 } 2817 return false; 2818 } 2819 2820 Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { 2821 // Delete no-op slices, i.e. where shape = operand shape. 
2822 if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) { 2823 return Status::OK(); 2824 } 2825 2826 if (slice->operand(0)->opcode() == HloOpcode::kSlice && 2827 IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) { 2828 HloInstruction* operand_slice = slice->mutable_operand(0); 2829 std::vector<int64> new_slice_starts = slice->slice_starts(); 2830 std::vector<int64> new_slice_limits = slice->slice_limits(); 2831 for (int64 i = 0; i < new_slice_starts.size(); ++i) { 2832 new_slice_starts[i] += operand_slice->slice_starts(i); 2833 new_slice_limits[i] += operand_slice->slice_starts(i); 2834 } 2835 return ReplaceWithNewInstruction( 2836 slice, HloInstruction::CreateSlice( 2837 slice->shape(), operand_slice->mutable_operand(0), 2838 new_slice_starts, new_slice_limits, slice->slice_strides())); 2839 } 2840 2841 auto only_broadcast_dims_sliced = [&] { 2842 if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) { 2843 return false; 2844 } 2845 for (int64 dim : slice->operand(0)->dimensions()) { 2846 if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 || 2847 slice->slice_limits(dim) != 2848 slice->operand(0)->shape().dimensions(dim)) { 2849 return false; 2850 } 2851 } 2852 return true; 2853 }; 2854 if (only_broadcast_dims_sliced()) { 2855 return ReplaceWithNewInstruction( 2856 slice, 2857 HloInstruction::CreateBroadcast( 2858 slice->shape(), slice->mutable_operand(0)->mutable_operand(0), 2859 slice->mutable_operand(0)->dimensions())); 2860 } 2861 2862 TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice)); 2863 if (replaced) { 2864 return Status::OK(); 2865 } 2866 2867 TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); 2868 if (replaced) { 2869 return Status::OK(); 2870 } 2871 return Status::OK(); 2872 } 2873 2874 Status AlgebraicSimplifierVisitor::HandleDynamicSlice( 2875 HloInstruction* dynamic_slice) { 2876 auto operand = dynamic_slice->mutable_operand(0); 2877 if 
(ShapeUtil::IsScalar(dynamic_slice->shape())) { 2878 return ReplaceInstruction(dynamic_slice, operand); 2879 } 2880 // DynamicSlice where operand has the same size as the output is simply equal 2881 // to operand. 2882 if (SameShape(operand, dynamic_slice)) { 2883 return ReplaceInstruction(dynamic_slice, operand); 2884 } 2885 return Status::OK(); 2886 } 2887 2888 Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( 2889 HloInstruction* dynamic_update_slice) { 2890 auto update = dynamic_update_slice->mutable_operand(1); 2891 2892 // DynamicUpdateSlice where operand and update have the same size is simply 2893 // equal to update. 2894 if (SameShape(dynamic_update_slice, update)) { 2895 return ReplaceInstruction(dynamic_update_slice, update); 2896 } 2897 2898 // If any dimension of update is 0, elide the DynamicUpdateSlice. This 2899 // optimization becomes invalid should we later prefer to warn about out of 2900 // bound indices. 2901 if (ShapeUtil::IsZeroElementArray(update->shape())) { 2902 return ReplaceInstruction(dynamic_update_slice, 2903 dynamic_update_slice->mutable_operand(0)); 2904 } 2905 return Status::OK(); 2906 } 2907 2908 Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { 2909 HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo); 2910 bool multi_output_reduce = reduce->shape().IsTuple(); 2911 2912 // For tuple reduce, we require all reduce shapes to be the same, up to the 2913 // element types, so we can just the first operand and the first result as a 2914 // representative. 2915 auto arg = reduce->inputs()[0]; 2916 auto init_value = reduce->init_values()[0]; 2917 const Shape& reduce_result_shape = 2918 multi_output_reduce ? 
reduce->shape().tuple_shapes(0) : reduce->shape(); 2919 2920 absl::Span<const int64> dimensions(reduce->dimensions()); 2921 HloComputation* function = reduce->to_apply(); 2922 if (ShapeUtil::IsZeroElementArray(arg->shape()) || 2923 ShapeUtil::IsZeroElementArray(reduce_result_shape)) { 2924 if (multi_output_reduce) { 2925 std::vector<HloInstruction*> broadcast_inits; 2926 int64 inputs = reduce->input_count(); 2927 for (int64 i = 0; i < inputs; ++i) { 2928 broadcast_inits.push_back(computation_->AddInstruction( 2929 HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i), 2930 reduce->init_values()[i], {}))); 2931 } 2932 return ReplaceWithNewInstruction( 2933 reduce, HloInstruction::CreateTuple(broadcast_inits)); 2934 } else { 2935 return ReplaceWithNewInstruction( 2936 reduce, 2937 HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {})); 2938 } 2939 } 2940 2941 // If the reduction results in the same number of elements, then the only 2942 // possible side effect would be a reshape. Since the init_value is an 2943 // identity of the reduction function, we can therefore replace the reduce 2944 // with a simple reshape, ignoring the reduction function completely. 2945 if (ShapeUtil::ElementsIn(reduce_result_shape) == 2946 ShapeUtil::ElementsIn(arg->shape())) { 2947 if (multi_output_reduce) { 2948 std::vector<HloInstruction*> reshaped_args; 2949 int64 inputs = reduce->input_count(); 2950 for (int64 i = 0; i < inputs; ++i) { 2951 reshaped_args.push_back( 2952 computation_->AddInstruction(HloInstruction::CreateReshape( 2953 reduce->shape().tuple_shapes(i), reduce->inputs()[i]))); 2954 } 2955 return ReplaceWithNewInstruction( 2956 reduce, HloInstruction::CreateTuple(reshaped_args)); 2957 } else { 2958 return ReplaceWithNewInstruction( 2959 reduce, HloInstruction::CreateReshape(reduce_result_shape, arg)); 2960 } 2961 } 2962 2963 // TODO(b/112040122): Most of those optimizations below can be done for 2964 // multi-output reduces. 
2965 if (multi_output_reduce) { 2966 return Status::OK(); 2967 } 2968 2969 // A Transpose feeding a reduce can simply permute the reduction dimensions 2970 // field if the output of the reduce is a vector or scalar. Higher ranked 2971 // result may require a transpose of the output. 2972 if (reduce_result_shape.rank() <= 1 && 2973 arg->opcode() == HloOpcode::kTranspose) { 2974 auto transpose_dimensions = arg->dimensions(); 2975 std::vector<int64> new_reduce_dimensions; 2976 for (auto dim : dimensions) { 2977 new_reduce_dimensions.push_back(transpose_dimensions[dim]); 2978 } 2979 return ReplaceWithNewInstruction( 2980 reduce, HloInstruction::CreateReduce( 2981 reduce_result_shape, arg->mutable_operand(0), init_value, 2982 new_reduce_dimensions, function)); 2983 } 2984 2985 // If a reduce feeds a reduce with the same computation and initial value, 2986 // they can be combined into a single reduce. 2987 if (arg->opcode() == HloOpcode::kReduce && 2988 init_value->Identical(*arg->operand(1)) && 2989 *function == *arg->to_apply()) { 2990 // Create a new reduce with the combined reduction dimensions of both 2991 // reduces. 2992 std::vector<int64> arg_dims = arg->dimensions(); 2993 absl::c_sort(arg_dims); 2994 std::vector<int64> reduce_dims = reduce->dimensions(); 2995 absl::c_sort(reduce_dims); 2996 // Transform reduce_dims to the same rank as the operand of the operand. 
2997 for (int64 arg_dim : arg_dims) { 2998 for (int64& dim : reduce_dims) { 2999 if (dim >= arg_dim) { 3000 ++dim; 3001 } 3002 } 3003 } 3004 std::vector<int64> new_dimensions; 3005 new_dimensions.reserve(arg->dimensions().size() + 3006 reduce->dimensions().size()); 3007 std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), 3008 reduce_dims.end(), std::back_inserter(new_dimensions)); 3009 return ReplaceWithNewInstruction( 3010 reduce, HloInstruction::CreateReduce( 3011 reduce_result_shape, arg->mutable_operand(0), init_value, 3012 new_dimensions, function)); 3013 } 3014 3015 // A reshape that collapses multiple dimensions into a dimension being 3016 // reduced can just reduce all of those dimensions instead of doing a 3017 // collapsing reshape before a reduction. 3018 if (arg->opcode() == HloOpcode::kReshape) { 3019 std::vector<std::pair<int64, int64>> unmodified_dims = 3020 ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(), 3021 arg->shape()); 3022 std::vector<bool> arg_dim_in_output(arg->shape().rank(), true); 3023 std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false); 3024 for (auto dim : dimensions) { 3025 arg_dim_in_output[dim] = false; 3026 } 3027 for (auto dim_pair : unmodified_dims) { 3028 arg_dim_unmodified[dim_pair.second] = true; 3029 } 3030 // The goal is to verify that all dimensions that are not removed in the 3031 // reduce are unmodified by the reshape. 
For example: 3032 // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2]) 3033 bool can_move_reshape_into_reduce = true; 3034 for (int64 i = 0; i < arg_dim_in_output.size(); ++i) { 3035 if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) { 3036 can_move_reshape_into_reduce = false; 3037 } 3038 } 3039 if (can_move_reshape_into_reduce) { 3040 changed_ = true; 3041 absl::flat_hash_set<int64> dimensions_not_to_reduce; 3042 for (auto dim_pair : unmodified_dims) { 3043 if (arg_dim_in_output[dim_pair.second]) { 3044 dimensions_not_to_reduce.insert(dim_pair.first); 3045 } 3046 } 3047 std::vector<int64> new_reduce_dimensions; 3048 for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) { 3049 if (!dimensions_not_to_reduce.contains(i)) { 3050 new_reduce_dimensions.push_back(i); 3051 } 3052 } 3053 return ReplaceWithNewInstruction( 3054 reduce, HloInstruction::CreateReduce( 3055 reduce_result_shape, arg->mutable_operand(0), init_value, 3056 new_reduce_dimensions, function)); 3057 } 3058 } 3059 // Convert Reduce(concat({a,b,...})) to 3060 // map(reduce(a),map(reduce(b),...,)) 3061 // 3062 // This should make fusion easier or use less memory bandwidth in the unfused 3063 // case. 
3064 if (arg->opcode() == HloOpcode::kConcatenate && 3065 absl::c_linear_search(reduce->dimensions(), 3066 arg->concatenate_dimension())) { 3067 HloInstruction* old_reduce = nullptr; 3068 for (HloInstruction* operand : arg->operands()) { 3069 HloInstruction* new_reduce = computation_->AddInstruction( 3070 HloInstruction::CreateReduce(reduce_result_shape, operand, init_value, 3071 reduce->dimensions(), function)); 3072 if (old_reduce != nullptr) { 3073 new_reduce = computation_->AddInstruction(HloInstruction::CreateMap( 3074 reduce_result_shape, {old_reduce, new_reduce}, function)); 3075 } 3076 old_reduce = new_reduce; 3077 } 3078 return ReplaceInstruction(reduce, old_reduce); 3079 } 3080 return Status::OK(); 3081 } 3082 3083 Status AlgebraicSimplifierVisitor::HandleReduceWindow( 3084 HloInstruction* reduce_window) { 3085 if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) { 3086 return ReplaceWithNewInstruction( 3087 reduce_window, 3088 HloInstruction::CreateBroadcast(reduce_window->shape(), 3089 reduce_window->mutable_operand(1), {})); 3090 } 3091 auto operand = reduce_window->mutable_operand(0); 3092 const Window& window = reduce_window->window(); 3093 auto function = reduce_window->to_apply(); 3094 if (ShapeUtil::IsScalar(operand->shape())) { 3095 TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape())); 3096 return ReplaceWithNewInstruction( 3097 reduce_window, 3098 HloInstruction::CreateMap(reduce_window->shape(), 3099 {reduce_window->mutable_operand(1), operand}, 3100 function)); 3101 } 3102 3103 if (options_.enable_window_reduce_to_reduce_replacement()) { 3104 // A reduce window can be expressed as a reduce and a reshape if all 3105 // dimensions either have a window size of one or the entire dimension. If 3106 // there is no stride, dilation, or padding, this is as easy as checking the 3107 // size of the output shape and window dimension. 3108 // 3109 // The reshape is a bitcast since it adds one-sized dimensions. 
Often these 3110 // ones are immediately removed as well with another reshape. The 3111 // implementation of reduce tends to be slightly more efficient at reducing 3112 // entire dimensions compared to reduce window. 3113 auto effective_reduce_dims = [&] { 3114 if (window_util::HasStride(window) || window_util::HasDilation(window) || 3115 window_util::HasPadding(window)) { 3116 return absl::InlinedVector<int64, 8>{}; 3117 } 3118 absl::InlinedVector<int64, 8> reduce_dims; 3119 for (int64 i = 0; i < window.dimensions_size(); ++i) { 3120 if (window.dimensions(i).size() == 1) { 3121 continue; 3122 } else if (reduce_window->shape().dimensions(i) == 1) { 3123 reduce_dims.push_back(i); 3124 } else { 3125 return absl::InlinedVector<int64, 8>{}; 3126 } 3127 } 3128 return reduce_dims; 3129 }(); 3130 3131 // If a reduce window can be expressed as a reduce, do so and reshape the 3132 // output. 3133 if (!effective_reduce_dims.empty()) { 3134 Shape reduce_shape = ShapeUtil::FilterDimensions( 3135 [&](int64 dim) { 3136 return !absl::c_linear_search(effective_reduce_dims, dim); 3137 }, 3138 reduce_window->shape()); 3139 HloInstruction* reduce = 3140 computation_->AddInstruction(HloInstruction::CreateReduce( 3141 /*shape=*/reduce_shape, 3142 /*operand=*/operand, 3143 /*init_value=*/reduce_window->mutable_operand(1), 3144 /*dimensions_to_reduce=*/effective_reduce_dims, 3145 /*reduce_computation=*/function)); 3146 return ReplaceWithNewInstruction( 3147 reduce_window, 3148 HloInstruction::CreateReshape(reduce_window->shape(), reduce)); 3149 } 3150 } 3151 3152 // This optimization folds a pad op into reduce_window. 
3153 HloInstruction* pad; 3154 const HloInstruction* convert = nullptr; 3155 if (operand->opcode() == HloOpcode::kPad) { 3156 pad = operand; 3157 } else if (operand->opcode() == HloOpcode::kConvert && 3158 operand->operand(0)->opcode() == HloOpcode::kPad) { 3159 convert = operand; 3160 pad = operand->mutable_operand(0); 3161 } else { 3162 VLOG(10) << "Not folding pad into reduce-window as there is no pad."; 3163 return Status::OK(); 3164 } 3165 3166 // Bail on dilation. 3167 if (window_util::HasDilation(window)) { 3168 VLOG(10) << "Not folding pad into reduce-window as there is dilation."; 3169 return Status::OK(); 3170 } 3171 3172 VLOG(10) << "Considering folding Pad: " << pad->ToString() 3173 << "\ninto reduce-window: " << reduce_window->ToString() 3174 << (convert != nullptr 3175 ? absl::StrCat("\nvia convert: ", convert->ToString()) 3176 : ""); 3177 3178 // Do not fold interior padding into ReduceWindow since the backends do not 3179 // support it. 3180 const PaddingConfig& pad_config = pad->padding_config(); 3181 if (HasInteriorPadding(pad_config)) { 3182 VLOG(10) << "Not folding pad into reduce-window due to interior padding."; 3183 return Status::OK(); 3184 } 3185 3186 // If reduce_window already has padding, the pad value of the pad op and the 3187 // init value of reduce_window must match to allow folding the pad. 
3188 const HloInstruction* pad_value = pad->operand(1); 3189 const HloInstruction* reduce_init_value = reduce_window->operand(1); 3190 if (pad_value != reduce_init_value) { 3191 auto literals_are_equivalent = [&] { 3192 auto& pad_literal = pad_value->literal(); 3193 auto& reduce_init_literal = reduce_init_value->literal(); 3194 if (pad_literal == reduce_init_literal) { 3195 return true; 3196 } 3197 auto converted_pad_literal = 3198 pad_literal.ConvertToShape(reduce_init_value->shape()); 3199 if (!converted_pad_literal.ok()) { 3200 return false; 3201 } 3202 return converted_pad_literal.ValueOrDie() == reduce_init_literal; 3203 }; 3204 // The pad value is usually a constant, so we handle that case and do not 3205 // try to get more fancy about proving equivalence in cases beyond that. 3206 if (pad_value->opcode() != HloOpcode::kConstant || 3207 reduce_init_value->opcode() != HloOpcode::kConstant || 3208 !literals_are_equivalent()) { 3209 VLOG(10) << "Not folding pad into reduce-window due to different pad " 3210 "values."; 3211 return Status::OK(); 3212 } 3213 } 3214 3215 // If the pad puts a single non-identity value in each window that we're 3216 // reducing, then this is a broadcast. 
3217 HloInstruction* pad_operand = pad->mutable_operand(0); 3218 auto is_effective_broadcast = [&] { 3219 if (window_util::HasStride(window)) { 3220 VLOG(10) << "Window has stride."; 3221 return false; 3222 } 3223 if (!window_util::HasSymmetricPadding(pad_config)) { 3224 VLOG(10) << "Window has uneven padding."; 3225 return false; 3226 } 3227 for (int64 i = 0; i < pad_config.dimensions_size(); ++i) { 3228 const auto& pad_dimension = pad_config.dimensions(i); 3229 if ((pad_dimension.edge_padding_low() != 0 || 3230 pad_dimension.edge_padding_high() != 0) && 3231 pad_operand->shape().dimensions(i) != 1) { 3232 VLOG(10) << "Found non-trivial dimension being padded: " << i; 3233 return false; 3234 } 3235 } 3236 VLOG(10) << "Found to be padding trivial dimensions only."; 3237 3238 for (int64 i = 0; i < window.dimensions_size(); ++i) { 3239 const auto& pad_dimension = pad_config.dimensions(i); 3240 const WindowDimension& window_dimension = window.dimensions(i); 3241 bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 || 3242 pad_dimension.edge_padding_high() != 0); 3243 if (dimension_has_padding && 3244 window_dimension.size() < pad_dimension.edge_padding_low() + 1) { 3245 VLOG(10) << "Found window did not cover single unpadded element in " 3246 "dimension: " 3247 << i; 3248 return false; 3249 } 3250 if (pad_operand->shape().dimensions(i) != 1 && 3251 window_dimension.size() != 1) { 3252 VLOG(10) << "Found window covers more than one element in non-trivial " 3253 "dimension: " 3254 << i; 3255 return false; 3256 } 3257 } 3258 VLOG(10) << "Found window covers a single unpadded element."; 3259 return true; 3260 }; 3261 3262 HloInstruction* new_reduce_window_operand; 3263 if (convert != nullptr) { 3264 new_reduce_window_operand = 3265 computation_->AddInstruction(HloInstruction::CreateConvert( 3266 ShapeUtil::ChangeElementType(pad_operand->shape(), 3267 convert->shape().element_type()), 3268 pad_operand)); 3269 } else { 3270 new_reduce_window_operand = 
pad_operand; 3271 } 3272 3273 if (is_effective_broadcast()) { 3274 VLOG(10) << "Replacing pad/reduce-window with broadcast."; 3275 auto fadd = [this](std::unique_ptr<HloInstruction> x) { 3276 return computation_->AddInstruction(std::move(x)); 3277 }; 3278 return ReplaceWithNewInstruction( 3279 reduce_window, HloInstruction::CreateBroadcastSequence( 3280 /*output_shape=*/reduce_window->shape(), 3281 /*operand=*/new_reduce_window_operand, fadd)); 3282 } 3283 3284 // Carry out the folding of the pad into reduce_window. 3285 VLOG(10) << "Folding pad into reduce-window."; 3286 Window new_window = window; 3287 const int64 rank = reduce_window->shape().rank(); 3288 TF_RET_CHECK(pad_config.dimensions_size() == rank); 3289 TF_RET_CHECK(window.dimensions_size() == rank); 3290 for (int64 i = 0; i < rank; ++i) { 3291 const auto& pad_dim = pad_config.dimensions(i); 3292 auto& window_dim = *new_window.mutable_dimensions(i); 3293 window_dim.set_padding_low(window_dim.padding_low() + 3294 pad_dim.edge_padding_low()); 3295 window_dim.set_padding_high(window_dim.padding_high() + 3296 pad_dim.edge_padding_high()); 3297 } 3298 3299 return ReplaceWithNewInstruction( 3300 reduce_window, HloInstruction::CreateReduceWindow( 3301 /*shape=*/reduce_window->shape(), 3302 /*operand=*/new_reduce_window_operand, 3303 /*init_value=*/reduce_window->mutable_operand(1), 3304 /*window=*/new_window, 3305 /*reduce_computation=*/function)); 3306 } 3307 3308 Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { 3309 // select(x, y, y) -> y. 3310 if (select->operand(1) == select->operand(2)) { 3311 return ReplaceInstruction(select, select->mutable_operand(1)); 3312 } 3313 // select(true, x, y) -> x. 3314 if (IsAll(select->operand(0), true)) { 3315 return ReplaceInstruction(select, select->mutable_operand(1)); 3316 } 3317 // select(false, x, y) -> y. 
3318 if (IsAll(select->operand(0), false)) { 3319 return ReplaceInstruction(select, select->mutable_operand(2)); 3320 } 3321 return Status::OK(); 3322 } 3323 3324 Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { 3325 auto operand = sort->mutable_operand(0); 3326 int64 dimension_to_sort = sort->dimensions(0); 3327 if (ShapeUtil::IsZeroElementArray(operand->shape()) || 3328 operand->shape().dimensions(dimension_to_sort) <= 1) { 3329 if (sort->operand_count() == 1) { 3330 return ReplaceInstruction(sort, operand); 3331 } 3332 // If it is key/value sort, the output of sort is a tuple. 3333 return ReplaceWithNewInstruction( 3334 sort, HloInstruction::CreateTuple(sort->operands())); 3335 } 3336 return Status::OK(); 3337 } 3338 3339 namespace { 3340 bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape, 3341 absl::Span<const int64> perm) { 3342 std::vector<int64> new_permutation; 3343 int64 degenerate_count = 0; 3344 for (int64 i = 0; i < perm.size(); ++i) { 3345 if (shape.dimensions(i) != 1) { 3346 new_permutation.push_back(perm[i]); 3347 } else { 3348 ++degenerate_count; 3349 } 3350 } 3351 return degenerate_count > 1 && absl::c_is_sorted(new_permutation); 3352 } 3353 } // namespace 3354 3355 Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { 3356 auto operand = transpose->mutable_operand(0); 3357 if (std::is_sorted(transpose->dimensions().begin(), 3358 transpose->dimensions().end())) { 3359 VLOG(10) << "deleting no-op transpose"; 3360 return ReplaceInstruction(transpose, operand); 3361 } 3362 3363 if (HloOpcode::kTranspose == operand->opcode()) { 3364 return ReplaceWithNewInstruction( 3365 transpose, HloInstruction::CreateTranspose( 3366 transpose->shape(), operand->mutable_operand(0), 3367 ComposePermutations(operand->dimensions(), 3368 transpose->dimensions()))); 3369 } 3370 3371 // Replace transpose with a reshape if more than one degenerate method is 3372 // permuted. 
3373 if (OnlyPermutesMoreThanOneDegenerateDim(transpose->shape(), 3374 transpose->dimensions())) { 3375 return ReplaceWithNewInstruction( 3376 transpose, HloInstruction::CreateReshape( 3377 transpose->shape(), transpose->mutable_operand(0))); 3378 } 3379 3380 if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { 3381 *operand->mutable_shape() = transpose->shape(); 3382 return ReplaceInstruction(transpose, operand); 3383 } 3384 3385 if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) { 3386 ReplaceWithBitcast(transpose); 3387 return Status::OK(); 3388 } 3389 3390 return Status::OK(); 3391 } 3392 3393 StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad( 3394 HloInstruction* convolution) { 3395 auto* lhs = convolution->mutable_operand(0); 3396 auto* rhs = convolution->mutable_operand(1); 3397 const auto& window = convolution->window(); 3398 const ConvolutionDimensionNumbers& dnums = 3399 convolution->convolution_dimension_numbers(); 3400 3401 if (lhs->opcode() != HloOpcode::kPad) { 3402 return false; 3403 } 3404 3405 // Convolution's padding is always zero, so bail if the kPad is adding 3406 // something other than zero. 3407 if (!IsAll(lhs->operand(1), 0)) { 3408 return false; 3409 } 3410 3411 const auto& padding = lhs->padding_config(); 3412 3413 // Can't pad batch or feature dims. 3414 for (int64 dim : 3415 {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { 3416 const auto& p = padding.dimensions(dim); 3417 if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || 3418 p.interior_padding() != 0) { 3419 return false; 3420 } 3421 } 3422 3423 // Compute the window which is the result of merging the kPad and the 3424 // convolution's existing window. 
3425 Window new_window = window; 3426 for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { 3427 auto& w = *new_window.mutable_dimensions(dim); 3428 const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); 3429 // Edge padding composes with itself in the straightforward way, but 3430 // composing interior padding is nontrivial, and we cowardly refuse to 3431 // think about it. If we see interior padding in either the kPad or conv, 3432 // bail if there's any sort of padding in the other. 3433 if (p.interior_padding() != 0 && 3434 (w.padding_low() != 0 || w.padding_high() != 0 || 3435 w.base_dilation() != 1)) { 3436 return false; 3437 } 3438 if (w.base_dilation() != 1 && 3439 (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || 3440 p.interior_padding() != 0)) { 3441 return false; 3442 } 3443 3444 w.set_padding_low(w.padding_low() + p.edge_padding_low()); 3445 w.set_padding_high(w.padding_high() + p.edge_padding_high()); 3446 if (p.interior_padding() != 0) { 3447 CHECK_EQ(w.base_dilation(), 1); 3448 w.set_base_dilation(1 + p.interior_padding()); 3449 } 3450 } 3451 3452 auto new_conv = convolution->CloneWithNewOperands( 3453 convolution->shape(), {lhs->mutable_operand(0), rhs}); 3454 new_conv->set_window(new_window); 3455 TF_RETURN_IF_ERROR( 3456 ReplaceWithNewInstruction(convolution, std::move(new_conv))); 3457 return true; 3458 } 3459 3460 StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad( 3461 HloInstruction* convolution) { 3462 auto* lhs = convolution->mutable_operand(0); 3463 auto* rhs = convolution->mutable_operand(1); 3464 const ConvolutionDimensionNumbers& dnums = 3465 convolution->convolution_dimension_numbers(); 3466 3467 if (rhs->opcode() != HloOpcode::kPad) { 3468 return false; 3469 } 3470 3471 // Convolution's padding is always zero, so bail if the kPad is adding 3472 // something other than zero. 
3473 if (!IsAll(rhs->operand(1), 0)) { 3474 return false; 3475 } 3476 3477 const auto& padding = rhs->padding_config(); 3478 3479 // Can't pad or dilate feature dims. 3480 for (int64 dim : {dnums.kernel_input_feature_dimension(), 3481 dnums.kernel_output_feature_dimension()}) { 3482 const auto& p = padding.dimensions(dim); 3483 if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || 3484 p.interior_padding() != 0) { 3485 return false; 3486 } 3487 } 3488 3489 // Compute the window which is the result of merging the kPad and the 3490 // convolution's existing window. 3491 Window new_window = convolution->window(); 3492 for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { 3493 auto& w = *new_window.mutable_dimensions(dim); 3494 const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); 3495 3496 // We can only do this transformation if p adds dilation to the filter -- 3497 // edge padding on the filter is not supported in conv. 3498 if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { 3499 return false; 3500 } 3501 3502 // Nothing to do if the kPad for this dim is entirely a nop. 3503 if (p.interior_padding() == 0) { 3504 continue; 3505 } 3506 3507 // We cowardly refuse to think about how dilation composes with itself; 3508 // bail if both the kPad and conv have dilation on this dimension. 
3509 if (w.window_dilation() > 1) { 3510 return false; 3511 } 3512 CHECK_EQ(w.window_dilation(), 1); 3513 w.set_window_dilation(1 + p.interior_padding()); 3514 w.set_size(rhs->operand(0)->shape().dimensions( 3515 dnums.kernel_spatial_dimensions(dim))); 3516 } 3517 3518 auto new_conv = convolution->CloneWithNewOperands( 3519 convolution->shape(), {lhs, rhs->mutable_operand(0)}); 3520 new_conv->set_window(new_window); 3521 TF_RETURN_IF_ERROR( 3522 ReplaceWithNewInstruction(convolution, std::move(new_conv))); 3523 return true; 3524 } 3525 3526 StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot( 3527 HloInstruction* convolution) { 3528 auto* lhs = convolution->mutable_operand(0); 3529 auto* rhs = convolution->mutable_operand(1); 3530 const auto& window = convolution->window(); 3531 const ConvolutionDimensionNumbers& dnums = 3532 convolution->convolution_dimension_numbers(); 3533 3534 if (!options_.enable_conv_simplification()) { 3535 return false; 3536 } 3537 3538 // TODO(b/31337498): For now, we cowardly refuse to do this optimization in 3539 // layout-insensitive mode, for fear of adding nontrivial reshapes. 3540 if (!options_.is_layout_sensitive()) { 3541 return false; 3542 } 3543 3544 const Shape& input_shape = lhs->shape(); 3545 const Shape& filter_shape = rhs->shape(); 3546 const Shape& convolution_shape = convolution->shape(); 3547 TF_RET_CHECK(LayoutUtil::HasLayout(input_shape)); 3548 TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape)); 3549 TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape)); 3550 3551 // Require the spatial dimensions in the kernel to have a bound of one. 3552 for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { 3553 if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) { 3554 return false; 3555 } 3556 } 3557 3558 // Stride ignores part of the output, which matrix multiplication does not do, 3559 // so require no stride. 
Padding and base (lhs) dilation both implicitly 3560 // extend the data, which matrix multiplication also does not do, so require 3561 // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect 3562 // for a 1x1 window, so window dilation is no problem. 3563 if (window_util::HasStride(window) || window_util::HasPadding(window) || 3564 window_util::HasBaseDilation(window)) { 3565 return false; 3566 } 3567 3568 // Also, the shapes must align for a rowmajor matmul: 3569 // - the input and output have the same layout. 3570 // - for input/output, the channel dimension must be the most minor. Other 3571 // spatial dims can be in any order. 3572 // - for filters, the input channel dimension must be more major than the 3573 // output channel dimension. The width+height don't matter because 3574 // they are 1. 3575 // 3576 // These constraints are harsh. If the channel dimension is the most major 3577 // and/or the layout of input/output feature dimensions are reversed, we can 3578 // still convert Conv into more efficient Matmul with operand transposition 3579 // (such as the transposition flags in cuBLAS SGEMM). 3580 if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || 3581 LayoutUtil::Minor(input_shape.layout(), 0) != 3582 dnums.input_feature_dimension() || 3583 LayoutUtil::Minor(convolution_shape.layout(), 0) != 3584 dnums.output_feature_dimension() || 3585 // The input feature dimension should come later in the minor-to-major 3586 // order. 
3587 (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), 3588 dnums.kernel_input_feature_dimension()) < 3589 PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), 3590 dnums.kernel_output_feature_dimension()))) { 3591 return false; 3592 } 3593 3594 auto add_bitcast = [&](Shape shape, HloInstruction* operand) { 3595 std::vector<int64> dims(operand->shape().dimensions_size()); 3596 std::iota(dims.begin(), dims.end(), 0); 3597 return computation_->AddInstruction( 3598 HloInstruction::CreateUnary(shape, HloOpcode::kBitcast, operand)); 3599 }; 3600 3601 // Replace it with a dot, with bitcasts around it to get the right shape. 3602 const int64 input_channels = 3603 input_shape.dimensions(dnums.input_feature_dimension()); 3604 const int64 output_channels = 3605 filter_shape.dimensions(dnums.kernel_output_feature_dimension()); 3606 3607 // Computes the product of the non-feature dimensions. 3608 int64 conv_width = 1; 3609 for (int i = 0; i < input_shape.dimensions_size(); ++i) { 3610 if (i != dnums.input_feature_dimension()) { 3611 conv_width *= input_shape.dimensions(i); 3612 } 3613 } 3614 3615 // We already checked feature_dimension is most minor, so data in input_shape 3616 // and row-major {conv_width,input_channels} are bitwise identical. 3617 const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( 3618 input_shape.element_type(), {conv_width, input_channels}); 3619 // We already checked input_feature_dimension is more major than 3620 // output_feature_dimension, so data in filter_shape and row-major 3621 // {input_channels,output_channels} are bitwise identical. 
3622 const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( 3623 filter_shape.element_type(), {input_channels, output_channels}); 3624 const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( 3625 convolution_shape.element_type(), {conv_width, output_channels}); 3626 3627 auto new_lhs = add_bitcast(new_input_shape, lhs); 3628 auto new_rhs = add_bitcast(new_filter_shape, rhs); 3629 DotDimensionNumbers dot_dimension_numbers; 3630 dot_dimension_numbers.add_lhs_contracting_dimensions(1); 3631 dot_dimension_numbers.add_rhs_contracting_dimensions(0); 3632 auto dot = computation_->AddInstruction(HloInstruction::CreateDot( 3633 dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, 3634 convolution->precision_config())); 3635 3636 TF_RETURN_IF_ERROR( 3637 ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot))); 3638 return true; 3639 } 3640 3641 Status AlgebraicSimplifierVisitor::HandleConvolution( 3642 HloInstruction* convolution) { 3643 // Zero-sized input or filter. 3644 if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) || 3645 ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) { 3646 return ReplaceWithNewInstruction( 3647 convolution, 3648 HloInstruction::CreateBroadcast( 3649 convolution->shape(), 3650 computation_->AddInstruction(HloInstruction::CreateConstant( 3651 LiteralUtil::Zero(convolution->shape().element_type()))), 3652 {})); 3653 } 3654 3655 // Try to merge padding/dilation of the input with the convolution's window. 3656 TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); 3657 if (folded_input_pad) { 3658 return Status::OK(); 3659 } 3660 3661 // Try to merge dilation of the filter with the convolution's window. 3662 TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); 3663 if (folded_filter_pad) { 3664 return Status::OK(); 3665 } 3666 3667 // Try to replace the convolution with a kDot instruction. 
3668 TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); 3669 if (replaced_with_dot) { 3670 return Status::OK(); 3671 } 3672 3673 return Status::OK(); 3674 } 3675 3676 bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( 3677 HloInstruction* root, HloInstruction* min, HloInstruction* min_operand, 3678 HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) { 3679 // Ensure shapes of min and max operand are equal to match current shape 3680 // inference. 3681 if (!SameShape(min_operand, max_operand)) { 3682 return false; 3683 } 3684 3685 auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp, 3686 max_operand, operand, min_operand); 3687 TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp))); 3688 return true; 3689 } 3690 3691 Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { 3692 auto* map_computation = map->to_apply(); 3693 auto* map_root = map_computation->root_instruction(); 3694 if (map_root->opcode() == HloOpcode::kParameter) { 3695 ReplaceInstructionIfSameShape( 3696 map, map->mutable_operand(map_root->parameter_number())); 3697 return Status::OK(); 3698 } 3699 if (map_root->opcode() == HloOpcode::kConstant) { 3700 if (!ShapeUtil::IsScalar(map_root->shape())) { 3701 return Status::OK(); 3702 } 3703 auto clone = map_root->CloneWithNewOperands(map_root->shape(), {}); 3704 if (ShapeUtil::IsScalar(map->shape())) { 3705 return ReplaceWithNewInstruction(map, std::move(clone)); 3706 } 3707 return ReplaceWithNewInstruction( 3708 map, 3709 HloInstruction::CreateBroadcast( 3710 map->shape(), computation_->AddInstruction(std::move(clone)), {})); 3711 } 3712 // Inline the map if the map computation only contains an elementwise 3713 // operation that can accept arbitrary shapes. 
3714 if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) { 3715 return Status::OK(); 3716 } 3717 std::vector<HloInstruction*> new_operands; 3718 for (auto* root_operand : map_root->operands()) { 3719 if (root_operand->opcode() != HloOpcode::kParameter) { 3720 return Status::OK(); 3721 } 3722 new_operands.push_back( 3723 map->mutable_operand(root_operand->parameter_number())); 3724 } 3725 auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands); 3726 return ReplaceWithNewInstruction(map, std::move(clone)); 3727 } 3728 3729 StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) { 3730 XLA_VLOG_LINES(2, 3731 "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); 3732 bool changed = false; 3733 for (auto* comp : module->MakeNonfusionComputations()) { 3734 if (AlgebraicSimplifierVisitor::Run(comp, options_)) { 3735 changed = true; 3736 } 3737 } 3738 XLA_VLOG_LINES(2, 3739 "AlgebraicSimplifier::Run(), after:\n" + module->ToString()); 3740 return changed; 3741 } 3742 3743 } // namespace xla 3744