1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ 18 19 #include <iosfwd> 20 #include <map> 21 #include <memory> 22 #include <set> 23 #include <string> 24 #include <unordered_map> 25 #include <utility> 26 #include <vector> 27 28 #include "tensorflow/compiler/xla/service/computation_layout.h" 29 #include "tensorflow/compiler/xla/service/hlo_computation.h" 30 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 31 #include "tensorflow/compiler/xla/service/hlo_module.h" 32 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 33 #include "tensorflow/compiler/xla/service/logical_buffer.h" 34 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 35 #include "tensorflow/compiler/xla/shape_layout.h" 36 #include "tensorflow/compiler/xla/shape_util.h" 37 #include "tensorflow/compiler/xla/statusor.h" 38 #include "tensorflow/compiler/xla/types.h" 39 #include "tensorflow/compiler/xla/xla_data.pb.h" 40 #include "tensorflow/core/lib/core/status.h" 41 #include "tensorflow/core/platform/types.h" 42 43 namespace xla { 44 45 // Abstract base class for layout constraints. These constraint objects are 46 // gathered together in LayoutConstraints object. 
class LayoutConstraint {
 public:
  // `mandatory`: if true, this constraint may not later be overwritten by a
  // different constraint. `dfs`: if true, the constraint is propagated
  // depth-first; otherwise breadth-first.
  LayoutConstraint(bool mandatory, bool dfs)
      : mandatory_(mandatory), dfs_(dfs) {}
  virtual ~LayoutConstraint() = default;

  // Human-readable description of this constraint, for debugging and logging.
  virtual string ToString() const = 0;

  // True if this constraint cannot be overwritten by a different constraint.
  bool mandatory() const { return mandatory_; }

  // When true, propagate in DFS. When false, constraint will propagate in BFS.
  bool dfs() const { return dfs_; }

 private:
  bool mandatory_;
  bool dfs_;
};

std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);

// Layout constraint on a single LogicalBuffer. This constrains the layout of an
// array produced by a particular instruction.
class BufferLayoutConstraint : public LayoutConstraint {
 public:
  BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory, bool dfs);

  // The constrained buffer. Non-owning; the buffer must outlive this object.
  const LogicalBuffer& buffer() const { return *buffer_; }
  // The layout required for `buffer`.
  const Layout& layout() const { return layout_; }

  string ToString() const override;

 private:
  Layout layout_;
  const LogicalBuffer* buffer_;  // Non-owning.
};

// Constraint on the layout of the operand of an instruction. The constrained
// shape can be arbitrarily shaped (array or tuple). This is a constraint on the
// use of a shaped value and is not a hard constraint on the instruction(s)
// which define the value as copies may be inserted between the definition and
// use.
90 class OperandLayoutConstraint : public LayoutConstraint { 91 public: 92 OperandLayoutConstraint(const ShapeLayout& shape_layout, 93 const HloInstruction* instruction, int64 operand_no, 94 bool mandatory, bool dfs); 95 96 const ShapeLayout& shape_layout() const { return shape_layout_; } 97 const HloInstruction* instruction() const { return instruction_; } 98 const int64 operand_no() const { return operand_no_; } 99 const HloInstruction* operand() const { 100 return instruction_->operand(operand_no_); 101 } 102 103 string ToString() const override; 104 105 private: 106 ShapeLayout shape_layout_; 107 const HloInstruction* instruction_; 108 int64 operand_no_; 109 }; 110 111 // Constraint on the layout of the result of the entry computation. 112 class ResultLayoutConstraint : public LayoutConstraint { 113 public: 114 explicit ResultLayoutConstraint(const ShapeLayout& shape_layout, 115 bool dfs = false) 116 : LayoutConstraint(/*mandatory=*/true, dfs), 117 shape_layout_(shape_layout) {} 118 119 const ShapeLayout& shape_layout() const { return shape_layout_; } 120 string ToString() const override; 121 122 private: 123 const ShapeLayout shape_layout_; 124 }; 125 126 // Class encapsulating the layout constraints of the values in a HLO 127 // computation. 128 class LayoutConstraints { 129 public: 130 LayoutConstraints(const TuplePointsToAnalysis& points_to_analysis, 131 HloComputation* computation); 132 ~LayoutConstraints() = default; 133 134 const HloComputation* computation() const { return computation_; } 135 HloComputation* computation() { return computation_; } 136 const TuplePointsToAnalysis& points_to_analysis() const { 137 return points_to_analysis_; 138 } 139 140 // Return a vector containing the constraints which have been added to the 141 // LayoutConstraints object since the construction of the object or since the 142 // last time ConsumeAddedConstraints() has been called. This is used to 143 // identify newly added constraints when propagating layouts. 
144 std::vector<const LayoutConstraint*> ConsumeAddedConstraints() { 145 std::vector<const LayoutConstraint*> ret_vec(std::move(added_constraints_)); 146 added_constraints_.clear(); 147 return ret_vec; 148 } 149 void ClearAddedConstraints() { added_constraints_.clear(); } 150 151 // Returns the layout of a LogicalBuffer, the layout of the operand of the 152 // instruction, or the layout of the result of the computation, respectively, 153 // if it has been constrained. Otherwise return nullptr. 154 const Layout* BufferLayout(const LogicalBuffer& buffer) const; 155 const BufferLayoutConstraint* GetBufferLayoutConstraint( 156 const LogicalBuffer& buffer) const; 157 const ShapeLayout* OperandLayout(const HloInstruction* instruction, 158 int64 operand_no) const; 159 const OperandLayoutConstraint* GetOperandLayoutConstraint( 160 const HloInstruction* instruction, int64 operand_no) const; 161 const ShapeLayout* ResultLayout() const; 162 163 // Add a constraint on the layout of a LogicalBuffer, the layout of the 164 // operand of the instruction, or the layout of the result of the computation, 165 // respectively. 166 Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, 167 bool mandatory = true, bool dfs = true); 168 Status SetOperandLayout(const Shape& shape_with_layout, 169 const HloInstruction* instruction, int64 operand_no, 170 bool mandatory = true, bool dfs = true); 171 Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true); 172 173 // Convenience wrapper around SetOperandLayout for setting the layout of a 174 // operand using a Layout object. The operand must be array-shaped. 175 Status SetArrayOperandLayout(const Layout& layout, 176 const HloInstruction* instruction, 177 int64 operand_no, bool mandatory = true, 178 bool dfs = true); 179 180 // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers 181 // created by the instruction to the layouts in the given shape. 
The 182 // instruction must define every logical buffer in its output. 183 Status SetInstructionLayout(const Shape& shape_with_layout, 184 const HloInstruction* instruction, 185 bool mandatory = true, bool dfs = true); 186 187 // Returns true if any buffer in the given operand is forwarded to the output 188 // of the given instruction. For example, the Tuple instruction forwards the 189 // buffers of its operands and would return true for each of its operands. 190 bool OperandBufferForwarded(const HloInstruction* instruction, 191 int64 operand_no) const; 192 193 // Returns the set of logical buffers (by LogicalBuffer:Id) which do not 194 // yet have a layout constraint 195 const std::set<LogicalBuffer::Id>& unconstrained_buffer_ids() const { 196 return unconstrained_buffer_ids_; 197 } 198 199 string ToString() const; 200 201 private: 202 // The set of BufferLayoutConstraints applied to the computation. 203 std::unordered_map<const LogicalBuffer*, BufferLayoutConstraint> 204 buffer_constraints_; 205 206 // The set of OperandLayoutConstraints applied to the computation. 207 using OperandConstraintKey = std::pair<const HloInstruction*, int64>; 208 std::map<OperandConstraintKey, OperandLayoutConstraint> operand_constraints_; 209 210 // The result constraint for the computation (can be null). 211 std::unique_ptr<ResultLayoutConstraint> result_constraint_; 212 213 // A vector which holds constraints as they are added. Can be cleared with 214 // ClearAddedConstraints. 215 std::vector<const LayoutConstraint*> added_constraints_; 216 217 // Points-to analysis for the module. Used to propagate constraints through 218 // the HLO graph. 219 const TuplePointsToAnalysis& points_to_analysis_; 220 221 // Array-shaped buffers which have not yet been constrained. 222 std::set<LogicalBuffer::Id> unconstrained_buffer_ids_; 223 224 HloComputation* computation_; 225 }; 226 227 // Contains constraints on the layout of channels; sends and recvs. 
228 class ChannelLayoutConstraints { 229 public: 230 // Construct an empty constraint set. 231 ChannelLayoutConstraints() {} 232 233 // Returns true if channel_id has a layout constraint. 234 bool IsChannelConstrained(int64 channel_id) const { 235 return constraints_.count(channel_id) > 0; 236 } 237 238 // Given `shape`, apply the layout for `channel_id`. `channel_id` must already 239 // be constrained. 240 Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const { 241 CHECK(IsChannelConstrained(channel_id)); 242 *shape.mutable_layout() = constraints_.at(channel_id); 243 return shape; 244 } 245 246 // Returns the layout constraint for `channel_id`, which must already be 247 // constrained. 248 Layout LayoutForChannel(int64 channel_id) const { 249 CHECK(IsChannelConstrained(channel_id)); 250 return constraints_.at(channel_id); 251 } 252 253 // Adds a new layout constraint for `channel_id`. If a constraint for 254 // `channel_id` already exists, this operation requires that the new layout is 255 // the same as the previously constrained layout. 256 void ConstrainChannel(int64 channel_id, const Layout& layout) { 257 CHECK(!IsChannelConstrained(channel_id) || 258 LayoutUtil::Equal(layout, constraints_[channel_id])); 259 constraints_[channel_id] = layout; 260 } 261 262 private: 263 std::unordered_map<int64, Layout> constraints_; 264 }; 265 266 // HLO pass which assigns layouts to all instructions in the HLO module while 267 // satisfying all necessary invariants and minimizing cost. 268 class LayoutAssignment : public HloPassInterface { 269 public: 270 // entry_computation_layout is modified to populate a layout for the result in 271 // the case that no particular layout is requested. 272 // 273 // channel_constraints is both an input and output. Any sends or recvs that 274 // are present in channel_constraints will be layed out as constrained. 
Any 275 // unconstrained sends or recvs will be layed out as locally optimal and their 276 // layout will be added as a constraint to channel_constraints. 277 // 278 // If channel_constraints is nullptr, no kSend or kRecvs must be contained 279 // within any module passed to `Run`. 280 explicit LayoutAssignment( 281 ComputationLayout* entry_computation_layout, 282 ChannelLayoutConstraints* channel_constraints = nullptr); 283 ~LayoutAssignment() override {} 284 tensorflow::StringPiece name() const override { return "layout-assignment"; } 285 286 // Assign layouts to the given module. Returns whether the module was changed 287 // (any layouts were changed). 288 StatusOr<bool> Run(HloModule* module) override; 289 290 protected: 291 // These methods, invoked by PropagateConstraints, propagate a layout 292 // constraint to its neighbors (i.e. operands and users) in order to minimize 293 // the cost of the instructions being constrainted on. New constraints are 294 // added to the given constraint set. 295 // 296 // Backends can override these methods with backend-specific propagation 297 // rules. 298 virtual Status PropagateBufferConstraint( 299 const BufferLayoutConstraint& layout_constraint, 300 LayoutConstraints* constraints); 301 virtual Status PropagateOperandConstraint( 302 const OperandLayoutConstraint& layout_constraint, 303 LayoutConstraints* constraints); 304 virtual Status PropagateResultConstraint( 305 const ResultLayoutConstraint& layout_constraint, 306 LayoutConstraints* constraints); 307 308 // By default LayoutAssignment ensures that inputs and outputs of CustomCalls 309 // have the "major-first" layout (i.e. {n, n-1, ..., 0}). 310 // 311 // If this function returns true, LayoutAssignment does not set a layout for 312 // the given CustomCall. It's up to the backend to set one in 313 // AddBackendConstraints, if necessary. 314 // 315 // Precondition: instruction->opcode() == HloOpcode::kCustomCall. 
316 virtual bool CustomCallRequiresMajorFirstLayout( 317 const HloInstruction* /*instruction*/) { 318 return true; 319 } 320 321 // Called after layouts of an instruction have been finalized to allow 322 // subclasses to check for platform specific assumptions. 323 virtual Status Verify(const HloInstruction* instruction) { 324 return Status::OK(); 325 } 326 327 // Propagates a buffer layout constraint into the operands that use it. 328 Status PropagateBufferConstraintToUses( 329 const BufferLayoutConstraint& layout_constraint, 330 LayoutConstraints* constraints); 331 332 // Propagates a layout constraint on the use of the result of the given 333 // instruction to the definitions of the LogicalBuffers which make up the 334 // result. 335 Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout, 336 const HloInstruction* instruction, 337 LayoutConstraints* constraints); 338 339 // Chooses a layout of operand `operand_no` of `instruction` that minimizes 340 // the cost of `instruction`. `output_layout` is the layout of `instruction`. 341 // Returns null if it can't decide the best layout. 342 // Precondition: `instruction` and the operand are array-shaped. 343 std::unique_ptr<Layout> ChooseOperandLayoutFromOutputLayout( 344 const Layout& output_layout, const HloInstruction* instruction, 345 int64 operand_no); 346 // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of 347 // `user` that minimizes its cost on that operand. Returns null if it can't 348 // decide the best layout. 349 // Precondition: `user` and the operand are array-shaped. 350 std::unique_ptr<Layout> ChooseOutputLayoutFromOperandLayout( 351 const Layout& operand_layout, const HloInstruction* user, 352 int64 operand_no); 353 354 private: 355 // Adds constraints which must be satisfied for correctness on all 356 // backends. Called once prior to propagating constraints. 
357 Status AddMandatoryConstraints( 358 const ComputationLayout& computation_layout, 359 const ChannelLayoutConstraints* channel_constraints, 360 HloComputation* computation, LayoutConstraints* constraints); 361 362 // This method can be overridden to add backend-specific constraints to the 363 // layout of the instructions of a computation. This method is called after 364 // all mandatory constraints have been added via AddMandatoryConstraints 365 // and before propagating constraints. 366 virtual Status AddBackendConstraints(LayoutConstraints* constraints) { 367 return Status::OK(); 368 } 369 370 // Construct contraints and assign layouts to all instructions in the 371 // computation satisfying the given ComputationLayout. Layouts constraints are 372 // added, then propagated until all LogicalBuffers in the computation are 373 // constrained. 374 Status RunOnComputation(const ComputationLayout& computation_layout, 375 const TuplePointsToAnalysis& points_to_analysis, 376 HloComputation* computation, 377 ChannelLayoutConstraints* channel_constraints); 378 379 // Assign layouts to the instructions of a computation which satisfy the given 380 // layout constraints. Copies may be added to satisfy the constraints. The 381 // given LayoutConstraints must have layout constraints every logical buffer 382 // in the computation. 383 Status AssignLayouts(const LayoutConstraints& constraints, 384 HloComputation* computation); 385 386 // Propagates layout constraints from a set of initial constraints in order to 387 // minimize the local cost of the computation. This propagation is *not* 388 // required for correctness. 389 Status PropagateConstraints(LayoutConstraints* constraints); 390 391 // Check that all layouts in the module have been set and satisfy all 392 // necessary conditions. 
393 Status CheckLayouts(HloModule* module); 394 395 ComputationLayout* entry_computation_layout_; 396 ChannelLayoutConstraints* channel_layout_constraints_; 397 398 protected: 399 // Map containing the layouts of all computations assigned so 400 // far. Computations are handled in a topological sort where computations are 401 // handled before their caller instructions so the layouts of caller 402 // instructions can be set to match the computation. 403 std::map<HloComputation*, ComputationLayout> computation_layouts_; 404 }; 405 406 } // namespace xla 407 408 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ 409