1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/window_util.h" 17 18 #include <vector> 19 20 #include "absl/algorithm/container.h" 21 #include "absl/strings/str_cat.h" 22 #include "tensorflow/compiler/xla/types.h" 23 #include "tensorflow/compiler/xla/xla_data.pb.h" 24 #include "tensorflow/core/platform/logging.h" 25 26 namespace xla { 27 namespace window_util { 28 29 Window MakeWindow(absl::Span<const int64> sizes) { 30 Window window; 31 for (int64 size : sizes) { 32 auto* dimension = window.add_dimensions(); 33 dimension->set_size(size); 34 dimension->set_stride(1); 35 dimension->set_base_dilation(1); 36 dimension->set_window_dilation(1); 37 } 38 return window; 39 } 40 41 PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) { 42 PaddingConfig config; 43 for (int64 size : sizes) { 44 auto* dimension = config.add_dimensions(); 45 dimension->set_edge_padding_low(size); 46 dimension->set_edge_padding_high(size); 47 } 48 return config; 49 } 50 51 /* static */ string ToString(const WindowDimension& dim) { 52 using absl::StrAppend; 53 using absl::StrCat; 54 string str = StrCat("(size=", dim.size()); 55 if (dim.stride() != 1) { 56 StrAppend(&str, ",stride=", dim.stride()); 57 } 58 if (dim.padding_low() != 0) { 59 StrAppend(&str, ",padding_low=", dim.padding_low()); 60 } 61 if (dim.padding_high() != 0) { 62 StrAppend(&str, ",padding_high=", dim.padding_high()); 63 } 64 if (dim.base_dilation() != 1) { 65 StrAppend(&str, ",base_dilation=", dim.base_dilation()); 66 } 67 if (dim.window_dilation() != 1) { 68 StrAppend(&str, ",window_dilation=", dim.window_dilation()); 69 } 70 if (dim.window_reversal()) { 71 StrAppend(&str, ",window_reversal"); 72 } 73 StrAppend(&str, ")"); 74 return str; 75 } 76 77 string ToString(const Window& window) { 78 using absl::StrAppend; 79 using absl::StrCat; 80 81 string str; 82 const auto add_field = 83 [&](const char* heading, 84 std::function<string(const WindowDimension&)> format) { 85 StrAppend(&str, heading, "="); 86 const char* prefix = ""; 87 for (const auto& window_dimension : window.dimensions()) { 88 StrAppend(&str, prefix, format(window_dimension)); 89 prefix = "x"; 90 } 91 }; 92 93 add_field("size", 94 [](const WindowDimension& dim) { return StrCat(dim.size()); }); 95 if (HasStride(window)) { 96 add_field(" stride", 97 [](const WindowDimension& dim) { return StrCat(dim.stride()); }); 98 } 99 if (HasPadding(window)) { 100 add_field(" pad", [](const WindowDimension& dim) { 101 return StrCat(dim.padding_low(), "_", dim.padding_high()); 102 }); 103 } 104 if (HasBaseDilation(window)) { 105 add_field(" lhs_dilate", [](const WindowDimension& dim) { 106 return StrCat(dim.base_dilation()); 107 }); 108 } 109 if (HasWindowDilation(window)) { 110 add_field(" rhs_dilate", [](const WindowDimension& dim) { 111 return StrCat(dim.window_dilation()); 112 }); 113 } 114 if (HasWindowReversal(window)) { 115 add_field(" rhs_reversal", [](const WindowDimension& dim) { 116 return StrCat(dim.window_reversal() ? 1 : 0); 117 }); 118 } 119 return str; 120 } 121 122 bool HasStride(const Window& window) { 123 for (const auto& dim : window.dimensions()) { 124 if (dim.stride() != 1) { 125 return true; 126 } 127 } 128 return false; 129 } 130 131 bool HasPadding(const Window& window) { 132 for (const auto& dim : window.dimensions()) { 133 if (dim.padding_low() != 0 || dim.padding_high() != 0) { 134 return true; 135 } 136 } 137 return false; 138 } 139 140 bool HasSymmetricPadding(const Window& window) { 141 return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) { 142 return dim.padding_low() == dim.padding_high(); 143 }); 144 } 145 146 bool HasSymmetricPadding(const PaddingConfig& padding_config) { 147 return absl::c_all_of(padding_config.dimensions(), 148 [](const PaddingConfig::PaddingConfigDimension& dim) { 149 return dim.edge_padding_low() == 150 dim.edge_padding_high(); 151 }); 152 } 153 154 bool HasNegativePadding(const Window& window) { 155 return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) { 156 return dim.padding_low() < 0 || dim.padding_high() < 0; 157 }); 158 } 159 160 bool HasBaseDilation(const Window& window) { 161 for (const auto& dim : window.dimensions()) { 162 if (dim.base_dilation() != 1) { 163 return true; 164 } 165 } 166 return false; 167 } 168 169 bool HasWindowDilation(const Window& window) { 170 for (const auto& dim : window.dimensions()) { 171 if (dim.window_dilation() != 1) { 172 return true; 173 } 174 } 175 return false; 176 } 177 178 bool HasWindowReversal(const Window& window) { 179 for (const auto& dim : window.dimensions()) { 180 if (dim.window_reversal()) { 181 return true; 182 } 183 } 184 return false; 185 } 186 187 bool AllOrNoneReversed(const Window& window) { 188 if (window.dimensions().empty()) { 189 return true; 190 } 191 bool reversed = window.dimensions()[0].window_reversal(); 192 return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) { 193 return dim.window_reversal() == reversed; 194 }); 195 } 196 197 bool HasDilation(const Window& window) { 198 return HasBaseDilation(window) || HasWindowDilation(window); 199 } 200 201 bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) { 202 const WindowDimension& window_dim = window.dimensions(logical_dim); 203 return window_dim.size() == 1 && window_dim.stride() == 1 && 204 window_dim.padding_low() == 0 && window_dim.padding_high() == 0; 205 } 206 207 bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { 208 return window_dimension.size() == 1 && window_dimension.stride() == 1 && 209 window_dimension.padding_low() == 0 && 210 window_dimension.padding_high() == 0 && 211 window_dimension.window_dilation() == 1 && 212 window_dimension.base_dilation() == 1; 213 } 214 215 int64 DilatedBound(int64 bound, int64 dilation) { 216 CHECK_GE(bound, 0); 217 CHECK_GE(dilation, 1); 218 if (bound == 0) { 219 return 0; 220 } 221 222 // Suppose the array has three entries 123 and the dilation factor is 4. Then 223 // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except 224 // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we 225 // add 1 to account for the final input element. 226 return (bound - 1) * dilation + 1; 227 } 228 229 int64 StridedBound(int64 bound, int64 window_size, int64 stride) { 230 CHECK_GE(window_size, 0); 231 CHECK_GE(bound, 0); 232 CHECK_GE(stride, 1); 233 234 if (bound == 0 || window_size > bound) { 235 return 0; 236 } 237 238 // Without considering stride, the maximum valid offset is bound - 239 // window_size. Taking stride into account, the valid offsets then have the 240 // form q * stride for q = 0, ..., Q such that q * stride <= bound - 241 // window_size. This implies that Q equals floor(bound - window_size / 242 // stride). There are Q + 1 valid values of q, yielding the formula below. 243 return (bound - window_size) / stride + 1; 244 } 245 246 } // namespace window_util 247 } // namespace xla 248