      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/shape_util.h"
     17 
     18 #include <algorithm>
     19 #include <functional>
     20 #include <numeric>
     21 #include <unordered_map>
     22 #include <utility>
     23 #include <vector>
     24 
     25 #include "tensorflow/compiler/xla/index_util.h"
     26 #include "tensorflow/compiler/xla/layout_util.h"
     27 #include "tensorflow/compiler/xla/primitive_util.h"
     28 #include "tensorflow/compiler/xla/status_macros.h"
     29 #include "tensorflow/compiler/xla/types.h"
     30 #include "tensorflow/compiler/xla/util.h"
     31 #include "tensorflow/core/lib/core/errors.h"
     32 #include "tensorflow/core/lib/core/stringpiece.h"
     33 #include "tensorflow/core/lib/gtl/iterator_range.h"
     34 #include "tensorflow/core/lib/gtl/optional.h"
     35 #include "tensorflow/core/lib/strings/numbers.h"
     36 #include "tensorflow/core/lib/strings/str_util.h"
     37 #include "tensorflow/core/lib/strings/strcat.h"
     38 #include "tensorflow/core/platform/logging.h"
     39 #include "tensorflow/core/platform/protobuf.h"
     40 #include "tensorflow/core/platform/regexp.h"
     41 
     42 namespace xla {
     43 
     44 string ShapeIndex::ToString() const {
     45   return tensorflow::strings::StrCat(
     46       "{", tensorflow::str_util::Join(indices_, ","), "}");
     47 }
     48 
     49 string ShapeIndexView::ToString() const {
     50   return tensorflow::strings::StrCat(
     51       "{",
     52       tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_),
     53                                  ","),
     54       "}");
     55 }
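         // Example (illustrative): ShapeIndex({1, 2}).ToString() yields "{1,2}",
         // and an empty index yields "{}"; ShapeIndexView formats identically.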
     56 
     57 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
     58   out << shape_index.ToString();
     59   return out;
     60 }
     61 
     62 std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
     63   out << shape_index.ToString();
     64   return out;
     65 }
     66 
     67 namespace {
     68 
     69 // Recursive helper for comparing the equality of two shapes. Returns true if
     70 // the shapes are the same. If compare_layouts is true, then layouts must also
     71 // match.
     72 bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
     73   if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) {
     74     return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) &&
     75            ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
     76                            [=](const Shape& l, const Shape& r) {
     77                              return CompareShapes(l, r, compare_layouts);
     78                            });
     79   } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) {
     80     return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs);
     81   }
     82 
     83   if (compare_layouts) {
     84     if (lhs.layout().format() != rhs.layout().format()) {
     85       return false;
     86     }
     87     if (LayoutUtil::IsDenseArray(lhs)) {
     88       if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs),
     89                            LayoutUtil::MinorToMajor(rhs))) {
     90         VLOG(3) << "CompareShapes: lhs layout != rhs layout";
     91         return false;
     92       }
     93       if (!ContainersEqual(lhs.layout().padded_dimensions(),
     94                            rhs.layout().padded_dimensions())) {
     95         VLOG(3)
     96             << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions";
     97         return false;
     98       }
     99       if (lhs.layout().padding_value() != rhs.layout().padding_value()) {
    100         VLOG(3) << "CompareShapes: lhs padding value != rhs padding_value";
    101         return false;
    102       }
    103     }
    104   }
    105 
    106   if (!ShapeUtil::SameDimensions(lhs, rhs)) {
    107     VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
    108     return false;
    109   }
    110   if (!ShapeUtil::SameElementType(lhs, rhs)) {
    111     VLOG(3) << "CompareShapes: lhs element type != rhs element type";
    112     return false;
    113   }
    114   return true;
    115 }
    116 
    117 // Constructs and returns the new shape with the given minor_to_major order in
    118 // its Layout.
    119 StatusOr<Shape> MakeShapeWithLayoutInternal(
    120     PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
    121     tensorflow::gtl::ArraySlice<int64> minor_to_major) {
    122   if (dimensions.size() != minor_to_major.size()) {
    123     return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
    124                            dimensions.size(), minor_to_major.size());
    125   }
    126   if (element_type == OPAQUE || element_type == TUPLE) {
    127     return InvalidArgument("Unsupported element type: %s",
    128                            PrimitiveType_Name(element_type).c_str());
    129   }
    130   Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
    131   auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
    132   min2maj->Clear();
    133   for (int64 value : minor_to_major) {
    134     min2maj->Add(value);
    135   }
    136   if (!shape.has_layout()) {
    137     return InvalidArgument("Shape has no layout.");
    138   }
    139   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
    140   return shape;
    141 }
    142 
    143 }  // namespace
    144 
    145 /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
    146   bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true);
    147   if (!equal && VLOG_IS_ON(3)) {
    148     VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
    149             << ", rhs = " << rhs.ShortDebugString();
    150   }
    151 
    152   return equal;
    153 }
    154 
    155 /* static */ int64 ShapeUtil::Rank(const Shape& shape) {
    156   CHECK(!ShapeUtil::IsTuple(shape))
    157       << "Tuples do not have a rank, shape: " << shape;
    158   return shape.dimensions_size();
    159 }
    160 
    161 /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
    162   int64 accum = 0;
    163   for (int64 dimension : shape.dimensions()) {
     164     // We do not count degenerate (size 1) dimensions.
    165     if (dimension != 1) {
    166       accum += 1;
    167     }
    168   }
    169   return accum;
    170 }
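         // Example (illustrative): TrueRank of f32[1,5,1,7] is 2, since only the
         // dimensions of size 5 and 7 are counted.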
    171 
    172 /* static */ ProgramShape ShapeUtil::MakeProgramShape(
    173     std::initializer_list<Shape> parameters, Shape result) {
    174   ProgramShape program_shape;
    175   for (const auto& shape : parameters) {
    176     *program_shape.add_parameters() = shape;
    177   }
    178   *program_shape.mutable_result() = std::move(result);
    179   return program_shape;
    180 }
    181 
    182 /* static */ Shape ShapeUtil::MakeShape(
    183     PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
    184   DCHECK_NE(TUPLE, element_type);
    185   DCHECK_NE(OPAQUE, element_type);
    186   Shape result;
    187   PopulateShape(element_type, dimensions, &result);
    188   return result;
    189 }
    190 
    191 /* static */ Shape ShapeUtil::MakeShapeWithLayout(
    192     PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
    193     tensorflow::gtl::ArraySlice<int64> minor_to_major) {
    194   return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major)
    195       .ValueOrDie();
    196 }
    197 
    198 /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
    199     PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
    200   std::vector<int64> layout(dimensions.size());
    201   std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
    202   return MakeShapeWithLayout(element_type, dimensions, layout);
    203 }
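         // Example (illustrative): MakeShapeWithLayout(F32, {2, 3}, {0, 1}) yields
         // f32[2,3]{0,1} (dimension 0 minor-most), whereas
         // MakeShapeWithDescendingLayout(F32, {2, 3}) yields f32[2,3]{1,0}.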
    204 
    205 /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
    206     PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
    207     int64 max_sparse_elements) {
    208   DCHECK_NE(TUPLE, element_type);
    209   DCHECK_NE(OPAQUE, element_type);
    210   Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
    211   *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
    212   TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
    213   return shape;
    214 }
    215 
    216 /* static */ Shape
    217 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    218     const Shape& shape) {
    219   std::vector<int64> dims(shape.dimensions_size());
    220   for (int i = 0; i < shape.dimensions_size(); ++i) {
    221     dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
    222   }
    223   return MakeShapeWithDescendingLayout(shape.element_type(), dims);
    224 }
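         // Example (illustrative): for f32[10,20]{0,1} (dimension 0 minor-most),
         // MakeShapeWithDescendingLayoutAndSamePhysicalLayout returns f32[20,10]{1,0},
         // which keeps the physical element order but uses a descending layout.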
    225 
    226 /* static */ void ShapeUtil::PopulateShape(
    227     PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
    228     Shape* shape) {
    229   shape->Clear();
    230   shape->set_element_type(element_type);
    231   for (int64 dimension : dimensions) {
    232     shape->add_dimensions(dimension);
    233   }
    234   LayoutUtil::SetToDefaultLayout(shape);
    235   TF_DCHECK_OK(ValidateShape(*shape));
    236 }
    237 
    238 /* static */ Shape ShapeUtil::MakeTupleShape(
    239     tensorflow::gtl::ArraySlice<Shape> shapes) {
    240   Shape result;
    241   result.set_element_type(TUPLE);
    242   for (const auto& shape : shapes) {
    243     AppendShapeToTuple(shape, &result);
    244   }
    245   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
    246   return result;
    247 }
    248 
    249 /* static */ Shape ShapeUtil::MakeOpaqueShape() {
    250   Shape result;
    251   result.set_element_type(OPAQUE);
    252   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
    253   return result;
    254 }
    255 
    256 /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
    257                                                 Shape* tuple_shape) {
    258   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
    259   *tuple_shape->add_tuple_shapes() = shape;
    260 }
    261 
    262 /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
    263   CHECK(LayoutUtil::IsDenseArray(*shape));
    264   shape->mutable_layout()->add_minor_to_major(Rank(*shape));
    265   shape->add_dimensions(bound);
    266   TF_DCHECK_OK(ValidateShape(*shape));
    267 }
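         // Example (illustrative): AppendMajorDimension(4, ...) on f32[2,3]{1,0}
         // yields f32[2,3,4]{1,0,2}; the new dimension becomes the major-most one.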
    268 
    269 /* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
    270   return primitive_util::IsIntegralType(shape.element_type());
    271 }
    272 
    273 /* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
    274                                                        int32 bits) {
    275   return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
    276 }
    277 
    278 /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
    279   if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) {
    280     return false;
    281   }
    282   return primitive_util::BitWidth(shape.element_type()) == bits;
    283 }
    284 
    285 /* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
    286   switch (shape.element_type()) {
    287     case S8:
    288     case S16:
    289     case S32:
    290     case S64:
    291     case F16:
    292     case BF16:
    293     case F32:
    294     case F64:
    295       return true;
    296 
    297     case PRED:
    298     case U8:
    299     case U16:
    300     case U32:
    301     case U64:
    302     case C64:
    303     case TUPLE:
    304     case OPAQUE:
    305       return false;
    306 
    307     default:
    308       LOG(FATAL) << "Unhandled element type " << shape.element_type();
    309   }
    310 }
    311 
    312 /* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
    313   return primitive_util::IsComplexType(shape.element_type());
    314 }
    315 
    316 /* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
    317   return primitive_util::IsFloatingPointType(shape.element_type());
    318 }
    319 
    320 /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
    321   return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
    322                                        shape.tuple_shapes().end(), IsTuple);
    323 }
    324 
    325 /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
    326   return IsTuple(shape) && TupleElementCount(shape) == 0;
    327 }
    328 
    329 /* static */ bool ShapeUtil::IsNil(const Shape& shape) {
    330   return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape);
    331 }
    332 
    333 /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
    334   CHECK(IsTuple(shape)) << HumanString(shape);
    335   return shape.tuple_shapes_size();
    336 }
    337 
    338 /* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
    339                                                           int64 index) {
    340   CHECK(IsTuple(shape));
    341   CHECK_GT(TupleElementCount(shape), index);
    342   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
    343   return shape.tuple_shapes(index);
    344 }
    345 
    346 /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
    347                                          int64 limit) {
    348   TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
    349   CHECK(IsTuple(tuple));
    350   CHECK_LE(start, TupleElementCount(tuple));
    351   CHECK_LE(limit, TupleElementCount(tuple));
    352 
    353   std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
    354                                   tuple.tuple_shapes().begin() + limit);
    355   return MakeTupleShape(new_elements);
    356 }
    357 
    358 // Returns the shape of a real or imaginary component.
    359 /* static */ Shape ShapeUtil::ComplexComponentShape(
    360     const Shape& complex_shape) {
    361   CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
    362   return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
    363                                               complex_shape.element_type()));
    364 }
    365 
    366 /* static */ bool ShapeUtil::ShapeIs(const Shape& shape,
    367                                      PrimitiveType element_type,
    368                                      std::initializer_list<int64> dimensions) {
    369   return Equal(shape, MakeShape(element_type, dimensions));
    370 }
    371 
    372 /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
    373   CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape);
    374   CHECK_EQ(shape.dimensions_size(), Rank(shape));
    375   return std::accumulate<decltype(shape.dimensions().begin()), int64>(
    376       shape.dimensions().begin(), shape.dimensions().end(), 1LL,
    377       std::multiplies<int64>());
    378 }
    379 
    380 /* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) {
    381   return ElementsIn(shape) == 0;
    382 }
    383 
    384 /* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
    385   return shape.element_type() == F32 && Rank(shape) == 0;
    386 }
    387 
    388 /* static */ string ShapeUtil::HumanString(const Shape& shape) {
    389   if (IsTuple(shape)) {
    390     string text = "(";
    391     const char* prefix = "";
    392     for (const Shape& elem_shape : shape.tuple_shapes()) {
    393       tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape));
    394       prefix = ", ";
    395     }
    396     text += ")";
    397     return text;
    398   } else {
    399     return tensorflow::strings::StrCat(
    400         tensorflow::str_util::Lowercase(
    401             PrimitiveType_Name(shape.element_type())),
    402         "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]");
    403   }
    404 }
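         // Example (illustrative): HumanString yields "f32[2,3]" for an array shape
         // and "(f32[2,3], pred[])" for a tuple of an array and a scalar predicate.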
    405 
    406 namespace {
    407 
    408 // Class to memoize the computation of
    409 //   tensorflow::str_util::Lowercase(PrimitiveType_Name(p))
    410 // for all PrimitiveType values "p"
    411 class PrimitiveTypeNameGenerator {
    412  public:
    413   PrimitiveTypeNameGenerator() {
    414     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
    415       if (PrimitiveType_IsValid(i)) {
    416         lowercase_name_[i] = tensorflow::str_util::Lowercase(
    417             PrimitiveType_Name(static_cast<PrimitiveType>(i)));
    418       }
    419     }
    420   }
    421   const string& LowercaseName(PrimitiveType t) {
    422     return lowercase_name_[static_cast<int>(t)];
    423   }
    424 
    425  private:
    426   string lowercase_name_[PrimitiveType_ARRAYSIZE];
    427 };
    428 
    429 const string& LowercasePrimitiveTypeName(PrimitiveType s) {
    430   static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator();
    431   return gen->LowercaseName(s);
    432 }
    433 
    434 StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
    435   static std::unordered_map<string, PrimitiveType>* name_to_type = [] {
    436     static auto* map = new std::unordered_map<string, PrimitiveType>;
    437     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
    438       if (PrimitiveType_IsValid(i)) {
    439         auto value = static_cast<PrimitiveType>(i);
    440         (*map)[LowercasePrimitiveTypeName(value)] = value;
    441       }
    442     }
    443     return map;
    444   }();
    445   auto found = name_to_type->find(name);
    446   if (found == name_to_type->end()) {
    447     return InvalidArgument("Invalid element type string: \"%s\".",
    448                            name.c_str());
    449   }
    450   return found->second;
    451 }
    452 
    453 }  // namespace
    454 
    455 /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
    456   if (IsTuple(shape)) {
    457     string text = "(";
    458     const char* prefix = "";
    459     for (const Shape& elem_shape : shape.tuple_shapes()) {
    460       tensorflow::strings::StrAppend(&text, prefix,
    461                                      HumanStringWithLayout(elem_shape));
    462       prefix = ", ";
    463     }
    464     text += ")";
    465     return text;
    466   } else {
    467     string result = tensorflow::strings::StrCat(
    468         LowercasePrimitiveTypeName(shape.element_type()), "[");
    469     for (int i = 0; i < shape.dimensions().size(); i++) {
    470       tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "",
    471                                      shape.dimensions(i));
    472     }
    473     result += "]";
    474     if (!IsScalar(shape) && !IsOpaque(shape)) {
    475       if (LayoutUtil::HasLayout(shape)) {
    476         tensorflow::strings::StrAppend(&result,
    477                                        LayoutUtil::HumanString(shape.layout()));
    478       }
    479     }
    480     return result;
    481   }
    482 }
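         // Example (illustrative): HumanStringWithLayout renders a dense array shape
         // as "f32[2,3]{1,0}", appending the minor-to-major order; scalar shapes omit
         // the layout suffix.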
    483 
    484 /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
    485   std::vector<string> parameters;
    486   for (auto& shape : program_shape.parameters()) {
    487     const int i = parameters.size();
    488     parameters.push_back(
    489         tensorflow::strings::StrCat(i < program_shape.parameter_names_size()
    490                                         ? program_shape.parameter_names(i)
    491                                         : "(unknown)",
    492                                     ": ", HumanString(shape)));
    493   }
    494   return tensorflow::strings::StrCat(
    495       "(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
    496       HumanString(program_shape.result()));
    497 }
    498 
    499 namespace {
    500 // Parses shapes with simple recursive descent structure -- consumes from the
    501 // front of s and passes that view recursively as required.
    502 StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
    503   tensorflow::str_util::RemoveLeadingWhitespace(s);
    504 
    505   if (s->Consume("(")) {  // Tuple.
    506     std::vector<Shape> shapes;
    507     bool must_end = false;
    508     while (true) {
    509       if (s->Consume(")")) {
    510         break;
    511       } else if (must_end) {
    512         return InvalidArgument("Expected end of tuple; got: \"%s\"",
    513                                s->ToString().c_str());
    514       }
    515       shapes.emplace_back();
    516       TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s));
    517       tensorflow::str_util::RemoveLeadingWhitespace(s);
    518       must_end = !s->Consume(",");
    519     }
    520     return ShapeUtil::MakeTupleShape(shapes);
    521   }
    522 
    523   string element_type_string;
    524   string dimensions_string;
    525   string format_string;
    526   string layout_string;
    527   // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so
     528   // we convert it to the RE2-consumable type and then consume the corresponding
    529   // amount from our StringPiece type.
    530   tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
    531   if (RE2::Consume(
    532           &s_consumable,
    533           "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?",
    534           &element_type_string, &dimensions_string, &format_string,
    535           &layout_string)) {
    536     size_t consumed = s->size() - s_consumable.size();
    537     s->remove_prefix(consumed);
    538     auto string_to_int64 = [&s](const string& input) -> StatusOr<int64> {
    539       int64 element;
    540       if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) {
    541         return InvalidArgument(
    542             "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"",
    543             input.c_str(), s->ToString().c_str());
    544       }
    545       return element;
    546     };
    547 
    548     auto comma_list_to_int64s =
    549         [&s,
    550          string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
    551       std::vector<int64> results;
    552       for (const string& piece : tensorflow::str_util::Split(input, ',')) {
    553         TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece));
    554         results.push_back(element);
    555       }
    556       return results;
    557     };
    558 
    559     // Extract the dimensions.
    560     TF_ASSIGN_OR_RETURN(std::vector<int64> dimensions,
    561                         comma_list_to_int64s(dimensions_string));
    562 
    563     // Extract the primitive element type.
    564     TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type,
    565                         StringToPrimitiveType(element_type_string));
    566     if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE ||
    567         primitive_type == OPAQUE) {
    568       return InvalidArgument("Invalid element type string: \"%s\".",
    569                              element_type_string.c_str());
    570     }
    571 
    572     Shape result;
    573     if (format_string.empty() && layout_string.empty()) {
    574       // Create a shape without a layout set.
    575       result = ShapeUtil::MakeShape(primitive_type, dimensions);
    576     } else if (format_string == "sparse") {
    577       TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string));
    578       result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions,
    579                                                     max_elements);
    580     } else if (format_string.empty() || format_string == "dense") {
    581       // Extract the layout minor-to-major and set it.
    582       TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
    583                           comma_list_to_int64s(layout_string));
    584       TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal(
    585                                       primitive_type, dimensions, min2maj));
    586     } else {
    587       // This should not be reached.
    588       LOG(FATAL) << "Unhandled condition when parsing shape; format: \""
    589                  << format_string << "\", layout: \"" << layout_string << "\"";
    590     }
    591     TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result));
    592     return std::move(result);
    593   }
    594 
    595   return InvalidArgument("Invalid shape string to parse: \"%s\"",
    596                          s->ToString().c_str());
    597 }
    598 }  // namespace
    599 
    600 /* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(
    601     tensorflow::StringPiece s) {
    602   TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s));
    603   if (!s.empty()) {
    604     return InvalidArgument("Invalid shape string to parse: \"%s\"",
    605                            s.ToString().c_str());
    606   }
    607   return shape;
    608 }
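         // Examples of strings accepted by ParseShapeString (illustrative):
         //   "f32[2,3]"           -- array shape without a layout
         //   "f32[2,3]{1,0}"      -- dense array with an explicit minor-to-major order
         //   "f32[100]sparse{10}" -- sparse array with at most 10 stored elements
         //   "(f32[2], s32[3])"   -- tuple shape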
    609 
    610 /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
    611                                             const Shape& rhs) {
    612   return ContainersEqual(lhs.dimensions(), rhs.dimensions());
    613 }
    614 
    615 /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
    616   if (lhs.element_type() == TUPLE) {
    617     return rhs.element_type() == TUPLE &&
    618            ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible);
    619   }
    620   return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs);
    621 }
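         // Example (illustrative): f32[2,3]{1,0} and f32[2,3]{0,1} are Compatible
         // (layout is ignored) but not Equal; f32[2,3] and s32[2,3] are neither,
         // though they are CompatibleIgnoringElementType.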
    622 
    623 /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
    624                                                            const Shape& rhs) {
    625   if (lhs.element_type() == TUPLE) {
    626     return rhs.element_type() == TUPLE &&
    627            ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
    628                            CompatibleIgnoringElementType);
    629   }
    630   return SameDimensions(lhs, rhs);
    631 }
    632 
    633 /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
    634                                                            const Shape& rhs) {
    635   if (lhs.element_type() == TUPLE) {
    636     return rhs.element_type() == TUPLE &&
    637            ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
    638                            CompatibleIgnoringFpPrecision);
    639   }
    640   if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
    641     return CompatibleIgnoringElementType(lhs, rhs);
    642   }
    643   return false;
    644 }
    645 
    646 /* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
    647                                            int64 dimension_number) {
    648   return shape.dimensions(GetDimensionNumber(shape, dimension_number));
    649 }
    650 
    651 /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
    652                                                  int64 dimension_number) {
    653   if (dimension_number < 0) {
    654     dimension_number += Rank(shape);
    655   }
    656   CHECK_GE(dimension_number, 0);
    657   return dimension_number;
    658 }
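         // Example (illustrative): negative dimension numbers count from the end, so
         // GetDimension(f32[4,7], -1) returns 7 and GetDimension(f32[4,7], -2)
         // returns 4.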
    659 
    660 /* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
    661     PrimitiveType primitive_type) {
    662   switch (primitive_type) {
    663     case PRED:
    664       return sizeof(int8);
    665     case TUPLE:
    666       LOG(FATAL) << "tuples have no definitive size";
    667     case OPAQUE:
    668       LOG(FATAL) << "opaque have no definitive size";
    669     case S8:
    670       return sizeof(int8);
    671     case S16:
    672       return sizeof(int16);
    673     case S32:
    674       return sizeof(int32);
    675     case S64:
    676       return sizeof(int64);
    677     case U8:
    678       return sizeof(uint8);
    679     case U16:
    680       return sizeof(uint16);
    681     case U32:
    682       return sizeof(uint32);
    683     case U64:
    684       return sizeof(uint64);
    685     case BF16:
    686       return sizeof(float) / 2;
    687     case F16:
    688       return sizeof(float) / 2;
    689     case F32:
    690       return sizeof(float);
    691     case F64:
    692       return sizeof(double);
    693     case C64:
    694       return sizeof(complex64);
    695     default:
    696       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
    697   }
    698 }
    699 
    700 /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
    701                                          int64 pointer_size) {
    702   TF_DCHECK_OK(ValidateShape(shape));
    703   DCHECK_NE(OPAQUE, shape.element_type());
    704   if (shape.element_type() == TUPLE) {
    705     return ByteSizeOfTupleIndexTable(shape, pointer_size);
    706   }
    707   int64 byte_size = ByteSizeOfElements(shape);
    708   if (LayoutUtil::IsSparseArray(shape)) {
    709     byte_size += ByteSizeOfSparseIndices(shape);
    710   }
    711   return byte_size;
    712 }
    713 
    714 /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
    715                                                         int64 pointer_size) {
    716   TF_DCHECK_OK(ValidateShape(shape));
    717   DCHECK_EQ(TUPLE, shape.element_type());
    718   CHECK_GT(pointer_size, 0);
    719   return pointer_size * shape.tuple_shapes_size();
    720 }
    721 
    722 /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
    723   TF_DCHECK_OK(ValidateShape(shape));
    724   DCHECK(ShapeUtil::IsArray(shape));
    725   int64 allocated_element_count;
    726 
    727   if (LayoutUtil::IsSparseArray(shape)) {
    728     allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
    729   } else {
    730     CHECK(LayoutUtil::IsDenseArray(shape));
    731     tensorflow::gtl::ArraySlice<int64> padded_dimensions =
    732         LayoutUtil::PaddedDimensions(shape);
    733     if (!padded_dimensions.empty()) {
    734       CHECK_EQ(Rank(shape), padded_dimensions.size());
    735       allocated_element_count = 1;
    736       for (int64 dimension_size : padded_dimensions) {
    737         allocated_element_count *= dimension_size;
    738       }
    739     } else {
    740       allocated_element_count = ElementsIn(shape);
    741     }
    742   }
    743   return allocated_element_count *
    744          ByteSizeOfPrimitiveType(shape.element_type());
    745 }
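         // Example (illustrative): a dense, unpadded f32[3,4] occupies 12 * 4 = 48
         // bytes; a sparse f32 array with at most 10 stored elements occupies
         // 10 * 4 = 40 bytes, excluding the index data accounted for below.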
    746 
    747 /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
    748   TF_DCHECK_OK(ValidateShape(shape));
    749   DCHECK(LayoutUtil::IsSparseArray(shape));
    750   return LayoutUtil::MaxSparseElements(shape.layout()) *
    751          ShapeUtil::Rank(shape) * sizeof(int64);
    752 }
    753 
    754 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
    755     const Shape& shape) {
    756   if (shape.element_type() == TUPLE) {
    757     if (shape.dimensions_size() != 0) {
    758       return InvalidArgument("tuples must not have dimensions specified");
    759     }
    760     for (auto& element_shape : shape.tuple_shapes()) {
    761       TF_RETURN_IF_ERROR(
    762           ValidateShapeWithOptionalLayoutInternal(element_shape));
    763     }
    764     return Status::OK();
    765   }
    766 
    767   // Non-tuple shape.
    768   if (shape.tuple_shapes_size() > 0) {
    769     return InvalidArgument("non-tuple shape has tuple_shapes field");
    770   }
    771   if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
    772     return InvalidArgument("shape has invalid element type: %s",
    773                            shape.ShortDebugString().c_str());
    774   }
    775   if (Rank(shape) != shape.dimensions_size()) {
    776     return InvalidArgument(
    777         "shape's rank is mismatched with dimension count; rank=%lld "
    778         "dimensions_size=%d",
    779         Rank(shape), shape.dimensions_size());
    780   }
    781   for (int64 i = 0; i < Rank(shape); ++i) {
    782     int64 dimension = shape.dimensions(i);
    783     if (dimension < 0) {
    784       return InvalidArgument(
    785           "shape's dimensions must not be < 0; dimension at index %lld was "
    786           "%lld",
    787           i, dimension);
    788     }
    789   }
    790 
    791   return Status::OK();
    792 }
    793 
    794 /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
    795     const Shape& shape) {
    796   if (LayoutUtil::HasLayout(shape)) {
    797     // Since a layout is present, upgrade to the full set of invariant checks.
    798     return ValidateShape(shape);
    799   }
    800   return ValidateShapeWithOptionalLayoutInternal(shape);
    801 }
    802 
    803 /* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
    804   TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
    805 
    806   return LayoutUtil::ValidateLayoutInShape(shape);
    807 }
    808 
    809 /* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
    810                                                 PrimitiveType type) {
    811   Shape new_shape = original;
    812   new_shape.set_element_type(type);
    813   return new_shape;
    814 }
    815 
    816 /* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
    817                                                  ShapeIndexView index) {
    818   const Shape* return_shape = &shape;
    819   for (auto i : index) {
    820     CHECK(IsTuple(*return_shape))
    821         << "Invalid index " << index << " for shape " << shape;
    822     return_shape = &return_shape->tuple_shapes(i);
    823   }
    824   return *return_shape;
    825 }
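         // Example (illustrative): for the tuple shape ((f32[2], s32[3]), pred[]),
         // GetSubshape with index {0, 1} returns s32[3]; an empty index returns the
         // shape itself.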
    826 
    827 /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
    828                                                   ShapeIndexView index) {
    829   Shape* return_shape = shape;
    830   for (auto i : index) {
    831     CHECK(IsTuple(*return_shape));
    832     return_shape = return_shape->mutable_tuple_shapes(i);
    833   }
    834   return return_shape;
    835 }
    836 
    837 /* static */
    838 bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
    839   return !IsTuple(GetSubshape(shape, index));
    840 }
    841 
    842 /* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
    843   std::vector<int64> dimension_sizes;
    844   std::vector<int64> degenerate_dimensions;
    845   for (int64 i = 0; i < shape.dimensions_size(); ++i) {
    846     if (shape.dimensions(i) == 1) {
    847       degenerate_dimensions.push_back(i);
    848     } else {
    849       dimension_sizes.push_back(shape.dimensions(i));
    850     }
    851   }
    852 
    853   // Construct minor_to_major of stripped shape. The order of the non-degenerate
    854   // dimensions should be preserved from the original shape. First, create
    855   // vector of the non-degenerate dimensions from the original minor_to_major
    856   // array.
    857   std::vector<int64> minor_to_major;
    858   for (int64 i : shape.layout().minor_to_major()) {
    859     if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(),
    860                   i) == degenerate_dimensions.end()) {
    861       minor_to_major.push_back(i);
    862     }
    863   }
    864 
    865   // The dimensions in minor_to_major need to be renumbered to account for the
     866   // degenerate dimensions which have been removed. Decrement each dimension
    867   // once for each degenerate dimension which has a smaller number.
    868   for (int i = 0; i < minor_to_major.size(); ++i) {
    869     int adjustment = 0;
    870     for (int64 dim : degenerate_dimensions) {
    871       if (minor_to_major[i] > dim) {
    872         adjustment++;
    873       }
    874     }
    875     minor_to_major[i] -= adjustment;
    876   }
    877 
    878   {
    879     std::vector<int64> dims(minor_to_major.size());
    880     std::iota(dims.begin(), dims.end(), 0);
    881     DCHECK(minor_to_major.size() == dims.size() &&
    882            std::is_permutation(minor_to_major.begin(), minor_to_major.end(),
    883                                dims.begin()));
    884   }
    885   Shape stripped_shape =
    886       shape.has_layout() ? MakeShapeWithLayout(shape.element_type(),
    887                                                dimension_sizes, minor_to_major)
    888                          : MakeShape(shape.element_type(), dimension_sizes);
    889 
    890   VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape);
    891   VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape);
    892   return stripped_shape;
    893 }
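         // Example (illustrative): stripping f32[1,2,1,3]{3,2,1,0} removes the two
         // size-1 dimensions and renumbers the remaining ones, yielding f32[2,3]{1,0}.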
    894 
    895 namespace {
    896 
    897 // Helper for ForEachSubshape which visits the subshapes of the given shape in
    898 // DFS pre-order starting with the index.
    899 Status ForEachSubshapeHelper(const Shape& shape,
    900                              const ShapeUtil::StatusVisitorFunction& func,
    901                              ShapeIndex* index) {
    902   TF_RETURN_IF_ERROR(func(shape, *index));
    903   if (ShapeUtil::IsTuple(shape)) {
    904     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    905       index->push_back(i);
    906       TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
    907           ShapeUtil::GetTupleElementShape(shape, i), func, index));
    908       index->pop_back();
    909     }
    910   }
    911   return Status::OK();
    912 }
    913 
    914 // Helper for ForEachMutableSubshape which visits the subshapes of the given
    915 // shape in DFS pre-order starting with the index.
    916 Status ForEachMutableSubshapeHelper(
    917     Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
    918     ShapeIndex* index) {
    919   TF_RETURN_IF_ERROR(func(shape, *index));
    920   if (ShapeUtil::IsTuple(*shape)) {
    921     for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
    922       index->push_back(i);
    923       TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
    924           shape->mutable_tuple_shapes(i), func, index));
    925       index->pop_back();
    926     }
    927   }
    928   return Status::OK();
    929 }
    930 
    931 }  // namespace
    932 
    933 /* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
    934                                              const VisitorFunction& func) {
    935   ShapeIndex index;
    936   ForEachSubshapeHelper(
    937       shape,
    938       [&func](const Shape& subshape, const ShapeIndex& index) {
    939         func(subshape, index);
    940         return Status::OK();
    941       },
    942       &index)
    943       .IgnoreError();
    944 }
    945 
    946 /* static */ void ShapeUtil::ForEachMutableSubshape(
    947     Shape* shape, const MutatingVisitorFunction& func) {
    948   ShapeIndex index;
    949   ForEachMutableSubshapeHelper(
    950       shape,
    951       [&func](Shape* subshape, const ShapeIndex& index) {
    952         func(subshape, index);
    953         return Status::OK();
    954       },
    955       &index)
    956       .IgnoreError();
    957 }
    958 
    959 /* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
    960     const Shape& shape, const StatusVisitorFunction& func) {
    961   ShapeIndex index;
    962   return ForEachSubshapeHelper(shape, func, &index);
    963 }
    964 
    965 /* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
    966     Shape* shape, const MutatingStatusVisitorFunction& func) {
    967   ShapeIndex index;
    968   return ForEachMutableSubshapeHelper(shape, func, &index);
    969 }
    970 
    971 /* static */ Shape ShapeUtil::PermuteDimensions(
    972     tensorflow::gtl::ArraySlice<int64> permutation, const Shape& shape) {
    973   Shape new_shape = shape;
    974   new_shape.clear_dimensions();
    975   for (auto dim : Permute(permutation, shape.dimensions())) {
    976     new_shape.add_dimensions(dim);
    977   }
    978   if (shape.has_layout()) {
    979     CHECK(LayoutUtil::IsDenseArray(shape));
    980     Layout* new_layout = new_shape.mutable_layout();
    981     new_layout->set_format(DENSE);
    982     new_layout->clear_minor_to_major();
    983     for (auto index : Permute(permutation, shape.layout().minor_to_major())) {
    984       new_layout->add_minor_to_major(index);
    985     }
    986     if (shape.layout().padded_dimensions_size() > 0) {
    987       new_layout->clear_padded_dimensions();
    988       for (auto dim :
    989            Permute(permutation, shape.layout().padded_dimensions())) {
    990         new_layout->add_padded_dimensions(dim);
    991       }
    992     }
    993   }
    994   return new_shape;
    995 }
    996 
    997 /* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
    998 ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
    999                                              const Shape& shape_post) {
   1000   auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
   1001 
   1002   std::vector<int64> deleted_indices;
   1003   std::vector<int64> inserted_indices;
    1004   // Returns false if any input/output dimension between prior_unmodified_dim_pair
    1005   // and unmodified_dim_pair has size >1. Otherwise, returns true and appends
    1006   // the degenerate input/output dimensions in the gap to
    1007   // deleted_indices/inserted_indices respectively.
   1008   auto check_modified_dims =
   1009       [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
   1010           std::pair<int64, int64> prior_unmodified_dim_pair,
   1011           std::pair<int64, int64> unmodified_dim_pair) {
   1012         for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
   1013              modified_input_dim < unmodified_dim_pair.first;
   1014              ++modified_input_dim) {
   1015           if (shape_pre.dimensions(modified_input_dim) > 1) {
   1016             return false;
   1017           }
   1018           deleted_indices.push_back(modified_input_dim);
   1019         }
   1020         for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
   1021              modified_output_dim < unmodified_dim_pair.second;
   1022              ++modified_output_dim) {
   1023           if (shape_post.dimensions(modified_output_dim) > 1) {
   1024             return false;
   1025           }
   1026           inserted_indices.push_back(modified_output_dim);
   1027         }
   1028         return true;
   1029       };
   1030 
   1031   std::vector<std::pair<int64, int64>> unmodified_dims =
   1032       DimensionsUnmodifiedByReshape(shape_pre, shape_post);
   1033   // Returns nil if the reshape modifies any non-degenerate input/output
   1034   // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
   1035   // dimensions, so we only need to check whether dimensions in the gaps (thus
   1036   // modified) have size >1.
   1037   for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
   1038     // Check (modified) dimensions between unmodified_dims[i-1] and
   1039     // unmodified_dims[i].
   1040     auto prior_unmodified_dim_pair =
   1041         i > 0 ? unmodified_dims[i - 1] : std::make_pair(-1LL, -1LL);
   1042     auto unmodified_dim_pair =
   1043         i < unmodified_dims.size()
   1044             ? unmodified_dims[i]
   1045             : std::make_pair(Rank(shape_pre), Rank(shape_post));
   1046     if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
   1047       return nil;
   1048     }
   1049   }
   1050 
   1051   return std::make_tuple(true, deleted_indices, inserted_indices);
   1052 }
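         // Example (illustrative): reshaping f32[2,3] to f32[2,1,3] yields
         // (true, {}, {1}) -- output dimension 1 was inserted -- whereas reshaping
         // f32[2,3] to f32[6] modifies non-degenerate dimensions and yields
         // (false, {}, {}).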
   1053 
   1054 /* static */ std::vector<std::pair<int64, int64>>
   1055 ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
   1056                                          const Shape& output_shape) {
   1057   // Unmodified dimensions are merely common factors of rank 1.
   1058   auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
   1059                                       AsInt64Slice(output_shape.dimensions()));
   1060   for (size_t i = 0; i < common_factors.size() - 1;) {
   1061     if (1 != common_factors[i + 1].first - common_factors[i].first ||
   1062         1 != common_factors[i + 1].second - common_factors[i].second) {
   1063       common_factors.erase(common_factors.begin() + i);
   1064     } else {
   1065       ++i;
   1066     }
   1067   }
   1068   // `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
   1069   common_factors.pop_back();
   1070   return common_factors;
   1071 }
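         // Example (illustrative): reshaping f32[2,3,5] to f32[2,15] leaves only
         // dimension 0 unmodified, so the result is {{0, 0}}.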
   1072 
   1073 /* static */ bool ShapeUtil::TransposeIsBitcast(
   1074     const Shape& input_shape, const Shape& output_shape,
   1075     tensorflow::gtl::ArraySlice<int64> dimension_mapping) {
   1076   // Can't insert bitcasts without layout information.
   1077   if (!LayoutUtil::HasLayout(input_shape) &&
   1078       !LayoutUtil::HasLayout(output_shape)) {
   1079     return false;
   1080   }
   1081 
   1082   // Padding is not handled.
   1083   if (LayoutUtil::IsPadded(input_shape) && LayoutUtil::IsPadded(output_shape)) {
   1084     return false;
   1085   }
   1086 
   1087   // Check the reshape permutes the positions of each dimension in the
   1088   // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
   1089   //   input_positions = apply(dimension_mapping, output_positions)
   1090   //
   1091   // Because the positions of each dimension are the inverse permutation of the
   1092   // minor-to-major order, the above check is equivalent to
   1093   //   inverse(input_dimensions) =
   1094   //       apply(dimension_mapping, inverse(output_dimensions))
   1095   //   # `I` indicates identity permutation.
   1096   //   apply(input_dimensions, I) =
   1097   //       apply(dimension_mapping, apply(output_dimensions, I))
   1098   //   apply(input_dimensions, I) =
   1099   //       apply((dimension_mapping * output_dimensions), I)
   1100   //   input_dimensions = dimension_mapping * output_dimensions
   1101   return ContainersEqual(
   1102       ComposePermutations(dimension_mapping,
   1103                           AsInt64Slice(output_shape.layout().minor_to_major())),
   1104       input_shape.layout().minor_to_major());
   1105 }
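         // Example (illustrative): transposing f32[3,5]{0,1} with dimension_mapping
         // {1,0} into f32[5,3]{1,0} is a bitcast, since the physical order of the
         // elements is unchanged.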
   1106 
   1107 /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
   1108                                               const Shape& output_shape) {
   1109   // Can't convert reshapes into bitcasts without layout information.
   1110   if (!LayoutUtil::HasLayout(input_shape) ||
   1111       !LayoutUtil::HasLayout(output_shape)) {
   1112     return false;
   1113   }
   1114 
   1115   // Padding is not handled.
   1116   if (LayoutUtil::IsPadded(input_shape) || LayoutUtil::IsPadded(output_shape)) {
   1117     return false;
   1118   }
   1119 
   1120   CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape));
   1121   if (ElementsIn(input_shape) == 0) {
   1122     return true;
   1123   }
   1124 
   1125   // TL;DR: The rest of the method checks that the reshape does not change the
   1126   // physical location of any unit input or output index. Unit indices have
   1127   // exactly one dimension that equals 1 and other dimensions 0. This condition
   1128   // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
   1129   // reshape shouldn't change the physical location of any element. It is also a
   1130   // sufficient condition as is proved below (note: many details are omitted for
   1131   // space).
   1132   //
   1133   // Definitions:
   1134   //
   1135   // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
   1136   // the size of i-th least significant dimension of IS or OS (this is opposite
   1137   // to how we define the index of Shape::dimensions()).
   1138   //
   1139   // * Given an input or output index I, denote by p(I) I's physical linear
   1140   // index (or physical index for short) and l(I) I's logical linear index (or
   1141   // logical index for short).
   1142   //
   1143   // * Given a logical index k, denote by II(k) the input index whose linear
   1144   // index is k, and OI(k) the corresponding output index.
   1145   //
   1146   // * Denote by IT[i] the increment of physical index if i-th dimension of the
   1147   // input index is increased by 1. Similarly, OT[i] means the increment if i-th
   1148   // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
   1149   // is a function of IS or OS and the layout, and not dependent on the specific
   1150   // input or output index.
   1151   //
   1152   // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
   1153   // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
   1154   // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
   1155   // to prove, with every increment on k, the above formula still holds.
   1156   //
   1157   // First, suppose reshaping from IS to OS is non-factorizable (we discuss
    1158   // factorizable reshapes later). A reshape from IS to OS is factorizable, if
   1159   // there exists (i,j) such that
   1160   //
   1161   //   0<=i<=|IS|
   1162   //   0<=j<=|OS|
   1163   //   |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
   1164   //   product(IS[i], IS[i+1], ..., IS[|IS|-1])
   1165   //     = product(OS[j], OS[j+1], ..., OS[|OS|-1])
   1166   //
   1167   // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
   1168   // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
   1169   // unit indices which are already tested. This also means IT[0]=OT[0]
   1170   // because p(II(1))=IT[0] and p(OI(1))=OT[0].
   1171   //
   1172   // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
   1173   // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
   1174   // to the output physical.
   1175   //
    1176   // When k=min(IS[0],OS[0]), the first wrap happens. Without loss of generality,
   1177   // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
   1178   // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
   1179   // logical index k-1 to logical index k, dimension 1 of the input index
   1180   // is increased by 1 and dimension 0 is reset to 0 thus decreased by
   1181   // IS[0]-1. Therefore, the physical input index is increased by
   1182   //
   1183   //   p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
   1184   //
   1185   // Because IS[0]<OS[0], the only change to the output index is that its
   1186   // dimension 0 is increased by one. Therefore,
   1187   //
   1188   //   p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
   1189   //
    1190   // Because II(k) is a unit index -- (0,..,0,1,0), we already tested that
   1191   // p(II(k))=p(OI(k)). Therefore,
   1192   //   IT[1] - (IS[0]-1) * IT[0] = IT[0]
   1193   //   IT[1] = IS[0] * IT[0]
   1194   // In other words, input dimension 1 is immediately more major than input
   1195   // dimension 0. We can now conceptually collapse these two dimensions because
   1196   // an increment in the logical index affecting only these two dimensions maps
   1197   // to IT[0] in the physical index.
   1198   //
   1199   // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
   1200   // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
   1201   // identical.
   1202   //
   1203   // A factorizable reshape can be factorized into a list of non-factorizable
   1204   // sub-reshapes, each of which can be handled similarly to the proof above.
   1205   // For example,
   1206   //
   1207   //   [7x9x2x15] -> [63x6x5]
   1208   //
   1209   // can be factorized into
   1210   //
   1211   //   [7x9] -> [63] and [2x15] -> [6x5].
   1212   //
   1213   // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
   1214   // same logical linear index. According to the factorization, we know
   1215   // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
   1216   // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
   1217   // similar proof, with the increment of the logical index set to
   1218   // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
   1219   // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
   1220   //
   1221   //   p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
   1222   //                  = p(y2,0,0) + p(0,0,y1,y0)
   1223   //                  = p(y2,y1,y0)
   1224   //
   1225   // check_input_unit_indices checks one way of the condition: each input unit
   1226   // index is mapped to an output index with the same physical location. This
   1227   // lambda will be called again with input_shape and output_shape reversed to
   1228   // check the other way.
   1229   auto check_input_unit_indices = [](const Shape& input_shape,
   1230                                      const Shape& output_shape) {
    1231     // input_shape_dim0_major/output_shape_dim0_major have the same "dimensions"
   1232     // as input_shape/output_shape and the dimension-0-major layout. These two
   1233     // shapes are used for conversion between logical linear indices and
   1234     // multi-dimensional indices.
   1235     Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
   1236         input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
   1237     Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
   1238         output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));
   1239 
   1240     for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) {
   1241       if (input_shape.dimensions(input_dim) <= 1) {
   1242         continue;
   1243       }
   1244 
   1245       std::vector<int64> input_unit_index(Rank(input_shape), 0);
   1246       input_unit_index[input_dim] = 1;
   1247       int64 logical_linear_index =
   1248           IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
   1249                                                         input_unit_index);
   1250       // output_index has the same logical linear index as input_unit_index.
   1251       std::vector<int64> output_index =
   1252           IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
   1253                                                         logical_linear_index);
   1254       // Check input_unit_index and output_index have the same physical linear
   1255       // index.
   1256       if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
   1257                                                         input_unit_index) !=
   1258           IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
   1259                                                         output_index)) {
   1260         return false;
   1261       }
   1262     }
   1263     return true;
   1264   };
   1265   return check_input_unit_indices(input_shape, output_shape) &&
   1266          check_input_unit_indices(output_shape, input_shape);
   1267 }
   1268 
   1269 /* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts(
   1270     const Shape& input_shape, const Shape& output_shape) {
   1271   int64 input_rank = Rank(input_shape);
   1272   int64 output_rank = Rank(output_shape);
   1273 
   1274   // First, calculate an alignment of the dimensions. A consecutive sequence
   1275   // of input dimensions and a consecutive sequence of output dimensions
   1276   // belong to the same alignment part if the products of their dimension
   1277   // bounds are equal. In the simplest case, an alignment part consists of one
   1278   // input dimension and one output dimension that have the same dimension
   1279   // bound. An alignment part specifies which dimensions need to be kept
   1280   // together in a physical layout if we want a reshape to be a bitcast. The
   1281   // order of the alignment parts is defined by the physical layout of the
   1282   // input shape, so when we construct the layout for the output shape we
   1283   // simply process the alignment parts in this order, and then lay out the
   1284   // dimensions belonging to each part in descending (major to minor) order.
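          // For example (illustrative only), aligning [2x3x5x7] with [6x35] gives
          // two alignment parts: input dimensions {0,1} pair with output dimension
          // 0 (2*3 = 6), and input dimensions {2,3} pair with output dimension 1
          // (5*7 = 35).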
   1285 
   1286   // Stores the input and output dimension numbers where each alignment part
   1287   // starts.
   1288   std::vector<std::pair<int64, int64>> alignment;
   1289   alignment.push_back({0, 0});
   1290 
   1291   // Stores a mapping from the input dimension to the alignment part it belongs
   1292   // to.
   1293   std::vector<int64> dimension_to_alignment_index(input_rank);
   1294   int64 input_dimension_product = 1, output_dimension_product = 1;
   1295   for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) {
   1296     // Check if we have reached the end of an alignment part.
   1297     if (input_dimension_product == output_dimension_product &&
   1298         input_dimension_product > 1) {
   1299       alignment.push_back({i, j});
   1300       input_dimension_product = output_dimension_product = 1;
   1301     }
   1302     if (input_dimension_product < output_dimension_product ||
   1303         j == output_rank) {
   1304       if (i == input_rank) {
   1305         return tensorflow::gtl::nullopt;
   1306       }
   1307       dimension_to_alignment_index[i] = alignment.size() - 1;
   1308       input_dimension_product *= input_shape.dimensions(i);
   1309       ++i;
   1310     } else {
   1311       output_dimension_product *= output_shape.dimensions(j);
   1312       ++j;
   1313     }
   1314   }
   1315   if (input_dimension_product != output_dimension_product) {
   1316     return tensorflow::gtl::nullopt;
   1317   }
   1318   // We also need to store an end element so that we know where the last
   1319   // alignment part ends.
   1320   alignment.push_back({input_rank, output_rank});
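          // For the illustrative [2x3x5x7] -> [6x35] example, 'alignment' is now
          // {{0, 0}, {2, 1}, {4, 2}} and dimension_to_alignment_index is
          // {0, 0, 1, 1}.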
   1321 
   1322   // Now check whether the physical layout of the input shape can be matched
   1323   // by choosing a suitable physical layout for the output shape. We need to
   1324   // check that all dimension numbers belonging to the same alignment part
   1325   // appear consecutively and in descending order. However, we can ignore any
   1326   // trivial dimensions with a bound of 1, because they can be placed anywhere.
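          // Continuing the illustrative [2x3x5x7] -> [6x35] example: an input
          // layout with minor_to_major {3, 2, 1, 0} keeps each alignment part
          // contiguous and descending, so it can be aligned, whereas {0, 2, 1, 3}
          // interleaves the two parts and the loop below returns nullopt.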
   1327   auto input_dimension_numbers = input_shape.layout().minor_to_major();
   1328   std::vector<int64> output_layout;
   1329   output_layout.reserve(output_rank);
   1330   for (int64 i = 0; i < input_rank;) {
   1331     int64 current_dimension_number = input_dimension_numbers[i];
   1332 
   1333     // Skip trivial dimensions with a bound of 1.
   1334     if (input_shape.dimensions(current_dimension_number) == 1) {
   1335       ++i;
   1336       continue;
   1337     }
   1338 
   1339     // Calculate the number of non-trivial dimension bounds in the input shape
   1340     // belonging to the current alignment part.
   1341     const int64 current_alignment_index =
   1342         dimension_to_alignment_index[current_dimension_number];
   1343     // Because of the special end element that we added, we can be sure that
   1344     // 'current_alignment_index' is < alignment.size() - 1.
   1345     CHECK_LT(current_alignment_index, alignment.size() - 1);
   1346     int64 num_non_trivial_dimensions_in_alignment_part = 0;
   1347     for (int64 j = alignment[current_alignment_index].first;
   1348          j < alignment[current_alignment_index + 1].first; ++j) {
   1349       if (input_shape.dimensions(j) != 1) {
   1350         ++num_non_trivial_dimensions_in_alignment_part;
   1351       }
   1352     }
   1353 
   1354     // Check that the following 'num_non_trivial_dimensions_in_alignment_part'
   1355     // dimension numbers (ignoring dimension numbers with dimension bound 1) are
   1356     // in descending order and belong to the current alignment part.
   1357     for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part;
   1358          ++i, ++j) {
   1359       if (i == input_rank) {
   1360         return tensorflow::gtl::nullopt;
   1361       }
   1362       // Skip trivial dimensions with a bound of 1.
   1363       if (input_shape.dimensions(input_dimension_numbers[i]) == 1) {
   1364         --j;
   1365         continue;
   1366       }
   1367       // If the current dimension number belongs to a different alignment part,
   1368       // or the dimension numbers are not in descending order, we can return
   1369       // early.
   1370       if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
   1371               current_alignment_index ||
   1372           input_dimension_numbers[i] > current_dimension_number) {
   1373         return tensorflow::gtl::nullopt;
   1374       }
   1375       current_dimension_number = input_dimension_numbers[i];
   1376     }
   1377 
   1378     // The output dimension numbers that belong to the current alignment part
   1379     // need to appear in the same descending order as in the input. Again, we
   1380     // can skip dimensions with a bound of 1.
   1381     for (int64 j = alignment[current_alignment_index + 1].second - 1;
   1382          j >= alignment[current_alignment_index].second; --j) {
   1383       if (output_shape.dimensions(j) != 1) {
   1384         output_layout.push_back(j);
   1385       }
   1386     }
   1387   }
   1388   // Now add all the dimensions with dimension bound 1 at the end of
   1389   // 'output_layout'.
   1390   for (int64 i = 0; i < output_rank; ++i) {
   1391     if (output_shape.dimensions(i) == 1) {
   1392       output_layout.push_back(i);
   1393     }
   1394   }
   1395   CHECK_EQ(output_layout.size(), output_rank);
   1396   Shape output_shape_with_layout = MakeShapeWithLayout(
   1397       output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
   1398       output_layout);
   1399   CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout));
   1400   return output_shape_with_layout;
   1401 }
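
        // Illustrative use of AlignLayouts (a hedged sketch; the shapes are
        // hypothetical and the outcome follows from the alignment example above):
        //
        //   Shape input =
        //       ShapeUtil::MakeShapeWithLayout(F32, {2, 3, 5, 7}, {3, 2, 1, 0});
        //   Shape output = ShapeUtil::MakeShape(F32, {6, 35});
        //   auto aligned = ShapeUtil::AlignLayouts(input, output);
        //   // 'aligned' holds [6x35] with minor_to_major {1, 0}, so reshaping the
        //   // row-major input to *aligned is a bitcast. An input layout such as
        //   // {0, 2, 1, 3} cannot be aligned and would yield nullopt instead.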
   1402 
   1403 /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
   1404                                               Shape shape) {
   1405   shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete);
   1406   if (LayoutUtil::HasLayout(shape)) {
   1407     Layout* layout = shape.mutable_layout();
   1408     layout->set_format(DENSE);
   1409     for (size_t i = 0; i < layout->minor_to_major().size();) {
   1410       if (layout->minor_to_major(i) == dim_to_delete) {
   1411         layout->mutable_minor_to_major()->erase(
   1412             layout->minor_to_major().begin() + i);
   1413         continue;
   1414       }
   1415       if (layout->minor_to_major(i) > dim_to_delete) {
   1416         (*layout->mutable_minor_to_major())[i] -= 1;
   1417       }
   1418       ++i;
   1419     }
   1420   }
   1421   return shape;
   1422 }
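
        // A hedged, illustrative example of DeleteDimension (hypothetical shape):
        //
        //   Shape s = ShapeUtil::MakeShapeWithLayout(F32, {3, 1, 4}, {2, 1, 0});
        //   Shape t = ShapeUtil::DeleteDimension(1, s);
        //   // t has dimensions {3, 4} and minor_to_major {1, 0}: the deleted
        //   // dimension is dropped from the layout and higher dimension numbers
        //   // are shifted down by one.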
   1423 
   1424 /* static */ Shape ShapeUtil::FilterDimensions(
   1425     const std::function<bool(int64)>& p, Shape shape) {
   1426   std::vector<int64> dims_to_delete;
   1427   for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
   1428     if (!p(i)) {
   1429       dims_to_delete.push_back(i);
   1430     }
   1431   }
   1432   for (int64 dim : dims_to_delete) {
   1433     shape = DeleteDimension(dim, shape);
   1434   }
   1435   return shape;
   1436 }
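
        // A hedged, illustrative example of FilterDimensions (hypothetical shape):
        //
        //   Shape s = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
        //   // Keep only the even-numbered dimensions 0 and 2.
        //   Shape t = ShapeUtil::FilterDimensions(
        //       [](int64 dim) { return dim % 2 == 0; }, s);
        //   // t now has dimensions {2, 4}.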
   1437 
   1438 std::ostream& operator<<(std::ostream& out, const Shape& shape) {
   1439   out << ShapeUtil::HumanString(shape);
   1440   return out;
   1441 }
   1442 
   1443 }  // namespace xla
   1444