Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_
     18 
     19 #include <vector>
     20 
     21 #include "tensorflow/compiler/xla/shape_util.h"
     22 
     23 namespace xla {
     24 namespace cpu {
     25 
     26 // ShapePartitionAssigner partitions the most-major dimensions of 'shape' such
     27 // that the total partition count <= 'target_partition_count'.
     28 //
     29 // Example 1:
     30 //
     31 //   Let 'shape' = [8, 16, 32] and 'target_partition_count' = 6.
     32 //
     33 //   Because the most-major dimension size is <= 'target_partition_count', we
     34 //   can generate our target number of partitions by partition the most-major
     35 //   dimensions.
     36 //
     37 //   This will result in the following partitions of the most-major dimension:
     38 //
     39 //     [0, 1), [1, 2), [2, 3), [3, 4), [4, 5) [5, 8)
     40 //
     41 //   Note that the last partition has residule because the dimension size is
     42 //   not a multiple of the partition count.
     43 //
     44 //
     45 // Example 2:
     46 //
     47 //   Let 'shape' = [8, 16, 32] and 'target_partition_count' = 16.
     48 //
     49 //   Because the most-major dimension only has size 8, we must also partition
     50 //   the next most-major dimension to generate the target of 16 partitions.
     51 //   We factor 'target_partition_count' by the number of most-major dimensions
     52 //   we need to partition, to get a per-dimension target partition count:
     53 //
     54 //     target_dimension_partition_count = 16 ^ (1 / 2) == 4
     55 //
     56 //   This will result in the following partitions of the most-major dimension:
     57 //
     58 //     [0, 2), [2, 4), [4, 6), [6, 8)
     59 //
     60 //   This will result in the following partitions of the second most-major
     61 //   dimension:
     62 //
     63 //     [0, 4), [4, 8), [8, 12), [12, 16)
     64 //
     65 class ShapePartitionAssigner {
     66  public:
     67   ShapePartitionAssigner(const Shape& shape) : shape_(shape) {}
     68 
     69   // Returns dimension partition counts (starting at outer-most dimension).
     70   std::vector<int64> Run(int64 target_partition_count);
     71 
     72   // Returns the total partition count based on 'dimension_partition_counts'.
     73   static int64 GetTotalPartitionCount(
     74       const std::vector<int64>& dimension_partition_counts);
     75 
     76  private:
     77   const Shape& shape_;
     78 };
     79 
     80 // ShapePartitionIterator iterates through outer-dimension partitions of
     81 // 'shape' as specified by 'dimension_partition_counts'.
     82 class ShapePartitionIterator {
     83  public:
     84   ShapePartitionIterator(const Shape& shape,
     85                          const std::vector<int64>& dimension_partition_counts);
     86 
     87   // Returns a partition [start, size] for each dimension.
     88   // Partitions are listed starting from outer-most dimension first.
     89   std::vector<std::pair<int64, int64>> GetPartition(int64 index) const;
     90 
     91   int64 GetTotalPartitionCount() const;
     92 
     93  private:
     94   const Shape& shape_;
     95   const std::vector<int64> dimension_partition_counts_;
     96 
     97   std::vector<int64> dimensions_;
     98   std::vector<int64> dimension_partition_sizes_;
     99   std::vector<int64> dimension_partition_strides_;
    100 };
    101 
    102 }  // namespace cpu
    103 }  // namespace xla
    104 
    105 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_
    106