1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/shape_util.h" 17 18 #include "tensorflow/compiler/xla/layout_util.h" 19 #include "tensorflow/compiler/xla/status_macros.h" 20 #include "tensorflow/compiler/xla/test.h" 21 #include "tensorflow/compiler/xla/test_helpers.h" 22 #include "tensorflow/compiler/xla/types.h" 23 #include "tensorflow/compiler/xla/util.h" 24 #include "tensorflow/compiler/xla/xla_data.pb.h" 25 26 namespace xla { 27 namespace { 28 29 using ::testing::ElementsAre; 30 31 TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) { 32 Shape matrix = ShapeUtil::MakeShape(F32, {2, 3}); 33 EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1)); 34 EXPECT_EQ(2, ShapeUtil::GetDimension(matrix, -2)); 35 } 36 37 TEST(ShapeUtilTest, GetDimensionHelperExampleInDocumentationTest) { 38 auto shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 4}); 39 ASSERT_EQ(4, ShapeUtil::GetDimension(shape, -1)); 40 } 41 42 TEST(ShapeUtilTest, NegativeIndexOobFails) { 43 Shape matrix = ShapeUtil::MakeShape(F32, {2, 3}); 44 ASSERT_DEATH(ShapeUtil::GetDimension(matrix, -3), "dimension_number >= 0"); 45 } 46 47 TEST(ShapeUtilTest, Rank1DimensionIndexing) { 48 Shape shape = ShapeUtil::MakeShape(F32, {3}); 49 ASSERT_EQ(3, shape.dimensions(0)); 50 } 51 52 TEST(ShapeUtilTest, Rank2DimensionIndexing) { 53 Shape shape = ShapeUtil::MakeShape(F32, {3, 2}); 54 ASSERT_EQ(2, shape.dimensions(1)); 55 ASSERT_EQ(3, shape.dimensions(0)); 56 } 57 58 TEST(ShapeUtilTest, Rank3DimensionIndexing) { 59 Shape shape = ShapeUtil::MakeShape(F32, {3, 2, 7}); 60 ASSERT_EQ(7, shape.dimensions(2)); 61 ASSERT_EQ(2, shape.dimensions(1)); 62 ASSERT_EQ(3, shape.dimensions(0)); 63 } 64 65 TEST(ShapeUtilTest, Rank4DimensionIndexing) { 66 Shape shape = ShapeUtil::MakeShape(F32, {3, 2, 7, 8}); 67 ASSERT_EQ(8, shape.dimensions(3)); 68 ASSERT_EQ(7, shape.dimensions(2)); 69 ASSERT_EQ(2, shape.dimensions(1)); 70 ASSERT_EQ(3, shape.dimensions(0)); 71 } 72 73 TEST(ShapeUtilTest, ParseShapeStringR2F32) { 74 string shape_string = "f32[123,456]"; 75 TF_ASSERT_OK_AND_ASSIGN(Shape actual, 76 ShapeUtil::ParseShapeString(shape_string)); 77 Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); 78 ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) 79 << "expected: " << ShapeUtil::HumanString(expected) 80 << "actual: " << ShapeUtil::HumanString(actual); 81 } 82 83 TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { 84 string shape_string = "(f32[1572864],s8[5120,1024])"; 85 TF_ASSERT_OK_AND_ASSIGN(Shape actual, 86 ShapeUtil::ParseShapeString(shape_string)); 87 Shape expected = 88 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), 89 ShapeUtil::MakeShape(S8, {5120, 1024})}); 90 ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) 91 << "expected: " << ShapeUtil::HumanString(expected) 92 << "actual: " << ShapeUtil::HumanString(actual); 93 } 94 95 TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { 96 string shape_string = "(f32[1],(f32[2]), f32[3])"; 97 TF_ASSERT_OK_AND_ASSIGN(Shape actual, 98 ShapeUtil::ParseShapeString(shape_string)); 99 Shape expected = ShapeUtil::MakeTupleShape({ 100 ShapeUtil::MakeShape(F32, {1}), 101 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), 102 ShapeUtil::MakeShape(F32, {3}), 103 }); 104 ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) 105 << "expected: " << ShapeUtil::HumanString(expected) 106 << "actual: " << ShapeUtil::HumanString(actual); 107 } 108 109 TEST(ShapeUtilTest, ParseShapeStringWithLayout) { 110 string shape_string = "f32[123,456]{0,1}"; 111 TF_ASSERT_OK_AND_ASSIGN(Shape actual, 112 ShapeUtil::ParseShapeString(shape_string)); 113 Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); 114 ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) 115 << "expected: " << ShapeUtil::HumanString(expected) 116 << "actual: " << ShapeUtil::HumanString(actual); 117 } 118 119 TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) { 120 string shape_string = "f32[123,456]dense{0,1}"; 121 TF_ASSERT_OK_AND_ASSIGN(Shape actual, 122 ShapeUtil::ParseShapeString(shape_string)); 123 Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); 124 ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) 125 << "expected: " << ShapeUtil::HumanString(expected) 126 << "actual: " << ShapeUtil::HumanString(actual); 127 } 128 129 TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { 130 string shape_string = "f32[123,456]sparse{10}"; 131 TF_ASSERT_OK_AND_ASSIGN(Shape actual, 132 ShapeUtil::ParseShapeString(shape_string)); 133 Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); 134 ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) 135 << "expected: " << ShapeUtil::HumanString(expected) 136 << "actual: " << ShapeUtil::HumanString(actual); 137 } 138 139 TEST(ShapeUtilTest, ParseInvalidShapeString) { 140 string shape_strings[] = { 141 "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", 142 "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", 143 }; 144 for (const string& shape_string : shape_strings) { 145 StatusOr<Shape> result = ShapeUtil::ParseShapeString(shape_string); 146 ASSERT_FALSE(result.ok()) << "shape: " << shape_string; 147 } 148 } 149 150 TEST(ShapeUtilTest, CompatibleIdenticalShapes) { 151 Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); 152 Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); 153 ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2)); 154 } 155 156 TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) { 157 Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); 158 auto layout_1 = shape_1.mutable_layout(); 159 layout_1->clear_minor_to_major(); 160 layout_1->add_minor_to_major(0); 161 layout_1->add_minor_to_major(1); 162 163 Shape shape_2 = ShapeUtil::MakeShape(F32, {3, 2}); 164 auto layout_2 = shape_2.mutable_layout(); 165 layout_2->clear_minor_to_major(); 166 layout_2->add_minor_to_major(1); 167 layout_2->add_minor_to_major(0); 168 169 EXPECT_FALSE(ShapeUtil::Equal(shape_1, shape_2)); 170 EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2)); 171 } 172 173 TEST(ShapeUtilTest, CompatibleIgnoringFpPrecision) { 174 Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2}); 175 Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); 176 ASSERT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); 177 } 178 179 TEST(ShapeUtilTest, IncompatibleIgnoringFpPrecision) { 180 Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2}); 181 Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2}); 182 ASSERT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); 183 } 184 185 TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { 186 Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); 187 Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2}); 188 EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2)); 189 } 190 191 TEST(ShapeUtilTest, CompatibleTuples) { 192 Shape tuple1 = ShapeUtil::MakeTupleShape( 193 {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); 194 Shape tuple2 = ShapeUtil::MakeTupleShape( 195 {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); 196 EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2)); 197 } 198 199 TEST(ShapeUtilTest, CompatibleTuplesIgnoringFpPrecision) { 200 Shape tuple1 = ShapeUtil::MakeTupleShape( 201 {ShapeUtil::MakeShape(BF16, {3, 2}), ShapeUtil::MakeShape(F32, {4, 5})}); 202 Shape tuple2 = ShapeUtil::MakeTupleShape( 203 {ShapeUtil::MakeShape(F64, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})}); 204 EXPECT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2)); 205 } 206 207 TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { 208 Shape tuple1 = ShapeUtil::MakeTupleShape( 209 {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); 210 Shape tuple2 = ShapeUtil::MakeTupleShape( 211 {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); 212 EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); 213 EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); 214 } 215 216 TEST(ShapeUtilTest, IncompatibleTuplesIgnoringFpPrecision) { 217 Shape tuple1 = ShapeUtil::MakeTupleShape( 218 {ShapeUtil::MakeShape(BF16, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); 219 Shape tuple2 = ShapeUtil::MakeTupleShape( 220 {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})}); 221 EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2)); 222 } 223 224 TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { 225 Shape tuple1 = ShapeUtil::MakeTupleShape( 226 {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); 227 Shape tuple2 = ShapeUtil::MakeTupleShape( 228 {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})}); 229 EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); 230 EXPECT_TRUE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); 231 } 232 233 TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { 234 Shape tuple1 = ShapeUtil::MakeTupleShape( 235 {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); 236 Shape tuple2 = ShapeUtil::MakeTupleShape( 237 {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {4, 2})}); 238 EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); 239 } 240 241 TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { 242 Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); 243 shape1.mutable_layout()->add_padded_dimensions(10); 244 245 Shape shape2 = ShapeUtil::MakeShape(F32, {20, 30}); 246 shape2.mutable_layout()->add_padded_dimensions(11); 247 248 EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2)); 249 } 250 251 TEST(ShapeUtilTest, CompareShapesWithPaddingValueMismatch) { 252 Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); 253 shape1.mutable_layout()->set_padding_value(ZERO_PAD); 254 255 Shape shape2 = ShapeUtil::MakeShape(F32, {20, 30}); 256 shape2.mutable_layout()->set_padding_value(LOWEST_PAD); 257 258 EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2)); 259 } 260 261 TEST(ShapeUtilTest, ScalarDefaultLayoutEqualsScalarEmptyMin2Maj) { 262 Shape scalar_default_layout = ShapeUtil::MakeShape(F32, {}); 263 ASSERT_TRUE(scalar_default_layout.has_layout()) 264 << ShapeUtil::HumanStringWithLayout(scalar_default_layout); 265 266 const Shape scalar_empty_min2maj = 267 ShapeUtil::MakeShapeWithLayout(F32, {}, {}); 268 ASSERT_TRUE(scalar_empty_min2maj.has_layout()) 269 << ShapeUtil::HumanStringWithLayout(scalar_empty_min2maj); 270 271 EXPECT_TRUE(ShapeUtil::Equal(scalar_default_layout, scalar_empty_min2maj)); 272 } 273 274 TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { 275 EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32)); 276 EXPECT_EQ(4, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {}))); 277 EXPECT_EQ(800, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {10, 20}))); 278 279 EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(F64)); 280 EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {}))); 281 EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {10, 20}))); 282 283 EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); 284 EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); 285 EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); 286 } 287 288 TEST(ShapeUtilTest, ByteSizeOfWithPadding) { 289 EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32)); 290 Shape shape = ShapeUtil::MakeShape(F32, {10, 20}); 291 EXPECT_EQ(800, ShapeUtil::ByteSizeOf(shape)); 292 293 shape.mutable_layout()->add_padded_dimensions(15); 294 shape.mutable_layout()->add_padded_dimensions(21); 295 EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape)); 296 } 297 298 TEST(ShapeUtilTest, NestedTuple) { 299 EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({}))); 300 EXPECT_FALSE(ShapeUtil::IsNestedTuple( 301 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})}))); 302 EXPECT_TRUE(ShapeUtil::IsNestedTuple( 303 ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({})}))); 304 EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape( 305 {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}))); 306 EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape( 307 {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeTupleShape({})}))); 308 EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape( 309 {ShapeUtil::MakeTupleShape({}), ShapeUtil::MakeShape(S32, {})}))); 310 EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape( 311 {ShapeUtil::MakeTupleShape({}), ShapeUtil::MakeTupleShape({})}))); 312 } 313 314 TEST(ShapeUtilTest, ElementsIn) { 315 EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {}))); 316 EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {0}))); 317 EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1}))); 318 EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1, 1}))); 319 EXPECT_EQ(2, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {2}))); 320 EXPECT_EQ(2, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {2, 1}))); 321 EXPECT_EQ(15, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {3, 5}))); 322 EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {3, 0, 5}))); 323 EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {0, 3, 0}))); 324 EXPECT_EQ(15, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1, 3, 5}))); 325 EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); 326 } 327 328 TEST(ShapeUtilTest, HasZeroElements) { 329 EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {}))); 330 EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0}))); 331 EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1}))); 332 EXPECT_EQ(false, 333 ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1}))); 334 EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2}))); 335 EXPECT_EQ(false, 336 ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1}))); 337 EXPECT_EQ(false, 338 ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5}))); 339 EXPECT_EQ(true, 340 ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5}))); 341 EXPECT_EQ(true, 342 ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0}))); 343 EXPECT_EQ(false, 344 ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5}))); 345 EXPECT_EQ(false, 346 ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17}))); 347 } 348 349 TEST(ShapeUtilTest, SameDimensions) { 350 EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}), 351 ShapeUtil::MakeShape(S32, {}))); 352 EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}), 353 ShapeUtil::MakeShape(F32, {}))); 354 EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}), 355 ShapeUtil::MakeShape(S32, {1}))); 356 EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {0}), 357 ShapeUtil::MakeShape(S32, {0}))); 358 EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {2}), 359 ShapeUtil::MakeShape(S32, {2}))); 360 EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}), 361 ShapeUtil::MakeShape(F32, {2}))); 362 EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {0, 0}), 363 ShapeUtil::MakeShape(F32, {0}))); 364 EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}), 365 ShapeUtil::MakeShape(F32, {1, 1}))); 366 EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}), 367 ShapeUtil::MakeShape(F32, {1}))); 368 EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}), 369 ShapeUtil::MakeShape(F32, {1, 1}))); 370 EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}), 371 ShapeUtil::MakeShape(F32, {1, 0}))); 372 EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1, 1}), 373 ShapeUtil::MakeShape(F32, {1, 2}))); 374 } 375 376 TEST(ShapeUtilTest, GetSubshape) { 377 // Test array shape. 378 Shape array_shape = ShapeUtil::MakeShape(F32, {42, 42, 123}); 379 EXPECT_TRUE( 380 ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(array_shape, {}))); 381 EXPECT_TRUE(ShapeUtil::Equal( 382 array_shape, *ShapeUtil::GetMutableSubshape(&array_shape, {}))); 383 384 // Test tuple shape. 385 Shape tuple_shape = 386 ShapeUtil::MakeTupleShape({array_shape, array_shape, array_shape}); 387 EXPECT_TRUE( 388 ShapeUtil::Equal(tuple_shape, ShapeUtil::GetSubshape(tuple_shape, {}))); 389 EXPECT_TRUE( 390 ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {0}))); 391 EXPECT_TRUE( 392 ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {1}))); 393 EXPECT_TRUE( 394 ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {2}))); 395 396 // Test nested tuple shape. 397 Shape nested_tuple_shape = ShapeUtil::MakeTupleShape( 398 {array_shape, ShapeUtil::MakeTupleShape({array_shape, array_shape}), 399 ShapeUtil::MakeTupleShape( 400 {ShapeUtil::MakeTupleShape({array_shape, array_shape}), 401 array_shape})}); 402 EXPECT_TRUE(ShapeUtil::Equal(nested_tuple_shape, 403 ShapeUtil::GetSubshape(nested_tuple_shape, {}))); 404 EXPECT_TRUE(ShapeUtil::Equal( 405 array_shape, ShapeUtil::GetSubshape(nested_tuple_shape, {0}))); 406 EXPECT_TRUE( 407 ShapeUtil::Equal(ShapeUtil::MakeTupleShape({array_shape, array_shape}), 408 ShapeUtil::GetSubshape(nested_tuple_shape, {1}))); 409 EXPECT_TRUE( 410 ShapeUtil::Equal(ShapeUtil::MakeTupleShape({array_shape, array_shape}), 411 ShapeUtil::GetSubshape(nested_tuple_shape, {2, 0}))); 412 } 413 414 TEST(ShapeUtilTest, IsLeafIndex) { 415 // Test array shape. 416 Shape array_shape = ShapeUtil::MakeShape(F32, {42, 42, 123}); 417 EXPECT_TRUE(ShapeUtil::IsLeafIndex(array_shape, {})); 418 419 // Test tuple shape. 420 Shape tuple_shape = ShapeUtil::MakeTupleShape({array_shape, array_shape}); 421 EXPECT_FALSE(ShapeUtil::IsLeafIndex(tuple_shape, {})); 422 EXPECT_TRUE(ShapeUtil::IsLeafIndex(tuple_shape, {0})); 423 EXPECT_TRUE(ShapeUtil::IsLeafIndex(tuple_shape, {1})); 424 425 // Test nested tuple shape. 426 Shape nested_tuple_shape = ShapeUtil::MakeTupleShape( 427 {array_shape, ShapeUtil::MakeTupleShape({array_shape, array_shape}), 428 ShapeUtil::MakeTupleShape( 429 {ShapeUtil::MakeTupleShape({array_shape, array_shape}), 430 array_shape})}); 431 EXPECT_FALSE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {})); 432 EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {0})); 433 EXPECT_FALSE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1})); 434 EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 0})); 435 EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 1})); 436 } 437 438 TEST(ShapeUtilTest, HumanString) { 439 Shape opaque = ShapeUtil::MakeOpaqueShape(); 440 Shape scalar = ShapeUtil::MakeShape(F32, {}); 441 Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); 442 Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); 443 Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); 444 Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix}); 445 446 EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); 447 EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); 448 EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); 449 EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); 450 EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", 451 ShapeUtil::HumanString(tuple)); 452 EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", 453 ShapeUtil::HumanString(nested_tuple)); 454 455 EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); 456 EXPECT_EQ("f32[]", ShapeUtil::HumanStringWithLayout(scalar)); 457 EXPECT_EQ("u32[1,2]{1,0}", ShapeUtil::HumanStringWithLayout(matrix)); 458 EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); 459 EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", 460 ShapeUtil::HumanStringWithLayout(tuple)); 461 EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})", 462 ShapeUtil::HumanStringWithLayout(nested_tuple)); 463 464 ProgramShape prog = ShapeUtil::MakeProgramShape( 465 {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); 466 EXPECT_EQ( 467 "((unknown): opaque[], " 468 "(unknown): f32[], " 469 "(unknown): u32[1,2], " 470 "(unknown): s32[3,4], " 471 "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " 472 "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " 473 "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", 474 ShapeUtil::HumanString(prog)); 475 476 prog.add_parameter_names("arg0"); 477 prog.add_parameter_names("scalar"); 478 prog.add_parameter_names("matrix"); 479 prog.add_parameter_names("matrix2"); 480 prog.add_parameter_names("tuple"); 481 prog.add_parameter_names("nested_tuple"); 482 EXPECT_EQ( 483 "(arg0: opaque[], " 484 "scalar: f32[], " 485 "matrix: u32[1,2], " 486 "matrix2: s32[3,4], " 487 "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " 488 "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " 489 "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", 490 ShapeUtil::HumanString(prog)); 491 } 492 493 TEST(ShapeUtilTest, ForEachSubshapeArray) { 494 const Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); 495 int calls = 0; 496 ShapeUtil::ForEachSubshape( 497 shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) { 498 EXPECT_EQ(&shape, &subshape); 499 EXPECT_TRUE(index.empty()); 500 ++calls; 501 }); 502 EXPECT_EQ(1, calls); 503 } 504 505 TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) { 506 const Shape shape = ShapeUtil::MakeTupleShape( 507 {ShapeUtil::MakeShape(F32, {42}), 508 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}), 509 ShapeUtil::MakeShape(PRED, {33})})}); 510 int calls = 0; 511 ShapeUtil::ForEachSubshape( 512 shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) { 513 EXPECT_TRUE( 514 ShapeUtil::Equal(subshape, ShapeUtil::GetSubshape(shape, index))); 515 if (calls == 0) { 516 // Visitation should go from outside in. 517 EXPECT_TRUE(index.empty()); 518 } else if (calls == 4) { 519 // Last visitation should be to the array with 33 elements. 520 EXPECT_EQ(33, ShapeUtil::ElementsIn(subshape)); 521 } 522 ++calls; 523 }); 524 EXPECT_EQ(5, calls); 525 } 526 527 TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) { 528 Shape shape = ShapeUtil::MakeTupleShape( 529 {ShapeUtil::MakeShape(F32, {42}), 530 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}), 531 ShapeUtil::MakeShape(PRED, {33})})}); 532 int calls = 0; 533 ShapeUtil::ForEachMutableSubshape( 534 &shape, [&calls, &shape](const Shape* subshape, const ShapeIndex& index) { 535 // Pointer values should be equal 536 EXPECT_EQ(subshape, ShapeUtil::GetMutableSubshape(&shape, index)); 537 if (calls == 0) { 538 // Visitation should go from outside in. 539 EXPECT_TRUE(index.empty()); 540 } else if (calls == 4) { 541 // Last visitation should be to the array with 33 elements. 542 EXPECT_EQ(33, ShapeUtil::ElementsIn(*subshape)); 543 } 544 ++calls; 545 }); 546 EXPECT_EQ(5, calls); 547 } 548 549 TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { 550 Shape shape0 = ShapeUtil::MakeShape(S32, {9, 1, 4}); 551 Shape shape1 = ShapeUtil::MakeShape(S32, {1, 9, 4, 1}); 552 Shape shape2 = ShapeUtil::MakeShape(S32, {3, 1, 12}); 553 EXPECT_TRUE(std::get<0>( 554 ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1))); 555 EXPECT_FALSE(std::get<0>( 556 ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); 557 } 558 559 TEST(ShapeUtilTest, ShapeIs) { 560 EXPECT_FALSE(ShapeUtil::ShapeIs(ShapeUtil::MakeShape(PRED, {2}), PRED, {})); 561 } 562 563 TEST(ShapeUtilTest, ForEachIndex) { 564 struct ShapeDimensionAndNumberInvocations { 565 std::vector<int64> dimensions; 566 int invocations; 567 } test_data[] = { 568 {{}, 1}, {{0}, 0}, {{16}, 16}, {{3, 0}, 0}, 569 {{0, 2}, 0}, {{4, 16}, 64}, {{6, 11, 17}, 1122}, {{6, 11, 5, 17}, 5610}, 570 }; 571 572 for (const auto& data : test_data) { 573 Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); 574 // Increments at every invocation. 575 int invocations = 0; 576 auto increment_func = [&invocations](const std::vector<int64>& indexes) { 577 invocations++; 578 return true; 579 }; 580 581 std::vector<int64> zero_base(data.dimensions.size(), 0); 582 std::vector<int64> step(data.dimensions.size(), 1); 583 584 ShapeUtil::ForEachIndex(shape, zero_base, data.dimensions, step, 585 increment_func); 586 587 EXPECT_EQ(invocations, data.invocations); 588 } 589 } 590 591 TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { 592 // All output dimensions should be unmodified. One of the input dimensions is 593 // modified because the input rank is larger by one. 594 EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( 595 ShapeUtil::MakeShape(S32, {1, 1, 1, 1}), 596 ShapeUtil::MakeShape(S32, {1, 1, 1})), 597 ElementsAre(std::make_pair(0, 0), std::make_pair(1, 1), 598 std::make_pair(2, 2))); 599 } 600 601 TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1_to_1x1x1x1) { 602 // All input dimensions should be unmodified. One of the output dimensions is 603 // modified because the output rank is larger by one. 604 EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( 605 ShapeUtil::MakeShape(S32, {1, 1, 1}), 606 ShapeUtil::MakeShape(S32, {1, 1, 1, 1})), 607 ElementsAre(std::make_pair(0, 0), std::make_pair(1, 1), 608 std::make_pair(2, 2))); 609 } 610 611 TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) { 612 // The only matching dimension is the one with size 5. 613 // 4, 1, 3, 5, 6, 7 614 // | 615 // 2, 6, 1, 5, 1, 42 616 EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( 617 ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}), 618 ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})), 619 ElementsAre(std::make_pair(3, 3))); 620 } 621 622 TEST(ShapeUtilTest, ReshapeIsBitcast_3x4_6x2) { 623 for (bool input_is_row_major : {true, false}) { 624 for (bool output_is_row_major : {true, false}) { 625 Layout input_layout = input_is_row_major ? LayoutUtil::MakeLayout({1, 0}) 626 : LayoutUtil::MakeLayout({0, 1}); 627 Layout output_layout = output_is_row_major 628 ? LayoutUtil::MakeLayout({1, 0}) 629 : LayoutUtil::MakeLayout({0, 1}); 630 // Suppose the input is logically (i.e. ignoring its layout) 631 // 0 1 2 3 632 // 4 5 6 7 633 // 8 9 10 11 634 // 635 // The reshape transforms the input to logically 636 // 0 1 637 // 2 3 638 // 4 5 639 // 6 7 640 // 8 9 641 // 10 11 642 // 643 // The input and the output have the same underlying data only if they 644 // are both row-major. 645 EXPECT_EQ( 646 ShapeUtil::ReshapeIsBitcast( 647 ShapeUtil::MakeShapeWithLayout( 648 F32, {3, 4}, AsInt64Slice(input_layout.minor_to_major())), 649 ShapeUtil::MakeShapeWithLayout( 650 F32, {6, 2}, AsInt64Slice(output_layout.minor_to_major()))), 651 input_is_row_major && output_is_row_major); 652 } 653 } 654 } 655 656 TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { 657 EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast( 658 ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {1, 0, 2}), 659 ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); 660 } 661 662 TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { 663 EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( 664 ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), 665 ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); 666 } 667 668 TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensions) { 669 Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, 670 {3, 2, 1, 0, 4}); 671 auto aligned_shape = ShapeUtil::AlignLayouts( 672 input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11})); 673 EXPECT_TRUE(aligned_shape); 674 EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), 675 ElementsAre(4, 3, 2, 1, 0, 5)); 676 EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); 677 678 aligned_shape = ShapeUtil::AlignLayouts( 679 input, ShapeUtil::MakeShape(xla::F32, {3, 2, 4, 35, 11})); 680 EXPECT_TRUE(aligned_shape); 681 EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), 682 ElementsAre(3, 2, 1, 0, 4)); 683 EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); 684 } 685 686 TEST(AlignmentTest, AlignLayoutsWithTrivialDimensions) { 687 Shape input = 688 ShapeUtil::MakeShapeWithLayout(xla::F32, {1, 3, 8, 1, 5, 7, 1, 11, 1, 1}, 689 {5, 0, 4, 2, 1, 3, 6, 7, 9, 8}); 690 auto aligned_shape = ShapeUtil::AlignLayouts( 691 input, ShapeUtil::MakeShape(xla::F32, {1, 4, 1, 3, 2, 7, 5, 11, 1})); 692 EXPECT_TRUE(aligned_shape); 693 EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), 694 ElementsAre(6, 5, 4, 3, 1, 7, 0, 2, 8)); 695 EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); 696 } 697 698 // A test case where the consecutive elements of the input shape belonging to 699 // the same layout part are not in descending order. 700 TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensionsWrongInputLayout) { 701 // Same physical layout as in AlignLayoutsWithoutTrivialDimensions, except 702 // that the first two dimension numbers are exchanged. 703 Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, 704 {2, 3, 1, 0, 4}); 705 auto aligned_shape = ShapeUtil::AlignLayouts( 706 input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11})); 707 EXPECT_FALSE(aligned_shape); 708 } 709 710 // A test case where the physical layout of the input shape does not place all 711 // dimensions that belong to the same alignment part consecutively. 712 TEST(AlignmentTest, 713 AlignLayoutsWithoutTrivialDimensionsNonConsecutiveAlignmentPart) { 714 Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, 715 {3, 2, 1, 0, 4}); 716 auto aligned_shape = ShapeUtil::AlignLayouts( 717 input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 5, 77})); 718 EXPECT_FALSE(aligned_shape); 719 } 720 721 } // namespace 722 } // namespace xla 723