Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
     18 
     19 #include <iosfwd>
     20 #include <map>
     21 #include <memory>
     22 #include <set>
     23 #include <string>
     24 #include <unordered_map>
     25 #include <utility>
     26 #include <vector>
     27 
     28 #include "tensorflow/compiler/xla/service/computation_layout.h"
     29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     31 #include "tensorflow/compiler/xla/service/hlo_module.h"
     32 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
     33 #include "tensorflow/compiler/xla/service/logical_buffer.h"
     34 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
     35 #include "tensorflow/compiler/xla/shape_layout.h"
     36 #include "tensorflow/compiler/xla/shape_util.h"
     37 #include "tensorflow/compiler/xla/statusor.h"
     38 #include "tensorflow/compiler/xla/types.h"
     39 #include "tensorflow/compiler/xla/xla_data.pb.h"
     40 #include "tensorflow/core/lib/core/status.h"
     41 #include "tensorflow/core/platform/types.h"
     42 
     43 namespace xla {
     44 
     45 // Abstract base class for layout constraints. These constraint objects are
     46 // gathered together in LayoutConstraints object.
     47 class LayoutConstraint {
     48  public:
     49   LayoutConstraint(bool mandatory, bool dfs)
     50       : mandatory_(mandatory), dfs_(dfs) {}
     51   virtual ~LayoutConstraint() = default;
     52 
     53   virtual string ToString() const = 0;
     54 
     55   // True if this constraint cannot be overwritten by a different constraint.
     56   bool mandatory() const { return mandatory_; }
     57 
     58   // When true, propagate in DFS. When false, constraint will propagate in BFS.
     59   bool dfs() const { return dfs_; }
     60 
     61  private:
     62   bool mandatory_;
     63   bool dfs_;
     64 };
     65 
// Pretty-printing support for LayoutConstraint; definition lives in the .cc
// file (presumably forwards to LayoutConstraint::ToString() — confirm there).
std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);
     67 
// Layout constraint on a single LogicalBuffer. This constrains the layout of an
// array produced by a particular instruction.
class BufferLayoutConstraint : public LayoutConstraint {
 public:
  // Pins `buffer` to `layout`. The buffer is captured by pointer, so it must
  // outlive this constraint. Defined in the .cc file.
  BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory, bool dfs);

  // The buffer being constrained (not owned by this object).
  const LogicalBuffer& buffer() const { return *buffer_; }
  // The layout the buffer is constrained to.
  const Layout& layout() const { return layout_; }

  string ToString() const override;

 private:
  Layout layout_;                // Owned copy of the constraining layout.
  const LogicalBuffer* buffer_;  // Not owned; must outlive the constraint.
};
     84 
     85 // Constraint on the layout of the operand of an instruction. The constrained
     86 // shape can be arbitrarily shaped (array or tuple). This is a constraint on the
     87 // use of a shaped value and is not a hard constraint on the instruction(s)
     88 // which define the value as copies may be inserted between the definition and
     89 // use.
     90 class OperandLayoutConstraint : public LayoutConstraint {
     91  public:
     92   OperandLayoutConstraint(const ShapeLayout& shape_layout,
     93                           const HloInstruction* instruction, int64 operand_no,
     94                           bool mandatory, bool dfs);
     95 
     96   const ShapeLayout& shape_layout() const { return shape_layout_; }
     97   const HloInstruction* instruction() const { return instruction_; }
     98   const int64 operand_no() const { return operand_no_; }
     99   const HloInstruction* operand() const {
    100     return instruction_->operand(operand_no_);
    101   }
    102 
    103   string ToString() const override;
    104 
    105  private:
    106   ShapeLayout shape_layout_;
    107   const HloInstruction* instruction_;
    108   int64 operand_no_;
    109 };
    110 
    111 // Constraint on the layout of the result of the entry computation.
    112 class ResultLayoutConstraint : public LayoutConstraint {
    113  public:
    114   explicit ResultLayoutConstraint(const ShapeLayout& shape_layout,
    115                                   bool dfs = false)
    116       : LayoutConstraint(/*mandatory=*/true, dfs),
    117         shape_layout_(shape_layout) {}
    118 
    119   const ShapeLayout& shape_layout() const { return shape_layout_; }
    120   string ToString() const override;
    121 
    122  private:
    123   const ShapeLayout shape_layout_;
    124 };
    125 
// Class encapsulating the layout constraints of the values in a HLO
// computation. Collects buffer, operand, and result constraints and tracks
// which logical buffers remain unconstrained.
class LayoutConstraints {
 public:
  // `points_to_analysis` is captured by reference and must outlive this
  // object; `computation` is the computation being constrained.
  LayoutConstraints(const TuplePointsToAnalysis& points_to_analysis,
                    HloComputation* computation);
  ~LayoutConstraints() = default;

  const HloComputation* computation() const { return computation_; }
  HloComputation* computation() { return computation_; }
  const TuplePointsToAnalysis& points_to_analysis() const {
    return points_to_analysis_;
  }

  // Return a vector containing the constraints which have been added to the
  // LayoutConstraints object since the construction of the object or since the
  // last time ConsumeAddedConstraints() has been called. This is used to
  // identify newly added constraints when propagating layouts.
  std::vector<const LayoutConstraint*> ConsumeAddedConstraints() {
    // Move the accumulated constraints out, then clear(): a moved-from vector
    // is in a valid but unspecified state, so the clear() guarantees the
    // accumulator is empty for the next round.
    std::vector<const LayoutConstraint*> ret_vec(std::move(added_constraints_));
    added_constraints_.clear();
    return ret_vec;
  }
  // Drops pending constraints without returning them.
  void ClearAddedConstraints() { added_constraints_.clear(); }

  // Returns the layout of a LogicalBuffer, the layout of the operand of the
  // instruction, or the layout of the result of the computation, respectively,
  // if it has been constrained. Otherwise return nullptr.
  const Layout* BufferLayout(const LogicalBuffer& buffer) const;
  const BufferLayoutConstraint* GetBufferLayoutConstraint(
      const LogicalBuffer& buffer) const;
  const ShapeLayout* OperandLayout(const HloInstruction* instruction,
                                   int64 operand_no) const;
  const OperandLayoutConstraint* GetOperandLayoutConstraint(
      const HloInstruction* instruction, int64 operand_no) const;
  const ShapeLayout* ResultLayout() const;

  // Add a constraint on the layout of a LogicalBuffer, the layout of the
  // operand of the instruction, or the layout of the result of the computation,
  // respectively. Defined in the .cc file.
  Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory = true, bool dfs = true);
  Status SetOperandLayout(const Shape& shape_with_layout,
                          const HloInstruction* instruction, int64 operand_no,
                          bool mandatory = true, bool dfs = true);
  Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true);

  // Convenience wrapper around SetOperandLayout for setting the layout of a
  // operand using a Layout object. The operand must be array-shaped.
  Status SetArrayOperandLayout(const Layout& layout,
                               const HloInstruction* instruction,
                               int64 operand_no, bool mandatory = true,
                               bool dfs = true);

  // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers
  // created by the instruction to the layouts in the given shape. The
  // instruction must define every logical buffer in its output.
  Status SetInstructionLayout(const Shape& shape_with_layout,
                              const HloInstruction* instruction,
                              bool mandatory = true, bool dfs = true);

  // Returns true if any buffer in the given operand is forwarded to the output
  // of the given instruction. For example, the Tuple instruction forwards the
  // buffers of its operands and would return true for each of its operands.
  bool OperandBufferForwarded(const HloInstruction* instruction,
                              int64 operand_no) const;

  // Returns the set of logical buffers (by LogicalBuffer:Id) which do not
  // yet have a layout constraint.
  const std::set<LogicalBuffer::Id>& unconstrained_buffer_ids() const {
    return unconstrained_buffer_ids_;
  }

  string ToString() const;

 private:
  // The set of BufferLayoutConstraints applied to the computation. Keyed by
  // buffer pointer; the LogicalBuffers are owned by the points-to analysis.
  std::unordered_map<const LogicalBuffer*, BufferLayoutConstraint>
      buffer_constraints_;

  // The set of OperandLayoutConstraints applied to the computation.
  using OperandConstraintKey = std::pair<const HloInstruction*, int64>;
  std::map<OperandConstraintKey, OperandLayoutConstraint> operand_constraints_;

  // The result constraint for the computation (can be null).
  std::unique_ptr<ResultLayoutConstraint> result_constraint_;

  // A vector which holds constraints as they are added. Can be cleared with
  // ClearAddedConstraints, or drained with ConsumeAddedConstraints.
  std::vector<const LayoutConstraint*> added_constraints_;

  // Points-to analysis for the module. Used to propagate constraints through
  // the HLO graph. Not owned; must outlive this object.
  const TuplePointsToAnalysis& points_to_analysis_;

  // Array-shaped buffers which have not yet been constrained.
  std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;

  HloComputation* computation_;  // Not owned.
};
    226 
    227 // Contains constraints on the layout of channels; sends and recvs.
    228 class ChannelLayoutConstraints {
    229  public:
    230   // Construct an empty constraint set.
    231   ChannelLayoutConstraints() {}
    232 
    233   // Returns true if channel_id has a layout constraint.
    234   bool IsChannelConstrained(int64 channel_id) const {
    235     return constraints_.count(channel_id) > 0;
    236   }
    237 
    238   // Given `shape`, apply the layout for `channel_id`. `channel_id` must already
    239   // be constrained.
    240   Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const {
    241     CHECK(IsChannelConstrained(channel_id));
    242     *shape.mutable_layout() = constraints_.at(channel_id);
    243     return shape;
    244   }
    245 
    246   // Returns the layout constraint for `channel_id`, which must already be
    247   // constrained.
    248   Layout LayoutForChannel(int64 channel_id) const {
    249     CHECK(IsChannelConstrained(channel_id));
    250     return constraints_.at(channel_id);
    251   }
    252 
    253   // Adds a new layout constraint for `channel_id`. If a constraint for
    254   // `channel_id` already exists, this operation requires that the new layout is
    255   // the same as the previously constrained layout.
    256   void ConstrainChannel(int64 channel_id, const Layout& layout) {
    257     CHECK(!IsChannelConstrained(channel_id) ||
    258           LayoutUtil::Equal(layout, constraints_[channel_id]));
    259     constraints_[channel_id] = layout;
    260   }
    261 
    262  private:
    263   std::unordered_map<int64, Layout> constraints_;
    264 };
    265 
    266 // HLO pass which assigns layouts to all instructions in the HLO module while
    267 // satisfying all necessary invariants and minimizing cost.
    268 class LayoutAssignment : public HloPassInterface {
    269  public:
    270   // entry_computation_layout is modified to populate a layout for the result in
    271   // the case that no particular layout is requested.
    272   //
    273   // channel_constraints is both an input and output. Any sends or recvs that
    274   // are present in channel_constraints will be layed out as constrained. Any
    275   // unconstrained sends or recvs will be layed out as locally optimal and their
    276   // layout will be added as a constraint to channel_constraints.
    277   //
    278   // If channel_constraints is nullptr, no kSend or kRecvs must be contained
    279   // within any module passed to `Run`.
    280   explicit LayoutAssignment(
    281       ComputationLayout* entry_computation_layout,
    282       ChannelLayoutConstraints* channel_constraints = nullptr);
    283   ~LayoutAssignment() override {}
    284   tensorflow::StringPiece name() const override { return "layout-assignment"; }
    285 
    286   // Assign layouts to the given module. Returns whether the module was changed
    287   // (any layouts were changed).
    288   StatusOr<bool> Run(HloModule* module) override;
    289 
    290  protected:
    291   // These methods, invoked by PropagateConstraints, propagate a layout
    292   // constraint to its neighbors (i.e. operands and users) in order to minimize
    293   // the cost of the instructions being constrainted on. New constraints are
    294   // added to the given constraint set.
    295   //
    296   // Backends can override these methods with backend-specific propagation
    297   // rules.
    298   virtual Status PropagateBufferConstraint(
    299       const BufferLayoutConstraint& layout_constraint,
    300       LayoutConstraints* constraints);
    301   virtual Status PropagateOperandConstraint(
    302       const OperandLayoutConstraint& layout_constraint,
    303       LayoutConstraints* constraints);
    304   virtual Status PropagateResultConstraint(
    305       const ResultLayoutConstraint& layout_constraint,
    306       LayoutConstraints* constraints);
    307 
    308   // By default LayoutAssignment ensures that inputs and outputs of CustomCalls
    309   // have the "major-first" layout (i.e.  {n, n-1, ..., 0}).
    310   //
    311   // If this function returns true, LayoutAssignment does not set a layout for
    312   // the given CustomCall.  It's up to the backend to set one in
    313   // AddBackendConstraints, if necessary.
    314   //
    315   // Precondition: instruction->opcode() == HloOpcode::kCustomCall.
    316   virtual bool CustomCallRequiresMajorFirstLayout(
    317       const HloInstruction* /*instruction*/) {
    318     return true;
    319   }
    320 
    321   // Called after layouts of an instruction have been finalized to allow
    322   // subclasses to check for platform specific assumptions.
    323   virtual Status Verify(const HloInstruction* instruction) {
    324     return Status::OK();
    325   }
    326 
    327   // Propagates a buffer layout constraint into the operands that use it.
    328   Status PropagateBufferConstraintToUses(
    329       const BufferLayoutConstraint& layout_constraint,
    330       LayoutConstraints* constraints);
    331 
    332   // Propagates a layout constraint on the use of the result of the given
    333   // instruction to the definitions of the LogicalBuffers which make up the
    334   // result.
    335   Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout,
    336                                       const HloInstruction* instruction,
    337                                       LayoutConstraints* constraints);
    338 
    339   // Chooses a layout of operand `operand_no` of `instruction` that minimizes
    340   // the cost of `instruction`. `output_layout` is the layout of `instruction`.
    341   // Returns null if it can't decide the best layout.
    342   // Precondition: `instruction` and the operand are array-shaped.
    343   std::unique_ptr<Layout> ChooseOperandLayoutFromOutputLayout(
    344       const Layout& output_layout, const HloInstruction* instruction,
    345       int64 operand_no);
    346   // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of
    347   // `user` that minimizes its cost on that operand.  Returns null if it can't
    348   // decide the best layout.
    349   // Precondition: `user` and the operand are array-shaped.
    350   std::unique_ptr<Layout> ChooseOutputLayoutFromOperandLayout(
    351       const Layout& operand_layout, const HloInstruction* user,
    352       int64 operand_no);
    353 
    354  private:
    355   // Adds constraints which must be satisfied for correctness on all
    356   // backends. Called once prior to propagating constraints.
    357   Status AddMandatoryConstraints(
    358       const ComputationLayout& computation_layout,
    359       const ChannelLayoutConstraints* channel_constraints,
    360       HloComputation* computation, LayoutConstraints* constraints);
    361 
    362   // This method can be overridden to add backend-specific constraints to the
    363   // layout of the instructions of a computation. This method is called after
    364   // all mandatory constraints have been added via AddMandatoryConstraints
    365   // and before propagating constraints.
    366   virtual Status AddBackendConstraints(LayoutConstraints* constraints) {
    367     return Status::OK();
    368   }
    369 
    370   // Construct contraints and assign layouts to all instructions in the
    371   // computation satisfying the given ComputationLayout. Layouts constraints are
    372   // added, then propagated until all LogicalBuffers in the computation are
    373   // constrained.
    374   Status RunOnComputation(const ComputationLayout& computation_layout,
    375                           const TuplePointsToAnalysis& points_to_analysis,
    376                           HloComputation* computation,
    377                           ChannelLayoutConstraints* channel_constraints);
    378 
    379   // Assign layouts to the instructions of a computation which satisfy the given
    380   // layout constraints. Copies may be added to satisfy the constraints. The
    381   // given LayoutConstraints must have layout constraints every logical buffer
    382   // in the computation.
    383   Status AssignLayouts(const LayoutConstraints& constraints,
    384                        HloComputation* computation);
    385 
    386   // Propagates layout constraints from a set of initial constraints in order to
    387   // minimize the local cost of the computation. This propagation is *not*
    388   // required for correctness.
    389   Status PropagateConstraints(LayoutConstraints* constraints);
    390 
    391   // Check that all layouts in the module have been set and satisfy all
    392   // necessary conditions.
    393   Status CheckLayouts(HloModule* module);
    394 
    395   ComputationLayout* entry_computation_layout_;
    396   ChannelLayoutConstraints* channel_layout_constraints_;
    397 
    398  protected:
    399   // Map containing the layouts of all computations assigned so
    400   // far. Computations are handled in a topological sort where computations are
    401   // handled before their caller instructions so the layouts of caller
    402   // instructions can be set to match the computation.
    403   std::map<HloComputation*, ComputationLayout> computation_layouts_;
    404 };
    405 
    406 }  // namespace xla
    407 
    408 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
    409