/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Shapes are protobuf messages, so this utility header offers a bunch of
// functionality for querying / poking at them.

#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_

#include <initializer_list>
#include <string>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// An index for specifying a particular nested subshape within a shape. Used in
// ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
// structures (trees) and ShapeIndex defines a path through the tree where each
// element of ShapeIndex indexes into a tuple (or nested tuple) within the
// shape. For a non-nested tuple, an index has a single element. For example,
// given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
// {1} corresponds to array b. For a nested tuple, the index can have more than
// one element. For the nested tuple (a, (b, c, d), e) below are the values
// corresponding to the given indices:
//
//   index {0}    : array a
//   index {1, 2} : array d
//   index {2}    : array e
//   index {0, 0} : invalid index (element at {0} is an array, not a tuple)
//
// For indexing into array shapes, the index is always trivially empty, i.e. {}.
//
// ShapeIndex is a trivial wrapper around std::vector with a minimum number of
// methods implemented.
class ShapeIndex {
 public:
  ShapeIndex() = default;
  ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}

  bool empty() const { return indices_.empty(); }
  size_t size() const { return indices_.size(); }
  void push_back(int64 value) { indices_.push_back(value); }
  void pop_back() { indices_.pop_back(); }

  // push_front is O(n) per call (and O(n^2) if used to build the whole index),
  // but shapes don't usually have a ton of dimensions.
  void push_front(int64 value) { indices_.insert(indices_.begin(), value); }

  std::vector<int64>::const_iterator begin() const { return indices_.begin(); }
  std::vector<int64>::const_iterator end() const { return indices_.end(); }
  std::vector<int64>::iterator begin() { return indices_.begin(); }
  std::vector<int64>::iterator end() { return indices_.end(); }

  const int64* data() const { return indices_.data(); }

  int64 back() const { return indices_.back(); }
  int64& back() { return indices_.back(); }

  const int64& operator[](size_t i) const { return indices_[i]; }
  int64& operator[](size_t i) { return indices_[i]; }

  bool operator==(const ShapeIndex& other) const {
    return indices_ == other.indices_;
  }
  bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
  bool operator<(const ShapeIndex& other) const {
    return indices_ < other.indices_;
  }

  string ToString() const;

 private:
  std::vector<int64> indices_;
};
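
// Example (an illustrative sketch, not part of the API): building and
// mutating the index {1, 2} for array d in the nested tuple (a, (b, c, d), e)
// described above.
//
//   ShapeIndex index = {1, 2};  // addresses array d
//   index.push_front(0);        // {0, 1, 2}: the same path, were the tuple
//                               // wrapped in an outer 1-element tuple
//   index.pop_back();           // {0, 1}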

// A view into a ShapeIndex as above, with the cheap/easy ability to consume
// the value at the front of the view.
//
// NB! ShapeIndexView does not own the memory backing the index array.
// The memory backing the index array should be owned by an object
// that lives longer than the ShapeIndexView instances pointing into
// it.
class ShapeIndexView {
 public:
  ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
      : ShapeIndexView(shape_index.data() + offset,
                       shape_index.data() + shape_index.size()) {
    CHECK_LE(offset, shape_index.size());
  }
  ShapeIndexView(std::initializer_list<int64> indices)
      : ShapeIndexView(indices.begin(), indices.end()) {}
  ShapeIndexView(const ShapeIndexView& other) = default;

  using iterator = const int64*;

  iterator begin() const { return begin_; }
  iterator end() const { return end_; }
  int64 size() const { return std::distance(begin_, end_); }
  bool empty() const { return begin_ == end_; }
  int64 front() const {
    CHECK(!empty());
    return *begin_;
  }
  ShapeIndexView ConsumeFront() const {
    CHECK(!empty());
    auto new_begin = begin_;
    ++new_begin;
    return ShapeIndexView(new_begin, end_);
  }

  string ToString() const;

 private:
  ShapeIndexView(iterator begin, iterator end) : begin_(begin), end_(end) {}

  iterator begin_;
  iterator end_;
};
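
// A minimal walk-through sketch (illustrative only): a ShapeIndexView is
// consumed from the front without copying the underlying ShapeIndex, which
// must outlive the view.
//
//   ShapeIndex index = {1, 2};
//   ShapeIndexView view(index);
//   while (!view.empty()) {
//     int64 element = view.front();  // visit this tuple-element index
//     view = view.ConsumeFront();
//   }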

std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);

// Namespaced collection of (static) shape utilities.
//
// These are all effectively convenience functions for testing/tweaking proto
// properties, which do invariant checks before / after the operation.
class ShapeUtil {
 public:
  // Returns the number of elements contained within the provided shape;
  // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes
  // may not actually be able to store this number of elements. See
  // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
  // elements that can be stored in a sparse shape.
  // Precondition: !IsTuple(shape)
  static int64 ElementsIn(const Shape& shape);

  // Returns true if 'shape' has zero elements.
  static bool HasZeroElements(const Shape& shape);

  // Returns the number of bytes required for an allocation of shape.  The
  // |pointer_size| parameter is used for calculating the size of tuple
  // shapes. This includes only the size of the top-level buffer. For example,
  // a tuple is stored as an array of pointers to other buffers. In this case,
  // this method only returns the size of the pointer array.
  // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) &&
  //               !ShapeUtil::IsOpaque(shape)
  static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);

  // Returns the number of bytes used to store the primitive_type.
  //
  // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
  static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes required to store the tuple member pointers
  // for an allocation of shape. The `shape` must be a TUPLE shape, and
  // `pointer_size` must be larger than zero.
  static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
                                         int64 pointer_size);

  // Returns the number of bytes required for the elements in an allocation of
  // `shape`, which must be an array shape. The return value does not include
  // the bytes needed to store sparse indices. Dense shapes use a separate
  // memory location for each element, and so for these shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this
  // size also includes padding if present in the layout. For sparse shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape) +
  // ByteSizeOfSparseIndices(shape)`.
  static int64 ByteSizeOfElements(const Shape& shape);

  // Returns the number of bytes required for the sparse indices in an
  // allocation of shape. The shape must be an array shape. The return value
  // does not include the bytes needed to store the array elements.
  static int64 ByteSizeOfSparseIndices(const Shape& shape);
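
  // A worked sketch of the byte-size accounting above (illustrative only;
  // assumes 4-byte F32 elements and an 8-byte pointer size):
  //
  //   Shape array = ShapeUtil::MakeShape(F32, {2, 3});
  //   int64 bytes = ShapeUtil::ByteSizeOfElements(array);  // 2 * 3 * 4 = 24
  //   Shape tuple = ShapeUtil::MakeTupleShape({array, array});
  //   int64 table =
  //       ShapeUtil::ByteSizeOfTupleIndexTable(tuple, /*pointer_size=*/8);
  //   // table == 2 * 8 == 16: just the pointer array, not element buffers.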

  // Returns a human-readable string that represents the given shape, with or
  // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
  static string HumanString(const Shape& shape);
  static string HumanStringWithLayout(const Shape& shape);

  // As above, but for program shapes, returns a string of the form:
  //
  // (param_name: f32[42x12], ...) -> f32[24x42]
  static string HumanString(const ProgramShape& program_shape);

  // Parses a ShapeUtil::HumanString-format shape string back into a shape
  // object.
  static StatusOr<Shape> ParseShapeString(tensorflow::StringPiece s);
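
  // A round-trip sketch (illustrative only): a shape printed with HumanString
  // (e.g. "f32[42x12]" per the format above) parses back to a shape that is
  // compatible with the original, ignoring layout.
  //
  //   Shape shape = ShapeUtil::MakeShape(F32, {42, 12});
  //   StatusOr<Shape> parsed =
  //       ShapeUtil::ParseShapeString(ShapeUtil::HumanString(shape));
  //   CHECK(parsed.ok() && ShapeUtil::Compatible(shape, parsed.ValueOrDie()));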

  // Returns whether the LHS and RHS shapes have the same dimensions; note:
  // does not check element type.
  static bool SameDimensions(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes have the same element type.
  static bool SameElementType(const Shape& lhs, const Shape& rhs) {
    return lhs.element_type() == rhs.element_type();
  }

  // As SameElementType, but allows floating point types to have different
  // precisions.
  static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
                                                 const Shape& b) {
    if (ElementIsFloating(a) && ElementIsFloating(b)) {
      return true;
    }
    return ShapeUtil::SameElementType(a, b);
  }

  // Returns the higher-precision element type if a and b are both floating
  // point types; otherwise, checks that they have the same element type
  // and returns it.
  static PrimitiveType HigherPrecisionElementType(const Shape& a,
                                                  const Shape& b) {
    if (SameElementType(a, b)) {
      return a.element_type();
    }
    CHECK(SameElementTypeIgnoringFpPrecision(a, b));
    return primitive_util::BitWidth(a.element_type()) <
                   primitive_util::BitWidth(b.element_type())
               ? b.element_type()
               : a.element_type();
  }
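
  // For example (a sketch of the rule above): BF16 has a smaller bit width
  // than F32, so the F32 element type wins.
  //
  //   Shape a = ShapeUtil::MakeShape(BF16, {4});
  //   Shape b = ShapeUtil::MakeShape(F32, {4});
  //   PrimitiveType t = ShapeUtil::HigherPrecisionElementType(a, b);  // F32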

  // Returns true if the rank, dimension sizes, and element type are
  // identical. Layout is ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool Compatible(const Shape& lhs, const Shape& rhs);

  // Returns true if the rank and dimension sizes are identical. Element type
  // and layout are ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);

  // As Compatible, but allows one of lhs and rhs to be BF16 while the other
  // is F32. Tuple elements are compared recursively for compatibility.
  static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes are identical protobufs.
  static bool Equal(const Shape& lhs, const Shape& rhs);

  // Returns the rank (number of dimensions) of the given shape.
  // Precondition: !IsTuple(shape)
  static int64 Rank(const Shape& shape);

  // Returns the number of dimensions for which the dimension is not
  // (trivially) 1; e.g., f32[2x1x1] has a true rank of 1, since the other
  // dimensions are just fluff. Note that zero dimensions are included in the
  // true rank; e.g., f32[3,0,1] has a true rank of 2.
  static int64 TrueRank(const Shape& shape);

  static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
                                       Shape result);

  ////////////////////
  // Scalar-specific

  static bool IsScalar(const Shape& shape) {
    return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0;
  }
  static bool IsEffectiveScalar(const Shape& shape) {
    return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0;
  }
  static bool IsScalarF32(const Shape& shape);

  // Extracts the size of the shape's dimension at dimension number
  // GetDimensionNumber(dimension_number).
  static int64 GetDimension(const Shape& shape, int64 dimension_number);

  // Resolves a dimension number, supporting negative indexing.
  //
  // Negative indexing has similar semantics to Python. For an N-dimensional
  // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
  // N-2, and so on.
  //
  // This function always returns a non-negative dimension number for any
  // given dimension_number (which itself can be negative).
  static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);
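
  // A quick sketch of the negative-indexing rule (illustrative only):
  //
  //   Shape shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  //   ShapeUtil::GetDimensionNumber(shape, -1);  // == 2, the last dimension
  //   ShapeUtil::GetDimension(shape, -1);        // == 4, that dimension's size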

  // Returns a shape with the same dimensions as the original, but with the
  // element type changed to type.
  static Shape ChangeElementType(const Shape& original, PrimitiveType type);

  // Creates a tuple shape from a slice of element shapes within the tuple.
  static Shape MakeTupleShape(tensorflow::gtl::ArraySlice<Shape> shapes);

  // Creates an opaque shape. These are generally used for threading a context
  // into a custom operation.
  static Shape MakeOpaqueShape();

  // Appends a shape to the given tuple.
  static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);

  // Appends a major dimension to the shape with the given bound.
  static void AppendMajorDimension(int bound, Shape* shape);

  // Returns an empty tuple shape. Can be used to indicate side-effects.
  static Shape MakeNil() { return MakeTupleShape({}); }

  // Constructs a new shape with the given element type and sequence of
  // dimensions.
  static Shape MakeShape(PrimitiveType element_type,
                         tensorflow::gtl::ArraySlice<int64> dimensions);

  // Constructs a new shape with the given minor_to_major order in its Layout.
  // Returns a value shape such that shape.has_layout().
  static Shape MakeShapeWithLayout(
      PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
      tensorflow::gtl::ArraySlice<int64> minor_to_major);

  static Shape MakeShapeWithSparseLayout(
      PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
      int64 max_sparse_elements);

  // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
  static Shape MakeShapeWithDescendingLayout(
      PrimitiveType element_type,
      tensorflow::gtl::ArraySlice<int64> dimensions);
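
  // A construction sketch (illustrative only): both calls below build the
  // same logical f32[2,3] shape; the second additionally pins a row-major
  // layout (minor_to_major = {1, 0}, i.e. dimension 1 varies fastest).
  //
  //   Shape s1 = ShapeUtil::MakeShape(F32, {2, 3});
  //   Shape s2 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
  //   CHECK(ShapeUtil::Compatible(s1, s2));  // Compatible ignores layout.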

  // Returns a new Shape based on the given Shape with low-dimension-major
  // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
  // rearranged so that it has the same in-memory layout as the given shape.
  //
  // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
  static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
      const Shape& shape);

  // As MakeShape, but the object to write to is passed in.
  static void PopulateShape(PrimitiveType element_type,
                            tensorflow::gtl::ArraySlice<int64> dimensions,
                            Shape* shape);

  // Validates that the provided shape satisfies invariants.
  static Status ValidateShape(const Shape& shape);

  // Validates that the provided shape satisfies invariants, except those that
  // pertain to layout.
  //
  // Layout is optional for client-provided shapes, so that the compiler may
  // determine and assign an optimized layout.
  static Status ValidateShapeWithOptionalLayout(const Shape& shape);

  // Returns whether the element type of the shape is integral (signed or
  // unsigned). Note that predicates are not considered integral here, since
  // they are logical values.
  static bool ElementIsIntegral(const Shape& shape);

  // Returns whether the element type of the shape is floating point.
  static bool ElementIsFloating(const Shape& shape);

  // Returns whether the element type of the shape is complex.
  static bool ElementIsComplex(const Shape& shape);

  // Returns whether the element type has the given bit width.
  static bool ElementHasBitWidth(const Shape& shape, int bits);

  // Returns whether the element type of the shape is integral and has
  // the specified number of bits.
  static bool ElementIsIntegralWithBits(const Shape& shape, int bits);

  // Returns whether the element type of the shape is signed. Note
  // that floating point numbers are signed.
  static bool ElementIsSigned(const Shape& shape);

  // Returns whether the shape is a tuple.
  static bool IsTuple(const Shape& shape) {
    return shape.element_type() == TUPLE;
  }

  // Returns whether the shape is an opaque value (i.e. an 'existential' typed
  // value that is passed to CustomCall operations).
  static bool IsOpaque(const Shape& shape) {
    return shape.element_type() == OPAQUE;
  }

  // Returns whether the shape is an array.  Note that scalars are considered
  // arrays.
  static bool IsArray(const Shape& shape) {
    return !IsTuple(shape) && !IsOpaque(shape);
  }

  // Returns whether the shape is a tuple with at least one element which is
  // also a tuple.
  static bool IsNestedTuple(const Shape& shape);

  // Returns true if shape is an empty tuple.
  static bool IsEmptyTuple(const Shape& shape);

  // Returns true if shape is an empty tuple, or is an array with no elements.
  static bool IsNil(const Shape& shape);

  // Returns the number of elements in the given tuple shape.
  // Precondition: IsTuple(shape)
  static int64 TupleElementCount(const Shape& shape);

  // Returns the tuple element shape at given index.
  // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
  static const Shape& GetTupleElementShape(const Shape& shape, int64 index);

  // Slices tuple elements in the range [start, limit) and returns a new tuple
  // shape. E.g. slicing a tuple (f32, s32, u32) with start=1, limit=3 yields
  // the tuple (s32, u32).
  static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);

  // Returns the shape of the real/imaginary components of the given complex
  // shape.
  static Shape ComplexComponentShape(const Shape& complex_shape);

  // Shorthand for testing whether a shape is of a given element type and
  // sequence of dimensions.
  //
  // DEPRECATED: Use Equal() instead.
  static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
                      std::initializer_list<int64> dimensions);

  // GetSubshape and GetMutableSubshape return a particular nested Shape within
  // the given Shape argument.
  static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
  static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);

  // Returns whether the given index in the given shape is a leaf element of
  // the shape.
  static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);
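
  // A navigation sketch (illustrative only), reusing the nested tuple
  // (a, (b, c, d), e) from the ShapeIndex comment above; `tuple_shape` is a
  // hypothetical Shape for that tuple.
  //
  //   const Shape& d_shape = ShapeUtil::GetSubshape(tuple_shape, {1, 2});
  //   CHECK(ShapeUtil::IsLeafIndex(tuple_shape, {1, 2}));  // d is an array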

  // Calls the given visitor function for each subshape of the given shape.
  // Subshapes are visited in DFS pre-order starting with the entire shape
  // (index {}).
  using VisitorFunction = std::function<void(const Shape& /*subshape*/,
                                             const ShapeIndex& /*index*/)>;
  static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
  using MutatingVisitorFunction =
      std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static void ForEachMutableSubshape(Shape* shape,
                                     const MutatingVisitorFunction& func);

  // Variants of ForEach(Mutable)Subshape which propagate Status from the
  // visitor function.
  using StatusVisitorFunction = std::function<Status(
      const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachSubshapeWithStatus(const Shape& shape,
                                          const StatusVisitorFunction& func);
  using MutatingStatusVisitorFunction =
      std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachMutableSubshapeWithStatus(
      Shape* shape, const MutatingStatusVisitorFunction& func);
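
  // A traversal sketch (illustrative only): counting the leaf (array)
  // subshapes of a possibly nested tuple shape.
  //
  //   int64 leaf_count = 0;
  //   ShapeUtil::ForEachSubshape(
  //       shape, [&](const Shape& subshape, const ShapeIndex& index) {
  //         if (ShapeUtil::IsLeafIndex(shape, index)) {
  //           ++leaf_count;
  //         }
  //       });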

  // Removes all degenerate dimensions (size one) from the given shape. The
  // stripped minor_to_major preserves the relative ordering of non-degenerate
  // dimensions. The stripped shape has the property that the underlying
  // representation (bits in memory) for the stripped shape is the same as the
  // original shape modulo padding. Examples:
  //
  // input shape:    F32 [1, 2, 1], minor_to_major = {0, 1, 2}
  // stripped shape: F32 [2], minor_to_major = {0}
  //
  // input shape:    F32 [6, 1, 5], minor_to_major = {2, 0, 1}
  // stripped shape: F32 [6, 5], minor_to_major = {1, 0}
  //
  // input shape:    F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1}
  // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1}
  //
  // input shape:    F32 [1, 1], minor_to_major = {0, 1}
  // stripped shape: F32 [], minor_to_major = {}
  // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
  static Shape StripDegenerateDimensions(const Shape& shape);

  // Permutes the dimensions by the given permutation, so
  // return_value.dimensions[permutation[i]] = argument.dimensions[i]
  static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
                                 const Shape& shape);

  // If we can go from `shape_pre` to `shape_post` by merely inserting or
  // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
  // dimensions.
  // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
  // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
  // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
  // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
  // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
  // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
  static std::tuple<bool, std::vector<int64>, std::vector<int64>>
  InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                    const Shape& shape_post);
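
  // A concrete sketch of the rule above (illustrative only): f32[2, 1, 3] can
  // become f32[2, 3, 1] by deleting the 1-sized dimension at index 1 of
  // `shape_pre` and inserting one at index 2 of `shape_post`.
  //
  //   auto result = ShapeUtil::InsertedOrDeleted1SizedDimensions(
  //       ShapeUtil::MakeShape(F32, {2, 1, 3}),
  //       ShapeUtil::MakeShape(F32, {2, 3, 1}));
  //   // result == (true, {1}, {2}); the bool is false if no such mapping
  //   // exists, e.g. when a non-degenerate dimension changes size.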

  // Suppose a reshape transforms input_shape to output_shape. Returns a vector
  // of pairs that indicate the input and output dimensions that this reshape
  // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
  // the returned vector, the reshape transforms any input index whose I-th
  // dimension is x to an output index whose O-th dimension is x too.
  //
  // Post-condition: the returned vector is sorted (by both input and output
  // dimensions because input and output dimensions have the same order).
  //
  // Example:
  //   input  shape = T[a, b, x, y, cd]
  //   output shape = T[ab, x, 1, y, c, d]
  //   return value = {{2, 1}, {3, 3}}
  //
  //   The two pairs represent the input and output dimension of size x and
  //   those of size y.
  static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
      const Shape& input_shape, const Shape& output_shape);

  // Returns whether a transpose from input_shape to output_shape with dimension
  // mapping "dimension_mapping" produces a result which is bit-wise identical
  // to its input and thus may be replaced with a bitcast.
  static bool TransposeIsBitcast(
      const Shape& input_shape, const Shape& output_shape,
      tensorflow::gtl::ArraySlice<int64> dimension_mapping);

  // Returns whether a reshape from "input_shape" to "output_shape" is a
  // bitcast.
  static bool ReshapeIsBitcast(const Shape& input_shape,
                               const Shape& output_shape);

  // Find a physical layout for 'output_shape' such that
  // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
  // true (where 'output_shape_with_layout' is 'output_shape' with the found
  // layout). The layout of 'input_shape' is kept fixed. Returns
  // 'output_shape_with_layout' if such a layout can be found, and nullopt
  // otherwise.
  static tensorflow::gtl::optional<Shape> AlignLayouts(
      const Shape& input_shape, const Shape& output_shape);

  // Returns a shape with the given dimension deleted.
  // For example:
  //  `DeleteDimension(1, T[m, n, k]) = T[m, k]`
  static Shape DeleteDimension(int64 dim_to_delete, Shape shape);

  // Returns a shape with all the dimensions of the input shape for which `p`
  // returns true.
  // For example:
  //  `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
  //  `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
  static Shape FilterDimensions(const std::function<bool(int64)>& p,
                                Shape shape);

  // Iterates through all the shape indexes, in minor to major order, starting
  // from the base indexes, incrementing by the incr steps, up to count
  // (index[i] < base[i] + count[i]), and calls the visitor_function with the
  // current index.
  // The visitor function should return true if it wants to continue, or false
  // otherwise.
  //
  // visitor_function must be a callable of type bool(const std::vector<int64>&)
  // or compatible.
  template <typename FnType>
  static void ForEachIndex(const Shape& shape,
                           tensorflow::gtl::ArraySlice<int64> base,
                           tensorflow::gtl::ArraySlice<int64> count,
                           tensorflow::gtl::ArraySlice<int64> incr,
                           const FnType& visitor_function) {
    if (ShapeUtil::HasZeroElements(shape)) {
      return;
    }
    CHECK_EQ(Rank(shape), base.size());
    CHECK_EQ(incr.size(), base.size());
    CHECK_EQ(count.size(), base.size());
    const int64 rank = LayoutUtil::MinorToMajor(shape).size();
    // Allows handling R0 arrays, such that the visitor function will be called
    // once with the proper empty indexes.
    int64 n = -1;
    std::vector<int64> indexes(base.begin(), base.end());
    while (n < rank && visitor_function(indexes)) {
      // Increments dimensions in minor to major order.
      for (n = 0; n < rank; ++n) {
        int64 dim = LayoutUtil::Minor(shape.layout(), n);
        indexes[dim] += incr[dim];
        if (indexes[dim] < base[dim] + count[dim]) {
          break;
        }
        indexes[dim] = base[dim];
      }
    }
  }
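
  // A usage sketch (illustrative only): visit every index (i, j) of a dense
  // f32[2,3] shape with unit steps, i.e. 0 <= i < 2 and 0 <= j < 3, in
  // minor-to-major order.
  //
  //   Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
  //   ShapeUtil::ForEachIndex(
  //       shape, /*base=*/{0, 0}, /*count=*/{2, 3}, /*incr=*/{1, 1},
  //       [](const std::vector<int64>& indexes) {
  //         // ... inspect indexes[0], indexes[1] here ...
  //         return true;  // keep iterating
  //       });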

 private:
  // Validates all of the non-layout properties of the shape -- this is a
  // helper used by both the layout-optional and layout-required public
  // methods.
  static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);

  TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
};

std::ostream& operator<<(std::ostream& out, const Shape& shape);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_