1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/window_util.h" 17 18 #include <vector> 19 20 #include "tensorflow/compiler/xla/types.h" 21 #include "tensorflow/compiler/xla/xla_data.pb.h" 22 #include "tensorflow/core/lib/strings/str_util.h" 23 #include "tensorflow/core/lib/strings/strcat.h" 24 #include "tensorflow/core/lib/strings/stringprintf.h" 25 26 namespace xla { 27 namespace window_util { 28 29 Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes) { 30 Window window; 31 for (int64 size : sizes) { 32 auto* dimension = window.add_dimensions(); 33 dimension->set_size(size); 34 dimension->set_stride(1); 35 dimension->set_base_dilation(1); 36 dimension->set_window_dilation(1); 37 } 38 return window; 39 } 40 41 PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) { 42 PaddingConfig config; 43 for (int64 size : sizes) { 44 auto* dimension = config.add_dimensions(); 45 dimension->set_edge_padding_low(size); 46 dimension->set_edge_padding_high(size); 47 } 48 return config; 49 } 50 51 /* static */ string ToString(const WindowDimension& dim) { 52 using tensorflow::strings::StrAppend; 53 using tensorflow::strings::StrCat; 54 string str = StrCat("(size=", dim.size()); 55 if (dim.stride() != 1) { 56 StrAppend(&str, ",stride=", dim.stride()); 57 } 58 if (dim.padding_low() != 0) { 59 StrAppend(&str, ",padding_low=", dim.padding_low()); 60 } 61 if (dim.padding_high() != 0) { 62 StrAppend(&str, ",padding_high=", dim.padding_high()); 63 } 64 if (dim.base_dilation() != 1) { 65 StrAppend(&str, ",base_dilation=", dim.base_dilation()); 66 } 67 if (dim.window_dilation() != 1) { 68 StrAppend(&str, ",window_dilation=", dim.window_dilation()); 69 } 70 if (dim.window_reversal()) { 71 StrAppend(&str, ",window_reversal"); 72 } 73 StrAppend(&str, ")"); 74 return str; 75 } 76 77 string ToString(const Window& window) { 78 using tensorflow::strings::StrAppend; 79 using tensorflow::strings::StrCat; 80 81 string str; 82 const auto add_field = 83 [&](const char* heading, 84 std::function<string(const WindowDimension&)> format) { 85 StrAppend(&str, heading, "="); 86 const char* prefix = ""; 87 for (const auto& window_dimension : window.dimensions()) { 88 StrAppend(&str, prefix, format(window_dimension)); 89 prefix = "x"; 90 } 91 }; 92 93 add_field("size", 94 [](const WindowDimension& dim) { return StrCat(dim.size()); }); 95 if (HasStride(window)) { 96 add_field(" stride", 97 [](const WindowDimension& dim) { return StrCat(dim.stride()); }); 98 } 99 if (HasPadding(window)) { 100 add_field(" pad", [](const WindowDimension& dim) { 101 return StrCat(dim.padding_low(), "_", dim.padding_high()); 102 }); 103 } 104 if (HasBaseDilation(window)) { 105 add_field(" lhs_dilate", [](const WindowDimension& dim) { 106 return StrCat(dim.base_dilation()); 107 }); 108 } 109 if (HasWindowDilation(window)) { 110 add_field(" rhs_dilate", [](const WindowDimension& dim) { 111 return StrCat(dim.window_dilation()); 112 }); 113 } 114 if (HasWindowReversal(window)) { 115 add_field(" rhs_reversal", [](const WindowDimension& dim) { 116 return StrCat(dim.window_reversal() ? 1 : 0); 117 }); 118 } 119 return str; 120 } 121 122 bool HasStride(const Window& window) { 123 for (const auto& dim : window.dimensions()) { 124 if (dim.stride() != 1) { 125 return true; 126 } 127 } 128 return false; 129 } 130 131 bool HasPadding(const Window& window) { 132 for (const auto& dim : window.dimensions()) { 133 if (dim.padding_low() != 0 || dim.padding_high() != 0) { 134 return true; 135 } 136 } 137 return false; 138 } 139 140 bool HasSymmetricPadding(const Window& window) { 141 return std::all_of(window.dimensions().begin(), window.dimensions().end(), 142 [](const WindowDimension& dim) { 143 return dim.padding_low() == dim.padding_high(); 144 }); 145 } 146 147 bool HasSymmetricPadding(const PaddingConfig& padding_config) { 148 return std::all_of(padding_config.dimensions().begin(), 149 padding_config.dimensions().end(), 150 [](const PaddingConfig::PaddingConfigDimension& dim) { 151 return dim.edge_padding_low() == dim.edge_padding_high(); 152 }); 153 } 154 155 bool HasNegativePadding(const Window& window) { 156 return std::any_of(window.dimensions().begin(), window.dimensions().end(), 157 [](const WindowDimension& dim) { 158 return dim.padding_low() < 0 || dim.padding_high() < 0; 159 }); 160 } 161 162 bool HasBaseDilation(const Window& window) { 163 for (const auto& dim : window.dimensions()) { 164 if (dim.base_dilation() != 1) { 165 return true; 166 } 167 } 168 return false; 169 } 170 171 bool HasWindowDilation(const Window& window) { 172 for (const auto& dim : window.dimensions()) { 173 if (dim.window_dilation() != 1) { 174 return true; 175 } 176 } 177 return false; 178 } 179 180 bool HasWindowReversal(const Window& window) { 181 for (const auto& dim : window.dimensions()) { 182 if (dim.window_reversal()) { 183 return true; 184 } 185 } 186 return false; 187 } 188 189 bool HasDilation(const Window& window) { 190 return HasBaseDilation(window) || HasWindowDilation(window); 191 } 192 193 bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) { 194 const WindowDimension& window_dim = window.dimensions(logical_dim); 195 return window_dim.size() == 1 && window_dim.stride() == 1 && 196 window_dim.padding_low() == 0 && window_dim.padding_high() == 0; 197 } 198 199 int64 DilatedBound(int64 bound, int64 dilation) { 200 CHECK_GE(bound, 0); 201 CHECK_GE(dilation, 1); 202 203 // Suppose the array has three entries 123 and the dilation factor is 4. Then 204 // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except 205 // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we 206 // add 1 to account for the final input element. 207 return (bound - 1) * dilation + 1; 208 } 209 210 int64 StridedBound(int64 bound, int64 window_size, int64 stride) { 211 CHECK_GE(window_size, 0); 212 CHECK_GE(bound, 0); 213 CHECK_GE(stride, 1); 214 215 if (window_size > bound) { 216 return 0; 217 } 218 219 // Without considering stride, the maximum valid offset is bound - 220 // window_size. Taking stride into account, the valid offsets then have the 221 // form q * stride for q = 0, ..., Q such that q * stride <= bound - 222 // window_size. This implies that Q equals floor(bound - window_size / 223 // stride). There are Q + 1 valid values of q, yielding the formula below. 224 return (bound - window_size) / stride + 1; 225 } 226 227 } // namespace window_util 228 } // namespace xla 229