1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" 17 18 #include <algorithm> 19 #include <random> 20 21 #include "tensorflow/compiler/xla/test_helpers.h" 22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 23 #include "tensorflow/compiler/xla/util.h" 24 25 namespace xla { 26 namespace cpu { 27 namespace { 28 29 class ShapePartitionAssignerTest : public HloTestBase { 30 protected: 31 typedef std::vector<int64> Vec; 32 33 void RunR2Test(const Shape& shape, const int64 expected_max_partition_count) { 34 ShapePartitionAssigner assigner(shape); 35 // Check all partitions of outer dimension. 36 for (int64 i = 1; i <= expected_max_partition_count; ++i) { 37 EXPECT_TRUE(ContainersEqual(Vec({i}), 38 assigner.Run(/*target_partition_count=*/i))); 39 } 40 // Check target_partition_count > outer dimension size. 41 EXPECT_TRUE(ContainersEqual( 42 Vec({expected_max_partition_count}), 43 assigner.Run( 44 /*target_partition_count=*/expected_max_partition_count + 1))); 45 } 46 }; 47 48 TEST_F(ShapePartitionAssignerTest, Shape13WithLayout10) { 49 RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 1); 50 } 51 52 TEST_F(ShapePartitionAssignerTest, Shape31WithLayout01) { 53 RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 1); 54 } 55 56 TEST_F(ShapePartitionAssignerTest, Shape53WithLayout10) { 57 RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 5); 58 } 59 60 TEST_F(ShapePartitionAssignerTest, Shape53WithLayout01) { 61 RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 3); 62 } 63 64 TEST_F(ShapePartitionAssignerTest, Shape532WithLayout210) { 65 Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); 66 ShapePartitionAssigner assigner(shape); 67 68 for (int64 i = 1; i <= 5; ++i) { 69 EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( 70 /*target_partition_count=*/i))); 71 } 72 73 EXPECT_TRUE( 74 ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); 75 EXPECT_TRUE( 76 ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); 77 EXPECT_TRUE( 78 ContainersEqual(Vec({4, 2}), assigner.Run(/*target_partition_count=*/8))); 79 EXPECT_TRUE( 80 ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); 81 EXPECT_TRUE(ContainersEqual(Vec({3, 3}), 82 assigner.Run(/*target_partition_count=*/10))); 83 EXPECT_TRUE(ContainersEqual(Vec({3, 3}), 84 assigner.Run(/*target_partition_count=*/11))); 85 EXPECT_TRUE(ContainersEqual(Vec({4, 3}), 86 assigner.Run(/*target_partition_count=*/12))); 87 EXPECT_TRUE(ContainersEqual(Vec({4, 3}), 88 assigner.Run(/*target_partition_count=*/13))); 89 EXPECT_TRUE(ContainersEqual(Vec({4, 3}), 90 assigner.Run(/*target_partition_count=*/14))); 91 EXPECT_TRUE(ContainersEqual(Vec({5, 3}), 92 assigner.Run(/*target_partition_count=*/15))); 93 EXPECT_TRUE(ContainersEqual(Vec({5, 3}), 94 assigner.Run(/*target_partition_count=*/16))); 95 } 96 97 TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { 98 Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}); 99 ShapePartitionAssigner assigner(shape); 100 101 for (int64 i = 1; i <= 3; ++i) { 102 EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( 103 /*target_partition_count=*/i))); 104 } 105 106 EXPECT_TRUE( 107 ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/4))); 108 EXPECT_TRUE( 109 ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/5))); 110 EXPECT_TRUE( 111 ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); 112 EXPECT_TRUE( 113 ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); 114 EXPECT_TRUE( 115 ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/8))); 116 EXPECT_TRUE( 117 ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); 118 EXPECT_TRUE(ContainersEqual(Vec({3, 3}), 119 assigner.Run(/*target_partition_count=*/10))); 120 EXPECT_TRUE(ContainersEqual(Vec({3, 3}), 121 assigner.Run(/*target_partition_count=*/11))); 122 EXPECT_TRUE(ContainersEqual(Vec({3, 4}), 123 assigner.Run(/*target_partition_count=*/12))); 124 EXPECT_TRUE(ContainersEqual(Vec({3, 4}), 125 assigner.Run(/*target_partition_count=*/13))); 126 EXPECT_TRUE(ContainersEqual(Vec({3, 4}), 127 assigner.Run(/*target_partition_count=*/14))); 128 EXPECT_TRUE(ContainersEqual(Vec({3, 5}), 129 assigner.Run(/*target_partition_count=*/15))); 130 EXPECT_TRUE(ContainersEqual(Vec({3, 5}), 131 assigner.Run(/*target_partition_count=*/16))); 132 } 133 134 class ShapePartitionIteratorTest : public HloTestBase { 135 protected: 136 typedef std::vector<std::pair<int64, int64>> Partition; 137 }; 138 139 TEST_F(ShapePartitionIteratorTest, Shape53WithLayout10) { 140 Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}); 141 142 { 143 ShapePartitionIterator iterator(shape, {1}); 144 EXPECT_EQ(1, iterator.GetTotalPartitionCount()); 145 EXPECT_TRUE(ContainersEqual(Partition({{0, 5}}), iterator.GetPartition(0))); 146 } 147 148 { 149 ShapePartitionIterator iterator(shape, {2}); 150 EXPECT_EQ(2, iterator.GetTotalPartitionCount()); 151 EXPECT_TRUE(ContainersEqual(Partition({{0, 2}}), iterator.GetPartition(0))); 152 EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(1))); 153 } 154 155 { 156 ShapePartitionIterator iterator(shape, {3}); 157 EXPECT_EQ(3, iterator.GetTotalPartitionCount()); 158 EXPECT_TRUE(ContainersEqual(Partition({{0, 1}}), iterator.GetPartition(0))); 159 EXPECT_TRUE(ContainersEqual(Partition({{1, 1}}), iterator.GetPartition(1))); 160 EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(2))); 161 } 162 } 163 164 TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { 165 Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); 166 167 { 168 ShapePartitionIterator iterator(shape, {1, 1}); 169 EXPECT_EQ(1, iterator.GetTotalPartitionCount()); 170 EXPECT_TRUE( 171 ContainersEqual(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0))); 172 } 173 174 { 175 ShapePartitionIterator iterator(shape, {2, 2}); 176 EXPECT_EQ(4, iterator.GetTotalPartitionCount()); 177 EXPECT_TRUE( 178 ContainersEqual(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0))); 179 EXPECT_TRUE( 180 ContainersEqual(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1))); 181 EXPECT_TRUE( 182 ContainersEqual(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2))); 183 EXPECT_TRUE( 184 ContainersEqual(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3))); 185 } 186 } 187 188 class RandomShapePartitionIteratorTest : public HloTestBase { 189 protected: 190 typedef std::vector<std::pair<int64, int64>> Partition; 191 RandomShapePartitionIteratorTest() 192 : generator_(rd_()), distribution_(1, 10) {} 193 194 std::vector<int64> RandR4Dims() { return {Rand(), Rand(), Rand(), Rand()}; } 195 196 int64 Rand() { return distribution_(generator_); } 197 198 std::random_device rd_; 199 std::mt19937 generator_; 200 std::uniform_int_distribution<int> distribution_; 201 }; 202 203 TEST_F(RandomShapePartitionIteratorTest, RandomShapeAndPartitions) { 204 // Choose random dimensions for R4 shape. 205 Shape shape = ShapeUtil::MakeShapeWithLayout(F32, RandR4Dims(), {3, 2, 1, 0}); 206 // Choose random number of outer dimensions to partition. 207 const int num_outer_dims_to_partition = 1 + (Rand() % 3); 208 // Choose random outer dimension partition counts. 209 std::vector<int64> dim_sizes(num_outer_dims_to_partition); 210 std::vector<int64> dim_partition_counts(num_outer_dims_to_partition); 211 int64 total_dim_size = 1; 212 for (int i = 0; i < num_outer_dims_to_partition; ++i) { 213 const int64 dimension = shape.layout().minor_to_major( 214 shape.layout().minor_to_major_size() - 1 - i); 215 dim_sizes[i] = shape.dimensions(dimension); 216 total_dim_size *= dim_sizes[i]; 217 // Choose dimension partition count in [1, dim_size] 218 const int64 dim_partition_count = 1 + Rand() % dim_sizes[i]; 219 dim_partition_counts[i] = dim_partition_count; 220 } 221 // Iterate through all partition: for each partition record covered 222 // index ranges by dimension. 223 std::vector<std::map<int64, int64>> ranges(num_outer_dims_to_partition); 224 ShapePartitionIterator partition_iterator(shape, dim_partition_counts); 225 const int64 partition_count = partition_iterator.GetTotalPartitionCount(); 226 for (int64 i = 0; i < partition_count; ++i) { 227 const auto& dim_partition = partition_iterator.GetPartition(i); 228 for (int dim = 0; dim < dim_partition.size(); ++dim) { 229 ranges[dim].insert( 230 std::make_pair(dim_partition[dim].first, 231 dim_partition[dim].first + dim_partition[dim].second)); 232 } 233 } 234 // Check that partitions cover entire dimension size range (for each 235 // partitioned dimension). 236 for (int i = 0; i < ranges.size(); ++i) { 237 int64 expected_index = 0; 238 for (auto& r : ranges[i]) { 239 EXPECT_EQ(expected_index, r.first); 240 expected_index = r.second; 241 } 242 EXPECT_EQ(expected_index, dim_sizes[i]); 243 } 244 } 245 246 } // namespace 247 } // namespace cpu 248 } // namespace xla 249