Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/compiler/xla/layout_util.h"

#include <stddef.h>

#include <algorithm>
#include <functional>
#include <numeric>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
     38 
     39 namespace xla {
     40 namespace {
     41 
     42 // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
     43 // minor_to_major to the value that represents the default layout.
     44 void SetDefaultLayoutToContainer(std::vector<int64>* minor_to_major) {
     45   // The default XLA layout is major-to-minor (dim 0 is major).
     46   // For more information on XLA layouts, see:
     47   // https://www.tensorflow.org/performance/xla/shapes
     48   const int64 size = minor_to_major->size();
     49   for (int64 i = 0; i < size; ++i) {
     50     (*minor_to_major)[i] = size - 1 - i;
     51   }
     52 }
     53 
     54 }  // namespace
     55 
     56 /* static */ Layout LayoutUtil::MakeLayout(
     57     absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
     58     int64 element_size_in_bits) {
     59   Layout layout;
     60   layout.set_format(DENSE);
     61   for (int64 dimension_number : minor_to_major) {
     62     layout.add_minor_to_major(dimension_number);
     63   }
     64   for (Tile tile : tiles) {
     65     for (int64 dim : tile.dimensions()) {
     66       if (dim < 0 && dim != Tile::kCombineDimension) {
     67         LOG(FATAL) << "Tile dimension size needs to be mininum int64 value if "
     68                       "it's negative. Value is "
     69                    << dim;
     70       }
     71     }
     72     *layout.add_tiles() = tile;
     73   }
     74   layout.set_element_size_in_bits(element_size_in_bits);
     75   return layout;
     76 }
     77 
     78 /* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) {
     79   std::vector<int64> layout(rank);
     80   std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
     81   return MakeLayout(layout);
     82 }
     83 
     84 /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
     85     absl::Span<const int64> major_to_minor) {
     86   Layout layout;
     87   layout.set_format(DENSE);
     88   for (int i = major_to_minor.size() - 1; i >= 0; i--) {
     89     layout.add_minor_to_major(major_to_minor[i]);
     90   }
     91   return layout;
     92 }
     93 
     94 /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) {
     95   Layout layout;
     96   layout.set_format(SPARSE);
     97   layout.set_max_sparse_elements(max_sparse_elements);
     98   return layout;
     99 }
    100 
    101 namespace {
    102 
    103 // Internal helper that creates a default layout for an array of the given rank.
    104 Layout CreateDefaultLayoutForRank(int64 rank) {
    105   Layout layout;
    106   layout.set_format(DENSE);
    107   std::vector<int64>* minor_to_major = layout.mutable_minor_to_major();
    108   minor_to_major->resize(rank, 0);
    109   SetDefaultLayoutToContainer(minor_to_major);
    110   return layout;
    111 }
    112 
    113 }  // namespace
    114 
    115 /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
    116   if (shape.IsOpaque() || shape.IsToken()) {
    117     // Opaque and token types have empty layouts.
    118     return Layout();
    119   }
    120 
    121   // A Layout proto corresponds to a single array, not a tuple.
    122   CHECK(shape.IsArray());
    123   return CreateDefaultLayoutForRank(shape.dimensions_size());
    124 }
    125 
// Returns the default (major-to-minor) layout for an array of rank `rank`.
/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
  return CreateDefaultLayoutForRank(rank);
}
    129 
// Convenience wrapper: default layout for a rank-2 array.
/* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
  return CreateDefaultLayoutForRank(2);
}
    133 
// Convenience wrapper: default layout for a rank-3 array.
/* static */ Layout LayoutUtil::GetDefaultLayoutForR3() {
  return CreateDefaultLayoutForRank(3);
}
    137 
// Convenience wrapper: default layout for a rank-4 array.
/* static */ Layout LayoutUtil::GetDefaultLayoutForR4() {
  return CreateDefaultLayoutForRank(4);
}
    141 
    142 /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) {
    143   if (shape->IsTuple()) {
    144     // Tuple shape.
    145     for (auto& element_shape : *shape->mutable_tuple_shapes()) {
    146       SetToDefaultLayout(&element_shape);
    147     }
    148     shape->clear_layout();
    149   } else if (shape->IsArray()) {
    150     shape->mutable_layout()->set_format(DENSE);
    151     auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
    152     minor_to_major->resize(shape->dimensions_size(), 0);
    153     SetDefaultLayoutToContainer(minor_to_major);
    154   } else {
    155     // Opaque, token types etc. have no layout.
    156     shape->clear_layout();
    157   }
    158 }
    159 
    160 /* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) {
    161   Shape copy(shape);
    162   LayoutUtil::SetToDefaultLayout(&copy);
    163   return copy;
    164 }
    165 
    166 /* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
    167   for (auto& parameter_shape : *program_shape->mutable_parameters()) {
    168     LayoutUtil::SetToDefaultLayout(&parameter_shape);
    169   }
    170   LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
    171 }
    172 
    173 /* static */ Status LayoutUtil::ValidateLayoutInShape(
    174     const Shape& shape, bool allow_missing_layouts) {
    175   if (shape.IsTuple()) {
    176     // Tuple shape.
    177     if (shape.has_layout()) {
    178       return InvalidArgument("tuple should not have a layout field");
    179     }
    180     for (auto& element_shape : shape.tuple_shapes()) {
    181       TF_RETURN_IF_ERROR(
    182           ValidateLayoutInShape(element_shape, allow_missing_layouts));
    183     }
    184     return Status::OK();
    185   } else if (shape.IsArray()) {
    186     if (!shape.has_layout()) {
    187       if (allow_missing_layouts) {
    188         return Status::OK();
    189       }
    190       return InvalidArgument("shape %s does not have a layout",
    191                              ShapeUtil::HumanString(shape));
    192     }
    193     return ValidateLayoutForShape(shape.layout(), shape);
    194   } else {
    195     // Token, opaque, etc. shape.
    196     if (shape.has_layout()) {
    197       return InvalidArgument(
    198           "shape of primitive type %s should not have a layout",
    199           PrimitiveType_Name(shape.element_type()));
    200     }
    201     return Status::OK();
    202   }
    203 }
    204 
    205 /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout,
    206                                                        const Shape& shape) {
    207   if (shape.IsTuple()) {
    208     return InvalidArgument("a single Layout is not valid for tuple shapes");
    209   }
    210 
    211   if (!shape.IsArray()) {
    212     if (layout.minor_to_major_size() != 0) {
    213       return InvalidArgument(
    214           "shape of primitive type %s should not have a non-trivial layout",
    215           PrimitiveType_Name(shape.element_type()));
    216     }
    217     return Status::OK();
    218   }
    219 
    220   if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) {
    221     return InvalidArgument("Layout has an invalid format (%d)",
    222                            layout.format());
    223   }
    224 
    225   if (layout.format() == DENSE) {
    226     if (layout.minor_to_major_size() != shape.rank()) {
    227       return InvalidArgument(
    228           "layout minor_to_major field contains %d elements, "
    229           "but shape is rank %d: {%s}; shape: %s",
    230           layout.minor_to_major_size(), shape.rank(),
    231           absl::StrJoin(layout.minor_to_major(), ", "),
    232           shape.ShortDebugString());
    233     }
    234 
    235     std::vector<bool> dimensions_in_layout(shape.rank(), false);
    236     for (int64 i = 0; i < shape.rank(); ++i) {
    237       int64 dim = layout.minor_to_major(i);
    238       if (dim < 0 || dim >= shape.rank()) {
    239         return InvalidArgument(
    240             "layout minor_to_major field has out-of-bounds value: %s",
    241             HumanString(layout));
    242       }
    243       if (dimensions_in_layout[dim]) {
    244         return InvalidArgument(
    245             "layout minor_to_major field has duplicate values: {%s}",
    246             HumanString(layout));
    247       }
    248       dimensions_in_layout[dim] = true;
    249     }
    250   } else {
    251     if (layout.tiles_size() != 0) {
    252       return InvalidArgument("Only dense layouts can be tiled.");
    253     }
    254   }
    255 
    256   return Status::OK();
    257 }
    258 
    259 /* static */ void LayoutUtil::ClearLayout(Shape* shape) {
    260   shape->clear_layout();
    261   for (auto& element_shape : *shape->mutable_tuple_shapes()) {
    262     ClearLayout(&element_shape);
    263   }
    264 }
    265 
    266 /* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) {
    267   for (auto& parameter_shape : *program_shape->mutable_parameters()) {
    268     LayoutUtil::ClearLayout(&parameter_shape);
    269   }
    270   LayoutUtil::ClearLayout(program_shape->mutable_result());
    271 }
    272 
    273 /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
    274   return shape.IsArray() && shape.has_layout() && IsDense(shape.layout());
    275 }
    276 
// True iff `layout` is in DENSE format (a minor-to-major dimension ordering).
/* static */ bool LayoutUtil::IsDense(const Layout& layout) {
  return layout.format() == DENSE;
}
    280 
    281 /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) {
    282   CHECK(layout.format() == DENSE);
    283   return std::is_sorted(layout.minor_to_major().begin(),
    284                         layout.minor_to_major().end());
    285 }
    286 
    287 /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) {
    288   CHECK(layout.format() == DENSE);
    289   return std::is_sorted(layout.minor_to_major().begin(),
    290                         layout.minor_to_major().end(), std::greater<int64>());
    291 }
    292 
    293 /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
    294   return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout());
    295 }
    296 
// True iff `layout` is in SPARSE format.
/* static */ bool LayoutUtil::IsSparse(const Layout& layout) {
  return layout.format() == SPARSE;
}
    300 
// Returns the maximum number of elements a sparse layout can hold.
// CHECK-fails if `layout` is not SPARSE.
/* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) {
  CHECK(IsSparse(layout));
  return layout.max_sparse_elements();
}
    305 
    306 /* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
    307   if (shape.IsTuple()) {
    308     // Tuple shape: all subshapes must have a layout.
    309     return absl::c_all_of(shape.tuple_shapes(),
    310                           [](const Shape& s) { return HasLayout(s); });
    311   } else if (!shape.IsArray()) {
    312     // Opaque, token types etc. ignore layout.
    313     return true;
    314   }
    315   return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
    316 }
    317 
    318 /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) {
    319   for (auto& parameter_shape : program_shape.parameters()) {
    320     if (!LayoutUtil::HasLayout(parameter_shape)) {
    321       return false;
    322     }
    323   }
    324   return LayoutUtil::HasLayout(program_shape.result());
    325 }
    326 
// Field-wise layout equality, delegated to Layout's operator==.
/* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
  return lhs == rhs;
}
    330 
// Returns the minor-to-major ordering of `shape`'s layout.  CHECK-fails
// unless `shape` is a dense array.  The returned span aliases the shape's
// layout, so it is only valid while `shape` is alive and unmodified.
/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
    const Shape& shape) {
  CHECK(IsDenseArray(shape));
  return AsInt64Slice(shape.layout().minor_to_major());
}
    336 
    337 /* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
    338     const Layout& layout) {
    339   CHECK(layout.format() == DENSE);
    340   return AsInt64Slice(layout.minor_to_major());
    341 }
    342 
// Returns the logical dimension that is the `physical_dimension_number`-th
// most MAJOR in `layout`.  CHECK-fails if the index is out of range.
/* static */ int64 LayoutUtil::Major(const Layout& layout,
                                     int64 physical_dimension_number) {
  CHECK_LE(0, physical_dimension_number);
  CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
  // The k-th most major dimension is the (rank - 1 - k)-th most minor.
  return Minor(layout,
               layout.minor_to_major_size() - 1 - physical_dimension_number);
}
    350 
// Returns the logical dimension that is the `physical_dimension_number`-th
// most MINOR in `layout`.  CHECK-fails for non-dense layouts or an
// out-of-range index.
/* static */ int64 LayoutUtil::Minor(const Layout& layout,
                                     int64 physical_dimension_number) {
  CHECK_EQ(layout.format(), DENSE);
  CHECK_LE(0, physical_dimension_number);
  CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
  return layout.minor_to_major(physical_dimension_number);
}
    358 
    359 /* static */ std::vector<int64> LayoutUtil::MakeLogicalToPhysical(
    360     const Layout& layout) {
    361   std::vector<int64> logical_to_physical(layout.minor_to_major_size());
    362   for (int64 physical = 0; physical < logical_to_physical.size(); ++physical) {
    363     const int64 logical = Major(layout, physical);
    364     logical_to_physical[logical] = physical;
    365   }
    366   return logical_to_physical;
    367 }
    368 
// Human-readable rendering of `layout`, delegated to Layout::ToString.
/* static */ string LayoutUtil::HumanString(const Layout& layout) {
  return layout.ToString();
}
    372 
    373 namespace {
    374 
    375 // Internal helper for recursively copying layouts.
    376 Status CopyLayoutInternal(const Shape& src, Shape* dst) {
    377   if (src.IsTuple() != dst->IsTuple()) {
    378     return InvalidArgument(
    379         "cannot copy layout from shape: shape structure differs");
    380   }
    381   if (src.IsTuple()) {
    382     if (ShapeUtil::TupleElementCount(src) !=
    383         ShapeUtil::TupleElementCount(*dst)) {
    384       return InvalidArgument(
    385           "cannot copy layout from shape: tuple element count differs");
    386     }
    387     for (int64 i = 0; i < ShapeUtil::TupleElementCount(src); ++i) {
    388       TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i),
    389                                             dst->mutable_tuple_shapes(i)));
    390     }
    391   } else {
    392     if (src.has_layout()) {
    393       if (src.rank() != dst->rank()) {
    394         return InvalidArgument("cannot copy layout from shape: ranks differs");
    395       }
    396       TF_RETURN_IF_ERROR(
    397           LayoutUtil::ValidateLayoutForShape(src.layout(), *dst));
    398       *dst->mutable_layout() = src.layout();
    399     } else {
    400       dst->clear_layout();
    401     }
    402   }
    403   return Status::OK();
    404 }
    405 
    406 }  // namespace
    407 
// Copies the layouts of `src` (recursively) onto `dst`; the two shapes must
// have identical structure.  Thin public wrapper over CopyLayoutInternal.
/* static */
Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
  return CopyLayoutInternal(src, dst);
}
    412 
    413 /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
    414                                                    const Shape& rhs) {
    415   if (lhs.IsTuple()) {
    416     if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) !=
    417                               ShapeUtil::TupleElementCount(rhs)) {
    418       return false;
    419     }
    420     for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
    421       if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) {
    422         return false;
    423       }
    424     }
    425     return true;
    426   } else if (lhs.IsArray()) {
    427     return lhs.rank() == rhs.rank() &&
    428            LayoutUtil::Equal(lhs.layout(), rhs.layout());
    429   } else {
    430     // Layouts of non-array and non-tuple shapes is ignored.
    431     return true;
    432   }
    433 }
    434 
    435 /* static */ bool LayoutUtil::AreDimensionsConsecutive(
    436     const Layout& layout, absl::Span<const int64> dims) {
    437   CHECK(IsDense(layout));
    438   std::vector<int64> positions_in_layout;
    439   for (int64 dim : dims) {
    440     positions_in_layout.push_back(
    441         PositionInContainer(layout.minor_to_major(), dim));
    442   }
    443   absl::c_sort(positions_in_layout);
    444   for (size_t i = 1; i < positions_in_layout.size(); ++i) {
    445     if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
    446       return false;
    447     }
    448   }
    449   return true;
    450 }
    451 
    452 /*static*/ size_t LayoutUtil::Hash(const Layout& layout) {
    453   using tensorflow::hash;
    454   using tensorflow::Hash64Combine;
    455 
    456   size_t hash_value = hash<Format>()(layout.format());
    457 
    458   for (int64 minor_to_major : layout.minor_to_major()) {
    459     hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
    460   }
    461   hash_value = Hash64Combine(hash_value, layout.max_sparse_elements());
    462 
    463   for (Tile tile : layout.tiles()) {
    464     for (int64 tile_dim : tile.dimensions()) {
    465       hash_value = Hash64Combine(hash_value, hash<int64>()(tile_dim));
    466     }
    467   }
    468   hash_value = Hash64Combine(hash_value, layout.element_size_in_bits());
    469 
    470   return hash_value;
    471 }
    472 
    473 }  // namespace xla
    474