Home | History | Annotate | Download | only in xla
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_
     17 #define TENSORFLOW_COMPILER_XLA_LAYOUT_H_
     18 
#include <limits>
#include <ostream>
#include <utility>
#include <vector>

#include "absl/types/span.h"

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
     27 
     28 namespace xla {
     29 
     30 // Describes a tile used in tiling-based layout. Refer to
     31 // g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for
     32 // details.
     33 class Tile {
     34  public:
     35   Tile() = default;
     36   explicit Tile(absl::Span<const int64> dimensions)
     37       : dimensions_(dimensions.begin(), dimensions.end()) {}
     38 
     39   // De/Serialize a Tile to and from a TileProto.
     40   static Tile CreateFromProto(const TileProto& tile_proto) {
     41     return Tile(AsInt64Slice(tile_proto.dimensions()));
     42   }
     43   TileProto ToProto() const;
     44 
     45   bool operator==(const Tile& other) const {
     46     return dimensions() == other.dimensions();
     47   }
     48   bool operator!=(const Tile& other) const { return !(*this == other); }
     49 
     50   string ToString() const;
     51 
     52   // Returns the bound of the tile in the given dimension index.
     53   int64 dimension(int i) const { return dimensions_.at(i); }
     54 
     55   // Returns the dimensions of the tile.
     56   const std::vector<int64>& dimensions() const { return dimensions_; }
     57 
     58   Tile& add_dimensions(int64 value) {
     59     dimensions_.push_back(value);
     60     return *this;
     61   }
     62 
     63   Tile& clear_dimensions() {
     64     dimensions_.clear();
     65     return *this;
     66   }
     67 
     68   // This dimension size means the corresponding dimension in the shape is
     69   // combined with the next minor dimension before tiling is applied.
     70   static constexpr int64 kCombineDimension = std::numeric_limits<int64>::min();
     71 
     72  private:
     73   // The bounds of the tile.
     74   std::vector<int64> dimensions_;
     75 };
     76 
     77 class Layout {
     78  public:
     79   Layout() = default;
     80 
     81   // Constructs a dense layout with the given minor-to-major order.
     82   explicit Layout(absl::Span<const int64> minor_to_major)
     83       : format_(DENSE),
     84         minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {}
     85 
     86   // Constructs a dense tiled layout with the given minor-to-major order and
     87   // tiles.
     88   Layout(absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
     89          int64 element_size_in_bits = 0)
     90       : format_(DENSE),
     91         minor_to_major_(minor_to_major.begin(), minor_to_major.end()),
     92         tiles_(tiles.begin(), tiles.end()),
     93         element_size_in_bits_(element_size_in_bits) {}
     94 
     95   // Construct a shape from a LayoutProto.
     96   static Layout CreateFromProto(const LayoutProto& proto);
     97 
     98   // Returns a LayoutProto representation of the Layout.
     99   LayoutProto ToProto() const;
    100 
    101   // Returns a human-readable string that represents this layout.
    102   string ToString() const;
    103 
    104   // Equal is a configurable functor to check the equality of two layouts.
    105   //
    106   // Examples:
    107   //
    108   // - Comparing two layouts ignoring their difference in tiles:
    109   //   Equal().IgnoreTiles()(layout1, layout2);
    110   //
    111   // - Comparing two layouts ignoring their difference in tiles and element
    112   //   size:
    113   //   Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2);
    114   class Equal {
    115    public:
    116     Equal() = default;
    117 
    118     bool operator()(const Layout& lhs, const Layout& rhs);
    119 
    120     Equal& IgnoreTiles() {
    121       ignore_tiles_ = true;
    122       return *this;
    123     }
    124 
    125     Equal& IgnoreElementSize() {
    126       ignore_element_size_ = true;
    127       return *this;
    128     }
    129 
    130    private:
    131     bool ignore_tiles_ = false;
    132     bool ignore_element_size_ = false;
    133   };
    134 
    135   bool operator==(const Layout& other) const;
    136   bool operator!=(const Layout& other) const { return !(*this == other); }
    137 
    138   // The following methods mirror the protobuf generated code interface for the
    139   // message LayoutProto. This enabled easy migration of this data structure
    140   // from a proto to a proper C++ class.
    141   //
    142   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
    143   // interface.
    144 
    145   // Methods for accessing the format.
    146   Format format() const { return format_; }
    147   Layout& set_format(Format value) {
    148     format_ = value;
    149     return *this;
    150   }
    151 
    152   // Methods for accessing the minor-to-major array.
    153   int minor_to_major_size() const { return minor_to_major_.size(); }
    154   int64 minor_to_major(int index) const { return minor_to_major_.at(index); }
    155   Layout& set_minor_to_major(int index, int64 value) {
    156     minor_to_major_.at(index) = value;
    157     return *this;
    158   }
    159   Layout& add_minor_to_major(int64 value) {
    160     minor_to_major_.push_back(value);
    161     return *this;
    162   }
    163   Layout& clear_minor_to_major() {
    164     minor_to_major_.clear();
    165     return *this;
    166   }
    167   const std::vector<int64>& minor_to_major() const { return minor_to_major_; }
    168   std::vector<int64>* mutable_minor_to_major() { return &minor_to_major_; }
    169 
    170   // Methods for accessing the tile field.
    171   int tiles_size() const { return tiles_.size(); }
    172   const Tile& tiles(int index) const { return tiles_.at(index); }
    173   Tile* mutable_tiles(int index) { return &tiles_.at(index); }
    174   Tile* add_tiles() {
    175     tiles_.push_back(Tile());
    176     return &tiles_.back();
    177   }
    178   Layout& clear_tiles() {
    179     tiles_.clear();
    180     return *this;
    181   }
    182   const std::vector<Tile>& tiles() const { return tiles_; }
    183   std::vector<Tile>* mutable_tiles() { return &tiles_; }
    184 
    185   // Methods for accessing the int64 fields.
    186   int64 max_sparse_elements() const { return max_sparse_elements_; }
    187   Layout& set_max_sparse_elements(int64 value) {
    188     max_sparse_elements_ = value;
    189     return *this;
    190   }
    191   int64 element_size_in_bits() const { return element_size_in_bits_; }
    192   Layout& set_element_size_in_bits(int64 value) {
    193     element_size_in_bits_ = value;
    194     return *this;
    195   }
    196 
    197   void Swap(Layout* other) {
    198     using std::swap;
    199     swap(*this, *other);
    200   }
    201 
    202   void Clear() {
    203     format_ = INVALID_FORMAT;
    204     minor_to_major_.clear();
    205     max_sparse_elements_ = 0;
    206     element_size_in_bits_ = 0;
    207   }
    208 
    209  private:
    210   // The format of this layout.
    211   Format format_ = INVALID_FORMAT;
    212 
    213   // Sequence of dimension numbers, from minor (fastest varying index) to major
    214   // (slowest varying index).
    215   std::vector<int64> minor_to_major_;
    216 
    217   // The maximum number of elements that can be stored for SPARSE formats.  This
    218   // can be used to determine the maximum size in bytes of arrays stored in
    219   // memory.  This field must be zero unless the format is SPARSE.
    220   int64 max_sparse_elements_ = 0;
    221 
    222   // The tiles used in tiling-based layout.
    223   std::vector<Tile> tiles_;
    224 
    225   // The number of bits used to store an individual array element.
    226   int64 element_size_in_bits_ = 0;
    227 };
    228 
    229 std::ostream& operator<<(std::ostream& out, const Tile& Tile);
    230 std::ostream& operator<<(std::ostream& out, const Layout& layout);
    231 
    232 }  // namespace xla
    233 
    234 #endif  // TENSORFLOW_COMPILER_XLA_LAYOUT_H_
    235