Home | History | Annotate | Download | only in xla
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/layout.h"
     17 
     18 #include "absl/strings/str_cat.h"
     19 #include "absl/strings/str_join.h"
     20 #include "tensorflow/compiler/xla/layout_util.h"
     21 
     22 namespace xla {
     23 
     24 TileProto Tile::ToProto() const {
     25   TileProto tile_proto;
     26   for (int64 i : dimensions()) {
     27     tile_proto.add_dimensions(i);
     28   }
     29   return tile_proto;
     30 }
     31 
     32 string Tile::ToString() const {
     33   std::vector<string> elements;
     34   for (auto dim : dimensions()) {
     35     if (dim >= 0) {
     36       elements.push_back(std::to_string(dim));
     37     } else {
     38       if (dim == kCombineDimension) {
     39         elements.push_back("*");
     40       } else {
     41         elements.push_back(absl::StrCat("Invalid value ", dim));
     42       }
     43     }
     44   }
     45   return absl::StrCat("(", absl::StrJoin(elements, ","), ")");
     46 }
     47 
     48 /* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) {
     49   Layout layout;
     50   layout.set_format(proto.format());
     51   layout.minor_to_major_.reserve(proto.minor_to_major_size());
     52   for (const int64 dimension : proto.minor_to_major()) {
     53     layout.add_minor_to_major(dimension);
     54   }
     55   layout.set_max_sparse_elements(proto.max_sparse_elements());
     56   for (const TileProto& tile_proto : proto.tiles()) {
     57     *layout.add_tiles() = Tile::CreateFromProto(tile_proto);
     58   }
     59   layout.set_element_size_in_bits(proto.element_size_in_bits());
     60   return layout;
     61 }
     62 
     63 LayoutProto Layout::ToProto() const {
     64   LayoutProto proto;
     65   proto.set_format(format_);
     66   proto.mutable_minor_to_major()->Reserve(minor_to_major_size());
     67   for (const int64 dimension : minor_to_major()) {
     68     proto.add_minor_to_major(dimension);
     69   }
     70   proto.set_max_sparse_elements(max_sparse_elements_);
     71   for (const Tile& tile : tiles()) {
     72     *proto.add_tiles() = tile.ToProto();
     73   }
     74   proto.set_element_size_in_bits(element_size_in_bits());
     75   return proto;
     76 }
     77 
     78 string Layout::ToString() const {
     79   if (format() == SPARSE) {
     80     CHECK_EQ(tiles_size(), 0) << "Sparse layout should not be tiled.";
     81     return absl::StrCat("sparse{", max_sparse_elements(), "}");
     82   } else if (format() == DENSE) {
     83     string colon_string = tiles().empty() ? "" : "T";
     84     for (Tile tile : tiles()) {
     85       absl::StrAppend(&colon_string, tile.ToString());
     86     }
     87     if (element_size_in_bits() != 0) {
     88       absl::StrAppend(&colon_string, "E(", element_size_in_bits(), ")");
     89     }
     90     return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","),
     91                         colon_string.empty() ? "" : ":", colon_string, "}");
     92   } else {
     93     CHECK_EQ(format(), INVALID_FORMAT);
     94     return "invalid{}";
     95   }
     96 }
     97 
     98 bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) {
     99   if (lhs.format() != rhs.format() ||
    100       lhs.minor_to_major() != rhs.minor_to_major() ||
    101       lhs.max_sparse_elements() != rhs.max_sparse_elements()) {
    102     return false;
    103   }
    104   if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) {
    105     return false;
    106   }
    107   if (!ignore_element_size_ &&
    108       lhs.element_size_in_bits() != rhs.element_size_in_bits()) {
    109     return false;
    110   }
    111   return true;
    112 }
    113 
    114 bool Layout::operator==(const Layout& other) const {
    115   return Equal()(*this, other);
    116 }
    117 
    118 std::ostream& operator<<(std::ostream& out, const Tile& tile) {
    119   out << tile.ToString();
    120   return out;
    121 }
    122 
    123 std::ostream& operator<<(std::ostream& out, const Layout& layout) {
    124   out << layout.ToString();
    125   return out;
    126 }
    127 
    128 }  // namespace xla
    129