/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/window_util.h"

#include <algorithm>
#include <functional>
#include <vector>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"

namespace xla {
namespace window_util {

Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes) {
  Window window;
  for (int64 size : sizes) {
    auto* dimension = window.add_dimensions();
    dimension->set_size(size);
    dimension->set_stride(1);
    dimension->set_base_dilation(1);
    dimension->set_window_dilation(1);
  }
  return window;
}

PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) {
  PaddingConfig config;
  for (int64 size : sizes) {
    auto* dimension = config.add_dimensions();
    dimension->set_edge_padding_low(size);
    dimension->set_edge_padding_high(size);
  }
  return config;
}

/* static */ string ToString(const WindowDimension& dim) {
  using tensorflow::strings::StrAppend;
  using tensorflow::strings::StrCat;
  string str = StrCat("(size=", dim.size());
  if (dim.stride() != 1) {
    StrAppend(&str, ",stride=", dim.stride());
  }
  if (dim.padding_low() != 0) {
    StrAppend(&str, ",padding_low=", dim.padding_low());
  }
  if (dim.padding_high() != 0) {
    StrAppend(&str, ",padding_high=", dim.padding_high());
  }
  if (dim.base_dilation() != 1) {
    StrAppend(&str, ",base_dilation=", dim.base_dilation());
  }
  if (dim.window_dilation() != 1) {
    StrAppend(&str, ",window_dilation=", dim.window_dilation());
  }
  if (dim.window_reversal()) {
    StrAppend(&str, ",window_reversal");
  }
  StrAppend(&str, ")");
  return str;
}

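// Produces a compact description such as "size=2x3 stride=1x2" for a
// two-dimensional window, emitting the stride/pad/dilation/reversal fields
// only when some dimension has a non-default value.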
string ToString(const Window& window) {
  using tensorflow::strings::StrAppend;
  using tensorflow::strings::StrCat;

  string str;
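  // Appends "heading=v0xv1x..." to str, where each vi is the result of
  // applying `format` to the corresponding window dimension.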
  const auto add_field =
      [&](const char* heading,
          std::function<string(const WindowDimension&)> format) {
        StrAppend(&str, heading, "=");
        const char* prefix = "";
        for (const auto& window_dimension : window.dimensions()) {
          StrAppend(&str, prefix, format(window_dimension));
          prefix = "x";
        }
      };

  add_field("size",
            [](const WindowDimension& dim) { return StrCat(dim.size()); });
  if (HasStride(window)) {
    add_field(" stride",
              [](const WindowDimension& dim) { return StrCat(dim.stride()); });
  }
  if (HasPadding(window)) {
    add_field(" pad", [](const WindowDimension& dim) {
      return StrCat(dim.padding_low(), "_", dim.padding_high());
    });
  }
  if (HasBaseDilation(window)) {
    add_field(" lhs_dilate", [](const WindowDimension& dim) {
      return StrCat(dim.base_dilation());
    });
  }
  if (HasWindowDilation(window)) {
    add_field(" rhs_dilate", [](const WindowDimension& dim) {
      return StrCat(dim.window_dilation());
    });
  }
  if (HasWindowReversal(window)) {
    add_field(" rhs_reversal", [](const WindowDimension& dim) {
      return StrCat(dim.window_reversal() ? 1 : 0);
    });
  }
  return str;
}

bool HasStride(const Window& window) {
  for (const auto& dim : window.dimensions()) {
    if (dim.stride() != 1) {
      return true;
    }
  }
  return false;
}

bool HasPadding(const Window& window) {
  for (const auto& dim : window.dimensions()) {
    if (dim.padding_low() != 0 || dim.padding_high() != 0) {
      return true;
    }
  }
  return false;
}

bool HasSymmetricPadding(const Window& window) {
  return std::all_of(window.dimensions().begin(), window.dimensions().end(),
                     [](const WindowDimension& dim) {
                       return dim.padding_low() == dim.padding_high();
                     });
}

bool HasSymmetricPadding(const PaddingConfig& padding_config) {
  return std::all_of(padding_config.dimensions().begin(),
                     padding_config.dimensions().end(),
                     [](const PaddingConfig::PaddingConfigDimension& dim) {
                       return dim.edge_padding_low() == dim.edge_padding_high();
                     });
}

bool HasNegativePadding(const Window& window) {
  return std::any_of(window.dimensions().begin(), window.dimensions().end(),
                     [](const WindowDimension& dim) {
                       return dim.padding_low() < 0 || dim.padding_high() < 0;
                     });
}

bool HasBaseDilation(const Window& window) {
  for (const auto& dim : window.dimensions()) {
    if (dim.base_dilation() != 1) {
      return true;
    }
  }
  return false;
}

bool HasWindowDilation(const Window& window) {
  for (const auto& dim : window.dimensions()) {
    if (dim.window_dilation() != 1) {
      return true;
    }
  }
  return false;
}

bool HasWindowReversal(const Window& window) {
  for (const auto& dim : window.dimensions()) {
    if (dim.window_reversal()) {
      return true;
    }
  }
  return false;
}

bool HasDilation(const Window& window) {
  return HasBaseDilation(window) || HasWindowDilation(window);
}

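// A window dimension is treated as inactive when it has size 1, unit stride,
// and no padding on either side; the dilation fields are not consulted here.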
bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) {
  const WindowDimension& window_dim = window.dimensions(logical_dim);
  return window_dim.size() == 1 && window_dim.stride() == 1 &&
         window_dim.padding_low() == 0 && window_dim.padding_high() == 0;
}

int64 DilatedBound(int64 bound, int64 dilation) {
  CHECK_GE(bound, 0);
  CHECK_GE(dilation, 1);
  if (bound == 0) {
    // An empty array stays empty under dilation; without this guard the
    // formula below would return a negative value when dilation > 1.
    return 0;
  }

  // Suppose the array has three entries 123 and the dilation factor is 4. Then
  // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry
  // except the last expands into 4 entries, so that is (bound - 1) * dilation.
  // Then we add 1 to account for the final input element.
  return (bound - 1) * dilation + 1;
}

int64 StridedBound(int64 bound, int64 window_size, int64 stride) {
  CHECK_GE(window_size, 0);
  CHECK_GE(bound, 0);
  CHECK_GE(stride, 1);

  if (window_size > bound) {
    return 0;
  }

  // Without considering stride, the maximum valid offset is bound -
  // window_size. Taking stride into account, the valid offsets then have the
  // form q * stride for q = 0, ..., Q such that q * stride <= bound -
  // window_size. This implies that Q equals floor((bound - window_size) /
  // stride). There are Q + 1 valid values of q, yielding the formula below.
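  // For example, with bound = 10, window_size = 3, and stride = 4, the valid
  // offsets are 0 and 4 (offset 8 would overrun, since 8 + 3 > 10), matching
  // (10 - 3) / 4 + 1 = 2.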
  return (bound - window_size) / stride + 1;
}

}  // namespace window_util
}  // namespace xla