1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/shape_tree.h" 17 18 #include "tensorflow/compiler/xla/shape_util.h" 19 #include "tensorflow/compiler/xla/test.h" 20 #include "tensorflow/compiler/xla/xla_data.pb.h" 21 22 namespace xla { 23 namespace { 24 25 class ShapeTreeTest : public ::testing::Test { 26 protected: 27 ShapeTreeTest() { 28 array_shape_ = ShapeUtil::MakeShape(F32, {42, 42, 123}); 29 tuple_shape_ = 30 ShapeUtil::MakeTupleShape({array_shape_, array_shape_, array_shape_}); 31 nested_tuple_shape_ = ShapeUtil::MakeTupleShape( 32 {array_shape_, ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 33 ShapeUtil::MakeTupleShape( 34 {ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 35 array_shape_})}); 36 } 37 38 void TestShapeConstructor(const Shape& shape, int expected_num_nodes); 39 void TestInitValueConstructor(const Shape& shape, int expected_num_nodes); 40 41 // An array shape (non-tuple). 42 Shape array_shape_; 43 44 // A three element tuple shape. 45 Shape tuple_shape_; 46 47 // A nested tuple shape of the following form: (a, (c, d), ((e, f), g)) 48 Shape nested_tuple_shape_; 49 }; 50 51 TEST_F(ShapeTreeTest, DefaultConstructor) { 52 ShapeTree<int> int_tree; 53 EXPECT_TRUE(ShapeUtil::IsNil(int_tree.shape())); 54 55 ShapeTree<bool> bool_tree; 56 EXPECT_TRUE(ShapeUtil::IsNil(bool_tree.shape())); 57 } 58 59 void ShapeTreeTest::TestShapeConstructor(const Shape& shape, 60 int expected_num_nodes) { 61 ShapeTree<int> int_tree(shape); 62 int num_nodes = 0; 63 int_tree.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) { 64 EXPECT_EQ(0, data); 65 ++num_nodes; 66 }); 67 EXPECT_EQ(expected_num_nodes, num_nodes); 68 69 ShapeTree<bool> bool_tree(shape); 70 num_nodes = 0; 71 bool_tree.ForEachElement( 72 [&num_nodes](const ShapeIndex& /*index*/, bool data) { 73 EXPECT_EQ(false, data); 74 ++num_nodes; 75 }); 76 EXPECT_EQ(expected_num_nodes, num_nodes); 77 } 78 79 TEST_F(ShapeTreeTest, ShapeConstructor) { 80 TestShapeConstructor(array_shape_, 1); 81 TestShapeConstructor(tuple_shape_, 4); 82 TestShapeConstructor(nested_tuple_shape_, 10); 83 } 84 85 void ShapeTreeTest::TestInitValueConstructor(const Shape& shape, 86 int expected_num_nodes) { 87 ShapeTree<int> tree(shape, 42); 88 int num_nodes = 0; 89 tree.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) { 90 EXPECT_EQ(42, data); 91 ++num_nodes; 92 }); 93 EXPECT_EQ(expected_num_nodes, num_nodes); 94 95 num_nodes = 0; 96 tree.ForEachMutableElement( 97 [&num_nodes](const ShapeIndex& /*index*/, int* data) { 98 EXPECT_EQ(42, *data); 99 *data = num_nodes; 100 ++num_nodes; 101 }); 102 EXPECT_EQ(expected_num_nodes, num_nodes); 103 104 num_nodes = 0; 105 tree.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) { 106 EXPECT_EQ(num_nodes, data); 107 ++num_nodes; 108 }); 109 EXPECT_EQ(expected_num_nodes, num_nodes); 110 } 111 112 TEST_F(ShapeTreeTest, InitValueConstructor) { 113 TestInitValueConstructor(array_shape_, 1); 114 TestInitValueConstructor(tuple_shape_, 4); 115 TestInitValueConstructor(nested_tuple_shape_, 10); 116 } 117 118 TEST_F(ShapeTreeTest, ArrayShape) { 119 ShapeTree<int> shape_tree{array_shape_}; 120 *shape_tree.mutable_element({}) = 42; 121 EXPECT_EQ(42, shape_tree.element({})); 122 *shape_tree.mutable_element({}) = 123; 123 EXPECT_EQ(123, shape_tree.element({})); 124 125 EXPECT_TRUE(ShapeUtil::Compatible(array_shape_, shape_tree.shape())); 126 127 // Test the copy constructor. 128 ShapeTree<int> copy{shape_tree}; 129 EXPECT_EQ(123, copy.element({})); 130 131 // Mutate the copy, and ensure the original doesn't change. 132 *copy.mutable_element({}) = 99; 133 EXPECT_EQ(99, copy.element({})); 134 EXPECT_EQ(123, shape_tree.element({})); 135 136 // Test the assignment operator. 137 copy = shape_tree; 138 EXPECT_EQ(123, copy.element({})); 139 } 140 141 TEST_F(ShapeTreeTest, TupleShape) { 142 ShapeTree<int> shape_tree{tuple_shape_}; 143 *shape_tree.mutable_element({}) = 1; 144 *shape_tree.mutable_element({0}) = 42; 145 *shape_tree.mutable_element({1}) = 123; 146 *shape_tree.mutable_element({2}) = -100; 147 EXPECT_EQ(1, shape_tree.element({})); 148 EXPECT_EQ(42, shape_tree.element({0})); 149 EXPECT_EQ(123, shape_tree.element({1})); 150 EXPECT_EQ(-100, shape_tree.element({2})); 151 152 EXPECT_TRUE(ShapeUtil::Compatible(tuple_shape_, shape_tree.shape())); 153 154 // Sum all elements in the shape. 155 int sum = 0; 156 shape_tree.ForEachElement( 157 [&sum](const ShapeIndex& /*index*/, int data) { sum += data; }); 158 EXPECT_EQ(66, sum); 159 160 // Test the copy constructor. 161 ShapeTree<int> copy{shape_tree}; 162 EXPECT_EQ(1, copy.element({})); 163 EXPECT_EQ(42, copy.element({0})); 164 EXPECT_EQ(123, copy.element({1})); 165 EXPECT_EQ(-100, copy.element({2})); 166 167 // Write zero to all data elements. 168 shape_tree.ForEachMutableElement( 169 [&sum](const ShapeIndex& /*index*/, int* data) { *data = 0; }); 170 EXPECT_EQ(0, shape_tree.element({})); 171 EXPECT_EQ(0, shape_tree.element({0})); 172 EXPECT_EQ(0, shape_tree.element({1})); 173 EXPECT_EQ(0, shape_tree.element({2})); 174 EXPECT_EQ(1, copy.element({})); 175 EXPECT_EQ(42, copy.element({0})); 176 EXPECT_EQ(123, copy.element({1})); 177 EXPECT_EQ(-100, copy.element({2})); 178 179 // Test the assignment operator. 180 copy = shape_tree; 181 EXPECT_EQ(0, copy.element({})); 182 EXPECT_EQ(0, copy.element({0})); 183 EXPECT_EQ(0, copy.element({1})); 184 EXPECT_EQ(0, copy.element({2})); 185 } 186 187 TEST_F(ShapeTreeTest, NestedTupleShape) { 188 ShapeTree<int> shape_tree{nested_tuple_shape_}; 189 *shape_tree.mutable_element({0}) = 42; 190 *shape_tree.mutable_element({1, 1}) = 123; 191 *shape_tree.mutable_element({2, 0, 1}) = -100; 192 EXPECT_EQ(42, shape_tree.element({0})); 193 EXPECT_EQ(123, shape_tree.element({1, 1})); 194 EXPECT_EQ(-100, shape_tree.element({2, 0, 1})); 195 196 EXPECT_TRUE(ShapeUtil::Compatible(nested_tuple_shape_, shape_tree.shape())); 197 198 // Test the copy constructor. 199 ShapeTree<int> copy{shape_tree}; 200 EXPECT_EQ(42, copy.element({0})); 201 EXPECT_EQ(123, copy.element({1, 1})); 202 EXPECT_EQ(-100, copy.element({2, 0, 1})); 203 204 // Mutate the copy, and ensure the original doesn't change. 205 *copy.mutable_element({0}) = 1; 206 *copy.mutable_element({1, 1}) = 2; 207 *copy.mutable_element({2, 0, 1}) = 3; 208 EXPECT_EQ(1, copy.element({0})); 209 EXPECT_EQ(2, copy.element({1, 1})); 210 EXPECT_EQ(3, copy.element({2, 0, 1})); 211 EXPECT_EQ(42, shape_tree.element({0})); 212 EXPECT_EQ(123, shape_tree.element({1, 1})); 213 EXPECT_EQ(-100, shape_tree.element({2, 0, 1})); 214 215 // Test the assignment operator. 216 copy = shape_tree; 217 EXPECT_EQ(42, copy.element({0})); 218 EXPECT_EQ(123, copy.element({1, 1})); 219 EXPECT_EQ(-100, copy.element({2, 0, 1})); 220 } 221 222 TEST_F(ShapeTreeTest, InvalidIndexingTuple) { 223 ShapeTree<int> shape_tree{tuple_shape_}; 224 225 EXPECT_DEATH(shape_tree.element({4}), ""); 226 } 227 228 TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { 229 ShapeTree<int> shape_tree{nested_tuple_shape_}; 230 231 EXPECT_DEATH(shape_tree.element({0, 0}), ""); 232 } 233 234 TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { 235 ShapeTree<std::unique_ptr<int>> shape_tree{tuple_shape_}; 236 EXPECT_EQ(shape_tree.element({2}).get(), nullptr); 237 *shape_tree.mutable_element({2}) = MakeUnique<int>(42); 238 EXPECT_EQ(*shape_tree.element({2}), 42); 239 } 240 241 TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) { 242 // Test CopySubtreeFrom method for a single value copied between array-shaped 243 // ShapeTrees. 244 ShapeTree<int> source(array_shape_); 245 *source.mutable_element(/*index=*/{}) = 42; 246 ShapeTree<int> destination(array_shape_, 123); 247 248 EXPECT_EQ(destination.element(/*index=*/{}), 123); 249 destination.CopySubtreeFrom(source, /*source_base_index=*/{}, 250 /*target_base_index=*/{}); 251 EXPECT_EQ(destination.element(/*index=*/{}), 42); 252 } 253 254 TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) { 255 // Test CopySubtreeFrom method for a copy of all elements from one 256 // tuple-shaped ShapeTree to another. 257 ShapeTree<int> source(tuple_shape_); 258 *source.mutable_element(/*index=*/{}) = 10; 259 *source.mutable_element(/*index=*/{0}) = 11; 260 *source.mutable_element(/*index=*/{1}) = 12; 261 *source.mutable_element(/*index=*/{2}) = 13; 262 263 ShapeTree<int> destination(tuple_shape_, 0); 264 265 destination.CopySubtreeFrom(source, /*source_base_index=*/{}, 266 /*target_base_index=*/{}); 267 EXPECT_EQ(destination.element(/*index=*/{}), 10); 268 EXPECT_EQ(destination.element(/*index=*/{0}), 11); 269 EXPECT_EQ(destination.element(/*index=*/{1}), 12); 270 EXPECT_EQ(destination.element(/*index=*/{2}), 13); 271 } 272 273 TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) { 274 // Test CopySubtreeFrom method for a copy of a single element from one 275 // tuple-shaped ShapeTree to another. 276 ShapeTree<int> source(tuple_shape_); 277 *source.mutable_element(/*index=*/{}) = 10; 278 *source.mutable_element(/*index=*/{0}) = 11; 279 *source.mutable_element(/*index=*/{1}) = 12; 280 *source.mutable_element(/*index=*/{2}) = 13; 281 282 ShapeTree<int> destination(tuple_shape_, 0); 283 284 destination.CopySubtreeFrom(source, /*source_base_index=*/{0}, 285 /*target_base_index=*/{1}); 286 EXPECT_EQ(destination.element(/*index=*/{}), 0); 287 EXPECT_EQ(destination.element(/*index=*/{0}), 0); 288 EXPECT_EQ(destination.element(/*index=*/{1}), 11); 289 EXPECT_EQ(destination.element(/*index=*/{2}), 0); 290 } 291 292 TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) { 293 // Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a 294 // nested-tuple-shaped ShapeTree. 295 ShapeTree<int> source( 296 ShapeUtil::MakeTupleShape({array_shape_, array_shape_})); 297 *source.mutable_element(/*index=*/{}) = 10; 298 *source.mutable_element(/*index=*/{0}) = 11; 299 *source.mutable_element(/*index=*/{1}) = 12; 300 301 ShapeTree<int> destination(nested_tuple_shape_, 0); 302 303 destination.CopySubtreeFrom(source, /*source_base_index=*/{}, 304 /*target_base_index=*/{2, 0}); 305 306 EXPECT_EQ(destination.element(/*index=*/{}), 0); 307 EXPECT_EQ(destination.element(/*index=*/{0}), 0); 308 EXPECT_EQ(destination.element(/*index=*/{1}), 0); 309 EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0); 310 EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0); 311 EXPECT_EQ(destination.element(/*index=*/{2}), 0); 312 EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10); 313 EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11); 314 EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12); 315 EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0); 316 } 317 318 TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) { 319 // Test CopySubtreeFrom method for a copy from a nested-tuple-shape. 320 ShapeTree<int> source(nested_tuple_shape_, 42); 321 *source.mutable_element(/*index=*/{1}) = 10; 322 *source.mutable_element(/*index=*/{1, 0}) = 11; 323 *source.mutable_element(/*index=*/{1, 1}) = 12; 324 325 ShapeTree<int> destination( 326 ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0); 327 328 destination.CopySubtreeFrom(source, /*source_base_index=*/{1}, 329 /*target_base_index=*/{}); 330 331 EXPECT_EQ(destination.element(/*index=*/{}), 10); 332 EXPECT_EQ(destination.element(/*index=*/{0}), 11); 333 EXPECT_EQ(destination.element(/*index=*/{1}), 12); 334 } 335 336 TEST_F(ShapeTreeTest, OperatorEquals) { 337 { 338 ShapeTree<int> a(array_shape_, 123); 339 ShapeTree<int> b(array_shape_, 42); 340 ShapeTree<int> c(array_shape_, 42); 341 EXPECT_FALSE(a == b); 342 EXPECT_TRUE(a != b); 343 EXPECT_TRUE(b == c); 344 } 345 { 346 ShapeTree<int> a(tuple_shape_); 347 *a.mutable_element(/*index=*/{}) = 10; 348 *a.mutable_element(/*index=*/{0}) = 11; 349 *a.mutable_element(/*index=*/{1}) = 12; 350 351 ShapeTree<int> b(tuple_shape_); 352 *b.mutable_element(/*index=*/{}) = 10; 353 *b.mutable_element(/*index=*/{0}) = 42; 354 *b.mutable_element(/*index=*/{1}) = 11; 355 356 ShapeTree<int> c(tuple_shape_); 357 *c.mutable_element(/*index=*/{}) = 10; 358 *c.mutable_element(/*index=*/{0}) = 42; 359 *c.mutable_element(/*index=*/{1}) = 11; 360 361 EXPECT_FALSE(a == b); 362 EXPECT_TRUE(a != b); 363 EXPECT_TRUE(b == c); 364 EXPECT_FALSE(b != c); 365 } 366 } 367 368 TEST_F(ShapeTreeTest, ConstructWithPointerToShape) { 369 // Construct a ShapeTree using a pointer to a shape, rather than a reference 370 // to a shape. This constructor is an optimization to let us avoid 371 // constructing and destroying temporary shapes when we have many ShapeTrees. 372 ShapeTree<int> t(&nested_tuple_shape_, 42); 373 int num_nodes = 0; 374 t.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) { 375 EXPECT_EQ(42, data); 376 ++num_nodes; 377 }); 378 EXPECT_EQ(10, num_nodes); 379 } 380 381 TEST_F(ShapeTreeTest, CopyWithPointerToShape) { 382 ShapeTree<int> source(&nested_tuple_shape_, 0); 383 ShapeTree<int> dest(source); 384 EXPECT_EQ(&dest.shape(), &nested_tuple_shape_); 385 } 386 387 TEST_F(ShapeTreeTest, CopyAssignWithPointerToShape) { 388 ShapeTree<int> source(&nested_tuple_shape_, 0); 389 ShapeTree<int> dest; 390 dest = source; 391 EXPECT_EQ(&dest.shape(), &nested_tuple_shape_); 392 } 393 394 TEST_F(ShapeTreeTest, IterateSimple) { 395 ShapeTree<int> t(nested_tuple_shape_, 42); 396 int num_nodes = 0; 397 for (auto index_to_data : t) { 398 EXPECT_EQ(42, index_to_data.second); 399 ++num_nodes; 400 } 401 EXPECT_EQ(10, num_nodes); 402 } 403 404 TEST_F(ShapeTreeTest, ConstIterate) { 405 const ShapeTree<int> t(nested_tuple_shape_, 42); 406 int num_nodes = 0; 407 for (const auto& index_to_data : t) { 408 EXPECT_EQ(42, index_to_data.second); 409 ++num_nodes; 410 } 411 EXPECT_EQ(10, num_nodes); 412 } 413 414 TEST_F(ShapeTreeTest, IterateAndMutate) { 415 ShapeTree<int> t(nested_tuple_shape_, 42); 416 int i = 0; 417 for (auto& index_to_data : t) { 418 EXPECT_EQ(42, index_to_data.second); 419 if (i == 1) { 420 index_to_data.second = 98; 421 } 422 ++i; 423 } 424 t.begin()->second = 78; 425 EXPECT_EQ(78, t.begin()->second); 426 i = 0; 427 for (auto& index_to_data : t) { 428 if (i == 0) { 429 EXPECT_EQ(78, index_to_data.second); 430 } else if (i == 1) { 431 EXPECT_EQ(98, index_to_data.second); 432 } else { 433 EXPECT_EQ(42, index_to_data.second); 434 } 435 ++i; 436 } 437 EXPECT_EQ(78, t.begin()->second); 438 EXPECT_EQ(98, std::next(t.begin())->second); 439 } 440 441 TEST_F(ShapeTreeTest, IterateOrder) { 442 ShapeTree<int> t(nested_tuple_shape_, 42); 443 std::vector<ShapeIndex> v; 444 for (auto& index_to_data : t) { 445 v.push_back(index_to_data.first); 446 } 447 EXPECT_EQ(v, (std::vector<ShapeIndex>{{}, 448 {0}, 449 {1}, 450 {1, 0}, 451 {1, 1}, 452 {2}, 453 {2, 0}, 454 {2, 0, 0}, 455 {2, 0, 1}, 456 {2, 1}})); 457 } 458 459 TEST_F(ShapeTreeTest, ReverseIterateOrder) { 460 ShapeTree<int> t(nested_tuple_shape_, 42); 461 std::vector<ShapeIndex> v; 462 for (auto it = t.rbegin(); it != t.rend(); ++it) { 463 v.push_back(it->first); 464 } 465 EXPECT_EQ(v, (std::vector<ShapeIndex>{ 466 {2, 1}, 467 {2, 0, 1}, 468 {2, 0, 0}, 469 {2, 0}, 470 {2}, 471 {1, 1}, 472 {1, 0}, 473 {1}, 474 {0}, 475 {}, 476 })); 477 } 478 479 TEST_F(ShapeTreeTest, IterateOrderLeaves) { 480 ShapeTree<int> t(nested_tuple_shape_, 42); 481 std::vector<ShapeIndex> v; 482 for (auto& index_to_data : t.leaves()) { 483 v.push_back(index_to_data.first); 484 } 485 EXPECT_EQ(v, (std::vector<ShapeIndex>{ 486 {0}, {1, 0}, {1, 1}, {2, 0, 0}, {2, 0, 1}, {2, 1}})); 487 } 488 489 TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) { 490 ShapeTree<int> t(nested_tuple_shape_, 42); 491 std::vector<ShapeIndex> v; 492 for (auto it = t.leaf_rbegin(); it != t.leaf_rend(); ++it) { 493 v.push_back(it->first); 494 } 495 EXPECT_EQ(v, (std::vector<ShapeIndex>{ 496 {2, 1}, 497 {2, 0, 1}, 498 {2, 0, 0}, 499 {1, 1}, 500 {1, 0}, 501 {0}, 502 })); 503 } 504 505 } // namespace 506 } // namespace xla 507