Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/window_util.h"
     17 
     18 #include <vector>
     19 
     20 #include "absl/algorithm/container.h"
     21 #include "absl/strings/str_cat.h"
     22 #include "tensorflow/compiler/xla/types.h"
     23 #include "tensorflow/compiler/xla/xla_data.pb.h"
     24 #include "tensorflow/core/platform/logging.h"
     25 
     26 namespace xla {
     27 namespace window_util {
     28 
     29 Window MakeWindow(absl::Span<const int64> sizes) {
     30   Window window;
     31   for (int64 size : sizes) {
     32     auto* dimension = window.add_dimensions();
     33     dimension->set_size(size);
     34     dimension->set_stride(1);
     35     dimension->set_base_dilation(1);
     36     dimension->set_window_dilation(1);
     37   }
     38   return window;
     39 }
     40 
     41 PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
     42   PaddingConfig config;
     43   for (int64 size : sizes) {
     44     auto* dimension = config.add_dimensions();
     45     dimension->set_edge_padding_low(size);
     46     dimension->set_edge_padding_high(size);
     47   }
     48   return config;
     49 }
     50 
     51 /* static */ string ToString(const WindowDimension& dim) {
     52   using absl::StrAppend;
     53   using absl::StrCat;
     54   string str = StrCat("(size=", dim.size());
     55   if (dim.stride() != 1) {
     56     StrAppend(&str, ",stride=", dim.stride());
     57   }
     58   if (dim.padding_low() != 0) {
     59     StrAppend(&str, ",padding_low=", dim.padding_low());
     60   }
     61   if (dim.padding_high() != 0) {
     62     StrAppend(&str, ",padding_high=", dim.padding_high());
     63   }
     64   if (dim.base_dilation() != 1) {
     65     StrAppend(&str, ",base_dilation=", dim.base_dilation());
     66   }
     67   if (dim.window_dilation() != 1) {
     68     StrAppend(&str, ",window_dilation=", dim.window_dilation());
     69   }
     70   if (dim.window_reversal()) {
     71     StrAppend(&str, ",window_reversal");
     72   }
     73   StrAppend(&str, ")");
     74   return str;
     75 }
     76 
     77 string ToString(const Window& window) {
     78   using absl::StrAppend;
     79   using absl::StrCat;
     80 
     81   string str;
     82   const auto add_field =
     83       [&](const char* heading,
     84           std::function<string(const WindowDimension&)> format) {
     85         StrAppend(&str, heading, "=");
     86         const char* prefix = "";
     87         for (const auto& window_dimension : window.dimensions()) {
     88           StrAppend(&str, prefix, format(window_dimension));
     89           prefix = "x";
     90         }
     91       };
     92 
     93   add_field("size",
     94             [](const WindowDimension& dim) { return StrCat(dim.size()); });
     95   if (HasStride(window)) {
     96     add_field(" stride",
     97               [](const WindowDimension& dim) { return StrCat(dim.stride()); });
     98   }
     99   if (HasPadding(window)) {
    100     add_field(" pad", [](const WindowDimension& dim) {
    101       return StrCat(dim.padding_low(), "_", dim.padding_high());
    102     });
    103   }
    104   if (HasBaseDilation(window)) {
    105     add_field(" lhs_dilate", [](const WindowDimension& dim) {
    106       return StrCat(dim.base_dilation());
    107     });
    108   }
    109   if (HasWindowDilation(window)) {
    110     add_field(" rhs_dilate", [](const WindowDimension& dim) {
    111       return StrCat(dim.window_dilation());
    112     });
    113   }
    114   if (HasWindowReversal(window)) {
    115     add_field(" rhs_reversal", [](const WindowDimension& dim) {
    116       return StrCat(dim.window_reversal() ? 1 : 0);
    117     });
    118   }
    119   return str;
    120 }
    121 
    122 bool HasStride(const Window& window) {
    123   for (const auto& dim : window.dimensions()) {
    124     if (dim.stride() != 1) {
    125       return true;
    126     }
    127   }
    128   return false;
    129 }
    130 
    131 bool HasPadding(const Window& window) {
    132   for (const auto& dim : window.dimensions()) {
    133     if (dim.padding_low() != 0 || dim.padding_high() != 0) {
    134       return true;
    135     }
    136   }
    137   return false;
    138 }
    139 
    140 bool HasSymmetricPadding(const Window& window) {
    141   return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) {
    142     return dim.padding_low() == dim.padding_high();
    143   });
    144 }
    145 
    146 bool HasSymmetricPadding(const PaddingConfig& padding_config) {
    147   return absl::c_all_of(padding_config.dimensions(),
    148                         [](const PaddingConfig::PaddingConfigDimension& dim) {
    149                           return dim.edge_padding_low() ==
    150                                  dim.edge_padding_high();
    151                         });
    152 }
    153 
    154 bool HasNegativePadding(const Window& window) {
    155   return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
    156     return dim.padding_low() < 0 || dim.padding_high() < 0;
    157   });
    158 }
    159 
    160 bool HasBaseDilation(const Window& window) {
    161   for (const auto& dim : window.dimensions()) {
    162     if (dim.base_dilation() != 1) {
    163       return true;
    164     }
    165   }
    166   return false;
    167 }
    168 
    169 bool HasWindowDilation(const Window& window) {
    170   for (const auto& dim : window.dimensions()) {
    171     if (dim.window_dilation() != 1) {
    172       return true;
    173     }
    174   }
    175   return false;
    176 }
    177 
    178 bool HasWindowReversal(const Window& window) {
    179   for (const auto& dim : window.dimensions()) {
    180     if (dim.window_reversal()) {
    181       return true;
    182     }
    183   }
    184   return false;
    185 }
    186 
    187 bool AllOrNoneReversed(const Window& window) {
    188   if (window.dimensions().empty()) {
    189     return true;
    190   }
    191   bool reversed = window.dimensions()[0].window_reversal();
    192   return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) {
    193     return dim.window_reversal() == reversed;
    194   });
    195 }
    196 
    197 bool HasDilation(const Window& window) {
    198   return HasBaseDilation(window) || HasWindowDilation(window);
    199 }
    200 
    201 bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) {
    202   const WindowDimension& window_dim = window.dimensions(logical_dim);
    203   return window_dim.size() == 1 && window_dim.stride() == 1 &&
    204          window_dim.padding_low() == 0 && window_dim.padding_high() == 0;
    205 }
    206 
    207 bool IsTrivialWindowDimension(const WindowDimension& window_dimension) {
    208   return window_dimension.size() == 1 && window_dimension.stride() == 1 &&
    209          window_dimension.padding_low() == 0 &&
    210          window_dimension.padding_high() == 0 &&
    211          window_dimension.window_dilation() == 1 &&
    212          window_dimension.base_dilation() == 1;
    213 }
    214 
    215 int64 DilatedBound(int64 bound, int64 dilation) {
    216   CHECK_GE(bound, 0);
    217   CHECK_GE(dilation, 1);
    218   if (bound == 0) {
    219     return 0;
    220   }
    221 
    222   // Suppose the array has three entries 123 and the dilation factor is 4. Then
    223   // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
    224   // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
    225   // add 1 to account for the final input element.
    226   return (bound - 1) * dilation + 1;
    227 }
    228 
    229 int64 StridedBound(int64 bound, int64 window_size, int64 stride) {
    230   CHECK_GE(window_size, 0);
    231   CHECK_GE(bound, 0);
    232   CHECK_GE(stride, 1);
    233 
    234   if (bound == 0 || window_size > bound) {
    235     return 0;
    236   }
    237 
    238   // Without considering stride, the maximum valid offset is bound -
    239   // window_size. Taking stride into account, the valid offsets then have the
    240   // form q * stride for q = 0, ..., Q such that q * stride <= bound -
    241   // window_size. This implies that Q equals floor(bound - window_size /
    242   // stride). There are Q + 1 valid values of q, yielding the formula below.
    243   return (bound - window_size) / stride + 1;
    244 }
    245 
    246 }  // namespace window_util
    247 }  // namespace xla
    248