Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/shape_util.h"
     17 
     18 #include "tensorflow/compiler/xla/layout_util.h"
     19 #include "tensorflow/compiler/xla/status_macros.h"
     20 #include "tensorflow/compiler/xla/test.h"
     21 #include "tensorflow/compiler/xla/test_helpers.h"
     22 #include "tensorflow/compiler/xla/types.h"
     23 #include "tensorflow/compiler/xla/util.h"
     24 #include "tensorflow/compiler/xla/xla_data.pb.h"
     25 
     26 namespace xla {
     27 namespace {
     28 
     29 using ::testing::ElementsAre;
     30 
     31 TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) {
     32   Shape matrix = ShapeUtil::MakeShape(F32, {2, 3});
     33   EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1));
     34   EXPECT_EQ(2, ShapeUtil::GetDimension(matrix, -2));
     35 }
     36 
     37 TEST(ShapeUtilTest, GetDimensionHelperExampleInDocumentationTest) {
     38   auto shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 4});
     39   ASSERT_EQ(4, ShapeUtil::GetDimension(shape, -1));
     40 }
     41 
     42 TEST(ShapeUtilTest, NegativeIndexOobFails) {
     43   Shape matrix = ShapeUtil::MakeShape(F32, {2, 3});
     44   ASSERT_DEATH(ShapeUtil::GetDimension(matrix, -3), "dimension_number >= 0");
     45 }
     46 
     47 TEST(ShapeUtilTest, Rank1DimensionIndexing) {
     48   Shape shape = ShapeUtil::MakeShape(F32, {3});
     49   ASSERT_EQ(3, shape.dimensions(0));
     50 }
     51 
     52 TEST(ShapeUtilTest, Rank2DimensionIndexing) {
     53   Shape shape = ShapeUtil::MakeShape(F32, {3, 2});
     54   ASSERT_EQ(2, shape.dimensions(1));
     55   ASSERT_EQ(3, shape.dimensions(0));
     56 }
     57 
     58 TEST(ShapeUtilTest, Rank3DimensionIndexing) {
     59   Shape shape = ShapeUtil::MakeShape(F32, {3, 2, 7});
     60   ASSERT_EQ(7, shape.dimensions(2));
     61   ASSERT_EQ(2, shape.dimensions(1));
     62   ASSERT_EQ(3, shape.dimensions(0));
     63 }
     64 
     65 TEST(ShapeUtilTest, Rank4DimensionIndexing) {
     66   Shape shape = ShapeUtil::MakeShape(F32, {3, 2, 7, 8});
     67   ASSERT_EQ(8, shape.dimensions(3));
     68   ASSERT_EQ(7, shape.dimensions(2));
     69   ASSERT_EQ(2, shape.dimensions(1));
     70   ASSERT_EQ(3, shape.dimensions(0));
     71 }
     72 
     73 TEST(ShapeUtilTest, ParseShapeStringR2F32) {
     74   string shape_string = "f32[123,456]";
     75   TF_ASSERT_OK_AND_ASSIGN(Shape actual,
     76                           ShapeUtil::ParseShapeString(shape_string));
     77   Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
     78   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
     79       << "expected: " << ShapeUtil::HumanString(expected)
     80       << "actual:   " << ShapeUtil::HumanString(actual);
     81 }
     82 
     83 TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
     84   string shape_string = "(f32[1572864],s8[5120,1024])";
     85   TF_ASSERT_OK_AND_ASSIGN(Shape actual,
     86                           ShapeUtil::ParseShapeString(shape_string));
     87   Shape expected =
     88       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
     89                                  ShapeUtil::MakeShape(S8, {5120, 1024})});
     90   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
     91       << "expected: " << ShapeUtil::HumanString(expected)
     92       << "actual:   " << ShapeUtil::HumanString(actual);
     93 }
     94 
     95 TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
     96   string shape_string = "(f32[1],(f32[2]), f32[3])";
     97   TF_ASSERT_OK_AND_ASSIGN(Shape actual,
     98                           ShapeUtil::ParseShapeString(shape_string));
     99   Shape expected = ShapeUtil::MakeTupleShape({
    100       ShapeUtil::MakeShape(F32, {1}),
    101       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}),
    102       ShapeUtil::MakeShape(F32, {3}),
    103   });
    104   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
    105       << "expected: " << ShapeUtil::HumanString(expected)
    106       << "actual:   " << ShapeUtil::HumanString(actual);
    107 }
    108 
    109 TEST(ShapeUtilTest, ParseShapeStringWithLayout) {
    110   string shape_string = "f32[123,456]{0,1}";
    111   TF_ASSERT_OK_AND_ASSIGN(Shape actual,
    112                           ShapeUtil::ParseShapeString(shape_string));
    113   Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
    114   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
    115       << "expected: " << ShapeUtil::HumanString(expected)
    116       << "actual:   " << ShapeUtil::HumanString(actual);
    117 }
    118 
    119 TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) {
    120   string shape_string = "f32[123,456]dense{0,1}";
    121   TF_ASSERT_OK_AND_ASSIGN(Shape actual,
    122                           ShapeUtil::ParseShapeString(shape_string));
    123   Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
    124   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
    125       << "expected: " << ShapeUtil::HumanString(expected)
    126       << "actual:   " << ShapeUtil::HumanString(actual);
    127 }
    128 
    129 TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) {
    130   string shape_string = "f32[123,456]sparse{10}";
    131   TF_ASSERT_OK_AND_ASSIGN(Shape actual,
    132                           ShapeUtil::ParseShapeString(shape_string));
    133   Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10);
    134   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
    135       << "expected: " << ShapeUtil::HumanString(expected)
    136       << "actual: " << ShapeUtil::HumanString(actual);
    137 }
    138 
    139 TEST(ShapeUtilTest, ParseInvalidShapeString) {
    140   string shape_strings[] = {
    141       "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
    142       "f32[123,456]dense{foo}",  "f32[123,456]sparse{foo}",
    143   };
    144   for (const string& shape_string : shape_strings) {
    145     StatusOr<Shape> result = ShapeUtil::ParseShapeString(shape_string);
    146     ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
    147   }
    148 }
    149 
    150 TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
    151   Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2});
    152   Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
    153   ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2));
    154 }
    155 
    156 TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) {
    157   Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
    158   auto layout_1 = shape_1.mutable_layout();
    159   layout_1->clear_minor_to_major();
    160   layout_1->add_minor_to_major(0);
    161   layout_1->add_minor_to_major(1);
    162 
    163   Shape shape_2 = ShapeUtil::MakeShape(F32, {3, 2});
    164   auto layout_2 = shape_2.mutable_layout();
    165   layout_2->clear_minor_to_major();
    166   layout_2->add_minor_to_major(1);
    167   layout_2->add_minor_to_major(0);
    168 
    169   EXPECT_FALSE(ShapeUtil::Equal(shape_1, shape_2));
    170   EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2));
    171 }
    172 
    173 TEST(ShapeUtilTest, CompatibleIgnoringFpPrecision) {
    174   Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2});
    175   Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
    176   ASSERT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
    177 }
    178 
    179 TEST(ShapeUtilTest, IncompatibleIgnoringFpPrecision) {
    180   Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2});
    181   Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2});
    182   ASSERT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
    183 }
    184 
    185 TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
    186   Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
    187   Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2});
    188   EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2));
    189 }
    190 
    191 TEST(ShapeUtilTest, CompatibleTuples) {
    192   Shape tuple1 = ShapeUtil::MakeTupleShape(
    193       {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
    194   Shape tuple2 = ShapeUtil::MakeTupleShape(
    195       {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
    196   EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2));
    197 }
    198 
    199 TEST(ShapeUtilTest, CompatibleTuplesIgnoringFpPrecision) {
    200   Shape tuple1 = ShapeUtil::MakeTupleShape(
    201       {ShapeUtil::MakeShape(BF16, {3, 2}), ShapeUtil::MakeShape(F32, {4, 5})});
    202   Shape tuple2 = ShapeUtil::MakeTupleShape(
    203       {ShapeUtil::MakeShape(F64, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})});
    204   EXPECT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2));
    205 }
    206 
    207 TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
    208   Shape tuple1 = ShapeUtil::MakeTupleShape(
    209       {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
    210   Shape tuple2 = ShapeUtil::MakeTupleShape(
    211       {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
    212   EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
    213   EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
    214 }
    215 
    216 TEST(ShapeUtilTest, IncompatibleTuplesIgnoringFpPrecision) {
    217   Shape tuple1 = ShapeUtil::MakeTupleShape(
    218       {ShapeUtil::MakeShape(BF16, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
    219   Shape tuple2 = ShapeUtil::MakeTupleShape(
    220       {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})});
    221   EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2));
    222 }
    223 
    224 TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
    225   Shape tuple1 = ShapeUtil::MakeTupleShape(
    226       {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
    227   Shape tuple2 = ShapeUtil::MakeTupleShape(
    228       {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})});
    229   EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
    230   EXPECT_TRUE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
    231 }
    232 
    233 TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) {
    234   Shape tuple1 = ShapeUtil::MakeTupleShape(
    235       {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
    236   Shape tuple2 = ShapeUtil::MakeTupleShape(
    237       {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {4, 2})});
    238   EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
    239 }
    240 
    241 TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) {
    242   Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30});
    243   shape1.mutable_layout()->add_padded_dimensions(10);
    244 
    245   Shape shape2 = ShapeUtil::MakeShape(F32, {20, 30});
    246   shape2.mutable_layout()->add_padded_dimensions(11);
    247 
    248   EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2));
    249 }
    250 
    251 TEST(ShapeUtilTest, CompareShapesWithPaddingValueMismatch) {
    252   Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30});
    253   shape1.mutable_layout()->set_padding_value(ZERO_PAD);
    254 
    255   Shape shape2 = ShapeUtil::MakeShape(F32, {20, 30});
    256   shape2.mutable_layout()->set_padding_value(LOWEST_PAD);
    257 
    258   EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2));
    259 }
    260 
    261 TEST(ShapeUtilTest, ScalarDefaultLayoutEqualsScalarEmptyMin2Maj) {
    262   Shape scalar_default_layout = ShapeUtil::MakeShape(F32, {});
    263   ASSERT_TRUE(scalar_default_layout.has_layout())
    264       << ShapeUtil::HumanStringWithLayout(scalar_default_layout);
    265 
    266   const Shape scalar_empty_min2maj =
    267       ShapeUtil::MakeShapeWithLayout(F32, {}, {});
    268   ASSERT_TRUE(scalar_empty_min2maj.has_layout())
    269       << ShapeUtil::HumanStringWithLayout(scalar_empty_min2maj);
    270 
    271   EXPECT_TRUE(ShapeUtil::Equal(scalar_default_layout, scalar_empty_min2maj));
    272 }
    273 
    274 TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) {
    275   EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32));
    276   EXPECT_EQ(4, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {})));
    277   EXPECT_EQ(800, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {10, 20})));
    278 
    279   EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(F64));
    280   EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {})));
    281   EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {10, 20})));
    282 
    283   EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64));
    284   EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {})));
    285   EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20})));
    286 }
    287 
    288 TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
    289   EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32));
    290   Shape shape = ShapeUtil::MakeShape(F32, {10, 20});
    291   EXPECT_EQ(800, ShapeUtil::ByteSizeOf(shape));
    292 
    293   shape.mutable_layout()->add_padded_dimensions(15);
    294   shape.mutable_layout()->add_padded_dimensions(21);
    295   EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape));
    296 }
    297 
    298 TEST(ShapeUtilTest, NestedTuple) {
    299   EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({})));
    300   EXPECT_FALSE(ShapeUtil::IsNestedTuple(
    301       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})})));
    302   EXPECT_TRUE(ShapeUtil::IsNestedTuple(
    303       ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({})})));
    304   EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
    305       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})})));
    306   EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
    307       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeTupleShape({})})));
    308   EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
    309       {ShapeUtil::MakeTupleShape({}), ShapeUtil::MakeShape(S32, {})})));
    310   EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
    311       {ShapeUtil::MakeTupleShape({}), ShapeUtil::MakeTupleShape({})})));
    312 }
    313 
    314 TEST(ShapeUtilTest, ElementsIn) {
    315   EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {})));
    316   EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {0})));
    317   EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1})));
    318   EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1, 1})));
    319   EXPECT_EQ(2, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {2})));
    320   EXPECT_EQ(2, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {2, 1})));
    321   EXPECT_EQ(15, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {3, 5})));
    322   EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {3, 0, 5})));
    323   EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {0, 3, 0})));
    324   EXPECT_EQ(15, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1, 3, 5})));
    325   EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
    326 }
    327 
    328 TEST(ShapeUtilTest, HasZeroElements) {
    329   EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {})));
    330   EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0})));
    331   EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1})));
    332   EXPECT_EQ(false,
    333             ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1})));
    334   EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2})));
    335   EXPECT_EQ(false,
    336             ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1})));
    337   EXPECT_EQ(false,
    338             ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5})));
    339   EXPECT_EQ(true,
    340             ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5})));
    341   EXPECT_EQ(true,
    342             ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0})));
    343   EXPECT_EQ(false,
    344             ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5})));
    345   EXPECT_EQ(false,
    346             ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17})));
    347 }
    348 
    349 TEST(ShapeUtilTest, SameDimensions) {
    350   EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}),
    351                                         ShapeUtil::MakeShape(S32, {})));
    352   EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}),
    353                                         ShapeUtil::MakeShape(F32, {})));
    354   EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
    355                                         ShapeUtil::MakeShape(S32, {1})));
    356   EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {0}),
    357                                         ShapeUtil::MakeShape(S32, {0})));
    358   EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {2}),
    359                                         ShapeUtil::MakeShape(S32, {2})));
    360   EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
    361                                          ShapeUtil::MakeShape(F32, {2})));
    362   EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {0, 0}),
    363                                          ShapeUtil::MakeShape(F32, {0})));
    364   EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
    365                                          ShapeUtil::MakeShape(F32, {1, 1})));
    366   EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}),
    367                                          ShapeUtil::MakeShape(F32, {1})));
    368   EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
    369                                          ShapeUtil::MakeShape(F32, {1, 1})));
    370   EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
    371                                          ShapeUtil::MakeShape(F32, {1, 0})));
    372   EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1, 1}),
    373                                          ShapeUtil::MakeShape(F32, {1, 2})));
    374 }
    375 
    376 TEST(ShapeUtilTest, GetSubshape) {
    377   // Test array shape.
    378   Shape array_shape = ShapeUtil::MakeShape(F32, {42, 42, 123});
    379   EXPECT_TRUE(
    380       ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(array_shape, {})));
    381   EXPECT_TRUE(ShapeUtil::Equal(
    382       array_shape, *ShapeUtil::GetMutableSubshape(&array_shape, {})));
    383 
    384   // Test tuple shape.
    385   Shape tuple_shape =
    386       ShapeUtil::MakeTupleShape({array_shape, array_shape, array_shape});
    387   EXPECT_TRUE(
    388       ShapeUtil::Equal(tuple_shape, ShapeUtil::GetSubshape(tuple_shape, {})));
    389   EXPECT_TRUE(
    390       ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {0})));
    391   EXPECT_TRUE(
    392       ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {1})));
    393   EXPECT_TRUE(
    394       ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {2})));
    395 
    396   // Test nested tuple shape.
    397   Shape nested_tuple_shape = ShapeUtil::MakeTupleShape(
    398       {array_shape, ShapeUtil::MakeTupleShape({array_shape, array_shape}),
    399        ShapeUtil::MakeTupleShape(
    400            {ShapeUtil::MakeTupleShape({array_shape, array_shape}),
    401             array_shape})});
    402   EXPECT_TRUE(ShapeUtil::Equal(nested_tuple_shape,
    403                                ShapeUtil::GetSubshape(nested_tuple_shape, {})));
    404   EXPECT_TRUE(ShapeUtil::Equal(
    405       array_shape, ShapeUtil::GetSubshape(nested_tuple_shape, {0})));
    406   EXPECT_TRUE(
    407       ShapeUtil::Equal(ShapeUtil::MakeTupleShape({array_shape, array_shape}),
    408                        ShapeUtil::GetSubshape(nested_tuple_shape, {1})));
    409   EXPECT_TRUE(
    410       ShapeUtil::Equal(ShapeUtil::MakeTupleShape({array_shape, array_shape}),
    411                        ShapeUtil::GetSubshape(nested_tuple_shape, {2, 0})));
    412 }
    413 
    414 TEST(ShapeUtilTest, IsLeafIndex) {
    415   // Test array shape.
    416   Shape array_shape = ShapeUtil::MakeShape(F32, {42, 42, 123});
    417   EXPECT_TRUE(ShapeUtil::IsLeafIndex(array_shape, {}));
    418 
    419   // Test tuple shape.
    420   Shape tuple_shape = ShapeUtil::MakeTupleShape({array_shape, array_shape});
    421   EXPECT_FALSE(ShapeUtil::IsLeafIndex(tuple_shape, {}));
    422   EXPECT_TRUE(ShapeUtil::IsLeafIndex(tuple_shape, {0}));
    423   EXPECT_TRUE(ShapeUtil::IsLeafIndex(tuple_shape, {1}));
    424 
    425   // Test nested tuple shape.
    426   Shape nested_tuple_shape = ShapeUtil::MakeTupleShape(
    427       {array_shape, ShapeUtil::MakeTupleShape({array_shape, array_shape}),
    428        ShapeUtil::MakeTupleShape(
    429            {ShapeUtil::MakeTupleShape({array_shape, array_shape}),
    430             array_shape})});
    431   EXPECT_FALSE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {}));
    432   EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {0}));
    433   EXPECT_FALSE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1}));
    434   EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 0}));
    435   EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 1}));
    436 }
    437 
    438 TEST(ShapeUtilTest, HumanString) {
    439   Shape opaque = ShapeUtil::MakeOpaqueShape();
    440   Shape scalar = ShapeUtil::MakeShape(F32, {});
    441   Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
    442   Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
    443   Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
    444   Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix});
    445 
    446   EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque));
    447   EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar));
    448   EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix));
    449   EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2));
    450   EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])",
    451             ShapeUtil::HumanString(tuple));
    452   EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
    453             ShapeUtil::HumanString(nested_tuple));
    454 
    455   EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque));
    456   EXPECT_EQ("f32[]", ShapeUtil::HumanStringWithLayout(scalar));
    457   EXPECT_EQ("u32[1,2]{1,0}", ShapeUtil::HumanStringWithLayout(matrix));
    458   EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2));
    459   EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})",
    460             ShapeUtil::HumanStringWithLayout(tuple));
    461   EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})",
    462             ShapeUtil::HumanStringWithLayout(nested_tuple));
    463 
    464   ProgramShape prog = ShapeUtil::MakeProgramShape(
    465       {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
    466   EXPECT_EQ(
    467       "((unknown): opaque[], "
    468       "(unknown): f32[], "
    469       "(unknown): u32[1,2], "
    470       "(unknown): s32[3,4], "
    471       "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
    472       "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
    473       "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
    474       ShapeUtil::HumanString(prog));
    475 
    476   prog.add_parameter_names("arg0");
    477   prog.add_parameter_names("scalar");
    478   prog.add_parameter_names("matrix");
    479   prog.add_parameter_names("matrix2");
    480   prog.add_parameter_names("tuple");
    481   prog.add_parameter_names("nested_tuple");
    482   EXPECT_EQ(
    483       "(arg0: opaque[], "
    484       "scalar: f32[], "
    485       "matrix: u32[1,2], "
    486       "matrix2: s32[3,4], "
    487       "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
    488       "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
    489       "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
    490       ShapeUtil::HumanString(prog));
    491 }
    492 
    493 TEST(ShapeUtilTest, ForEachSubshapeArray) {
    494   const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
    495   int calls = 0;
    496   ShapeUtil::ForEachSubshape(
    497       shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) {
    498         EXPECT_EQ(&shape, &subshape);
    499         EXPECT_TRUE(index.empty());
    500         ++calls;
    501       });
    502   EXPECT_EQ(1, calls);
    503 }
    504 
    505 TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) {
    506   const Shape shape = ShapeUtil::MakeTupleShape(
    507       {ShapeUtil::MakeShape(F32, {42}),
    508        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}),
    509                                   ShapeUtil::MakeShape(PRED, {33})})});
    510   int calls = 0;
    511   ShapeUtil::ForEachSubshape(
    512       shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) {
    513         EXPECT_TRUE(
    514             ShapeUtil::Equal(subshape, ShapeUtil::GetSubshape(shape, index)));
    515         if (calls == 0) {
    516           // Visitation should go from outside in.
    517           EXPECT_TRUE(index.empty());
    518         } else if (calls == 4) {
    519           // Last visitation should be to the array with 33 elements.
    520           EXPECT_EQ(33, ShapeUtil::ElementsIn(subshape));
    521         }
    522         ++calls;
    523       });
    524   EXPECT_EQ(5, calls);
    525 }
    526 
    527 TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) {
    528   Shape shape = ShapeUtil::MakeTupleShape(
    529       {ShapeUtil::MakeShape(F32, {42}),
    530        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}),
    531                                   ShapeUtil::MakeShape(PRED, {33})})});
    532   int calls = 0;
    533   ShapeUtil::ForEachMutableSubshape(
    534       &shape, [&calls, &shape](const Shape* subshape, const ShapeIndex& index) {
    535         // Pointer values should be equal
    536         EXPECT_EQ(subshape, ShapeUtil::GetMutableSubshape(&shape, index));
    537         if (calls == 0) {
    538           // Visitation should go from outside in.
    539           EXPECT_TRUE(index.empty());
    540         } else if (calls == 4) {
    541           // Last visitation should be to the array with 33 elements.
    542           EXPECT_EQ(33, ShapeUtil::ElementsIn(*subshape));
    543         }
    544         ++calls;
    545       });
    546   EXPECT_EQ(5, calls);
    547 }
    548 
    549 TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) {
    550   Shape shape0 = ShapeUtil::MakeShape(S32, {9, 1, 4});
    551   Shape shape1 = ShapeUtil::MakeShape(S32, {1, 9, 4, 1});
    552   Shape shape2 = ShapeUtil::MakeShape(S32, {3, 1, 12});
    553   EXPECT_TRUE(std::get<0>(
    554       ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1)));
    555   EXPECT_FALSE(std::get<0>(
    556       ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
    557 }
    558 
    559 TEST(ShapeUtilTest, ShapeIs) {
    560   EXPECT_FALSE(ShapeUtil::ShapeIs(ShapeUtil::MakeShape(PRED, {2}), PRED, {}));
    561 }
    562 
    563 TEST(ShapeUtilTest, ForEachIndex) {
    564   struct ShapeDimensionAndNumberInvocations {
    565     std::vector<int64> dimensions;
    566     int invocations;
    567   } test_data[] = {
    568       {{}, 1},     {{0}, 0},      {{16}, 16},          {{3, 0}, 0},
    569       {{0, 2}, 0}, {{4, 16}, 64}, {{6, 11, 17}, 1122}, {{6, 11, 5, 17}, 5610},
    570   };
    571 
    572   for (const auto& data : test_data) {
    573     Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
    574     // Increments at every invocation.
    575     int invocations = 0;
    576     auto increment_func = [&invocations](const std::vector<int64>& indexes) {
    577       invocations++;
    578       return true;
    579     };
    580 
    581     std::vector<int64> zero_base(data.dimensions.size(), 0);
    582     std::vector<int64> step(data.dimensions.size(), 1);
    583 
    584     ShapeUtil::ForEachIndex(shape, zero_base, data.dimensions, step,
    585                             increment_func);
    586 
    587     EXPECT_EQ(invocations, data.invocations);
    588   }
    589 }
    590 
    591 TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
    592   // All output dimensions should be unmodified. One of the input dimensions is
    593   // modified because the input rank is larger by one.
    594   EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape(
    595                   ShapeUtil::MakeShape(S32, {1, 1, 1, 1}),
    596                   ShapeUtil::MakeShape(S32, {1, 1, 1})),
    597               ElementsAre(std::make_pair(0, 0), std::make_pair(1, 1),
    598                           std::make_pair(2, 2)));
    599 }
    600 
    601 TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1_to_1x1x1x1) {
    602   // All input dimensions should be unmodified. One of the output dimensions is
    603   // modified because the output rank is larger by one.
    604   EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape(
    605                   ShapeUtil::MakeShape(S32, {1, 1, 1}),
    606                   ShapeUtil::MakeShape(S32, {1, 1, 1, 1})),
    607               ElementsAre(std::make_pair(0, 0), std::make_pair(1, 1),
    608                           std::make_pair(2, 2)));
    609 }
    610 
    611 TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) {
    612   // The only matching dimension is the one with size 5.
    613   // 4, 1, 3, 5, 6, 7
    614   //          |
    615   // 2, 6, 1, 5, 1, 42
    616   EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape(
    617                   ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}),
    618                   ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})),
    619               ElementsAre(std::make_pair(3, 3)));
    620 }
    621 
    622 TEST(ShapeUtilTest, ReshapeIsBitcast_3x4_6x2) {
    623   for (bool input_is_row_major : {true, false}) {
    624     for (bool output_is_row_major : {true, false}) {
    625       Layout input_layout = input_is_row_major ? LayoutUtil::MakeLayout({1, 0})
    626                                                : LayoutUtil::MakeLayout({0, 1});
    627       Layout output_layout = output_is_row_major
    628                                  ? LayoutUtil::MakeLayout({1, 0})
    629                                  : LayoutUtil::MakeLayout({0, 1});
    630       // Suppose the input is logically (i.e. ignoring its layout)
    631       //   0  1  2  3
    632       //   4  5  6  7
    633       //   8  9  10 11
    634       //
    635       // The reshape transforms the input to logically
    636       //   0  1
    637       //   2  3
    638       //   4  5
    639       //   6  7
    640       //   8  9
    641       //   10 11
    642       //
    643       // The input and the output have the same underlying data only if they
    644       // are both row-major.
    645       EXPECT_EQ(
    646           ShapeUtil::ReshapeIsBitcast(
    647               ShapeUtil::MakeShapeWithLayout(
    648                   F32, {3, 4}, AsInt64Slice(input_layout.minor_to_major())),
    649               ShapeUtil::MakeShapeWithLayout(
    650                   F32, {6, 2}, AsInt64Slice(output_layout.minor_to_major()))),
    651           input_is_row_major && output_is_row_major);
    652     }
    653   }
    654 }
    655 
    656 TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) {
    657   EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(
    658       ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {1, 0, 2}),
    659       ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
    660 }
    661 
    662 TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
    663   EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
    664       ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}),
    665       ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
    666 }
    667 
    668 TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensions) {
    669   Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11},
    670                                                {3, 2, 1, 0, 4});
    671   auto aligned_shape = ShapeUtil::AlignLayouts(
    672       input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11}));
    673   EXPECT_TRUE(aligned_shape);
    674   EXPECT_THAT(aligned_shape.value().layout().minor_to_major(),
    675               ElementsAre(4, 3, 2, 1, 0, 5));
    676   EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value()));
    677 
    678   aligned_shape = ShapeUtil::AlignLayouts(
    679       input, ShapeUtil::MakeShape(xla::F32, {3, 2, 4, 35, 11}));
    680   EXPECT_TRUE(aligned_shape);
    681   EXPECT_THAT(aligned_shape.value().layout().minor_to_major(),
    682               ElementsAre(3, 2, 1, 0, 4));
    683   EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value()));
    684 }
    685 
    686 TEST(AlignmentTest, AlignLayoutsWithTrivialDimensions) {
    687   Shape input =
    688       ShapeUtil::MakeShapeWithLayout(xla::F32, {1, 3, 8, 1, 5, 7, 1, 11, 1, 1},
    689                                      {5, 0, 4, 2, 1, 3, 6, 7, 9, 8});
    690   auto aligned_shape = ShapeUtil::AlignLayouts(
    691       input, ShapeUtil::MakeShape(xla::F32, {1, 4, 1, 3, 2, 7, 5, 11, 1}));
    692   EXPECT_TRUE(aligned_shape);
    693   EXPECT_THAT(aligned_shape.value().layout().minor_to_major(),
    694               ElementsAre(6, 5, 4, 3, 1, 7, 0, 2, 8));
    695   EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value()));
    696 }
    697 
    698 // A test case where the consecutive elements of the input shape belonging to
    699 // the same layout part are not in descending order.
    700 TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensionsWrongInputLayout) {
    701   // Same physical layout as in AlignLayoutsWithoutTrivialDimensions, except
    702   // that the first two dimension numbers are exchanged.
    703   Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11},
    704                                                {2, 3, 1, 0, 4});
    705   auto aligned_shape = ShapeUtil::AlignLayouts(
    706       input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11}));
    707   EXPECT_FALSE(aligned_shape);
    708 }
    709 
    710 // A test case where the physical layout of the input shape does not place all
    711 // dimensions that belong to the same alignment part consecutively.
    712 TEST(AlignmentTest,
    713      AlignLayoutsWithoutTrivialDimensionsNonConsecutiveAlignmentPart) {
    714   Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11},
    715                                                {3, 2, 1, 0, 4});
    716   auto aligned_shape = ShapeUtil::AlignLayouts(
    717       input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 5, 77}));
    718   EXPECT_FALSE(aligned_shape);
    719 }
    720 
    721 }  // namespace
    722 }  // namespace xla
    723