Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/shape_tree.h"
     17 
     18 #include "tensorflow/compiler/xla/shape_util.h"
     19 #include "tensorflow/compiler/xla/test.h"
     20 #include "tensorflow/compiler/xla/xla_data.pb.h"
     21 
     22 namespace xla {
     23 namespace {
     24 
     25 class ShapeTreeTest : public ::testing::Test {
     26  protected:
     27   ShapeTreeTest() {
     28     array_shape_ = ShapeUtil::MakeShape(F32, {42, 42, 123});
     29     tuple_shape_ =
     30         ShapeUtil::MakeTupleShape({array_shape_, array_shape_, array_shape_});
     31     nested_tuple_shape_ = ShapeUtil::MakeTupleShape(
     32         {array_shape_, ShapeUtil::MakeTupleShape({array_shape_, array_shape_}),
     33          ShapeUtil::MakeTupleShape(
     34              {ShapeUtil::MakeTupleShape({array_shape_, array_shape_}),
     35               array_shape_})});
     36   }
     37 
     38   void TestShapeConstructor(const Shape& shape, int expected_num_nodes);
     39   void TestInitValueConstructor(const Shape& shape, int expected_num_nodes);
     40 
     41   // An array shape (non-tuple).
     42   Shape array_shape_;
     43 
     44   // A three element tuple shape.
     45   Shape tuple_shape_;
     46 
     47   // A nested tuple shape of the following form: (a, (c, d), ((e, f), g))
     48   Shape nested_tuple_shape_;
     49 };
     50 
     51 TEST_F(ShapeTreeTest, DefaultConstructor) {
     52   ShapeTree<int> int_tree;
     53   EXPECT_TRUE(ShapeUtil::IsNil(int_tree.shape()));
     54 
     55   ShapeTree<bool> bool_tree;
     56   EXPECT_TRUE(ShapeUtil::IsNil(bool_tree.shape()));
     57 }
     58 
     59 void ShapeTreeTest::TestShapeConstructor(const Shape& shape,
     60                                          int expected_num_nodes) {
     61   ShapeTree<int> int_tree(shape);
     62   int num_nodes = 0;
     63   int_tree.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) {
     64     EXPECT_EQ(0, data);
     65     ++num_nodes;
     66   });
     67   EXPECT_EQ(expected_num_nodes, num_nodes);
     68 
     69   ShapeTree<bool> bool_tree(shape);
     70   num_nodes = 0;
     71   bool_tree.ForEachElement(
     72       [&num_nodes](const ShapeIndex& /*index*/, bool data) {
     73         EXPECT_EQ(false, data);
     74         ++num_nodes;
     75       });
     76   EXPECT_EQ(expected_num_nodes, num_nodes);
     77 }
     78 
     79 TEST_F(ShapeTreeTest, ShapeConstructor) {
     80   TestShapeConstructor(array_shape_, 1);
     81   TestShapeConstructor(tuple_shape_, 4);
     82   TestShapeConstructor(nested_tuple_shape_, 10);
     83 }
     84 
     85 void ShapeTreeTest::TestInitValueConstructor(const Shape& shape,
     86                                              int expected_num_nodes) {
     87   ShapeTree<int> tree(shape, 42);
     88   int num_nodes = 0;
     89   tree.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) {
     90     EXPECT_EQ(42, data);
     91     ++num_nodes;
     92   });
     93   EXPECT_EQ(expected_num_nodes, num_nodes);
     94 
     95   num_nodes = 0;
     96   tree.ForEachMutableElement(
     97       [&num_nodes](const ShapeIndex& /*index*/, int* data) {
     98         EXPECT_EQ(42, *data);
     99         *data = num_nodes;
    100         ++num_nodes;
    101       });
    102   EXPECT_EQ(expected_num_nodes, num_nodes);
    103 
    104   num_nodes = 0;
    105   tree.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) {
    106     EXPECT_EQ(num_nodes, data);
    107     ++num_nodes;
    108   });
    109   EXPECT_EQ(expected_num_nodes, num_nodes);
    110 }
    111 
    112 TEST_F(ShapeTreeTest, InitValueConstructor) {
    113   TestInitValueConstructor(array_shape_, 1);
    114   TestInitValueConstructor(tuple_shape_, 4);
    115   TestInitValueConstructor(nested_tuple_shape_, 10);
    116 }
    117 
    118 TEST_F(ShapeTreeTest, ArrayShape) {
    119   ShapeTree<int> shape_tree{array_shape_};
    120   *shape_tree.mutable_element({}) = 42;
    121   EXPECT_EQ(42, shape_tree.element({}));
    122   *shape_tree.mutable_element({}) = 123;
    123   EXPECT_EQ(123, shape_tree.element({}));
    124 
    125   EXPECT_TRUE(ShapeUtil::Compatible(array_shape_, shape_tree.shape()));
    126 
    127   // Test the copy constructor.
    128   ShapeTree<int> copy{shape_tree};
    129   EXPECT_EQ(123, copy.element({}));
    130 
    131   // Mutate the copy, and ensure the original doesn't change.
    132   *copy.mutable_element({}) = 99;
    133   EXPECT_EQ(99, copy.element({}));
    134   EXPECT_EQ(123, shape_tree.element({}));
    135 
    136   // Test the assignment operator.
    137   copy = shape_tree;
    138   EXPECT_EQ(123, copy.element({}));
    139 }
    140 
    141 TEST_F(ShapeTreeTest, TupleShape) {
    142   ShapeTree<int> shape_tree{tuple_shape_};
    143   *shape_tree.mutable_element({}) = 1;
    144   *shape_tree.mutable_element({0}) = 42;
    145   *shape_tree.mutable_element({1}) = 123;
    146   *shape_tree.mutable_element({2}) = -100;
    147   EXPECT_EQ(1, shape_tree.element({}));
    148   EXPECT_EQ(42, shape_tree.element({0}));
    149   EXPECT_EQ(123, shape_tree.element({1}));
    150   EXPECT_EQ(-100, shape_tree.element({2}));
    151 
    152   EXPECT_TRUE(ShapeUtil::Compatible(tuple_shape_, shape_tree.shape()));
    153 
    154   // Sum all elements in the shape.
    155   int sum = 0;
    156   shape_tree.ForEachElement(
    157       [&sum](const ShapeIndex& /*index*/, int data) { sum += data; });
    158   EXPECT_EQ(66, sum);
    159 
    160   // Test the copy constructor.
    161   ShapeTree<int> copy{shape_tree};
    162   EXPECT_EQ(1, copy.element({}));
    163   EXPECT_EQ(42, copy.element({0}));
    164   EXPECT_EQ(123, copy.element({1}));
    165   EXPECT_EQ(-100, copy.element({2}));
    166 
    167   // Write zero to all data elements.
    168   shape_tree.ForEachMutableElement(
    169       [&sum](const ShapeIndex& /*index*/, int* data) { *data = 0; });
    170   EXPECT_EQ(0, shape_tree.element({}));
    171   EXPECT_EQ(0, shape_tree.element({0}));
    172   EXPECT_EQ(0, shape_tree.element({1}));
    173   EXPECT_EQ(0, shape_tree.element({2}));
    174   EXPECT_EQ(1, copy.element({}));
    175   EXPECT_EQ(42, copy.element({0}));
    176   EXPECT_EQ(123, copy.element({1}));
    177   EXPECT_EQ(-100, copy.element({2}));
    178 
    179   // Test the assignment operator.
    180   copy = shape_tree;
    181   EXPECT_EQ(0, copy.element({}));
    182   EXPECT_EQ(0, copy.element({0}));
    183   EXPECT_EQ(0, copy.element({1}));
    184   EXPECT_EQ(0, copy.element({2}));
    185 }
    186 
    187 TEST_F(ShapeTreeTest, NestedTupleShape) {
    188   ShapeTree<int> shape_tree{nested_tuple_shape_};
    189   *shape_tree.mutable_element({0}) = 42;
    190   *shape_tree.mutable_element({1, 1}) = 123;
    191   *shape_tree.mutable_element({2, 0, 1}) = -100;
    192   EXPECT_EQ(42, shape_tree.element({0}));
    193   EXPECT_EQ(123, shape_tree.element({1, 1}));
    194   EXPECT_EQ(-100, shape_tree.element({2, 0, 1}));
    195 
    196   EXPECT_TRUE(ShapeUtil::Compatible(nested_tuple_shape_, shape_tree.shape()));
    197 
    198   // Test the copy constructor.
    199   ShapeTree<int> copy{shape_tree};
    200   EXPECT_EQ(42, copy.element({0}));
    201   EXPECT_EQ(123, copy.element({1, 1}));
    202   EXPECT_EQ(-100, copy.element({2, 0, 1}));
    203 
    204   // Mutate the copy, and ensure the original doesn't change.
    205   *copy.mutable_element({0}) = 1;
    206   *copy.mutable_element({1, 1}) = 2;
    207   *copy.mutable_element({2, 0, 1}) = 3;
    208   EXPECT_EQ(1, copy.element({0}));
    209   EXPECT_EQ(2, copy.element({1, 1}));
    210   EXPECT_EQ(3, copy.element({2, 0, 1}));
    211   EXPECT_EQ(42, shape_tree.element({0}));
    212   EXPECT_EQ(123, shape_tree.element({1, 1}));
    213   EXPECT_EQ(-100, shape_tree.element({2, 0, 1}));
    214 
    215   // Test the assignment operator.
    216   copy = shape_tree;
    217   EXPECT_EQ(42, copy.element({0}));
    218   EXPECT_EQ(123, copy.element({1, 1}));
    219   EXPECT_EQ(-100, copy.element({2, 0, 1}));
    220 }
    221 
    222 TEST_F(ShapeTreeTest, InvalidIndexingTuple) {
    223   ShapeTree<int> shape_tree{tuple_shape_};
    224 
    225   EXPECT_DEATH(shape_tree.element({4}), "");
    226 }
    227 
    228 TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
    229   ShapeTree<int> shape_tree{nested_tuple_shape_};
    230 
    231   EXPECT_DEATH(shape_tree.element({0, 0}), "");
    232 }
    233 
    234 TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
    235   ShapeTree<std::unique_ptr<int>> shape_tree{tuple_shape_};
    236   EXPECT_EQ(shape_tree.element({2}).get(), nullptr);
    237   *shape_tree.mutable_element({2}) = MakeUnique<int>(42);
    238   EXPECT_EQ(*shape_tree.element({2}), 42);
    239 }
    240 
    241 TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) {
    242   // Test CopySubtreeFrom method for a single value copied between array-shaped
    243   // ShapeTrees.
    244   ShapeTree<int> source(array_shape_);
    245   *source.mutable_element(/*index=*/{}) = 42;
    246   ShapeTree<int> destination(array_shape_, 123);
    247 
    248   EXPECT_EQ(destination.element(/*index=*/{}), 123);
    249   destination.CopySubtreeFrom(source, /*source_base_index=*/{},
    250                               /*target_base_index=*/{});
    251   EXPECT_EQ(destination.element(/*index=*/{}), 42);
    252 }
    253 
    254 TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) {
    255   // Test CopySubtreeFrom method for a copy of all elements from one
    256   // tuple-shaped ShapeTree to another.
    257   ShapeTree<int> source(tuple_shape_);
    258   *source.mutable_element(/*index=*/{}) = 10;
    259   *source.mutable_element(/*index=*/{0}) = 11;
    260   *source.mutable_element(/*index=*/{1}) = 12;
    261   *source.mutable_element(/*index=*/{2}) = 13;
    262 
    263   ShapeTree<int> destination(tuple_shape_, 0);
    264 
    265   destination.CopySubtreeFrom(source, /*source_base_index=*/{},
    266                               /*target_base_index=*/{});
    267   EXPECT_EQ(destination.element(/*index=*/{}), 10);
    268   EXPECT_EQ(destination.element(/*index=*/{0}), 11);
    269   EXPECT_EQ(destination.element(/*index=*/{1}), 12);
    270   EXPECT_EQ(destination.element(/*index=*/{2}), 13);
    271 }
    272 
    273 TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) {
    274   // Test CopySubtreeFrom method for a copy of a single element from one
    275   // tuple-shaped ShapeTree to another.
    276   ShapeTree<int> source(tuple_shape_);
    277   *source.mutable_element(/*index=*/{}) = 10;
    278   *source.mutable_element(/*index=*/{0}) = 11;
    279   *source.mutable_element(/*index=*/{1}) = 12;
    280   *source.mutable_element(/*index=*/{2}) = 13;
    281 
    282   ShapeTree<int> destination(tuple_shape_, 0);
    283 
    284   destination.CopySubtreeFrom(source, /*source_base_index=*/{0},
    285                               /*target_base_index=*/{1});
    286   EXPECT_EQ(destination.element(/*index=*/{}), 0);
    287   EXPECT_EQ(destination.element(/*index=*/{0}), 0);
    288   EXPECT_EQ(destination.element(/*index=*/{1}), 11);
    289   EXPECT_EQ(destination.element(/*index=*/{2}), 0);
    290 }
    291 
    292 TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) {
    293   // Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a
    294   // nested-tuple-shaped ShapeTree.
    295   ShapeTree<int> source(
    296       ShapeUtil::MakeTupleShape({array_shape_, array_shape_}));
    297   *source.mutable_element(/*index=*/{}) = 10;
    298   *source.mutable_element(/*index=*/{0}) = 11;
    299   *source.mutable_element(/*index=*/{1}) = 12;
    300 
    301   ShapeTree<int> destination(nested_tuple_shape_, 0);
    302 
    303   destination.CopySubtreeFrom(source, /*source_base_index=*/{},
    304                               /*target_base_index=*/{2, 0});
    305 
    306   EXPECT_EQ(destination.element(/*index=*/{}), 0);
    307   EXPECT_EQ(destination.element(/*index=*/{0}), 0);
    308   EXPECT_EQ(destination.element(/*index=*/{1}), 0);
    309   EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0);
    310   EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0);
    311   EXPECT_EQ(destination.element(/*index=*/{2}), 0);
    312   EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10);
    313   EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11);
    314   EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12);
    315   EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0);
    316 }
    317 
    318 TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) {
    319   // Test CopySubtreeFrom method for a copy from a nested-tuple-shape.
    320   ShapeTree<int> source(nested_tuple_shape_, 42);
    321   *source.mutable_element(/*index=*/{1}) = 10;
    322   *source.mutable_element(/*index=*/{1, 0}) = 11;
    323   *source.mutable_element(/*index=*/{1, 1}) = 12;
    324 
    325   ShapeTree<int> destination(
    326       ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0);
    327 
    328   destination.CopySubtreeFrom(source, /*source_base_index=*/{1},
    329                               /*target_base_index=*/{});
    330 
    331   EXPECT_EQ(destination.element(/*index=*/{}), 10);
    332   EXPECT_EQ(destination.element(/*index=*/{0}), 11);
    333   EXPECT_EQ(destination.element(/*index=*/{1}), 12);
    334 }
    335 
    336 TEST_F(ShapeTreeTest, OperatorEquals) {
    337   {
    338     ShapeTree<int> a(array_shape_, 123);
    339     ShapeTree<int> b(array_shape_, 42);
    340     ShapeTree<int> c(array_shape_, 42);
    341     EXPECT_FALSE(a == b);
    342     EXPECT_TRUE(a != b);
    343     EXPECT_TRUE(b == c);
    344   }
    345   {
    346     ShapeTree<int> a(tuple_shape_);
    347     *a.mutable_element(/*index=*/{}) = 10;
    348     *a.mutable_element(/*index=*/{0}) = 11;
    349     *a.mutable_element(/*index=*/{1}) = 12;
    350 
    351     ShapeTree<int> b(tuple_shape_);
    352     *b.mutable_element(/*index=*/{}) = 10;
    353     *b.mutable_element(/*index=*/{0}) = 42;
    354     *b.mutable_element(/*index=*/{1}) = 11;
    355 
    356     ShapeTree<int> c(tuple_shape_);
    357     *c.mutable_element(/*index=*/{}) = 10;
    358     *c.mutable_element(/*index=*/{0}) = 42;
    359     *c.mutable_element(/*index=*/{1}) = 11;
    360 
    361     EXPECT_FALSE(a == b);
    362     EXPECT_TRUE(a != b);
    363     EXPECT_TRUE(b == c);
    364     EXPECT_FALSE(b != c);
    365   }
    366 }
    367 
    368 TEST_F(ShapeTreeTest, ConstructWithPointerToShape) {
    369   // Construct a ShapeTree using a pointer to a shape, rather than a reference
    370   // to a shape.  This constructor is an optimization to let us avoid
    371   // constructing and destroying temporary shapes when we have many ShapeTrees.
    372   ShapeTree<int> t(&nested_tuple_shape_, 42);
    373   int num_nodes = 0;
    374   t.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) {
    375     EXPECT_EQ(42, data);
    376     ++num_nodes;
    377   });
    378   EXPECT_EQ(10, num_nodes);
    379 }
    380 
    381 TEST_F(ShapeTreeTest, CopyWithPointerToShape) {
    382   ShapeTree<int> source(&nested_tuple_shape_, 0);
    383   ShapeTree<int> dest(source);
    384   EXPECT_EQ(&dest.shape(), &nested_tuple_shape_);
    385 }
    386 
    387 TEST_F(ShapeTreeTest, CopyAssignWithPointerToShape) {
    388   ShapeTree<int> source(&nested_tuple_shape_, 0);
    389   ShapeTree<int> dest;
    390   dest = source;
    391   EXPECT_EQ(&dest.shape(), &nested_tuple_shape_);
    392 }
    393 
    394 TEST_F(ShapeTreeTest, IterateSimple) {
    395   ShapeTree<int> t(nested_tuple_shape_, 42);
    396   int num_nodes = 0;
    397   for (auto index_to_data : t) {
    398     EXPECT_EQ(42, index_to_data.second);
    399     ++num_nodes;
    400   }
    401   EXPECT_EQ(10, num_nodes);
    402 }
    403 
    404 TEST_F(ShapeTreeTest, ConstIterate) {
    405   const ShapeTree<int> t(nested_tuple_shape_, 42);
    406   int num_nodes = 0;
    407   for (const auto& index_to_data : t) {
    408     EXPECT_EQ(42, index_to_data.second);
    409     ++num_nodes;
    410   }
    411   EXPECT_EQ(10, num_nodes);
    412 }
    413 
    414 TEST_F(ShapeTreeTest, IterateAndMutate) {
    415   ShapeTree<int> t(nested_tuple_shape_, 42);
    416   int i = 0;
    417   for (auto& index_to_data : t) {
    418     EXPECT_EQ(42, index_to_data.second);
    419     if (i == 1) {
    420       index_to_data.second = 98;
    421     }
    422     ++i;
    423   }
    424   t.begin()->second = 78;
    425   EXPECT_EQ(78, t.begin()->second);
    426   i = 0;
    427   for (auto& index_to_data : t) {
    428     if (i == 0) {
    429       EXPECT_EQ(78, index_to_data.second);
    430     } else if (i == 1) {
    431       EXPECT_EQ(98, index_to_data.second);
    432     } else {
    433       EXPECT_EQ(42, index_to_data.second);
    434     }
    435     ++i;
    436   }
    437   EXPECT_EQ(78, t.begin()->second);
    438   EXPECT_EQ(98, std::next(t.begin())->second);
    439 }
    440 
    441 TEST_F(ShapeTreeTest, IterateOrder) {
    442   ShapeTree<int> t(nested_tuple_shape_, 42);
    443   std::vector<ShapeIndex> v;
    444   for (auto& index_to_data : t) {
    445     v.push_back(index_to_data.first);
    446   }
    447   EXPECT_EQ(v, (std::vector<ShapeIndex>{{},
    448                                         {0},
    449                                         {1},
    450                                         {1, 0},
    451                                         {1, 1},
    452                                         {2},
    453                                         {2, 0},
    454                                         {2, 0, 0},
    455                                         {2, 0, 1},
    456                                         {2, 1}}));
    457 }
    458 
    459 TEST_F(ShapeTreeTest, ReverseIterateOrder) {
    460   ShapeTree<int> t(nested_tuple_shape_, 42);
    461   std::vector<ShapeIndex> v;
    462   for (auto it = t.rbegin(); it != t.rend(); ++it) {
    463     v.push_back(it->first);
    464   }
    465   EXPECT_EQ(v, (std::vector<ShapeIndex>{
    466                    {2, 1},
    467                    {2, 0, 1},
    468                    {2, 0, 0},
    469                    {2, 0},
    470                    {2},
    471                    {1, 1},
    472                    {1, 0},
    473                    {1},
    474                    {0},
    475                    {},
    476                }));
    477 }
    478 
    479 TEST_F(ShapeTreeTest, IterateOrderLeaves) {
    480   ShapeTree<int> t(nested_tuple_shape_, 42);
    481   std::vector<ShapeIndex> v;
    482   for (auto& index_to_data : t.leaves()) {
    483     v.push_back(index_to_data.first);
    484   }
    485   EXPECT_EQ(v, (std::vector<ShapeIndex>{
    486                    {0}, {1, 0}, {1, 1}, {2, 0, 0}, {2, 0, 1}, {2, 1}}));
    487 }
    488 
    489 TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) {
    490   ShapeTree<int> t(nested_tuple_shape_, 42);
    491   std::vector<ShapeIndex> v;
    492   for (auto it = t.leaf_rbegin(); it != t.leaf_rend(); ++it) {
    493     v.push_back(it->first);
    494   }
    495   EXPECT_EQ(v, (std::vector<ShapeIndex>{
    496                    {2, 1},
    497                    {2, 0, 1},
    498                    {2, 0, 0},
    499                    {1, 1},
    500                    {1, 0},
    501                    {0},
    502                }));
    503 }
    504 
    505 }  // namespace
    506 }  // namespace xla
    507