1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/layout.h" 17 18 #include "absl/strings/str_cat.h" 19 #include "absl/strings/str_join.h" 20 #include "tensorflow/compiler/xla/layout_util.h" 21 22 namespace xla { 23 24 TileProto Tile::ToProto() const { 25 TileProto tile_proto; 26 for (int64 i : dimensions()) { 27 tile_proto.add_dimensions(i); 28 } 29 return tile_proto; 30 } 31 32 string Tile::ToString() const { 33 std::vector<string> elements; 34 for (auto dim : dimensions()) { 35 if (dim >= 0) { 36 elements.push_back(std::to_string(dim)); 37 } else { 38 if (dim == kCombineDimension) { 39 elements.push_back("*"); 40 } else { 41 elements.push_back(absl::StrCat("Invalid value ", dim)); 42 } 43 } 44 } 45 return absl::StrCat("(", absl::StrJoin(elements, ","), ")"); 46 } 47 48 /* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) { 49 Layout layout; 50 layout.set_format(proto.format()); 51 layout.minor_to_major_.reserve(proto.minor_to_major_size()); 52 for (const int64 dimension : proto.minor_to_major()) { 53 layout.add_minor_to_major(dimension); 54 } 55 layout.set_max_sparse_elements(proto.max_sparse_elements()); 56 for (const TileProto& tile_proto : proto.tiles()) { 57 *layout.add_tiles() = Tile::CreateFromProto(tile_proto); 58 } 59 layout.set_element_size_in_bits(proto.element_size_in_bits()); 60 return layout; 61 } 62 63 LayoutProto Layout::ToProto() const { 64 LayoutProto proto; 65 proto.set_format(format_); 66 proto.mutable_minor_to_major()->Reserve(minor_to_major_size()); 67 for (const int64 dimension : minor_to_major()) { 68 proto.add_minor_to_major(dimension); 69 } 70 proto.set_max_sparse_elements(max_sparse_elements_); 71 for (const Tile& tile : tiles()) { 72 *proto.add_tiles() = tile.ToProto(); 73 } 74 proto.set_element_size_in_bits(element_size_in_bits()); 75 return proto; 76 } 77 78 string Layout::ToString() const { 79 if (format() == SPARSE) { 80 CHECK_EQ(tiles_size(), 0) << "Sparse layout should not be tiled."; 81 return absl::StrCat("sparse{", max_sparse_elements(), "}"); 82 } else if (format() == DENSE) { 83 string colon_string = tiles().empty() ? "" : "T"; 84 for (Tile tile : tiles()) { 85 absl::StrAppend(&colon_string, tile.ToString()); 86 } 87 if (element_size_in_bits() != 0) { 88 absl::StrAppend(&colon_string, "E(", element_size_in_bits(), ")"); 89 } 90 return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), 91 colon_string.empty() ? "" : ":", colon_string, "}"); 92 } else { 93 CHECK_EQ(format(), INVALID_FORMAT); 94 return "invalid{}"; 95 } 96 } 97 98 bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { 99 if (lhs.format() != rhs.format() || 100 lhs.minor_to_major() != rhs.minor_to_major() || 101 lhs.max_sparse_elements() != rhs.max_sparse_elements()) { 102 return false; 103 } 104 if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) { 105 return false; 106 } 107 if (!ignore_element_size_ && 108 lhs.element_size_in_bits() != rhs.element_size_in_bits()) { 109 return false; 110 } 111 return true; 112 } 113 114 bool Layout::operator==(const Layout& other) const { 115 return Equal()(*this, other); 116 } 117 118 std::ostream& operator<<(std::ostream& out, const Tile& tile) { 119 out << tile.ToString(); 120 return out; 121 } 122 123 std::ostream& operator<<(std::ostream& out, const Layout& layout) { 124 out << layout.ToString(); 125 return out; 126 } 127 128 } // namespace xla 129