/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_

#include <iosfwd>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Abstract base class for layout constraints. These constraint objects are
// gathered together in a LayoutConstraints object.
class LayoutConstraint {
 public:
  LayoutConstraint(bool mandatory, bool dfs)
      : mandatory_(mandatory), dfs_(dfs) {}
  virtual ~LayoutConstraint() = default;

  virtual string ToString() const = 0;

  // True if this constraint cannot be overwritten by a different constraint.
  bool mandatory() const { return mandatory_; }

  // When true, the constraint is propagated in DFS order; when false, it is
  // propagated in BFS order.
  bool dfs() const { return dfs_; }

 private:
  bool mandatory_;
  bool dfs_;
};

std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);

// Layout constraint on a single LogicalBuffer. This constrains the layout of
// an array produced by a particular instruction.
class BufferLayoutConstraint : public LayoutConstraint {
 public:
  BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory, bool dfs);

  const LogicalBuffer& buffer() const { return *buffer_; }
  const Layout& layout() const { return layout_; }

  string ToString() const override;

 private:
  Layout layout_;
  const LogicalBuffer* buffer_;
};

// Constraint on the layout of the operand of an instruction. The constrained
// value can have any shape (array or tuple). This is a constraint on the use
// of a shaped value and is not a hard constraint on the instruction(s) which
// define the value, as copies may be inserted between the definition and use.
class OperandLayoutConstraint : public LayoutConstraint {
 public:
  OperandLayoutConstraint(const ShapeLayout& shape_layout,
                          const HloInstruction* instruction, int64 operand_no,
                          bool mandatory, bool dfs);

  const ShapeLayout& shape_layout() const { return shape_layout_; }
  const HloInstruction* instruction() const { return instruction_; }
  int64 operand_no() const { return operand_no_; }
  const HloInstruction* operand() const {
    return instruction_->operand(operand_no_);
  }

  string ToString() const override;

 private:
  ShapeLayout shape_layout_;
  const HloInstruction* instruction_;
  int64 operand_no_;
};

// Constraint on the layout of the result of the entry computation.
class ResultLayoutConstraint : public LayoutConstraint {
 public:
  explicit ResultLayoutConstraint(const ShapeLayout& shape_layout,
                                  bool dfs = false)
      : LayoutConstraint(/*mandatory=*/true, dfs),
        shape_layout_(shape_layout) {}

  const ShapeLayout& shape_layout() const { return shape_layout_; }
  string ToString() const override;

 private:
  const ShapeLayout shape_layout_;
};

// Class encapsulating the layout constraints of the values in an HLO
// computation.
class LayoutConstraints {
 public:
  LayoutConstraints(const TuplePointsToAnalysis& points_to_analysis,
                    HloComputation* computation);
  ~LayoutConstraints() = default;

  const HloComputation* computation() const { return computation_; }
  HloComputation* computation() { return computation_; }
  const TuplePointsToAnalysis& points_to_analysis() const {
    return points_to_analysis_;
  }

  // Returns a vector containing the constraints which have been added to the
  // LayoutConstraints object since the construction of the object, or since
  // the last time ConsumeAddedConstraints() was called. This is used to
  // identify newly added constraints when propagating layouts.
  std::vector<const LayoutConstraint*> ConsumeAddedConstraints() {
    std::vector<const LayoutConstraint*> ret_vec(std::move(added_constraints_));
    added_constraints_.clear();
    return ret_vec;
  }
  void ClearAddedConstraints() { added_constraints_.clear(); }
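
  // Illustrative sketch (not part of the API): a propagation driver can drain
  // newly added constraints into a worklist, honoring each constraint's dfs()
  // preference. `constraints` names a LayoutConstraints instance and is
  // hypothetical here.
  //
  //   std::deque<const LayoutConstraint*> worklist;
  //   for (const LayoutConstraint* c : constraints.ConsumeAddedConstraints()) {
  //     // DFS constraints are processed first (front), BFS ones last (back).
  //     c->dfs() ? worklist.push_front(c) : worklist.push_back(c);
  //   }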

  // Returns the layout of a LogicalBuffer, the layout of the operand of the
  // instruction, or the layout of the result of the computation, respectively,
  // if it has been constrained. Otherwise returns nullptr.
  const Layout* BufferLayout(const LogicalBuffer& buffer) const;
  const BufferLayoutConstraint* GetBufferLayoutConstraint(
      const LogicalBuffer& buffer) const;
  const ShapeLayout* OperandLayout(const HloInstruction* instruction,
                                   int64 operand_no) const;
  const OperandLayoutConstraint* GetOperandLayoutConstraint(
      const HloInstruction* instruction, int64 operand_no) const;
  const ShapeLayout* ResultLayout() const;

  // Adds a constraint on the layout of a LogicalBuffer, the layout of the
  // operand of the instruction, or the layout of the result of the
  // computation, respectively.
  Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory = true, bool dfs = true);
  Status SetOperandLayout(const Shape& shape_with_layout,
                          const HloInstruction* instruction, int64 operand_no,
                          bool mandatory = true, bool dfs = true);
  Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true);

  // Convenience wrapper around SetOperandLayout for setting the layout of an
  // operand using a Layout object. The operand must be array-shaped.
  Status SetArrayOperandLayout(const Layout& layout,
                               const HloInstruction* instruction,
                               int64 operand_no, bool mandatory = true,
                               bool dfs = true);

  // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers
  // created by the instruction to the layouts in the given shape. The
  // instruction must define every logical buffer in its output.
  Status SetInstructionLayout(const Shape& shape_with_layout,
                              const HloInstruction* instruction,
                              bool mandatory = true, bool dfs = true);
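
  // For example (an illustrative sketch; `constraints` and `instr` are
  // hypothetical names for a LayoutConstraints* and an HloInstruction* in
  // scope), pinning a rank-2 f32 output to a row-major layout:
  //
  //   Shape shape = ShapeUtil::MakeShapeWithLayout(
  //       F32, /*dimensions=*/{2, 3}, /*minor_to_major=*/{1, 0});
  //   TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(shape, instr));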

  // Returns true if any buffer in the given operand is forwarded to the output
  // of the given instruction. For example, the Tuple instruction forwards the
  // buffers of its operands and would return true for each of its operands.
  bool OperandBufferForwarded(const HloInstruction* instruction,
                              int64 operand_no) const;
  // Returns the set of logical buffers (by LogicalBuffer::Id) which do not yet
  // have a layout constraint.
  const std::set<LogicalBuffer::Id>& unconstrained_buffer_ids() const {
    return unconstrained_buffer_ids_;
  }

  string ToString() const;

 private:
  // Finds a buffer set in the buffer-set cache. This is useful because we may
  // otherwise create the flattened buffer set for the same instruction many
  // times, which is slow.
  PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const;

  // The set of BufferLayoutConstraints applied to the computation.
  std::unordered_map<const LogicalBuffer*, BufferLayoutConstraint>
      buffer_constraints_;

  // The set of OperandLayoutConstraints applied to the computation.
  using OperandConstraintKey = std::pair<const HloInstruction*, int64>;
  std::map<OperandConstraintKey, OperandLayoutConstraint> operand_constraints_;

  // The result constraint for the computation (can be null).
  std::unique_ptr<ResultLayoutConstraint> result_constraint_;

  // A vector which holds constraints as they are added. Can be cleared with
  // ClearAddedConstraints.
  std::vector<const LayoutConstraint*> added_constraints_;

  // Points-to analysis for the module. Used to propagate constraints through
  // the HLO graph.
  const TuplePointsToAnalysis& points_to_analysis_;

  // Array-shaped buffers which have not yet been constrained.
  std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;

  mutable absl::flat_hash_map<const HloInstruction*,
                              std::unique_ptr<PointsToSet::BufferSet>>
      buffer_sets_cache_;

  HloComputation* computation_;
};

// Contains constraints on the layouts of channels (i.e., sends and recvs).
class ChannelLayoutConstraints {
 public:
  // Constructs an empty constraint set.
  ChannelLayoutConstraints() {}

  // Returns true if channel_id has a layout constraint.
  bool IsChannelConstrained(int64 channel_id) const {
    return constraints_.contains(channel_id);
  }

  // Given `shape`, applies the layout for `channel_id` to it. `channel_id`
  // must already be constrained.
  Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const {
    auto it = constraints_.find(channel_id);
    CHECK(it != constraints_.end()) << "Channel " << channel_id;
    *shape.mutable_layout() = it->second;
    return shape;
  }

  // Returns the layout constraint for `channel_id`, which must already be
  // constrained.
  const Layout& LayoutForChannel(int64 channel_id) const {
    auto it = constraints_.find(channel_id);
    CHECK(it != constraints_.end()) << "Channel " << channel_id;
    return it->second;
  }

  // Adds a new layout constraint for `channel_id`. Returns nullptr if the
  // constraint was newly added, or if the given layout matches the constraint
  // already set. Otherwise, leaves the existing constraint unchanged and
  // returns the layout which has already been set for the channel.
  const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) {
    auto it = constraints_.emplace(channel_id, layout);
    if (it.second) {
      return nullptr;
    }
    return LayoutUtil::Equal(layout, it.first->second) ? nullptr
                                                       : &it.first->second;
  }

 private:
  absl::flat_hash_map<int64, Layout> constraints_;
};
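
// Illustrative usage of ConstrainChannel (channel id and layouts are made up):
//
//   ChannelLayoutConstraints channels;
//   Layout a = LayoutUtil::MakeLayout({0, 1});
//   CHECK(channels.ConstrainChannel(/*channel_id=*/42, a) == nullptr);
//   // Re-constraining with a conflicting layout returns the existing layout:
//   Layout b = LayoutUtil::MakeLayout({1, 0});
//   const Layout* prev = channels.ConstrainChannel(42, b);
//   CHECK(prev != nullptr && LayoutUtil::Equal(*prev, a));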

// HLO pass which assigns layouts to all instructions in the HLO module while
// satisfying all necessary invariants and minimizing cost.
class LayoutAssignment : public HloModulePass {
 public:
  // entry_computation_layout is modified to populate a layout for the result
  // in the case that no particular layout is requested.
  //
  // instruction_can_change_layout_func is a function object that determines
  // whether an instruction can change layouts. An instruction not being able
  // to change layout means that it requires operands with the same rank as the
  // output to have the same layout as the output.
  //
  // channel_constraints is both an input and an output. Any sends or recvs
  // that are present in channel_constraints will be laid out as constrained.
  // Any unconstrained sends or recvs will be laid out as locally optimal and
  // their layout will be added as a constraint to channel_constraints.
  //
  // If channel_constraints is nullptr, no kSend or kRecv instructions may be
  // contained in any module passed to `Run`.
  explicit LayoutAssignment(
      ComputationLayout* entry_computation_layout,
      std::function<bool(const HloInstruction*)>
          instruction_can_change_layout_func = InstructionCanChangeLayout,
      ChannelLayoutConstraints* channel_constraints = nullptr);
  ~LayoutAssignment() override {}
  absl::string_view name() const override { return "layout-assignment"; }

  // Assigns layouts to the given module. Returns whether the module was
  // changed (any layouts were changed).
  StatusOr<bool> Run(HloModule* module) override;
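
  // Illustrative invocation sketch (assumes `module` is an HloModule* whose
  // entry computation layout should be used and updated in place):
  //
  //   ComputationLayout* entry_layout =
  //       module->mutable_entry_computation_layout();
  //   LayoutAssignment layout_assignment(entry_layout);
  //   TF_ASSIGN_OR_RETURN(bool changed, layout_assignment.Run(module));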

  // Determines whether an instruction can change layouts. An instruction not
  // being able to change layout means that it requires operands with the same
  // rank as the output to have the same layout as the output.
  static bool InstructionCanChangeLayout(const HloInstruction* instruction);

  // For an array shape, returns true iff the rank is at most 1. For a tuple
  // shape, returns true iff all leaf shapes have rank at most 1.
  static bool IsAtMostRank1(const Shape& shape);
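
  // For example: f32[5] and f32[] are at most rank 1, while f32[2,3] is not;
  // the tuple (f32[], f32[7]) is at most rank 1, but (f32[], f32[2,2]) is not.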

 protected:
  // These methods, invoked by PropagateConstraints, propagate a layout
  // constraint to its neighbors (i.e. operands and users) in order to minimize
  // the cost of the constrained instructions. New constraints are added to the
  // given constraint set.
  //
  // Backends can override these methods with backend-specific propagation
  // rules.
  virtual Status PropagateBufferConstraint(
      const BufferLayoutConstraint& layout_constraint,
      LayoutConstraints* constraints);
  virtual Status PropagateOperandConstraint(
      const OperandLayoutConstraint& layout_constraint,
      LayoutConstraints* constraints);
  virtual Status PropagateResultConstraint(
      const ResultLayoutConstraint& layout_constraint,
      LayoutConstraints* constraints);

  // Called after the layouts of an instruction have been finalized, to allow
  // subclasses to check for platform-specific assumptions.
  virtual Status Verify(const HloInstruction* instruction) {
    return Status::OK();
  }

  // Propagates a buffer layout constraint into the operands that use it.
  Status PropagateBufferConstraintToUses(
      const BufferLayoutConstraint& layout_constraint,
      LayoutConstraints* constraints);

  // Propagates a layout constraint on the use of the result of the given
  // instruction to the definitions of the LogicalBuffers which make up the
  // result.
  Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout,
                                      const HloInstruction* instruction,
                                      LayoutConstraints* constraints);

  // Chooses a layout of operand `operand_no` of `instruction` that minimizes
  // the cost of `instruction`. `output_layout` is the layout of `instruction`.
  // Returns null if it can't decide the best layout.
  // Precondition: `instruction` and the operand are array-shaped.
  std::unique_ptr<Layout> ChooseOperandLayoutFromOutputLayout(
      const Layout& output_layout, const HloInstruction* instruction,
      int64 operand_no);

  // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of
  // `user` that minimizes its cost on that operand. Returns null if it can't
  // decide the best layout.
  // Precondition: `user` and the operand are array-shaped.
  virtual std::unique_ptr<Layout> ChooseOutputLayoutFromOperandLayout(
      const Layout& operand_layout, const HloInstruction* user,
      int64 operand_no);

 private:
  // Initializes the layout assignment object for a new Run() call.
  Status Init();

  // Adds constraints which must be satisfied for correctness on all
  // backends. Called once prior to propagating constraints.
  Status AddMandatoryConstraints(const ComputationLayout* computation_layout,
                                 ChannelLayoutConstraints* channel_constraints,
                                 HloComputation* computation,
                                 LayoutConstraints* constraints);

  // This method can be overridden to add backend-specific constraints to the
  // layout of the instructions of a computation. This method is called after
  // all mandatory constraints have been added via AddMandatoryConstraints
  // and before propagating constraints.
  virtual Status AddBackendConstraints(LayoutConstraints* constraints) {
    return Status::OK();
  }
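
  // Illustrative backend override sketch (MyBackendLayoutAssignment is
  // hypothetical; private virtuals can be overridden in C++):
  //
  //   class MyBackendLayoutAssignment : public LayoutAssignment {
  //     Status AddBackendConstraints(LayoutConstraints* constraints) override {
  //       // e.g., constrain operands of backend-specific ops here.
  //       return Status::OK();
  //     }
  //   };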

  // Constructs constraints and assigns layouts to all instructions in the
  // computation satisfying the given ComputationLayout, if not nullptr.
  // Otherwise the ComputationLayout will be calculated by propagating the
  // computation's instruction constraints.
  // Layout constraints are added, then propagated until all LogicalBuffers in
  // the computation are constrained.
  Status RunOnComputation(ComputationLayout* computation_layout,
                          const TuplePointsToAnalysis& points_to_analysis,
                          HloComputation* computation,
                          ChannelLayoutConstraints* channel_constraints);

  // Assigns layouts to the instructions of a computation which satisfy the
  // given layout constraints. Copies may be added to satisfy the constraints.
  // The given LayoutConstraints must have layout constraints for every logical
  // buffer in the computation.
  Status AssignLayouts(const LayoutConstraints& constraints,
                       HloComputation* computation);

  // Propagates layout constraints from a set of initial constraints in order
  // to minimize the local cost of the computation. This propagation is *not*
  // required for correctness.
  Status PropagateConstraints(LayoutConstraints* constraints);

  Status PropagateBufferConstraintToOperands(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints);

  // Checks that all layouts in the module have been set and satisfy all
  // necessary conditions.
  Status CheckLayouts(HloModule* module);

  // Computes the ComputationLayout of the given computation based on the
  // layouts assigned to its parameters and root instruction, and inserts it
  // into the computation_layouts_ map.
  Status CalculateComputationLayout(HloComputation* computation);

  // Clears all the layouts which can be cleared within a computation.
  Status ClearComputationLayouts(HloComputation* computation);

  // Clears the side effects of a previous pass, like added copy instructions.
  Status ClearPreviousPassSideEffects(HloModule* module);

  // Propagates the layouts computed by the layout assignment pass on the given
  // computation to the computation layout passed in to this API.
  // This API propagates missing layouts, and also checks that the layouts the
  // caller specified have been respected, by comparing them with the
  // parameters and the root instruction of the computation.
  Status PropagateComputationLayouts(HloComputation* computation,
                                     ComputationLayout* computation_layout);

  // The pointer to the ComputationLayout passed as constructor parameter.
  ComputationLayout* entry_computation_layout_;

  // A copy of entry_computation_layout_ used to reset it to the initial values
  // during the multiple passes done by the layout assignment operation.
  ComputationLayout saved_entry_computation_layout_;

 protected:
  // Sets up the copy instruction according to the characteristics (sharding,
  // metadata, ...) of the reference instruction. The index argument is used
  // when the instruction is a tuple, and in such a case the index represents
  // the location from which the copy instruction was created.
  // If the index is empty, the whole sharding will be propagated, even in case
  // the instruction has a tuple sharding.
  static void SetupCopiedInstruction(const HloInstruction& instruction,
                                     HloInstruction* copy,
                                     const ShapeIndex& index);

  // Creates and returns a copy of the given instruction with a different
  // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
  // instruction producing the copy is returned.
  StatusOr<HloInstruction*> CreateCopyWithNewLayout(
      const Shape& shape_with_layout, HloInstruction* instruction);

  // Creates a copy of the given operand if the operand's layout does not match
  // the given layout. This copy replaces the use in the given instruction.
  // Tuple operands will be deep-copied.
  Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
                                    HloInstruction* instruction,
                                    int64 operand_no);

  // Registers a copy instruction added by the layout assignment pass.
  void RegisterAddedCopy(HloInstruction* copy) {
    CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
    added_copies_.insert(copy);
  }
  // Adds a copy for the operand of an instruction, unless such operand is
  // already a copy and has a single user (which is necessarily the instruction
  // itself).
  Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number);

  // Applies the channel layout constraints by populating the
  // channel_constraints data structure passed in at constructor time. Adds
  // copies where needed, in case the two ends of a channel ended up with
  // different layouts.
  Status ConstrainChannelLayouts(HloComputation* computation,
                                 ChannelLayoutConstraints* channel_constraints);

  // Resets the input ChannelLayoutConstraints to the original copy received
  // from the constructor input.
  void ResetChannelConstraints() {
    if (channel_layout_constraints_ != nullptr) {
      *channel_layout_constraints_ = channel_constraints_;
    }
  }

  // Adds constraints related to host Send/Recv instructions.
  Status BuildHostChannelConstraints(HloComputation* computation);

  // Map containing the layouts of all computations assigned so far.
  // Computations are handled in topological order, with a computation
  // processed before its caller instructions, so the layouts of caller
  // instructions can be set to match the called computation.
  std::map<HloComputation*, ComputationLayout> computation_layouts_;

  // Every copy added to the module by the layout assignment pass is registered
  // here.
  absl::flat_hash_set<HloInstruction*> added_copies_;

  // The pointer to the channel layout constraints passed in with the
  // constructor. If not nullptr, this is an input/output argument.
  ChannelLayoutConstraints* channel_layout_constraints_ = nullptr;

  // A copy of the input layout constraints used to reset the above pointer in
  // case we have to undo operations due to the multiple passes over the
  // computations/instructions.
  ChannelLayoutConstraints channel_constraints_;

  // Layout constraints for send/recv instructions which communicate with the
  // host.
  ChannelLayoutConstraints host_channel_constraints_;

  // The set of HLO instructions which lacked any layout constraint, thus
  // receiving propagated default layouts.
  absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_;

  std::function<bool(const HloInstruction*)>
      instruction_can_change_layout_func_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_