1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // HLO shardings describe how an HLO instruction is split across multiple 17 // computations. 18 19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ 20 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ 21 22 #include <string> 23 24 #include "tensorflow/compiler/xla/array.h" 25 #include "tensorflow/compiler/xla/literal_util.h" 26 #include "tensorflow/compiler/xla/protobuf_util.h" 27 #include "tensorflow/compiler/xla/shape_tree.h" 28 #include "tensorflow/compiler/xla/xla_data.pb.h" 29 #include "tensorflow/core/lib/gtl/array_slice.h" 30 #include "tensorflow/core/lib/hash/hash.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/macros.h" 33 #include "tensorflow/core/platform/types.h" 34 35 namespace xla { 36 37 // HLO shardings describe how an HLO instruction is split across multiple 38 // computations. 39 class HloSharding { 40 public: 41 // Creates a trivial sharding that replicates a maximal tile across all 42 // devices. 43 static HloSharding Replicate() { return HloSharding(); } 44 45 // Creates a sharding that emulates device placement; a tile shape equal to 46 // the input shape (one tile) assigned to a single device. 47 static HloSharding AssignDevice(int64 device_id); 48 49 // Creates a new sharding which splits a shape into tiles each with shape 50 // `tile_shape`. Each tile is assigned to one device, which is specified by 51 // `tile_assignment`. Any tensor not a multiple of the tile size in any 52 // dimension is implicitly padded to the tile size. 53 // 54 // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like: 55 // 2 1 padding 56 // <------><-> 57 // +----+----+ 58 // | 0 | 1 | 59 // +----+----+ 60 // 61 // Split into two tiles, one of which is implicitly padded by one. 62 static HloSharding Tile(const Shape& tile_shape, 63 const Array<int64>& tile_assignment) { 64 return HloSharding(tile_shape, tile_assignment); 65 } 66 67 // Creates a new sharding which splits a one-dimensional input shape into 68 // `num_tiles` tiles. 69 static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles); 70 71 // Creates a new sharding for a tuple type. The given ShapeTree must have 72 // elements for every leaf shape contained in the tuple. 73 static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings) { 74 std::vector<HloSharding> flattened_list; 75 flattened_list.reserve( 76 std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); 77 for (const auto& index_to_sharding : sub_shardings.leaves()) { 78 flattened_list.push_back(index_to_sharding.second); 79 } 80 return HloSharding(flattened_list); 81 } 82 83 // Creates a new sharding for a tuple type. The requested tuple shape must not 84 // be nested. For nested tuples, use the ShapeTree overload. 85 static HloSharding Tuple(const Shape& tuple_shape, 86 tensorflow::gtl::ArraySlice<HloSharding> shardings) { 87 CHECK(ShapeUtil::IsTuple(tuple_shape)); 88 CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); 89 std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end()); 90 CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); 91 return HloSharding(flattened_list); 92 } 93 94 // Create a new sharding from a protobuf OpSharding. 95 static StatusOr<HloSharding> FromProto(const OpSharding& proto); 96 97 OpSharding ToProto() const; 98 string ToString() const; 99 100 // Validate that this sharding can be applied to a tensor with shape `shape`. 101 Status Validate(const Shape& shape, int64 num_devices) const; 102 103 // Returns true if the sharding has tuple type. 104 bool IsTuple() const { return tuple_; } 105 106 // Returns true if the sharding is trivial: replicate on all devices. 107 bool IsReplicated() const { 108 if (!IsTuple()) { 109 return replicated_; 110 } 111 return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), 112 [](const HloSharding& s) { return s.IsReplicated(); }); 113 } 114 115 // Returns true if the tile size is the same as the input size. 116 bool IsTileMaximal() const { 117 if (!IsTuple()) { 118 return maximal_; 119 } 120 return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), 121 [](const HloSharding& s) { return s.IsTileMaximal(); }); 122 } 123 124 // Returns true if the sharding defines an operation on the given device. 125 bool UsesDevice(int64 device) const; 126 127 // Returns the tile that should be executed on the given device. 128 // REQUIRES: !IsTuple() 129 std::vector<int64> TileIndexForDevice(int64 device) const; 130 131 // Returns the device that should execute the given tile. 132 // It is an error to call this if is_replicated() is true. 133 // REQUIRES: !IsTuple() 134 int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const; 135 136 // Given a device ID, returns the offset within the input space of the 137 // tile that should be executed on the given core. This returns the lower 138 // extent of the tile in the input space. 139 // REQUIRES: !IsTuple() 140 std::vector<int64> TileOffsetForDevice(int64 device) const; 141 142 // Given a device ID, returns the limit within the input space of the 143 // tile that should be executed on the given core. This returns the upper 144 // extent of the tile in the input space. 145 // REQUIRES: !IsTuple() 146 std::vector<int64> TileLimitForDevice(int64 device) const; 147 148 // Returns the single device this op operates on. 149 // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal() 150 StatusOr<int64> UniqueDevice() const; 151 152 // Returns true if this op only uses a single device. 153 bool HasUniqueDevice() const; 154 155 // Returns the ShapeTree containing the shardings for each element of this 156 // tuple, if IsTuple, or a ShapeTree with a single element containing this 157 // sharding. Only the leaf elements are populated. This creates a new 158 // ShapeTree object so is not cheap. 159 ShapeTree<HloSharding> GetAsShapeTree(const Shape& shape) const { 160 if (IsTuple()) { 161 ShapeTree<HloSharding> result(shape, HloSharding::Replicate()); 162 CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), 163 tuple_elements_.size()); 164 auto it = tuple_elements_.begin(); 165 for (auto& index_to_sharding : result.leaves()) { 166 index_to_sharding.second = *it++; 167 } 168 return result; 169 } else { 170 return ShapeTree<HloSharding>(shape, *this); 171 } 172 } 173 174 bool operator==(const HloSharding& other) const { 175 return replicated_ == other.replicated_ && maximal_ == other.maximal_ && 176 protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && 177 tile_assignment_ == other.tile_assignment_ && 178 tuple_elements_ == other.tuple_elements_; 179 } 180 bool operator!=(const HloSharding& other) const { return !(*this == other); } 181 182 size_t Hash() const { 183 if (!tuple_) { 184 size_t h = 0; 185 for (const auto& element : tuple_elements_) { 186 h = tensorflow::Hash64Combine(h, element.Hash()); 187 } 188 return h; 189 } 190 if (replicated_) { 191 return 0; 192 } 193 size_t h = 0; 194 for (uint32 v : tile_assignment_) { 195 h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v)); 196 } 197 for (uint32 v : tile_shape_.dimensions()) { 198 h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v)); 199 } 200 return h; 201 } 202 203 // Gets the tile shape. 204 // REQUIRES: !IsTileMaximal() && !IsTuple() 205 const Shape& tile_shape() const { return tile_shape_; } 206 // Gets the tile assignment tensor. 207 // REQUIRES: !IsReplicated() && !IsTuple() 208 const Array<int64>& tile_assignment() const { return tile_assignment_; } 209 210 private: 211 HloSharding() 212 : replicated_(true), 213 maximal_(true), 214 tuple_(false), 215 tile_shape_(), 216 tile_assignment_({0}) {} 217 explicit HloSharding(int64 device_id) 218 : replicated_(false), 219 maximal_(true), 220 tuple_(false), 221 tile_shape_(), 222 tile_assignment_({1}, device_id) {} 223 HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment) 224 : replicated_(false), 225 maximal_(false), 226 tuple_(false), 227 tile_shape_(tile_shape), 228 tile_assignment_(tile_assignment) {} 229 HloSharding(const std::vector<HloSharding>& tuple_shardings) 230 : replicated_(false), 231 maximal_(false), 232 tuple_(true), 233 tile_assignment_({0}), 234 tuple_elements_(tuple_shardings) {} 235 236 // Internal helper to validate a tuple sharding. 237 Status ValidateTuple(const Shape& shape, int64 num_devices) const; 238 // Internal helper to validate a non-tuple (leaf) sharding. 239 Status ValidateNonTuple(const Shape& shape, int64 num_devices) const; 240 241 bool replicated_; 242 bool maximal_; 243 bool tuple_; 244 Shape tile_shape_; 245 Array<int64> tile_assignment_; 246 // Only non-empty when tuple_ is true, but because empty tuples are allowed 247 // may also be empty even then. This is a flattened list of all the leaf 248 // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). 249 std::vector<HloSharding> tuple_elements_; 250 }; 251 252 } // namespace xla 253 254 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ 255