1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/layout_util.h" 17 18 #include <stddef.h> 19 #include <algorithm> 20 #include <functional> 21 #include <random> 22 #include <string> 23 #include <unordered_map> 24 #include <vector> 25 26 #include "absl/strings/str_cat.h" 27 #include "absl/strings/str_join.h" 28 #include "tensorflow/compiler/xla/protobuf_util.h" 29 #include "tensorflow/compiler/xla/shape_util.h" 30 #include "tensorflow/compiler/xla/status_macros.h" 31 #include "tensorflow/compiler/xla/types.h" 32 #include "tensorflow/compiler/xla/util.h" 33 #include "tensorflow/core/lib/core/errors.h" 34 #include "tensorflow/core/lib/hash/hash.h" 35 #include "tensorflow/core/lib/strings/numbers.h" 36 #include "tensorflow/core/platform/logging.h" 37 #include "tensorflow/core/platform/protobuf.h" 38 39 namespace xla { 40 namespace { 41 42 // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets 43 // minor_to_major to the value that represents the default layout. 44 void SetDefaultLayoutToContainer(std::vector<int64>* minor_to_major) { 45 // The default XLA layout is major-to-minor (dim 0 is major). 46 // For more information on XLA layouts, see: 47 // https://www.tensorflow.org/performance/xla/shapes 48 const int64 size = minor_to_major->size(); 49 for (int64 i = 0; i < size; ++i) { 50 (*minor_to_major)[i] = size - 1 - i; 51 } 52 } 53 54 } // namespace 55 56 /* static */ Layout LayoutUtil::MakeLayout( 57 absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles, 58 int64 element_size_in_bits) { 59 Layout layout; 60 layout.set_format(DENSE); 61 for (int64 dimension_number : minor_to_major) { 62 layout.add_minor_to_major(dimension_number); 63 } 64 for (Tile tile : tiles) { 65 for (int64 dim : tile.dimensions()) { 66 if (dim < 0 && dim != Tile::kCombineDimension) { 67 LOG(FATAL) << "Tile dimension size needs to be mininum int64 value if " 68 "it's negative. Value is " 69 << dim; 70 } 71 } 72 *layout.add_tiles() = tile; 73 } 74 layout.set_element_size_in_bits(element_size_in_bits); 75 return layout; 76 } 77 78 /* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) { 79 std::vector<int64> layout(rank); 80 std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0)); 81 return MakeLayout(layout); 82 } 83 84 /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( 85 absl::Span<const int64> major_to_minor) { 86 Layout layout; 87 layout.set_format(DENSE); 88 for (int i = major_to_minor.size() - 1; i >= 0; i--) { 89 layout.add_minor_to_major(major_to_minor[i]); 90 } 91 return layout; 92 } 93 94 /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { 95 Layout layout; 96 layout.set_format(SPARSE); 97 layout.set_max_sparse_elements(max_sparse_elements); 98 return layout; 99 } 100 101 namespace { 102 103 // Internal helper that creates a default layout for an array of the given rank. 104 Layout CreateDefaultLayoutForRank(int64 rank) { 105 Layout layout; 106 layout.set_format(DENSE); 107 std::vector<int64>* minor_to_major = layout.mutable_minor_to_major(); 108 minor_to_major->resize(rank, 0); 109 SetDefaultLayoutToContainer(minor_to_major); 110 return layout; 111 } 112 113 } // namespace 114 115 /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { 116 if (shape.IsOpaque() || shape.IsToken()) { 117 // Opaque and token types have empty layouts. 118 return Layout(); 119 } 120 121 // A Layout proto corresponds to a single array, not a tuple. 122 CHECK(shape.IsArray()); 123 return CreateDefaultLayoutForRank(shape.dimensions_size()); 124 } 125 126 /* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) { 127 return CreateDefaultLayoutForRank(rank); 128 } 129 130 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() { 131 return CreateDefaultLayoutForRank(2); 132 } 133 134 /* static */ Layout LayoutUtil::GetDefaultLayoutForR3() { 135 return CreateDefaultLayoutForRank(3); 136 } 137 138 /* static */ Layout LayoutUtil::GetDefaultLayoutForR4() { 139 return CreateDefaultLayoutForRank(4); 140 } 141 142 /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) { 143 if (shape->IsTuple()) { 144 // Tuple shape. 145 for (auto& element_shape : *shape->mutable_tuple_shapes()) { 146 SetToDefaultLayout(&element_shape); 147 } 148 shape->clear_layout(); 149 } else if (shape->IsArray()) { 150 shape->mutable_layout()->set_format(DENSE); 151 auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); 152 minor_to_major->resize(shape->dimensions_size(), 0); 153 SetDefaultLayoutToContainer(minor_to_major); 154 } else { 155 // Opaque, token types etc. have no layout. 156 shape->clear_layout(); 157 } 158 } 159 160 /* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) { 161 Shape copy(shape); 162 LayoutUtil::SetToDefaultLayout(©); 163 return copy; 164 } 165 166 /* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) { 167 for (auto& parameter_shape : *program_shape->mutable_parameters()) { 168 LayoutUtil::SetToDefaultLayout(¶meter_shape); 169 } 170 LayoutUtil::SetToDefaultLayout(program_shape->mutable_result()); 171 } 172 173 /* static */ Status LayoutUtil::ValidateLayoutInShape( 174 const Shape& shape, bool allow_missing_layouts) { 175 if (shape.IsTuple()) { 176 // Tuple shape. 177 if (shape.has_layout()) { 178 return InvalidArgument("tuple should not have a layout field"); 179 } 180 for (auto& element_shape : shape.tuple_shapes()) { 181 TF_RETURN_IF_ERROR( 182 ValidateLayoutInShape(element_shape, allow_missing_layouts)); 183 } 184 return Status::OK(); 185 } else if (shape.IsArray()) { 186 if (!shape.has_layout()) { 187 if (allow_missing_layouts) { 188 return Status::OK(); 189 } 190 return InvalidArgument("shape %s does not have a layout", 191 ShapeUtil::HumanString(shape)); 192 } 193 return ValidateLayoutForShape(shape.layout(), shape); 194 } else { 195 // Token, opaque, etc. shape. 196 if (shape.has_layout()) { 197 return InvalidArgument( 198 "shape of primitive type %s should not have a layout", 199 PrimitiveType_Name(shape.element_type())); 200 } 201 return Status::OK(); 202 } 203 } 204 205 /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout, 206 const Shape& shape) { 207 if (shape.IsTuple()) { 208 return InvalidArgument("a single Layout is not valid for tuple shapes"); 209 } 210 211 if (!shape.IsArray()) { 212 if (layout.minor_to_major_size() != 0) { 213 return InvalidArgument( 214 "shape of primitive type %s should not have a non-trivial layout", 215 PrimitiveType_Name(shape.element_type())); 216 } 217 return Status::OK(); 218 } 219 220 if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { 221 return InvalidArgument("Layout has an invalid format (%d)", 222 layout.format()); 223 } 224 225 if (layout.format() == DENSE) { 226 if (layout.minor_to_major_size() != shape.rank()) { 227 return InvalidArgument( 228 "layout minor_to_major field contains %d elements, " 229 "but shape is rank %d: {%s}; shape: %s", 230 layout.minor_to_major_size(), shape.rank(), 231 absl::StrJoin(layout.minor_to_major(), ", "), 232 shape.ShortDebugString()); 233 } 234 235 std::vector<bool> dimensions_in_layout(shape.rank(), false); 236 for (int64 i = 0; i < shape.rank(); ++i) { 237 int64 dim = layout.minor_to_major(i); 238 if (dim < 0 || dim >= shape.rank()) { 239 return InvalidArgument( 240 "layout minor_to_major field has out-of-bounds value: %s", 241 HumanString(layout)); 242 } 243 if (dimensions_in_layout[dim]) { 244 return InvalidArgument( 245 "layout minor_to_major field has duplicate values: {%s}", 246 HumanString(layout)); 247 } 248 dimensions_in_layout[dim] = true; 249 } 250 } else { 251 if (layout.tiles_size() != 0) { 252 return InvalidArgument("Only dense layouts can be tiled."); 253 } 254 } 255 256 return Status::OK(); 257 } 258 259 /* static */ void LayoutUtil::ClearLayout(Shape* shape) { 260 shape->clear_layout(); 261 for (auto& element_shape : *shape->mutable_tuple_shapes()) { 262 ClearLayout(&element_shape); 263 } 264 } 265 266 /* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) { 267 for (auto& parameter_shape : *program_shape->mutable_parameters()) { 268 LayoutUtil::ClearLayout(¶meter_shape); 269 } 270 LayoutUtil::ClearLayout(program_shape->mutable_result()); 271 } 272 273 /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) { 274 return shape.IsArray() && shape.has_layout() && IsDense(shape.layout()); 275 } 276 277 /* static */ bool LayoutUtil::IsDense(const Layout& layout) { 278 return layout.format() == DENSE; 279 } 280 281 /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) { 282 CHECK(layout.format() == DENSE); 283 return std::is_sorted(layout.minor_to_major().begin(), 284 layout.minor_to_major().end()); 285 } 286 287 /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) { 288 CHECK(layout.format() == DENSE); 289 return std::is_sorted(layout.minor_to_major().begin(), 290 layout.minor_to_major().end(), std::greater<int64>()); 291 } 292 293 /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { 294 return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout()); 295 } 296 297 /* static */ bool LayoutUtil::IsSparse(const Layout& layout) { 298 return layout.format() == SPARSE; 299 } 300 301 /* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) { 302 CHECK(IsSparse(layout)); 303 return layout.max_sparse_elements(); 304 } 305 306 /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { 307 if (shape.IsTuple()) { 308 // Tuple shape: all subshapes must have a layout. 309 return absl::c_all_of(shape.tuple_shapes(), 310 [](const Shape& s) { return HasLayout(s); }); 311 } else if (!shape.IsArray()) { 312 // Opaque, token types etc. ignore layout. 313 return true; 314 } 315 return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; 316 } 317 318 /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) { 319 for (auto& parameter_shape : program_shape.parameters()) { 320 if (!LayoutUtil::HasLayout(parameter_shape)) { 321 return false; 322 } 323 } 324 return LayoutUtil::HasLayout(program_shape.result()); 325 } 326 327 /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) { 328 return lhs == rhs; 329 } 330 331 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor( 332 const Shape& shape) { 333 CHECK(IsDenseArray(shape)); 334 return AsInt64Slice(shape.layout().minor_to_major()); 335 } 336 337 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor( 338 const Layout& layout) { 339 CHECK(layout.format() == DENSE); 340 return AsInt64Slice(layout.minor_to_major()); 341 } 342 343 /* static */ int64 LayoutUtil::Major(const Layout& layout, 344 int64 physical_dimension_number) { 345 CHECK_LE(0, physical_dimension_number); 346 CHECK_LT(physical_dimension_number, layout.minor_to_major_size()); 347 return Minor(layout, 348 layout.minor_to_major_size() - 1 - physical_dimension_number); 349 } 350 351 /* static */ int64 LayoutUtil::Minor(const Layout& layout, 352 int64 physical_dimension_number) { 353 CHECK_EQ(layout.format(), DENSE); 354 CHECK_LE(0, physical_dimension_number); 355 CHECK_LT(physical_dimension_number, layout.minor_to_major_size()); 356 return layout.minor_to_major(physical_dimension_number); 357 } 358 359 /* static */ std::vector<int64> LayoutUtil::MakeLogicalToPhysical( 360 const Layout& layout) { 361 std::vector<int64> logical_to_physical(layout.minor_to_major_size()); 362 for (int64 physical = 0; physical < logical_to_physical.size(); ++physical) { 363 const int64 logical = Major(layout, physical); 364 logical_to_physical[logical] = physical; 365 } 366 return logical_to_physical; 367 } 368 369 /* static */ string LayoutUtil::HumanString(const Layout& layout) { 370 return layout.ToString(); 371 } 372 373 namespace { 374 375 // Internal helper for recursively copying layouts. 376 Status CopyLayoutInternal(const Shape& src, Shape* dst) { 377 if (src.IsTuple() != dst->IsTuple()) { 378 return InvalidArgument( 379 "cannot copy layout from shape: shape structure differs"); 380 } 381 if (src.IsTuple()) { 382 if (ShapeUtil::TupleElementCount(src) != 383 ShapeUtil::TupleElementCount(*dst)) { 384 return InvalidArgument( 385 "cannot copy layout from shape: tuple element count differs"); 386 } 387 for (int64 i = 0; i < ShapeUtil::TupleElementCount(src); ++i) { 388 TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i), 389 dst->mutable_tuple_shapes(i))); 390 } 391 } else { 392 if (src.has_layout()) { 393 if (src.rank() != dst->rank()) { 394 return InvalidArgument("cannot copy layout from shape: ranks differs"); 395 } 396 TF_RETURN_IF_ERROR( 397 LayoutUtil::ValidateLayoutForShape(src.layout(), *dst)); 398 *dst->mutable_layout() = src.layout(); 399 } else { 400 dst->clear_layout(); 401 } 402 } 403 return Status::OK(); 404 } 405 406 } // namespace 407 408 /* static */ 409 Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { 410 return CopyLayoutInternal(src, dst); 411 } 412 413 /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, 414 const Shape& rhs) { 415 if (lhs.IsTuple()) { 416 if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) != 417 ShapeUtil::TupleElementCount(rhs)) { 418 return false; 419 } 420 for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { 421 if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) { 422 return false; 423 } 424 } 425 return true; 426 } else if (lhs.IsArray()) { 427 return lhs.rank() == rhs.rank() && 428 LayoutUtil::Equal(lhs.layout(), rhs.layout()); 429 } else { 430 // Layouts of non-array and non-tuple shapes is ignored. 431 return true; 432 } 433 } 434 435 /* static */ bool LayoutUtil::AreDimensionsConsecutive( 436 const Layout& layout, absl::Span<const int64> dims) { 437 CHECK(IsDense(layout)); 438 std::vector<int64> positions_in_layout; 439 for (int64 dim : dims) { 440 positions_in_layout.push_back( 441 PositionInContainer(layout.minor_to_major(), dim)); 442 } 443 absl::c_sort(positions_in_layout); 444 for (size_t i = 1; i < positions_in_layout.size(); ++i) { 445 if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) { 446 return false; 447 } 448 } 449 return true; 450 } 451 452 /*static*/ size_t LayoutUtil::Hash(const Layout& layout) { 453 using tensorflow::hash; 454 using tensorflow::Hash64Combine; 455 456 size_t hash_value = hash<Format>()(layout.format()); 457 458 for (int64 minor_to_major : layout.minor_to_major()) { 459 hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major)); 460 } 461 hash_value = Hash64Combine(hash_value, layout.max_sparse_elements()); 462 463 for (Tile tile : layout.tiles()) { 464 for (int64 tile_dim : tile.dimensions()) { 465 hash_value = Hash64Combine(hash_value, hash<int64>()(tile_dim)); 466 } 467 } 468 hash_value = Hash64Combine(hash_value, layout.element_size_in_bits()); 469 470 return hash_value; 471 } 472 473 } // namespace xla 474