1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/shape_util.h" 17 18 #include <algorithm> 19 #include <functional> 20 #include <numeric> 21 #include <unordered_map> 22 #include <utility> 23 #include <vector> 24 25 #include "tensorflow/compiler/xla/index_util.h" 26 #include "tensorflow/compiler/xla/layout_util.h" 27 #include "tensorflow/compiler/xla/primitive_util.h" 28 #include "tensorflow/compiler/xla/status_macros.h" 29 #include "tensorflow/compiler/xla/types.h" 30 #include "tensorflow/compiler/xla/util.h" 31 #include "tensorflow/core/lib/core/errors.h" 32 #include "tensorflow/core/lib/core/stringpiece.h" 33 #include "tensorflow/core/lib/gtl/iterator_range.h" 34 #include "tensorflow/core/lib/gtl/optional.h" 35 #include "tensorflow/core/lib/strings/numbers.h" 36 #include "tensorflow/core/lib/strings/str_util.h" 37 #include "tensorflow/core/lib/strings/strcat.h" 38 #include "tensorflow/core/platform/logging.h" 39 #include "tensorflow/core/platform/protobuf.h" 40 #include "tensorflow/core/platform/regexp.h" 41 42 namespace xla { 43 44 string ShapeIndex::ToString() const { 45 return tensorflow::strings::StrCat( 46 "{", tensorflow::str_util::Join(indices_, ","), "}"); 47 } 48 49 string ShapeIndexView::ToString() const { 50 return tensorflow::strings::StrCat( 51 "{", 52 tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_), 53 ","), 54 "}"); 55 } 56 57 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { 58 out << shape_index.ToString(); 59 return out; 60 } 61 62 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { 63 out << shape_index.ToString(); 64 return out; 65 } 66 67 namespace { 68 69 // Recursive helper for comparing the equality of two shapes. Returns true if 70 // the shapes are the same. If compare_layouts is true, then layouts must also 71 // match. 72 bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { 73 if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { 74 return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && 75 ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), 76 [=](const Shape& l, const Shape& r) { 77 return CompareShapes(l, r, compare_layouts); 78 }); 79 } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { 80 return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); 81 } 82 83 if (compare_layouts) { 84 if (lhs.layout().format() != rhs.layout().format()) { 85 return false; 86 } 87 if (LayoutUtil::IsDenseArray(lhs)) { 88 if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs), 89 LayoutUtil::MinorToMajor(rhs))) { 90 VLOG(3) << "CompareShapes: lhs layout != rhs layout"; 91 return false; 92 } 93 if (!ContainersEqual(lhs.layout().padded_dimensions(), 94 rhs.layout().padded_dimensions())) { 95 VLOG(3) 96 << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; 97 return false; 98 } 99 if (lhs.layout().padding_value() != rhs.layout().padding_value()) { 100 VLOG(3) << "CompareShapes: lhs padding value != rhs padding_value"; 101 return false; 102 } 103 } 104 } 105 106 if (!ShapeUtil::SameDimensions(lhs, rhs)) { 107 VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; 108 return false; 109 } 110 if (!ShapeUtil::SameElementType(lhs, rhs)) { 111 VLOG(3) << "CompareShapes: lhs element type != rhs element type"; 112 return false; 113 } 114 return true; 115 } 116 117 // Constructs and returns the new shape with the given minor_to_major order in 118 // its Layout. 119 StatusOr<Shape> MakeShapeWithLayoutInternal( 120 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions, 121 tensorflow::gtl::ArraySlice<int64> minor_to_major) { 122 if (dimensions.size() != minor_to_major.size()) { 123 return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", 124 dimensions.size(), minor_to_major.size()); 125 } 126 if (element_type == OPAQUE || element_type == TUPLE) { 127 return InvalidArgument("Unsupported element type: %s", 128 PrimitiveType_Name(element_type).c_str()); 129 } 130 Shape shape = ShapeUtil::MakeShape(element_type, dimensions); 131 auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); 132 min2maj->Clear(); 133 for (int64 value : minor_to_major) { 134 min2maj->Add(value); 135 } 136 if (!shape.has_layout()) { 137 return InvalidArgument("Shape has no layout."); 138 } 139 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); 140 return shape; 141 } 142 143 } // namespace 144 145 /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { 146 bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true); 147 if (!equal && VLOG_IS_ON(3)) { 148 VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() 149 << ", rhs = " << rhs.ShortDebugString(); 150 } 151 152 return equal; 153 } 154 155 /* static */ int64 ShapeUtil::Rank(const Shape& shape) { 156 CHECK(!ShapeUtil::IsTuple(shape)) 157 << "Tuples do not have a rank, shape: " << shape; 158 return shape.dimensions_size(); 159 } 160 161 /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { 162 int64 accum = 0; 163 for (int64 dimension : shape.dimensions()) { 164 // We do not count zero dimensions. 165 if (dimension != 1) { 166 accum += 1; 167 } 168 } 169 return accum; 170 } 171 172 /* static */ ProgramShape ShapeUtil::MakeProgramShape( 173 std::initializer_list<Shape> parameters, Shape result) { 174 ProgramShape program_shape; 175 for (const auto& shape : parameters) { 176 *program_shape.add_parameters() = shape; 177 } 178 *program_shape.mutable_result() = std::move(result); 179 return program_shape; 180 } 181 182 /* static */ Shape ShapeUtil::MakeShape( 183 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) { 184 DCHECK_NE(TUPLE, element_type); 185 DCHECK_NE(OPAQUE, element_type); 186 Shape result; 187 PopulateShape(element_type, dimensions, &result); 188 return result; 189 } 190 191 /* static */ Shape ShapeUtil::MakeShapeWithLayout( 192 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions, 193 tensorflow::gtl::ArraySlice<int64> minor_to_major) { 194 return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) 195 .ValueOrDie(); 196 } 197 198 /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( 199 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) { 200 std::vector<int64> layout(dimensions.size()); 201 std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0)); 202 return MakeShapeWithLayout(element_type, dimensions, layout); 203 } 204 205 /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( 206 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions, 207 int64 max_sparse_elements) { 208 DCHECK_NE(TUPLE, element_type); 209 DCHECK_NE(OPAQUE, element_type); 210 Shape shape = ShapeUtil::MakeShape(element_type, dimensions); 211 *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); 212 TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); 213 return shape; 214 } 215 216 /* static */ Shape 217 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( 218 const Shape& shape) { 219 std::vector<int64> dims(shape.dimensions_size()); 220 for (int i = 0; i < shape.dimensions_size(); ++i) { 221 dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i)); 222 } 223 return MakeShapeWithDescendingLayout(shape.element_type(), dims); 224 } 225 226 /* static */ void ShapeUtil::PopulateShape( 227 PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions, 228 Shape* shape) { 229 shape->Clear(); 230 shape->set_element_type(element_type); 231 for (int64 dimension : dimensions) { 232 shape->add_dimensions(dimension); 233 } 234 LayoutUtil::SetToDefaultLayout(shape); 235 TF_DCHECK_OK(ValidateShape(*shape)); 236 } 237 238 /* static */ Shape ShapeUtil::MakeTupleShape( 239 tensorflow::gtl::ArraySlice<Shape> shapes) { 240 Shape result; 241 result.set_element_type(TUPLE); 242 for (const auto& shape : shapes) { 243 AppendShapeToTuple(shape, &result); 244 } 245 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); 246 return result; 247 } 248 249 /* static */ Shape ShapeUtil::MakeOpaqueShape() { 250 Shape result; 251 result.set_element_type(OPAQUE); 252 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); 253 return result; 254 } 255 256 /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape, 257 Shape* tuple_shape) { 258 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); 259 *tuple_shape->add_tuple_shapes() = shape; 260 } 261 262 /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { 263 CHECK(LayoutUtil::IsDenseArray(*shape)); 264 shape->mutable_layout()->add_minor_to_major(Rank(*shape)); 265 shape->add_dimensions(bound); 266 TF_DCHECK_OK(ValidateShape(*shape)); 267 } 268 269 /* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) { 270 return primitive_util::IsIntegralType(shape.element_type()); 271 } 272 273 /* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape, 274 int32 bits) { 275 return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits); 276 } 277 278 /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { 279 if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) { 280 return false; 281 } 282 return primitive_util::BitWidth(shape.element_type()) == bits; 283 } 284 285 /* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) { 286 switch (shape.element_type()) { 287 case S8: 288 case S16: 289 case S32: 290 case S64: 291 case F16: 292 case BF16: 293 case F32: 294 case F64: 295 return true; 296 297 case PRED: 298 case U8: 299 case U16: 300 case U32: 301 case U64: 302 case C64: 303 case TUPLE: 304 case OPAQUE: 305 return false; 306 307 default: 308 LOG(FATAL) << "Unhandled element type " << shape.element_type(); 309 } 310 } 311 312 /* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) { 313 return primitive_util::IsComplexType(shape.element_type()); 314 } 315 316 /* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) { 317 return primitive_util::IsFloatingPointType(shape.element_type()); 318 } 319 320 /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { 321 return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), 322 shape.tuple_shapes().end(), IsTuple); 323 } 324 325 /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) { 326 return IsTuple(shape) && TupleElementCount(shape) == 0; 327 } 328 329 /* static */ bool ShapeUtil::IsNil(const Shape& shape) { 330 return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape); 331 } 332 333 /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { 334 CHECK(IsTuple(shape)) << HumanString(shape); 335 return shape.tuple_shapes_size(); 336 } 337 338 /* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape, 339 int64 index) { 340 CHECK(IsTuple(shape)); 341 CHECK_GT(TupleElementCount(shape), index); 342 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index))); 343 return shape.tuple_shapes(index); 344 } 345 346 /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start, 347 int64 limit) { 348 TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple)); 349 CHECK(IsTuple(tuple)); 350 CHECK_LE(start, TupleElementCount(tuple)); 351 CHECK_LE(limit, TupleElementCount(tuple)); 352 353 std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start, 354 tuple.tuple_shapes().begin() + limit); 355 return MakeTupleShape(new_elements); 356 } 357 358 // Returns the shape of a real or imaginary component. 359 /* static */ Shape ShapeUtil::ComplexComponentShape( 360 const Shape& complex_shape) { 361 CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape); 362 return ChangeElementType(complex_shape, primitive_util::ComplexComponentType( 363 complex_shape.element_type())); 364 } 365 366 /* static */ bool ShapeUtil::ShapeIs(const Shape& shape, 367 PrimitiveType element_type, 368 std::initializer_list<int64> dimensions) { 369 return Equal(shape, MakeShape(element_type, dimensions)); 370 } 371 372 /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { 373 CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape); 374 CHECK_EQ(shape.dimensions_size(), Rank(shape)); 375 return std::accumulate<decltype(shape.dimensions().begin()), int64>( 376 shape.dimensions().begin(), shape.dimensions().end(), 1LL, 377 std::multiplies<int64>()); 378 } 379 380 /* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) { 381 return ElementsIn(shape) == 0; 382 } 383 384 /* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { 385 return shape.element_type() == F32 && Rank(shape) == 0; 386 } 387 388 /* static */ string ShapeUtil::HumanString(const Shape& shape) { 389 if (IsTuple(shape)) { 390 string text = "("; 391 const char* prefix = ""; 392 for (const Shape& elem_shape : shape.tuple_shapes()) { 393 tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape)); 394 prefix = ", "; 395 } 396 text += ")"; 397 return text; 398 } else { 399 return tensorflow::strings::StrCat( 400 tensorflow::str_util::Lowercase( 401 PrimitiveType_Name(shape.element_type())), 402 "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]"); 403 } 404 } 405 406 namespace { 407 408 // Class to memoize the computation of 409 // tensorflow::str_util::Lowercase(PrimitiveType_Name(p)) 410 // for all PrimitiveType values "p" 411 class PrimitiveTypeNameGenerator { 412 public: 413 PrimitiveTypeNameGenerator() { 414 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { 415 if (PrimitiveType_IsValid(i)) { 416 lowercase_name_[i] = tensorflow::str_util::Lowercase( 417 PrimitiveType_Name(static_cast<PrimitiveType>(i))); 418 } 419 } 420 } 421 const string& LowercaseName(PrimitiveType t) { 422 return lowercase_name_[static_cast<int>(t)]; 423 } 424 425 private: 426 string lowercase_name_[PrimitiveType_ARRAYSIZE]; 427 }; 428 429 const string& LowercasePrimitiveTypeName(PrimitiveType s) { 430 static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); 431 return gen->LowercaseName(s); 432 } 433 434 StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) { 435 static std::unordered_map<string, PrimitiveType>* name_to_type = [] { 436 static auto* map = new std::unordered_map<string, PrimitiveType>; 437 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { 438 if (PrimitiveType_IsValid(i)) { 439 auto value = static_cast<PrimitiveType>(i); 440 (*map)[LowercasePrimitiveTypeName(value)] = value; 441 } 442 } 443 return map; 444 }(); 445 auto found = name_to_type->find(name); 446 if (found == name_to_type->end()) { 447 return InvalidArgument("Invalid element type string: \"%s\".", 448 name.c_str()); 449 } 450 return found->second; 451 } 452 453 } // namespace 454 455 /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { 456 if (IsTuple(shape)) { 457 string text = "("; 458 const char* prefix = ""; 459 for (const Shape& elem_shape : shape.tuple_shapes()) { 460 tensorflow::strings::StrAppend(&text, prefix, 461 HumanStringWithLayout(elem_shape)); 462 prefix = ", "; 463 } 464 text += ")"; 465 return text; 466 } else { 467 string result = tensorflow::strings::StrCat( 468 LowercasePrimitiveTypeName(shape.element_type()), "["); 469 for (int i = 0; i < shape.dimensions().size(); i++) { 470 tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "", 471 shape.dimensions(i)); 472 } 473 result += "]"; 474 if (!IsScalar(shape) && !IsOpaque(shape)) { 475 if (LayoutUtil::HasLayout(shape)) { 476 tensorflow::strings::StrAppend(&result, 477 LayoutUtil::HumanString(shape.layout())); 478 } 479 } 480 return result; 481 } 482 } 483 484 /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) { 485 std::vector<string> parameters; 486 for (auto& shape : program_shape.parameters()) { 487 const int i = parameters.size(); 488 parameters.push_back( 489 tensorflow::strings::StrCat(i < program_shape.parameter_names_size() 490 ? program_shape.parameter_names(i) 491 : "(unknown)", 492 ": ", HumanString(shape))); 493 } 494 return tensorflow::strings::StrCat( 495 "(", tensorflow::str_util::Join(parameters, ", "), ") -> ", 496 HumanString(program_shape.result())); 497 } 498 499 namespace { 500 // Parses shapes with simple recursive descent structure -- consumes from the 501 // front of s and passes that view recursively as required. 502 StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) { 503 tensorflow::str_util::RemoveLeadingWhitespace(s); 504 505 if (s->Consume("(")) { // Tuple. 506 std::vector<Shape> shapes; 507 bool must_end = false; 508 while (true) { 509 if (s->Consume(")")) { 510 break; 511 } else if (must_end) { 512 return InvalidArgument("Expected end of tuple; got: \"%s\"", 513 s->ToString().c_str()); 514 } 515 shapes.emplace_back(); 516 TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); 517 tensorflow::str_util::RemoveLeadingWhitespace(s); 518 must_end = !s->Consume(","); 519 } 520 return ShapeUtil::MakeTupleShape(shapes); 521 } 522 523 string element_type_string; 524 string dimensions_string; 525 string format_string; 526 string layout_string; 527 // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so 528 // we convert in to the RE2-consumable type and then consume the corresponding 529 // amount from our StringPiece type. 530 tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); 531 if (RE2::Consume( 532 &s_consumable, 533 "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?", 534 &element_type_string, &dimensions_string, &format_string, 535 &layout_string)) { 536 size_t consumed = s->size() - s_consumable.size(); 537 s->remove_prefix(consumed); 538 auto string_to_int64 = [&s](const string& input) -> StatusOr<int64> { 539 int64 element; 540 if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { 541 return InvalidArgument( 542 "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", 543 input.c_str(), s->ToString().c_str()); 544 } 545 return element; 546 }; 547 548 auto comma_list_to_int64s = 549 [&s, 550 string_to_int64](const string& input) -> StatusOr<std::vector<int64>> { 551 std::vector<int64> results; 552 for (const string& piece : tensorflow::str_util::Split(input, ',')) { 553 TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); 554 results.push_back(element); 555 } 556 return results; 557 }; 558 559 // Extract the dimensions. 560 TF_ASSIGN_OR_RETURN(std::vector<int64> dimensions, 561 comma_list_to_int64s(dimensions_string)); 562 563 // Extract the primitive element type. 564 TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, 565 StringToPrimitiveType(element_type_string)); 566 if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || 567 primitive_type == OPAQUE) { 568 return InvalidArgument("Invalid element type string: \"%s\".", 569 element_type_string.c_str()); 570 } 571 572 Shape result; 573 if (format_string.empty() && layout_string.empty()) { 574 // Create a shape without a layout set. 575 result = ShapeUtil::MakeShape(primitive_type, dimensions); 576 } else if (format_string == "sparse") { 577 TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string)); 578 result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions, 579 max_elements); 580 } else if (format_string.empty() || format_string == "dense") { 581 // Extract the layout minor-to-major and set it. 582 TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj, 583 comma_list_to_int64s(layout_string)); 584 TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( 585 primitive_type, dimensions, min2maj)); 586 } else { 587 // This should not be reached. 588 LOG(FATAL) << "Unhandled condition when parsing shape; format: \"" 589 << format_string << "\", layout: \"" << layout_string << "\""; 590 } 591 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); 592 return std::move(result); 593 } 594 595 return InvalidArgument("Invalid shape string to parse: \"%s\"", 596 s->ToString().c_str()); 597 } 598 } // namespace 599 600 /* static */ StatusOr<Shape> ShapeUtil::ParseShapeString( 601 tensorflow::StringPiece s) { 602 TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); 603 if (!s.empty()) { 604 return InvalidArgument("Invalid shape string to parse: \"%s\"", 605 s.ToString().c_str()); 606 } 607 return shape; 608 } 609 610 /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, 611 const Shape& rhs) { 612 return ContainersEqual(lhs.dimensions(), rhs.dimensions()); 613 } 614 615 /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { 616 if (lhs.element_type() == TUPLE) { 617 return rhs.element_type() == TUPLE && 618 ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); 619 } 620 return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs); 621 } 622 623 /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, 624 const Shape& rhs) { 625 if (lhs.element_type() == TUPLE) { 626 return rhs.element_type() == TUPLE && 627 ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), 628 CompatibleIgnoringElementType); 629 } 630 return SameDimensions(lhs, rhs); 631 } 632 633 /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, 634 const Shape& rhs) { 635 if (lhs.element_type() == TUPLE) { 636 return rhs.element_type() == TUPLE && 637 ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), 638 CompatibleIgnoringFpPrecision); 639 } 640 if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { 641 return CompatibleIgnoringElementType(lhs, rhs); 642 } 643 return false; 644 } 645 646 /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, 647 int64 dimension_number) { 648 return shape.dimensions(GetDimensionNumber(shape, dimension_number)); 649 } 650 651 /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape, 652 int64 dimension_number) { 653 if (dimension_number < 0) { 654 dimension_number += Rank(shape); 655 } 656 CHECK_GE(dimension_number, 0); 657 return dimension_number; 658 } 659 660 /* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType( 661 PrimitiveType primitive_type) { 662 switch (primitive_type) { 663 case PRED: 664 return sizeof(int8); 665 case TUPLE: 666 LOG(FATAL) << "tuples have no definitive size"; 667 case OPAQUE: 668 LOG(FATAL) << "opaque have no definitive size"; 669 case S8: 670 return sizeof(int8); 671 case S16: 672 return sizeof(int16); 673 case S32: 674 return sizeof(int32); 675 case S64: 676 return sizeof(int64); 677 case U8: 678 return sizeof(uint8); 679 case U16: 680 return sizeof(uint16); 681 case U32: 682 return sizeof(uint32); 683 case U64: 684 return sizeof(uint64); 685 case BF16: 686 return sizeof(float) / 2; 687 case F16: 688 return sizeof(float) / 2; 689 case F32: 690 return sizeof(float); 691 case F64: 692 return sizeof(double); 693 case C64: 694 return sizeof(complex64); 695 default: 696 LOG(FATAL) << "Unhandled primitive type " << primitive_type; 697 } 698 } 699 700 /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape, 701 int64 pointer_size) { 702 TF_DCHECK_OK(ValidateShape(shape)); 703 DCHECK_NE(OPAQUE, shape.element_type()); 704 if (shape.element_type() == TUPLE) { 705 return ByteSizeOfTupleIndexTable(shape, pointer_size); 706 } 707 int64 byte_size = ByteSizeOfElements(shape); 708 if (LayoutUtil::IsSparseArray(shape)) { 709 byte_size += ByteSizeOfSparseIndices(shape); 710 } 711 return byte_size; 712 } 713 714 /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, 715 int64 pointer_size) { 716 TF_DCHECK_OK(ValidateShape(shape)); 717 DCHECK_EQ(TUPLE, shape.element_type()); 718 CHECK_GT(pointer_size, 0); 719 return pointer_size * shape.tuple_shapes_size(); 720 } 721 722 /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { 723 TF_DCHECK_OK(ValidateShape(shape)); 724 DCHECK(ShapeUtil::IsArray(shape)); 725 int64 allocated_element_count; 726 727 if (LayoutUtil::IsSparseArray(shape)) { 728 allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); 729 } else { 730 CHECK(LayoutUtil::IsDenseArray(shape)); 731 tensorflow::gtl::ArraySlice<int64> padded_dimensions = 732 LayoutUtil::PaddedDimensions(shape); 733 if (!padded_dimensions.empty()) { 734 CHECK_EQ(Rank(shape), padded_dimensions.size()); 735 allocated_element_count = 1; 736 for (int64 dimension_size : padded_dimensions) { 737 allocated_element_count *= dimension_size; 738 } 739 } else { 740 allocated_element_count = ElementsIn(shape); 741 } 742 } 743 return allocated_element_count * 744 ByteSizeOfPrimitiveType(shape.element_type()); 745 } 746 747 /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { 748 TF_DCHECK_OK(ValidateShape(shape)); 749 DCHECK(LayoutUtil::IsSparseArray(shape)); 750 return LayoutUtil::MaxSparseElements(shape.layout()) * 751 ShapeUtil::Rank(shape) * sizeof(int64); 752 } 753 754 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( 755 const Shape& shape) { 756 if (shape.element_type() == TUPLE) { 757 if (shape.dimensions_size() != 0) { 758 return InvalidArgument("tuples must not have dimensions specified"); 759 } 760 for (auto& element_shape : shape.tuple_shapes()) { 761 TF_RETURN_IF_ERROR( 762 ValidateShapeWithOptionalLayoutInternal(element_shape)); 763 } 764 return Status::OK(); 765 } 766 767 // Non-tuple shape. 768 if (shape.tuple_shapes_size() > 0) { 769 return InvalidArgument("non-tuple shape has tuple_shapes field"); 770 } 771 if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { 772 return InvalidArgument("shape has invalid element type: %s", 773 shape.ShortDebugString().c_str()); 774 } 775 if (Rank(shape) != shape.dimensions_size()) { 776 return InvalidArgument( 777 "shape's rank is mismatched with dimension count; rank=%lld " 778 "dimensions_size=%d", 779 Rank(shape), shape.dimensions_size()); 780 } 781 for (int64 i = 0; i < Rank(shape); ++i) { 782 int64 dimension = shape.dimensions(i); 783 if (dimension < 0) { 784 return InvalidArgument( 785 "shape's dimensions must not be < 0; dimension at index %lld was " 786 "%lld", 787 i, dimension); 788 } 789 } 790 791 return Status::OK(); 792 } 793 794 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout( 795 const Shape& shape) { 796 if (LayoutUtil::HasLayout(shape)) { 797 // Since a layout is present, upgrade to the full set of invariant checks. 798 return ValidateShape(shape); 799 } 800 return ValidateShapeWithOptionalLayoutInternal(shape); 801 } 802 803 /* static */ Status ShapeUtil::ValidateShape(const Shape& shape) { 804 TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape)); 805 806 return LayoutUtil::ValidateLayoutInShape(shape); 807 } 808 809 /* static */ Shape ShapeUtil::ChangeElementType(const Shape& original, 810 PrimitiveType type) { 811 Shape new_shape = original; 812 new_shape.set_element_type(type); 813 return new_shape; 814 } 815 816 /* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape, 817 ShapeIndexView index) { 818 const Shape* return_shape = &shape; 819 for (auto i : index) { 820 CHECK(IsTuple(*return_shape)) 821 << "Invalid index " << index << " for shape " << shape; 822 return_shape = &return_shape->tuple_shapes(i); 823 } 824 return *return_shape; 825 } 826 827 /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape, 828 ShapeIndexView index) { 829 Shape* return_shape = shape; 830 for (auto i : index) { 831 CHECK(IsTuple(*return_shape)); 832 return_shape = return_shape->mutable_tuple_shapes(i); 833 } 834 return return_shape; 835 } 836 837 /* static */ 838 bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { 839 return !IsTuple(GetSubshape(shape, index)); 840 } 841 842 /* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { 843 std::vector<int64> dimension_sizes; 844 std::vector<int64> degenerate_dimensions; 845 for (int64 i = 0; i < shape.dimensions_size(); ++i) { 846 if (shape.dimensions(i) == 1) { 847 degenerate_dimensions.push_back(i); 848 } else { 849 dimension_sizes.push_back(shape.dimensions(i)); 850 } 851 } 852 853 // Construct minor_to_major of stripped shape. The order of the non-degenerate 854 // dimensions should be preserved from the original shape. First, create 855 // vector of the non-degenerate dimensions from the original minor_to_major 856 // array. 857 std::vector<int64> minor_to_major; 858 for (int64 i : shape.layout().minor_to_major()) { 859 if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(), 860 i) == degenerate_dimensions.end()) { 861 minor_to_major.push_back(i); 862 } 863 } 864 865 // The dimensions in minor_to_major need to be renumbered to account for the 866 // degenerate dimensions which have removed. Decrement each dimension number 867 // once for each degenerate dimension which has a smaller number. 868 for (int i = 0; i < minor_to_major.size(); ++i) { 869 int adjustment = 0; 870 for (int64 dim : degenerate_dimensions) { 871 if (minor_to_major[i] > dim) { 872 adjustment++; 873 } 874 } 875 minor_to_major[i] -= adjustment; 876 } 877 878 { 879 std::vector<int64> dims(minor_to_major.size()); 880 std::iota(dims.begin(), dims.end(), 0); 881 DCHECK(minor_to_major.size() == dims.size() && 882 std::is_permutation(minor_to_major.begin(), minor_to_major.end(), 883 dims.begin())); 884 } 885 Shape stripped_shape = 886 shape.has_layout() ? MakeShapeWithLayout(shape.element_type(), 887 dimension_sizes, minor_to_major) 888 : MakeShape(shape.element_type(), dimension_sizes); 889 890 VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape); 891 VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape); 892 return stripped_shape; 893 } 894 895 namespace { 896 897 // Helper for ForEachSubshape which visits the subshapes of the given shape in 898 // DFS pre-order starting with the index. 899 Status ForEachSubshapeHelper(const Shape& shape, 900 const ShapeUtil::StatusVisitorFunction& func, 901 ShapeIndex* index) { 902 TF_RETURN_IF_ERROR(func(shape, *index)); 903 if (ShapeUtil::IsTuple(shape)) { 904 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { 905 index->push_back(i); 906 TF_RETURN_IF_ERROR(ForEachSubshapeHelper( 907 ShapeUtil::GetTupleElementShape(shape, i), func, index)); 908 index->pop_back(); 909 } 910 } 911 return Status::OK(); 912 } 913 914 // Helper for ForEachMutableSubshape which visits the subshapes of the given 915 // shape in DFS pre-order starting with the index. 916 Status ForEachMutableSubshapeHelper( 917 Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func, 918 ShapeIndex* index) { 919 TF_RETURN_IF_ERROR(func(shape, *index)); 920 if (ShapeUtil::IsTuple(*shape)) { 921 for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) { 922 index->push_back(i); 923 TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper( 924 shape->mutable_tuple_shapes(i), func, index)); 925 index->pop_back(); 926 } 927 } 928 return Status::OK(); 929 } 930 931 } // namespace 932 933 /* static */ void ShapeUtil::ForEachSubshape(const Shape& shape, 934 const VisitorFunction& func) { 935 ShapeIndex index; 936 ForEachSubshapeHelper( 937 shape, 938 [&func](const Shape& subshape, const ShapeIndex& index) { 939 func(subshape, index); 940 return Status::OK(); 941 }, 942 &index) 943 .IgnoreError(); 944 } 945 946 /* static */ void ShapeUtil::ForEachMutableSubshape( 947 Shape* shape, const MutatingVisitorFunction& func) { 948 ShapeIndex index; 949 ForEachMutableSubshapeHelper( 950 shape, 951 [&func](Shape* subshape, const ShapeIndex& index) { 952 func(subshape, index); 953 return Status::OK(); 954 }, 955 &index) 956 .IgnoreError(); 957 } 958 959 /* static */ Status ShapeUtil::ForEachSubshapeWithStatus( 960 const Shape& shape, const StatusVisitorFunction& func) { 961 ShapeIndex index; 962 return ForEachSubshapeHelper(shape, func, &index); 963 } 964 965 /* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus( 966 Shape* shape, const MutatingStatusVisitorFunction& func) { 967 ShapeIndex index; 968 return ForEachMutableSubshapeHelper(shape, func, &index); 969 } 970 971 /* static */ Shape ShapeUtil::PermuteDimensions( 972 tensorflow::gtl::ArraySlice<int64> permutation, const Shape& shape) { 973 Shape new_shape = shape; 974 new_shape.clear_dimensions(); 975 for (auto dim : Permute(permutation, shape.dimensions())) { 976 new_shape.add_dimensions(dim); 977 } 978 if (shape.has_layout()) { 979 CHECK(LayoutUtil::IsDenseArray(shape)); 980 Layout* new_layout = new_shape.mutable_layout(); 981 new_layout->set_format(DENSE); 982 new_layout->clear_minor_to_major(); 983 for (auto index : Permute(permutation, shape.layout().minor_to_major())) { 984 new_layout->add_minor_to_major(index); 985 } 986 if (shape.layout().padded_dimensions_size() > 0) { 987 new_layout->clear_padded_dimensions(); 988 for (auto dim : 989 Permute(permutation, shape.layout().padded_dimensions())) { 990 new_layout->add_padded_dimensions(dim); 991 } 992 } 993 } 994 return new_shape; 995 } 996 997 /* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>> 998 ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, 999 const Shape& shape_post) { 1000 auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>()); 1001 1002 std::vector<int64> deleted_indices; 1003 std::vector<int64> inserted_indices; 1004 // Returns false if any input/output index between prior_unmodified_dim_pair 1005 // and unmodified_dim_pair have size >1. Otherwise, returns true and appends 1006 // the degerenate input/output dimensions in the gap to 1007 // deleted_indices/inserted_indices respectively. 1008 auto check_modified_dims = 1009 [&shape_pre, &shape_post, &deleted_indices, &inserted_indices]( 1010 std::pair<int64, int64> prior_unmodified_dim_pair, 1011 std::pair<int64, int64> unmodified_dim_pair) { 1012 for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1; 1013 modified_input_dim < unmodified_dim_pair.first; 1014 ++modified_input_dim) { 1015 if (shape_pre.dimensions(modified_input_dim) > 1) { 1016 return false; 1017 } 1018 deleted_indices.push_back(modified_input_dim); 1019 } 1020 for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1; 1021 modified_output_dim < unmodified_dim_pair.second; 1022 ++modified_output_dim) { 1023 if (shape_post.dimensions(modified_output_dim) > 1) { 1024 return false; 1025 } 1026 inserted_indices.push_back(modified_output_dim); 1027 } 1028 return true; 1029 }; 1030 1031 std::vector<std::pair<int64, int64>> unmodified_dims = 1032 DimensionsUnmodifiedByReshape(shape_pre, shape_post); 1033 // Returns nil if the reshape modifies any non-degenerate input/output 1034 // dimension. DimensionsUnmodifiedByReshape gives us all unmodified 1035 // dimensions, so we only need to check whether dimensions in the gaps (thus 1036 // modified) have size >1. 1037 for (size_t i = 0; i <= unmodified_dims.size(); ++i) { 1038 // Check (modified) dimensions between unmodified_dims[i-1] and 1039 // unmodified_dims[i]. 1040 auto prior_unmodified_dim_pair = 1041 i > 0 ? unmodified_dims[i - 1] : std::make_pair(-1LL, -1LL); 1042 auto unmodified_dim_pair = 1043 i < unmodified_dims.size() 1044 ? unmodified_dims[i] 1045 : std::make_pair(Rank(shape_pre), Rank(shape_post)); 1046 if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) { 1047 return nil; 1048 } 1049 } 1050 1051 return std::make_tuple(true, deleted_indices, inserted_indices); 1052 } 1053 1054 /* static */ std::vector<std::pair<int64, int64>> 1055 ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, 1056 const Shape& output_shape) { 1057 // Unmodified dimensions are merely common factors of rank 1. 1058 auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), 1059 AsInt64Slice(output_shape.dimensions())); 1060 for (size_t i = 0; i < common_factors.size() - 1;) { 1061 if (1 != common_factors[i + 1].first - common_factors[i].first || 1062 1 != common_factors[i + 1].second - common_factors[i].second) { 1063 common_factors.erase(common_factors.begin() + i); 1064 } else { 1065 ++i; 1066 } 1067 } 1068 // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it. 1069 common_factors.pop_back(); 1070 return common_factors; 1071 } 1072 1073 /* static */ bool ShapeUtil::TransposeIsBitcast( 1074 const Shape& input_shape, const Shape& output_shape, 1075 tensorflow::gtl::ArraySlice<int64> dimension_mapping) { 1076 // Can't insert bitcasts without layout information. 1077 if (!LayoutUtil::HasLayout(input_shape) && 1078 !LayoutUtil::HasLayout(output_shape)) { 1079 return false; 1080 } 1081 1082 // Padding is not handled. 1083 if (LayoutUtil::IsPadded(input_shape) && LayoutUtil::IsPadded(output_shape)) { 1084 return false; 1085 } 1086 1087 // Check the reshape permutes the positions of each dimension in the 1088 // minor-to-major order. positions[i]=k means dimension `i` is k-th minor. 1089 // input_positions = apply(dimension_mapping, output_positions) 1090 // 1091 // Because the positions of each dimension are the inverse permutation of the 1092 // minor-to-major order, the above check is equivalent to 1093 // inverse(input_dimensions) = 1094 // apply(dimension_mapping, inverse(output_dimensions)) 1095 // # `I` indicates identity permutation. 1096 // apply(input_dimensions, I) = 1097 // apply(dimension_mapping, apply(output_dimensions, I)) 1098 // apply(input_dimensions, I) = 1099 // apply((dimension_mapping * output_dimensions), I) 1100 // input_dimensions = dimension_mapping * output_dimensions 1101 return ContainersEqual( 1102 ComposePermutations(dimension_mapping, 1103 AsInt64Slice(output_shape.layout().minor_to_major())), 1104 input_shape.layout().minor_to_major()); 1105 } 1106 1107 /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, 1108 const Shape& output_shape) { 1109 // Can't convert reshapes into bitcasts without layout information. 1110 if (!LayoutUtil::HasLayout(input_shape) || 1111 !LayoutUtil::HasLayout(output_shape)) { 1112 return false; 1113 } 1114 1115 // Padding is not handled. 1116 if (LayoutUtil::IsPadded(input_shape) || LayoutUtil::IsPadded(output_shape)) { 1117 return false; 1118 } 1119 1120 CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape)); 1121 if (ElementsIn(input_shape) == 0) { 1122 return true; 1123 } 1124 1125 // TL;DR: The rest of the method checks that the reshape does not change the 1126 // physical location of any unit input or output index. Unit indices have 1127 // exactly one dimension that equals 1 and other dimensions 0. This condition 1128 // is necessary for the reshape to be a bitcast, because a bitcast-equivalent 1129 // reshape shouldn't change the physical location of any element. It is also a 1130 // sufficient condition as is proved below (note: many details are omitted for 1131 // space). 1132 // 1133 // Definitions: 1134 // 1135 // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means 1136 // the size of i-th least significant dimension of IS or OS (this is opposite 1137 // to how we define the index of Shape::dimensions()). 1138 // 1139 // * Given an input or output index I, denote by p(I) I's physical linear 1140 // index (or physical index for short) and l(I) I's logical linear index (or 1141 // logical index for short). 1142 // 1143 // * Given a logical index k, denote by II(k) the input index whose linear 1144 // index is k, and OI(k) the corresponding output index. 1145 // 1146 // * Denote by IT[i] the increment of physical index if i-th dimension of the 1147 // input index is increased by 1. Similarly, OT[i] means the increment if i-th 1148 // dimension of the output index is increased by 1. Note that IT[i] or OT[i] 1149 // is a function of IS or OS and the layout, and not dependent on the specific 1150 // input or output index. 1151 // 1152 // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove 1153 // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by 1154 // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is 1155 // to prove, with every increment on k, the above formula still holds. 1156 // 1157 // First, suppose reshaping from IS to OS is non-factorizable (we discuss 1158 // refactorizable reshapes later). A reshape from IS to OS is factorizable, if 1159 // there exists (i,j) such that 1160 // 1161 // 0<=i<=|IS| 1162 // 0<=j<=|OS| 1163 // |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end) 1164 // product(IS[i], IS[i+1], ..., IS[|IS|-1]) 1165 // = product(OS[j], OS[j+1], ..., OS[|OS|-1]) 1166 // 1167 // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0)) 1168 // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are 1169 // unit indices which are already tested. This also means IT[0]=OT[0] 1170 // because p(II(1))=IT[0] and p(OI(1))=OT[0]. 1171 // 1172 // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each 1173 // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0]) 1174 // to the output physical. 1175 // 1176 // When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality, 1177 // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0]. 1178 // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From 1179 // logical index k-1 to logical index k, dimension 1 of the input index 1180 // is increased by 1 and dimension 0 is reset to 0 thus decreased by 1181 // IS[0]-1. Therefore, the physical input index is increased by 1182 // 1183 // p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0] 1184 // 1185 // Because IS[0]<OS[0], the only change to the output index is that its 1186 // dimension 0 is increased by one. Therefore, 1187 // 1188 // p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0] 1189 // 1190 // Because II(k) is an unit index -- (0,..,0,1,0), we already tested that 1191 // p(II(k))=p(OI(k)). Therefore, 1192 // IT[1] - (IS[0]-1) * IT[0] = IT[0] 1193 // IT[1] = IS[0] * IT[0] 1194 // In other words, input dimension 1 is immediately more major than input 1195 // dimension 0. We can now conceptually collapse these two dimensions because 1196 // an increment in the logical index affecting only these two dimensions maps 1197 // to IT[0] in the physical index. 1198 // 1199 // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and 1200 // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise 1201 // identical. 1202 // 1203 // A factorizable reshape can be factorized into a list of non-factorizable 1204 // sub-reshapes, each of which can be handled similarly to the proof above. 1205 // For example, 1206 // 1207 // [7x9x2x15] -> [63x6x5] 1208 // 1209 // can be factorized into 1210 // 1211 // [7x9] -> [63] and [2x15] -> [6x5]. 1212 // 1213 // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the 1214 // same logical linear index. According to the factorization, we know 1215 // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for 1216 // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a 1217 // similar proof, with the increment of the logical index set to 1218 // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove 1219 // p(x3,x2,0,0)=p(y2,0,0) too. Therefore, 1220 // 1221 // p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0) 1222 // = p(y2,0,0) + p(0,0,y1,y0) 1223 // = p(y2,y1,y0) 1224 // 1225 // check_input_unit_indices checks one way of the condition: each input unit 1226 // index is mapped to an output index with the same physical location. This 1227 // lambda will be called again with input_shape and output_shape reversed to 1228 // check the other way. 1229 auto check_input_unit_indices = [](const Shape& input_shape, 1230 const Shape& output_shape) { 1231 // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions" 1232 // as input_shape/output_shape and the dimension-0-major layout. These two 1233 // shapes are used for conversion between logical linear indices and 1234 // multi-dimensional indices. 1235 Shape input_shape_dim0_major = MakeShapeWithDescendingLayout( 1236 input_shape.element_type(), AsInt64Slice(input_shape.dimensions())); 1237 Shape output_shape_dim0_major = MakeShapeWithDescendingLayout( 1238 output_shape.element_type(), AsInt64Slice(output_shape.dimensions())); 1239 1240 for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) { 1241 if (input_shape.dimensions(input_dim) <= 1) { 1242 continue; 1243 } 1244 1245 std::vector<int64> input_unit_index(Rank(input_shape), 0); 1246 input_unit_index[input_dim] = 1; 1247 int64 logical_linear_index = 1248 IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major, 1249 input_unit_index); 1250 // output_index has the same logical linear index as input_unit_index. 1251 std::vector<int64> output_index = 1252 IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major, 1253 logical_linear_index); 1254 // Check input_unit_index and output_index have the same physical linear 1255 // index. 1256 if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape, 1257 input_unit_index) != 1258 IndexUtil::MultidimensionalIndexToLinearIndex(output_shape, 1259 output_index)) { 1260 return false; 1261 } 1262 } 1263 return true; 1264 }; 1265 return check_input_unit_indices(input_shape, output_shape) && 1266 check_input_unit_indices(output_shape, input_shape); 1267 } 1268 1269 /* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts( 1270 const Shape& input_shape, const Shape& output_shape) { 1271 int64 input_rank = Rank(input_shape); 1272 int64 output_rank = Rank(output_shape); 1273 1274 // First, calculate an alignment of the dimensions. A consecutive sequence of 1275 // input dimensions and output dimensions belong to the same alignment part if 1276 // the products of their dimension bounds are the same. In the easiest case, 1277 // an alignment part consists of one input dimension and one output dimension 1278 // which both have the same dimension bound. An alignment part specifies which 1279 // dimensions need to be kept together in a physical layout if we want a 1280 // reshape to be a bitcast. The order of the alignment parts is defined by the 1281 // physical layout of the input shape, so when we construct the layout for the 1282 // output shape we just process the alignment parts in this order, and then 1283 // layout the dimensions belonging to each part in descending (major to minor) 1284 // order. 1285 1286 // Stores the input and output dimension numbers where each alignment part 1287 // starts. 1288 std::vector<std::pair<int64, int64>> alignment; 1289 alignment.push_back({0, 0}); 1290 1291 // Stores a mapping from the input dimension to the alignment part it belongs 1292 // to. 1293 std::vector<int64> dimension_to_alignment_index(input_rank); 1294 int64 input_dimension_product = 1, output_dimension_product = 1; 1295 for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) { 1296 // Check if we have reached the end of an alignment part. 1297 if (input_dimension_product == output_dimension_product && 1298 input_dimension_product > 1) { 1299 alignment.push_back({i, j}); 1300 input_dimension_product = output_dimension_product = 1; 1301 } 1302 if (input_dimension_product < output_dimension_product || 1303 j == output_rank) { 1304 if (i == input_rank) { 1305 return tensorflow::gtl::nullopt; 1306 } 1307 dimension_to_alignment_index[i] = alignment.size() - 1; 1308 input_dimension_product *= input_shape.dimensions(i); 1309 ++i; 1310 } else { 1311 output_dimension_product *= output_shape.dimensions(j); 1312 ++j; 1313 } 1314 } 1315 if (input_dimension_product != output_dimension_product) { 1316 return tensorflow::gtl::nullopt; 1317 } 1318 // We also need to store an end element so that we know where the last 1319 // alignment part ends. 1320 alignment.push_back({input_rank, output_rank}); 1321 1322 // Now check if the physical layout can potentially be aligned to the output 1323 // shape by changing the physical layout of the output shape. We need to check 1324 // that all dimension numbers that belong to the same alignment part appear 1325 // consecutively, and are in descending order. However we can ignore any 1326 // trivial dimension bounds of 1, because they can be placed anywhere. 1327 auto input_dimension_numbers = input_shape.layout().minor_to_major(); 1328 std::vector<int64> output_layout; 1329 output_layout.reserve(output_rank); 1330 for (int64 i = 0; i < input_rank;) { 1331 int64 current_dimension_number = input_dimension_numbers[i]; 1332 1333 // Skip trivial dimensions with a bound of 1. 1334 if (input_shape.dimensions(current_dimension_number) == 1) { 1335 ++i; 1336 continue; 1337 } 1338 1339 // Calculate the number of non-trivial dimension bounds in the input shape 1340 // belonging to the current alignment part. 1341 const int64 current_alignment_index = 1342 dimension_to_alignment_index[current_dimension_number]; 1343 // Because of the special end element that we added, we can be sure that 1344 // 'current_alignment_index' is < alignment.size() - 1. 1345 CHECK_LT(current_alignment_index, alignment.size() - 1); 1346 int64 num_non_trivial_dimensions_in_alignment_part = 0; 1347 for (int64 j = alignment[current_alignment_index].first; 1348 j < alignment[current_alignment_index + 1].first; ++j) { 1349 if (input_shape.dimensions(j) != 1) { 1350 ++num_non_trivial_dimensions_in_alignment_part; 1351 } 1352 } 1353 1354 // Check that the following 'num_non_trivial_dimensions_in_alignment_part' 1355 // dimension numbers (ignoring dimension numbers with dimension bound 1) are 1356 // in descending order and belong to the current alignment part. 1357 for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; 1358 ++i, ++j) { 1359 if (i == input_rank) { 1360 return tensorflow::gtl::nullopt; 1361 } 1362 // Skip trivial dimensions with a bound of 1. 1363 if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { 1364 --j; 1365 continue; 1366 } 1367 // If the current dimension number belongs to a different alignment part, 1368 // or the dimension numbers are not in descending order, we can return 1369 // early. 1370 if (dimension_to_alignment_index[input_dimension_numbers[i]] != 1371 current_alignment_index || 1372 input_dimension_numbers[i] > current_dimension_number) { 1373 return tensorflow::gtl::nullopt; 1374 } 1375 current_dimension_number = input_dimension_numbers[i]; 1376 } 1377 1378 // The output dimension numbers that belong to the current alignment part 1379 // need to appear in the same descending order as in the input. Again, we 1380 // can skip dimensions with a bound of 1. 1381 for (int64 j = alignment[current_alignment_index + 1].second - 1; 1382 j >= alignment[current_alignment_index].second; --j) { 1383 if (output_shape.dimensions(j) != 1) { 1384 output_layout.push_back(j); 1385 } 1386 } 1387 } 1388 // Now add all the dimensions with dimension bound 1 at the end of 1389 // 'output_layout'. 1390 for (int64 i = 0; i < output_rank; ++i) { 1391 if (output_shape.dimensions(i) == 1) { 1392 output_layout.push_back(i); 1393 } 1394 } 1395 CHECK_EQ(output_layout.size(), output_rank); 1396 Shape output_shape_with_layout = MakeShapeWithLayout( 1397 output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), 1398 output_layout); 1399 CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout)); 1400 return output_shape_with_layout; 1401 } 1402 1403 /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, 1404 Shape shape) { 1405 shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); 1406 if (LayoutUtil::HasLayout(shape)) { 1407 Layout* layout = shape.mutable_layout(); 1408 layout->set_format(DENSE); 1409 for (size_t i = 0; i < layout->minor_to_major().size();) { 1410 if (layout->minor_to_major(i) == dim_to_delete) { 1411 layout->mutable_minor_to_major()->erase( 1412 layout->minor_to_major().begin() + i); 1413 continue; 1414 } 1415 if (layout->minor_to_major(i) > dim_to_delete) { 1416 (*layout->mutable_minor_to_major())[i] -= 1; 1417 } 1418 ++i; 1419 } 1420 } 1421 return shape; 1422 } 1423 1424 /* static */ Shape ShapeUtil::FilterDimensions( 1425 const std::function<bool(int64)>& p, Shape shape) { 1426 std::vector<int64> dims_to_delete; 1427 for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { 1428 if (!p(i)) { 1429 dims_to_delete.push_back(i); 1430 } 1431 } 1432 for (int64 dim : dims_to_delete) { 1433 shape = DeleteDimension(dim, shape); 1434 } 1435 return shape; 1436 } 1437 1438 std::ostream& operator<<(std::ostream& out, const Shape& shape) { 1439 out << ShapeUtil::HumanString(shape); 1440 return out; 1441 } 1442 1443 } // namespace xla 1444