1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_sharding.h" 17 18 #include "tensorflow/core/lib/core/errors.h" 19 #include "tensorflow/core/lib/strings/str_util.h" 20 21 namespace xla { 22 23 using ::tensorflow::strings::StrCat; 24 25 HloSharding HloSharding::AssignDevice(int64 device_id) { 26 return HloSharding(device_id); 27 } 28 29 HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { 30 CHECK_EQ(1, ShapeUtil::Rank(input_shape)); 31 CHECK_GT(num_tiles, 1); 32 std::vector<int64> dimensions(1, num_tiles); 33 Shape tile_shape = input_shape; 34 auto& tile_dimension = (*tile_shape.mutable_dimensions())[0]; 35 tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles); 36 Array<int64> assignment(dimensions); 37 std::iota(assignment.begin(), assignment.end(), 0); 38 return HloSharding(tile_shape, assignment); 39 } 40 41 string HloSharding::ToString() const { 42 if (IsTuple()) { 43 std::vector<string> parts; 44 parts.reserve(tuple_elements_.size()); 45 for (const HloSharding& element : tuple_elements_) { 46 parts.push_back(element.ToString()); 47 } 48 return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); 49 } 50 51 string result = StrCat("{", (replicated_ ? " replicated" : ""), 52 (maximal_ ? " maximal" : "")); 53 54 if (replicated_) { 55 return "{replicated}"; 56 } else if (maximal_) { 57 return StrCat( 58 "{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}"); 59 } else { 60 return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", 61 "devices=", VectorString(tile_assignment_), "}"); 62 } 63 } 64 65 bool HloSharding::UsesDevice(int64 device) const { 66 if (IsTuple()) { 67 return std::any_of( 68 tuple_elements_.begin(), tuple_elements_.end(), 69 [&](const HloSharding& s) { return s.UsesDevice(device); }); 70 } 71 const auto& devices = tile_assignment_; 72 return replicated_ || 73 std::find(devices.begin(), devices.end(), device) != devices.end(); 74 } 75 76 std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const { 77 CHECK(!ShapeUtil::IsTuple(tile_shape_)); 78 CHECK(!maximal_); 79 CHECK(!IsTuple()); 80 std::vector<int64> ret_index; 81 tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) { 82 if (d == device) { 83 ret_index = {index.begin(), index.end()}; 84 } 85 }); 86 CHECK(!ret_index.empty()); 87 return ret_index; 88 } 89 90 int64 HloSharding::DeviceForTileIndex( 91 tensorflow::gtl::ArraySlice<int64> index) const { 92 CHECK(!replicated_); 93 CHECK(!IsTuple()); 94 if (maximal_) { 95 return *tile_assignment_.begin(); 96 } 97 CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size()); 98 return tile_assignment_(index); 99 } 100 101 std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const { 102 CHECK(!IsTuple()); 103 104 std::vector<int64> index = TileIndexForDevice(device); 105 if (maximal_) { 106 // Index will always be all zeroes if we're maximal, and tile_shape_ is not 107 // valid. 108 return index; 109 } 110 for (int64 i = 0; i < index.size(); ++i) { 111 index[i] *= tile_shape_.dimensions(i); 112 } 113 return index; 114 } 115 116 std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const { 117 CHECK(!IsTuple()); 118 CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. 119 120 std::vector<int64> index = TileIndexForDevice(device); 121 for (int64 i = 0; i < index.size(); ++i) { 122 index[i] = (index[i] + 1) * tile_shape_.dimensions(i); 123 } 124 return index; 125 } 126 127 StatusOr<int64> HloSharding::UniqueDevice() const { 128 if (IsTuple()) { 129 if (tuple_elements_.empty()) { 130 return tensorflow::errors::InvalidArgument( 131 "UniqueDevice() called on empty tuple"); 132 } 133 std::vector<StatusOr<int64>> results; 134 std::transform(tuple_elements_.begin(), tuple_elements_.end(), 135 std::back_inserter(results), 136 [](const HloSharding& s) { return s.UniqueDevice(); }); 137 if (std::all_of(results.begin(), results.end(), 138 [&](const StatusOr<int64>& s) { 139 return s.ok() && results[0].ok() && 140 s.ValueOrDie() == results[0].ValueOrDie(); 141 })) { 142 return results[0]; 143 } else { 144 return tensorflow::errors::InvalidArgument( 145 "Tuple did not contain a unique device"); 146 } 147 } 148 if (!replicated_ && maximal_ && !IsTuple()) { 149 return static_cast<int64>(*tile_assignment_.begin()); 150 } 151 return tensorflow::errors::InvalidArgument( 152 "UniqueDevice() called on sharding that executes on multiple devices"); 153 } 154 155 bool HloSharding::HasUniqueDevice() const { 156 if (IsTuple()) { 157 return UniqueDevice().status().ok(); 158 } else { 159 return !IsReplicated() && IsTileMaximal(); 160 } 161 } 162 163 Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { 164 if (!ShapeUtil::IsTuple(shape)) { 165 return tensorflow::errors::InvalidArgument( 166 StrCat("Sharding is tuple-shaped but validation shape is not.")); 167 } 168 // The easiest way to get the number of elements in a nested tuple is just to 169 // create a shape tree. We could call GetAsShapeTree, but that will try and 170 // apply our tuple_shardings_ to the shape tree, and that might cause a crash 171 // at this point as we haven't validated them. 172 ShapeTree<bool> bool_shape_tree(shape, false); 173 int64 num_leaves = 174 std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end()); 175 if (num_leaves != tuple_elements_.size()) { 176 return tensorflow::errors::InvalidArgument( 177 StrCat("Validation tuple shape has ", num_leaves, 178 " leaf elements, but this sharding contains ", 179 tuple_elements_.size(), " elements.")); 180 } 181 182 // Now we've validated the number of tuple elements, it's safe to request a 183 // shape tree. 184 ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape); 185 for (const auto& index_to_sharding : shape_tree.leaves()) { 186 Status status = index_to_sharding.second.ValidateNonTuple( 187 ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices); 188 if (!status.ok()) { 189 tensorflow::errors::AppendToMessage( 190 &status, StrCat("Note: While validating sharding tuple element ", 191 index_to_sharding.first.ToString(), " which is ", 192 index_to_sharding.second.ToString())); 193 return status; 194 } 195 } 196 return Status::OK(); 197 } 198 199 Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { 200 Status status = IsTuple() ? ValidateTuple(shape, num_devices) 201 : ValidateNonTuple(shape, num_devices); 202 if (!status.ok()) { 203 tensorflow::errors::AppendToMessage( 204 &status, StrCat("Note: While validating sharding ", ToString(), 205 " against shape ", ShapeUtil::HumanString(shape))); 206 } 207 return status; 208 } 209 210 Status HloSharding::ValidateNonTuple(const Shape& shape, 211 int64 num_devices) const { 212 if (ShapeUtil::IsTuple(shape)) { 213 return tensorflow::errors::InvalidArgument( 214 StrCat("Validation shape is a tuple but sharding is not.")); 215 } 216 if (replicated_) { 217 return Status::OK(); 218 } 219 220 // All tile assignments must be less than the number of available cores and 221 // unique. 222 Status status = Status::OK(); 223 std::set<int64> seen_cores; 224 tile_assignment_.Each( 225 [&](tensorflow::gtl::ArraySlice<int64> indices, uint32 core) { 226 // Don't overwrite a bad status, so we report the first error. 227 if (status.ok()) { 228 if (core >= num_devices) { 229 status = tensorflow::errors::InvalidArgument(StrCat( 230 "core ", core, " > ", num_devices, " in tile assignment")); 231 } else if (seen_cores.count(core) != 0) { 232 status = tensorflow::errors::InvalidArgument( 233 StrCat("core ", core, " is not unique in tile assignment")); 234 } 235 } 236 seen_cores.insert(core); 237 }); 238 if (!status.ok()) { 239 return status; 240 } 241 242 if (IsTileMaximal()) { 243 return Status::OK(); 244 } 245 246 // The tile rank must be the same as the input rank. 247 if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { 248 return tensorflow::errors::InvalidArgument( 249 "Tile rank is different to the input rank. sharding=", ToString(), 250 ", input_shape=", ShapeUtil::HumanString(shape)); 251 } 252 253 // The tile shape must not be the same as the input shape without maximal_ 254 // also set. If this is the case, we're not actually sharded and the correct 255 // constructor should have been used. 256 if (ShapeUtil::Equal(shape, tile_shape_)) { 257 return tensorflow::errors::InvalidArgument( 258 "Tile shape is the same as the input shape. If a replicated sharding " 259 "was intended, use HloSharding::Replicated(). If a device placement " 260 "was intended, use HloSharding::AssignDevice()"); 261 } 262 263 // The tile shape must not be greater than the input shape in any dimension. 264 for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) { 265 auto tile_dim = tile_shape_.dimensions(i); 266 auto shape_dim = shape.dimensions(i); 267 if (tile_dim > shape_dim) { 268 return tensorflow::errors::InvalidArgument( 269 StrCat("Tile is larger than input shape (dimension ", i, ", ", 270 tile_dim, " > ", shape_dim)); 271 } 272 } 273 274 // The tile assignment tensor must be exactly dimensioned to ceil(shape[dim] 275 // tile[dim]) for every dimension contained within tile. 276 for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) { 277 int64 expected_dim = 278 CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i)); 279 if (tile_assignment_.dimensions()[i] != expected_dim) { 280 return tensorflow::errors::InvalidArgument( 281 StrCat("Tile assignment tensor has incorrect shape. Dimension ", i, 282 " expected ", expected_dim, " but got ", 283 tile_assignment_.dimensions()[i])); 284 } 285 } 286 287 return Status::OK(); 288 } 289 290 /*static*/ StatusOr<HloSharding> HloSharding::FromProto( 291 const OpSharding& proto) { 292 if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) { 293 std::vector<HloSharding> tuple_shardings; 294 tuple_shardings.reserve(proto.tuple_shardings().size()); 295 for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) { 296 TF_ASSIGN_OR_RETURN(HloSharding sharding, 297 HloSharding::FromProto(tuple_sharding_proto)); 298 tuple_shardings.push_back(sharding); 299 } 300 return HloSharding(tuple_shardings); 301 } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { 302 return Replicate(); 303 } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || 304 proto.tile_assignment_devices().size() == 1) { 305 return HloSharding(proto.tile_assignment_devices(0)); 306 } 307 // Some versions of gcc cannot infer the TileAssignment constructor from a 308 // braced initializer-list, so create one manually. 309 std::vector<int64> devices(proto.tile_assignment_devices().begin(), 310 proto.tile_assignment_devices().end()); 311 Array<int64> tile_assignment( 312 std::vector<int64>(proto.tile_assignment_dimensions().begin(), 313 proto.tile_assignment_dimensions().end())); 314 std::copy(proto.tile_assignment_devices().begin(), 315 proto.tile_assignment_devices().end(), tile_assignment.begin()); 316 return HloSharding(proto.tile_shape(), tile_assignment); 317 } 318 319 OpSharding HloSharding::ToProto() const { 320 OpSharding result; 321 322 if (IsTuple()) { 323 for (const HloSharding& element : tuple_elements_) { 324 *result.add_tuple_shardings() = element.ToProto(); 325 } 326 result.set_type(OpSharding::Type::OpSharding_Type_TUPLE); 327 return result; 328 } 329 330 *result.mutable_tile_shape() = tile_shape_; 331 for (int64 dim : tile_assignment_.dimensions()) { 332 result.add_tile_assignment_dimensions(dim); 333 } 334 for (auto device : tile_assignment_) { 335 result.add_tile_assignment_devices(device); 336 } 337 if (IsReplicated()) { 338 result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); 339 } else if (IsTileMaximal()) { 340 result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); 341 } else { 342 result.set_type(OpSharding::Type::OpSharding_Type_OTHER); 343 } 344 return result; 345 } 346 347 } // namespace xla 348