1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Implementation note: 17 // 18 // The general idea behind this pass is that we're converting from this: 19 // %param.A = OldShape 20 // %param.B = OldShape 21 // %reshape.A = NewShape reshape(%param.A) 22 // %reshape.B = NewShape reshape(%param.B) 23 // %instruction = NewShape instruction(%reshape.A, %reshape.B) 24 // To this: 25 // %param.A = OldShape 26 // %param.B = OldShape 27 // %instruction = OldShape instruction(%param.A, %param.B) 28 // %reshape = NewShape reshape(%instruction) 29 // 30 // Where the instruction must be elementwise, and both reshapes and transposes 31 // are moved. 32 // 33 // Most elementwise instructions support implicit broadcast of scalar operands, 34 // but select is a special-case. The signature is Select(Pred, A, B), and the 35 // only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or 36 // transposes to a scalar should be cheap, we simply never move them. 37 38 #include "tensorflow/compiler/xla/service/reshape_mover.h" 39 40 #include <algorithm> 41 #include "tensorflow/compiler/xla/literal_util.h" 42 #include "tensorflow/compiler/xla/shape_util.h" 43 #include "tensorflow/compiler/xla/status_macros.h" 44 #include "tensorflow/compiler/xla/util.h" 45 #include "tensorflow/core/lib/core/errors.h" 46 47 namespace xla { 48 49 namespace { 50 51 bool IsReshapeOrTranspose(const HloInstruction* instruction) { 52 return instruction->opcode() == HloOpcode::kReshape || 53 instruction->opcode() == HloOpcode::kTranspose; 54 } 55 56 // Returns true iff `instruction` can change its shape simply by adjusting 57 // metadata. 58 bool CanTriviallyChangeShape(const HloInstruction* instruction) { 59 // NOTE: Technically a sequence of reshape(reshape(constant)) is also 60 // trivially reshapable, so we might be tempted to simply recurse if 61 // IsReshapeOrTranspose(instruction)==true. 62 // 63 // But it's not that simple. E.g. reshape(reshape(rng)) is only trivially 64 // reshapable if *all* instructions in the chain have user_count == 1. And 65 // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar; we 66 // rely on implicit scalar broadcast for scalars to be trivial. In addition, 67 // these cases make it harder to maintain correctness of the UpdateOperand 68 // logic below. 69 // 70 // So don't handle these chains, unless you update the tests and code to deal 71 // with these properly. One idea is to add a pass immediately beforehand that 72 // collapses trivial runs of reshapes / transposes. 73 74 // Scalars can operate with any shape. 75 if (ShapeUtil::IsScalar(instruction->shape())) { 76 return true; 77 } 78 79 // A constant can trivially reshape the literal it holds. 80 if (instruction->opcode() == HloOpcode::kConstant) { 81 return true; 82 } 83 84 // An Rng instruction can be any shape as long as it has one user. Two copies 85 // of the same Rng would be problematic if an Rng of a different shape would 86 // produce random numbers in a different order. 87 if (instruction->opcode() == HloOpcode::kRng && 88 instruction->user_count() == 1) { 89 return true; 90 } 91 return false; 92 } 93 94 // Finds the first non-scalar operand of an instruction that is a non-trivial 95 // reshape or transpose. Returns the operand if it is found or nullptr if not 96 // found. 97 HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( 98 const HloInstruction* hlo) { 99 for (HloInstruction* operand : hlo->operands()) { 100 if (!ShapeUtil::IsScalar(operand->shape()) && 101 IsReshapeOrTranspose(operand) && 102 !CanTriviallyChangeShape(operand->operand(0))) { 103 VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " 104 << hlo->ToString(HloPrintOptions().set_print_metadata(false)) 105 << ":\n\t" 106 << operand->ToString(HloPrintOptions().set_print_metadata(false)); 107 return operand; 108 } 109 } 110 return nullptr; 111 } 112 113 // Returns whether `a` and `b` are equivalent for the purposes of this pass. 114 bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { 115 if (a->opcode() != b->opcode() || 116 !ShapeUtil::SameDimensions(a->shape(), b->shape())) { 117 return false; 118 } 119 switch (a->opcode()) { 120 case HloOpcode::kTranspose: 121 return a->dimensions() == b->dimensions(); 122 case HloOpcode::kReshape: 123 return ShapeUtil::SameDimensions(a->operand(0)->shape(), 124 b->operand(0)->shape()); 125 default: 126 return false; 127 } 128 } 129 130 // Returns true if all operands of `instruction` can easily change shape. 131 // Operands can easily change shape if they are all reshapes/transposes to and 132 // from the same shape. Additionally, operands like constant, rng, and any 133 // scalar change shape with only an adjustment of metadata. 134 bool AllOperandsHaveEasyShapeChanges( 135 const HloInstruction* instruction, 136 const HloInstruction* first_reshape_operand) { 137 auto print_no_metadata = HloPrintOptions().set_print_metadata(false); 138 VLOG(3) << "** Checking whether all operands have easy shape changes: " 139 << instruction->ToString(print_no_metadata); 140 // Check whether all operands: 141 // 0. Have the same dimensions as the output -- if not, it may be 142 // implicitly broadcast, which can confound the movement's 143 // correctness. 144 // 145 // And one of the following: 146 // 1. Are reshapes or transposes that have the same input and 147 // output shapes as all other reshaped or transposed operands. 148 // or 149 // 2. Are one of kConstant, kRng, and scalars that can change shape 150 // trivially, 151 for (const HloInstruction* operand : instruction->operands()) { 152 if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { 153 VLOG(5) << "Operand shape differs from output shape; may be " 154 "implicitly broadcast, so preventing " 155 "movement\n\toperand: " 156 << operand->ToString(print_no_metadata) << "\n\tinstruction: " 157 << instruction->ToString(print_no_metadata); 158 return false; 159 } 160 161 if (AreEquivalentReshapes(first_reshape_operand, operand)) { 162 VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " 163 << first_reshape_operand->ToString(print_no_metadata) 164 << "\n\toperand: " << operand->ToString(print_no_metadata); 165 continue; 166 } 167 168 if (CanTriviallyChangeShape(operand)) { 169 VLOG(5) << "Operand can trivially change shape: " 170 << operand->ToString(print_no_metadata); 171 continue; 172 } 173 174 // TODO(someone): Look into supporting general ops for the operands as 175 // well. 176 VLOG(5) << "Operand is neither equalivant to the first Reshape operand" 177 "nor can trivially change shape: " 178 << operand->ToString(print_no_metadata); 179 return false; 180 } 181 182 VLOG(3) << "All operands have easy shape changes: " 183 << instruction->ToString(print_no_metadata); 184 return true; 185 } 186 187 // This function is called once we've decided to sink reshape/transpose operands 188 // across an instruction. It returns an updated `operand` with a shape that 189 // plays nicely with `new_operand_shape`; either it has the same shape (of the 190 // correct type), or it is a scalar that may be implicitly broadcast. 191 HloInstruction* UpdateOperand(HloComputation* computation, 192 const HloInstruction* first_reshape_operand, 193 const Shape& new_operand_shape, 194 HloInstruction* operand) { 195 const PrimitiveType element_type = operand->shape().element_type(); 196 const Shape new_shape = 197 ShapeUtil::ChangeElementType(new_operand_shape, element_type); 198 199 switch (operand->opcode()) { 200 case HloOpcode::kConstant: { 201 if (first_reshape_operand->opcode() == HloOpcode::kReshape) { 202 VLOG(5) << "Adding reshape to kConstant operand"; 203 return computation->AddInstruction( 204 HloInstruction::CreateReshape(new_shape, operand)); 205 } else { 206 CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose); 207 VLOG(5) << "Adding transpose to kConstant operand"; 208 std::vector<int64> inverse_permutation = 209 InversePermutation(first_reshape_operand->dimensions()); 210 return computation->AddInstruction(HloInstruction::CreateTranspose( 211 new_shape, operand, inverse_permutation)); 212 } 213 } 214 case HloOpcode::kRng: { 215 CHECK_EQ(operand->user_count(), 1); 216 VLOG(5) << "Cloning kRng operand with new shape"; 217 return computation->AddInstruction( 218 operand->CloneWithNewOperands(new_shape, operand->operands())); 219 } 220 case HloOpcode::kReshape: 221 case HloOpcode::kTranspose: { 222 VLOG(5) << "Using existing operand of kReshape or kTranspose"; 223 return operand->mutable_operand(0); 224 } 225 default: 226 LOG(FATAL) << "Unexpected operand opcode during update: " << operand; 227 } 228 } 229 230 // Try to sink any reshape or transpose operands of `instruction` across it. We 231 // do so if `instruction` is elementwise and all operands are either equivalent 232 // reshapes/transposes or are trivially reshapable. 233 StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation, 234 HloInstruction* instruction) { 235 // Only perform sinks for live elementwise instructions with operands. 236 const bool is_dead = instruction->user_count() == 0 && 237 instruction != computation->root_instruction(); 238 if (!instruction->IsElementwise() || instruction->operands().empty() || 239 is_dead) { 240 return false; 241 } 242 243 // Only perform sinks if there are any nontrivial reshape/transpose operands. 244 const HloInstruction* first_reshape_operand = 245 FirstNonScalarAndNonTrivialReshapeOperand(instruction); 246 if (!first_reshape_operand) { 247 return false; 248 } 249 250 // Only perform sinks if all operands can easily change shape. 251 if (!AllOperandsHaveEasyShapeChanges(instruction, first_reshape_operand)) { 252 return false; 253 } 254 255 auto print_no_metadata = HloPrintOptions().set_print_metadata(false); 256 // At this point we've decided to sink reshape/transpose operands. 257 const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); 258 VLOG(3) << "** Sinking reshape or transpose: " 259 << instruction->ToString(print_no_metadata) 260 << "\n\tfirst reshape operand: " 261 << first_reshape_operand->ToString(print_no_metadata) 262 << "\n\tnew operand shape: " 263 << ShapeUtil::HumanString(new_operand_shape); 264 265 auto operands = instruction->operands(); 266 for (size_t i = 0; i < operands.size(); ++i) { 267 // All scalar operands remain as-is, even if they're reshape or transpose, 268 // to simplify handling wrt special scalar broadcast rules for ops like 269 // Select. Scalar reshapes should be cheap anyways. 270 if (ShapeUtil::IsScalar(operands[i]->shape())) { 271 continue; 272 } 273 VLOG(3) << "Updating operand #" << i << ": " 274 << operands[i]->ToString(print_no_metadata); 275 operands[i] = UpdateOperand(computation, first_reshape_operand, 276 new_operand_shape, operands[i]); 277 } 278 if (HloOpcode::kFusion == instruction->opcode()) { 279 // Here we already know `instruction` is elementwise, and no operand is 280 // implicit broadcast as if it were the operands would not have easy shape 281 // changes, so all the fused instructions have the same dimensions. 282 for (const auto& fused_instruction : instruction->fused_instructions()) { 283 Shape* shape = fused_instruction->mutable_shape(); 284 *shape->mutable_dimensions() = new_operand_shape.dimensions(); 285 *shape->mutable_layout() = new_operand_shape.layout(); 286 } 287 } 288 HloInstruction* new_elementwise = 289 computation->AddInstruction(instruction->CloneWithNewOperands( 290 // `instruction` may change the element type, e.g., from 291 // operands[0] -> reshape -> convert (`instruction`) 292 // to 293 // operands[0] -> convert' -> reshape' 294 // 295 // In this case, convert' should have the same element type as 296 // `convert` and the same dimensions as operands[0]. 297 ShapeUtil::ChangeElementType(new_operand_shape, 298 instruction->shape().element_type()), 299 operands)); 300 301 std::unique_ptr<HloInstruction> new_reshape; 302 switch (first_reshape_operand->opcode()) { 303 case HloOpcode::kReshape: 304 VLOG(3) << "Creating new reshape for new elementwise op: " 305 << new_elementwise->ToString(print_no_metadata); 306 new_reshape = 307 HloInstruction::CreateReshape(instruction->shape(), new_elementwise); 308 break; 309 case HloOpcode::kTranspose: 310 new_reshape = 311 HloInstruction::CreateTranspose(instruction->shape(), new_elementwise, 312 first_reshape_operand->dimensions()); 313 break; 314 default: 315 LOG(FATAL) << "Bad opcode"; 316 } 317 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( 318 instruction, std::move(new_reshape))); 319 return true; 320 } 321 322 } // namespace 323 324 StatusOr<bool> ReshapeMover::Run(HloModule* module) { 325 bool changed = false; 326 VLOG(2) << "Pre ReshapeMover HLO:"; 327 XLA_VLOG_LINES(2, module->ToString()); 328 for (auto* comp : module->MakeNonfusionComputations()) { 329 for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { 330 TF_ASSIGN_OR_RETURN(bool did_change, 331 TrySinkReshapeOrTranspose(comp, instruction)); 332 changed |= did_change; 333 } 334 } 335 VLOG(2) << "Post ReshapeMover HLO:"; 336 XLA_VLOG_LINES(2, module->ToString()); 337 return changed; 338 } 339 340 } // namespace xla 341