Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // HLO shardings describe how an HLO instruction is split across multiple
     17 // computations.
     18 
     19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
     20 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
     21 
#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
     34 
     35 namespace xla {
     36 
     37 // HLO shardings describe how an HLO instruction is split across multiple
     38 // computations.
     39 class HloSharding {
     40  public:
     41   // Creates a trivial sharding that replicates a maximal tile across all
     42   // devices.
     43   static HloSharding Replicate() { return HloSharding(); }
     44 
     45   // Creates a sharding that emulates device placement; a tile shape equal to
     46   // the input shape (one tile) assigned to a single device.
     47   static HloSharding AssignDevice(int64 device_id);
     48 
     49   // Creates a new sharding which splits a shape into tiles each with shape
     50   // `tile_shape`. Each tile is assigned to one device, which is specified by
     51   // `tile_assignment`. Any tensor not a multiple of the tile size in any
     52   // dimension is implicitly padded to the tile size.
     53   //
     54   // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like:
     55   //      2     1 padding
     56   //   <------><->
     57   //   +----+----+
     58   //   | 0  |  1 |
     59   //   +----+----+
     60   //
     61   // Split into two tiles, one of which is implicitly padded by one.
     62   static HloSharding Tile(const Shape& tile_shape,
     63                           const Array<int64>& tile_assignment) {
     64     return HloSharding(tile_shape, tile_assignment);
     65   }
     66 
     67   // Creates a new sharding which splits a one-dimensional input shape into
     68   // `num_tiles` tiles.
     69   static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);
     70 
     71   // Creates a new sharding for a tuple type. The given ShapeTree must have
     72   // elements for every leaf shape contained in the tuple.
     73   static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings) {
     74     std::vector<HloSharding> flattened_list;
     75     flattened_list.reserve(
     76         std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end()));
     77     for (const auto& index_to_sharding : sub_shardings.leaves()) {
     78       flattened_list.push_back(index_to_sharding.second);
     79     }
     80     return HloSharding(flattened_list);
     81   }
     82 
     83   // Creates a new sharding for a tuple type. The requested tuple shape must not
     84   // be nested. For nested tuples, use the ShapeTree overload.
     85   static HloSharding Tuple(const Shape& tuple_shape,
     86                            tensorflow::gtl::ArraySlice<HloSharding> shardings) {
     87     CHECK(ShapeUtil::IsTuple(tuple_shape));
     88     CHECK(!ShapeUtil::IsNestedTuple(tuple_shape));
     89     std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
     90     CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape));
     91     return HloSharding(flattened_list);
     92   }
     93 
     94   // Create a new sharding from a protobuf OpSharding.
     95   static StatusOr<HloSharding> FromProto(const OpSharding& proto);
     96 
     97   OpSharding ToProto() const;
     98   string ToString() const;
     99 
    100   // Validate that this sharding can be applied to a tensor with shape `shape`.
    101   Status Validate(const Shape& shape, int64 num_devices) const;
    102 
    103   // Returns true if the sharding has tuple type.
    104   bool IsTuple() const { return tuple_; }
    105 
    106   // Returns true if the sharding is trivial: replicate on all devices.
    107   bool IsReplicated() const {
    108     if (!IsTuple()) {
    109       return replicated_;
    110     }
    111     return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
    112                        [](const HloSharding& s) { return s.IsReplicated(); });
    113   }
    114 
    115   // Returns true if the tile size is the same as the input size.
    116   bool IsTileMaximal() const {
    117     if (!IsTuple()) {
    118       return maximal_;
    119     }
    120     return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
    121                        [](const HloSharding& s) { return s.IsTileMaximal(); });
    122   }
    123 
    124   // Returns true if the sharding defines an operation on the given device.
    125   bool UsesDevice(int64 device) const;
    126 
    127   // Returns the tile that should be executed on the given device.
    128   // REQUIRES: !IsTuple()
    129   std::vector<int64> TileIndexForDevice(int64 device) const;
    130 
    131   // Returns the device that should execute the given tile.
    132   // It is an error to call this if is_replicated() is true.
    133   // REQUIRES: !IsTuple()
    134   int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
    135 
    136   // Given a device ID, returns the offset within the input space of the
    137   // tile that should be executed on the given core. This returns the lower
    138   // extent of the tile in the input space.
    139   // REQUIRES: !IsTuple()
    140   std::vector<int64> TileOffsetForDevice(int64 device) const;
    141 
    142   // Given a device ID, returns the limit within the input space of the
    143   // tile that should be executed on the given core. This returns the upper
    144   // extent of the tile in the input space.
    145   // REQUIRES: !IsTuple()
    146   std::vector<int64> TileLimitForDevice(int64 device) const;
    147 
    148   // Returns the single device this op operates on.
    149   // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal()
    150   StatusOr<int64> UniqueDevice() const;
    151 
    152   // Returns true if this op only uses a single device.
    153   bool HasUniqueDevice() const;
    154 
    155   // Returns the ShapeTree containing the shardings for each element of this
    156   // tuple, if IsTuple, or a ShapeTree with a single element containing this
    157   // sharding. Only the leaf elements are populated. This creates a new
    158   // ShapeTree object so is not cheap.
    159   ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const {
    160     if (IsTuple()) {
    161       ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
    162       CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()),
    163                tuple_elements_.size());
    164       auto it = tuple_elements_.begin();
    165       for (auto& index_to_sharding : result.leaves()) {
    166         index_to_sharding.second = *it++;
    167       }
    168       return result;
    169     } else {
    170       return ShapeTree<HloSharding>(shape, *this);
    171     }
    172   }
    173 
    174   bool operator==(const HloSharding& other) const {
    175     return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
    176            protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) &&
    177            tile_assignment_ == other.tile_assignment_ &&
    178            tuple_elements_ == other.tuple_elements_;
    179   }
    180   bool operator!=(const HloSharding& other) const { return !(*this == other); }
    181 
    182   size_t Hash() const {
    183     if (!tuple_) {
    184       size_t h = 0;
    185       for (const auto& element : tuple_elements_) {
    186         h = tensorflow::Hash64Combine(h, element.Hash());
    187       }
    188       return h;
    189     }
    190     if (replicated_) {
    191       return 0;
    192     }
    193     size_t h = 0;
    194     for (uint32 v : tile_assignment_) {
    195       h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
    196     }
    197     for (uint32 v : tile_shape_.dimensions()) {
    198       h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
    199     }
    200     return h;
    201   }
    202 
    203   // Gets the tile shape.
    204   // REQUIRES: !IsTileMaximal() && !IsTuple()
    205   const Shape& tile_shape() const { return tile_shape_; }
    206   // Gets the tile assignment tensor.
    207   // REQUIRES: !IsReplicated() && !IsTuple()
    208   const Array<int64>& tile_assignment() const { return tile_assignment_; }
    209 
    210  private:
    211   HloSharding()
    212       : replicated_(true),
    213         maximal_(true),
    214         tuple_(false),
    215         tile_shape_(),
    216         tile_assignment_({0}) {}
    217   explicit HloSharding(int64 device_id)
    218       : replicated_(false),
    219         maximal_(true),
    220         tuple_(false),
    221         tile_shape_(),
    222         tile_assignment_({1}, device_id) {}
    223   HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
    224       : replicated_(false),
    225         maximal_(false),
    226         tuple_(false),
    227         tile_shape_(tile_shape),
    228         tile_assignment_(tile_assignment) {}
    229   HloSharding(const std::vector<HloSharding>& tuple_shardings)
    230       : replicated_(false),
    231         maximal_(false),
    232         tuple_(true),
    233         tile_assignment_({0}),
    234         tuple_elements_(tuple_shardings) {}
    235 
    236   // Internal helper to validate a tuple sharding.
    237   Status ValidateTuple(const Shape& shape, int64 num_devices) const;
    238   // Internal helper to validate a non-tuple (leaf) sharding.
    239   Status ValidateNonTuple(const Shape& shape, int64 num_devices) const;
    240 
    241   bool replicated_;
    242   bool maximal_;
    243   bool tuple_;
    244   Shape tile_shape_;
    245   Array<int64> tile_assignment_;
    246   // Only non-empty when tuple_ is true, but because empty tuples are allowed
    247   // may also be empty even then. This is a flattened list of all the leaf
    248   // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order).
    249   std::vector<HloSharding> tuple_elements_;
    250 };
    251 
    252 }  // namespace xla
    253 
    254 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
    255