Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/ops_util.h"
     17 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     18 #include "tensorflow/core/framework/tensor.h"
     19 #include "tensorflow/core/platform/test.h"
     20 
     21 namespace tensorflow {
     22 namespace {
     23 
     24 class OpsUtilTest : public ::testing::Test {
     25  protected:
     26   OpsUtilTest() {}
     27   ~OpsUtilTest() override {}
     28 
     29   // Padding structure.
     30   struct padding_struct {
     31     // Input parameters.
     32     struct {
     33       int in_height;
     34       int in_width;
     35       int filter_height;
     36       int filter_width;
     37       int row_stride;
     38       int col_stride;
     39       Padding padding;
     40     } input;
     41     // Output.
     42     struct {
     43       int new_height;
     44       int new_width;
     45       int pad_top;
     46       int pad_bottom;
     47       int pad_left;
     48       int pad_right;
     49     } output;
     50   };
     51 
     52   // Broadcast structure.
     53   struct bcast_struct {
     54     // Input parameters.
     55     struct {
     56       int index;     // Current index.
     57       int in_size;   // Size of the dimension.
     58       int ksize;     // Kernel size.
     59       int stride;    // Stride.
     60       int pad_size;  // Padding size.
     61     } input;
     62     // Output.
     63     struct {
     64       int new_index;  // New starting index.
     65       int new_size;   // New broadcast size.
     66     } output;
     67   };
     68 
     69   static void VerifyGet2dOutputSizeBoundaries(padding_struct pad_struct,
     70                                               error::Code code) {
     71     int64 new_height, new_width, pad_rows, pad_cols;
     72     Status status = GetWindowedOutputSize(
     73         pad_struct.input.in_height, pad_struct.input.filter_height,
     74         pad_struct.input.row_stride, pad_struct.input.padding, &new_height,
     75         &pad_rows);
     76     EXPECT_EQ(status.code(), code) << status;
     77     status = GetWindowedOutputSize(
     78         pad_struct.input.in_width, pad_struct.input.filter_width,
     79         pad_struct.input.col_stride, pad_struct.input.padding, &new_width,
     80         &pad_cols);
     81     EXPECT_EQ(status.code(), code) << status;
     82   }
     83 
     84   static void VerifyGet2dOutputSizeValues(padding_struct pad_struct,
     85                                           error::Code code) {
     86     int64 new_height, new_width, pad_rows, pad_cols;
     87     Status status = GetWindowedOutputSize(
     88         pad_struct.input.in_height, pad_struct.input.filter_height,
     89         pad_struct.input.row_stride, pad_struct.input.padding, &new_height,
     90         &pad_rows);
     91     EXPECT_EQ(status.code(), code) << status;
     92     status = GetWindowedOutputSize(
     93         pad_struct.input.in_width, pad_struct.input.filter_width,
     94         pad_struct.input.col_stride, pad_struct.input.padding, &new_width,
     95         &pad_cols);
     96     EXPECT_EQ(status.code(), code) << status;
     97     EXPECT_EQ(pad_struct.output.new_height, new_height);
     98     EXPECT_EQ(pad_struct.output.new_width, new_width);
     99     EXPECT_EQ(pad_struct.output.pad_top, pad_rows);
    100     EXPECT_EQ(pad_struct.output.pad_left, pad_cols);
    101   }
    102 
    103   static void VerifyGet2dOutputVerboseSizeValues(padding_struct pad_struct,
    104                                                  error::Code code) {
    105     int64 new_height, new_width, pad_top, pad_bottom, pad_left, pad_right;
    106     Status status = GetWindowedOutputSizeVerbose(
    107         pad_struct.input.in_height, pad_struct.input.filter_height,
    108         pad_struct.input.row_stride, pad_struct.input.padding, &new_height,
    109         &pad_top, &pad_bottom);
    110     EXPECT_EQ(status.code(), code) << status;
    111     status = GetWindowedOutputSizeVerbose(
    112         pad_struct.input.in_width, pad_struct.input.filter_width,
    113         pad_struct.input.col_stride, pad_struct.input.padding, &new_width,
    114         &pad_left, &pad_right);
    115     EXPECT_EQ(status.code(), code) << status;
    116     EXPECT_EQ(pad_struct.output.new_height, new_height);
    117     EXPECT_EQ(pad_struct.output.new_width, new_width);
    118     EXPECT_EQ(pad_struct.output.pad_top, pad_top);
    119     EXPECT_EQ(pad_struct.output.pad_bottom, pad_bottom);
    120     EXPECT_EQ(pad_struct.output.pad_left, pad_left);
    121     EXPECT_EQ(pad_struct.output.pad_right, pad_right);
    122   }
    123 
    124   static void VerifyBoundaries(bcast_struct bcast, error::Code code) {
    125     int new_index, new_size;
    126     Status status = GetBroadcastSize(
    127         bcast.input.index, bcast.input.in_size, bcast.input.ksize,
    128         bcast.input.stride, bcast.input.pad_size, &new_index, &new_size);
    129     EXPECT_EQ(status.code(), code) << status;
    130   }
    131 
    132   static void VerifyBcastValues(bcast_struct bcast) {
    133     int new_index, new_size;
    134     EXPECT_EQ(Status::OK(),
    135               GetBroadcastSize(bcast.input.index, bcast.input.in_size,
    136                                bcast.input.ksize, bcast.input.stride,
    137                                bcast.input.pad_size, &new_index, &new_size));
    138     EXPECT_EQ(bcast.output.new_index, new_index);
    139     EXPECT_EQ(bcast.output.new_size, new_size);
    140   }
    141 };
    142 
    143 TEST_F(OpsUtilTest, Get2dOutputSizeNegativeSizeTest) {
    144   padding_struct pad_struct = {{1, 1, 3, 3, 1, 1, VALID}, {-1, -1, 0, 0, 0, 0}};
    145   VerifyGet2dOutputSizeBoundaries(pad_struct, error::INVALID_ARGUMENT);
    146 }
    147 
    148 TEST_F(OpsUtilTest, Get2dOutputSizeSquareFilterTest) {
    149   padding_struct pad_struct1 = {{3, 3, 2, 2, 2, 2, SAME}, {2, 2, 0, 0, 0, 0}};
    150   padding_struct pad_struct2 = {{3, 3, 2, 2, 2, 2, VALID}, {1, 1, 0, 0, 0, 0}};
    151   VerifyGet2dOutputSizeValues(pad_struct1, error::OK);
    152   VerifyGet2dOutputSizeValues(pad_struct2, error::OK);
    153 }
    154 
    155 TEST_F(OpsUtilTest, Get2dOutputSizeNonSquareFilterTest) {
    156   padding_struct pad_struct1 = {{4, 5, 1, 2, 1, 1, SAME}, {4, 5, 0, 0, 0, 0}};
    157   padding_struct pad_struct2 = {{4, 5, 1, 2, 1, 1, VALID}, {4, 4, 0, 0, 0, 0}};
    158   VerifyGet2dOutputSizeValues(pad_struct1, error::OK);
    159   VerifyGet2dOutputSizeValues(pad_struct2, error::OK);
    160 }
    161 
    162 TEST_F(OpsUtilTest, Get2dOutputSizeUnevenStrideTest) {
    163   padding_struct pad_struct1 = {{4, 4, 2, 2, 1, 2, VALID}, {3, 2, 0, 0, 0, 0}};
    164   padding_struct pad_struct2 = {{4, 4, 2, 2, 2, 1, VALID}, {2, 3, 0, 0, 0, 0}};
    165   VerifyGet2dOutputSizeValues(pad_struct1, error::OK);
    166   VerifyGet2dOutputSizeValues(pad_struct2, error::OK);
    167 }
    168 
    169 TEST_F(OpsUtilTest, Get2dOutputSizeVerbose) {
    170   padding_struct pad_struct1 = {{3, 3, 2, 2, 2, 2, SAME}, {2, 2, 0, 1, 0, 1}};
    171   padding_struct pad_struct2 = {{3, 3, 2, 2, 2, 2, VALID}, {1, 1, 0, 0, 0, 0}};
    172   VerifyGet2dOutputVerboseSizeValues(pad_struct1, error::OK);
    173   VerifyGet2dOutputVerboseSizeValues(pad_struct2, error::OK);
    174 }
    175 
    176 // Test index * stride > in_size fails with INVALID_ARGUMENT.
    177 TEST_F(OpsUtilTest, GetBroadcastTestBadIndex) {
    178   bcast_struct bcast = {{2, 3, 1, 2, 0}, {0, 3}};
    179   VerifyBoundaries(bcast, error::INVALID_ARGUMENT);
    180 }
    181 
    182 // in_size = 3, ksize = 3, stride = 1, pad_size = 0
    183 TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_0) {
    184   bcast_struct bcast[] = {
    185       {{0, 3, 3, 1, 0}, {0, 3}},
    186       {{1, 3, 3, 1, 0}, {1, 2}},
    187       {{2, 3, 3, 1, 0}, {2, 1}},
    188   };
    189   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    190     VerifyBcastValues(bcast[i]);
    191   }
    192 }
    193 
    194 // in_size = 3, ksize = 3, stride = 1, pad_size = 1
    195 TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_1) {
    196   bcast_struct bcast[] = {
    197       {{0, 3, 3, 1, 1}, {0, 2}},
    198       {{1, 3, 3, 1, 1}, {0, 3}},
    199       {{2, 3, 3, 1, 1}, {1, 2}},
    200   };
    201   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    202     VerifyBcastValues(bcast[i]);
    203   }
    204 }
    205 
    206 // in_size = 3, ksize = 3, stride = 1, pad_size = 2
    207 TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_2) {
    208   bcast_struct bcast[] = {
    209       {{0, 3, 3, 1, 2}, {0, 1}},
    210       {{1, 3, 3, 1, 2}, {0, 2}},
    211       {{2, 3, 3, 1, 2}, {0, 3}},
    212   };
    213   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    214     VerifyBcastValues(bcast[i]);
    215   }
    216 }
    217 
    218 // in_size = 3, ksize = 3, stride = 2, pad_size = 0
    219 TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_0) {
    220   bcast_struct bcast[] = {
    221       {{0, 3, 3, 2, 0}, {0, 3}},
    222       {{1, 3, 3, 2, 0}, {2, 1}},
    223   };
    224   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    225     VerifyBcastValues(bcast[i]);
    226   }
    227 }
    228 
    229 // in_size = 3, ksize = 3, stride = 2, pad_size = 1
    230 TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_1) {
    231   bcast_struct bcast[] = {
    232       {{0, 3, 3, 2, 1}, {0, 2}},
    233       {{1, 3, 3, 2, 1}, {1, 2}},
    234   };
    235   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    236     VerifyBcastValues(bcast[i]);
    237   }
    238 }
    239 
    240 // in_size = 3, ksize = 3, stride = 2, pad_size = 2
    241 TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_2) {
    242   bcast_struct bcast[] = {
    243       {{0, 3, 3, 2, 2}, {0, 1}},
    244   };
    245   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    246     VerifyBcastValues(bcast[i]);
    247   }
    248 }
    249 
    250 // in_size = 3, ksize = 3, stride = 3, pad_size = 0
    251 TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_0) {
    252   bcast_struct bcast[] = {
    253       {{0, 3, 3, 3, 0}, {0, 3}},
    254   };
    255   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    256     VerifyBcastValues(bcast[i]);
    257   }
    258 }
    259 
    260 // in_size = 3, ksize = 3, stride = 3, pad_size = 1
    261 TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_1) {
    262   bcast_struct bcast[] = {
    263       {{0, 3, 3, 3, 1}, {0, 2}},
    264       {{1, 3, 3, 3, 1}, {2, 1}},
    265   };
    266   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    267     VerifyBcastValues(bcast[i]);
    268   }
    269 }
    270 
    271 // in_size = 3, ksize = 3, stride = 3, pad_size = 2
    272 TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_2) {
    273   bcast_struct bcast[] = {
    274       {{0, 3, 3, 3, 2}, {0, 1}},
    275   };
    276   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    277     VerifyBcastValues(bcast[i]);
    278   }
    279 }
    280 
    281 // in_size = 3, ksize = 1, stride = 2, pad_size = 0
    282 TEST_F(OpsUtilTest, GetBroadcastTest3_1_2_0) {
    283   bcast_struct bcast[] = {
    284       {{0, 3, 1, 2, 0}, {0, 1}},
    285       {{1, 3, 1, 2, 0}, {2, 1}},
    286   };
    287   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    288     VerifyBcastValues(bcast[i]);
    289   }
    290 }
    291 
    292 // in_size = 3, ksize = 2, stride = 3, pad_size = 0
    293 TEST_F(OpsUtilTest, GetBroadcastTest3_2_3_0) {
    294   bcast_struct bcast[] = {
    295       {{0, 3, 2, 3, 0}, {0, 2}},
    296   };
    297   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    298     VerifyBcastValues(bcast[i]);
    299   }
    300 }
    301 
    302 // in_size = 3, ksize = 2, stride = 3, pad_size = 1
    303 TEST_F(OpsUtilTest, GetBroadcastTest3_2_3_1) {
    304   bcast_struct bcast[] = {
    305       {{0, 3, 2, 3, 1}, {0, 1}},
    306       {{1, 3, 2, 3, 1}, {2, 1}},
    307   };
    308   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
    309     VerifyBcastValues(bcast[i]);
    310   }
    311 }
    312 
    313 TEST_F(OpsUtilTest, SanitizeThreadSuffix) {
    314   EXPECT_EQ("_aBc123_-___", SanitizeThreadSuffix("/aBc123_-  /"));
    315 }
    316 
    317 TEST_F(OpsUtilTest, Aligned1DSlice) {
    318 #if EIGEN_MAX_ALIGN_BYTES == 0
    319   // When EIGEN_MAX_ALIGN_BYTES is 0, a 1D tensor is always aligned.
    320   Tensor t(DT_FLOAT, TensorShape({3}));
    321   int64 start = 0;
    322   int64 end = 1;
    323   bool output = IsDim0SliceAligned<float>(t.shape(), start, end);
    324   EXPECT_EQ(output, true);
    325 #else
    326   Tensor t(DT_FLOAT, TensorShape({EIGEN_MAX_ALIGN_BYTES * 2}));
    327   int64 start = 0;
    328   int64 end = EIGEN_MAX_ALIGN_BYTES;
    329   bool output = IsDim0SliceAligned<float>(t.shape(), start, end);
    330   EXPECT_EQ(output, true);
    331   // Checks sliced 1D tensor is aligned for sanity.
    332   Tensor sliced;
    333   CHECK(sliced.CopyFrom(t.Slice(start, end), TensorShape({end - start})));
    334   EXPECT_EQ(sliced.IsAligned(), true);
    335 #endif
    336 }
    337 
    338 #if EIGEN_MAX_ALIGN_BYTES > 0
    339 TEST_F(OpsUtilTest, Misaligned1DSlice) {
    340   Tensor t(DT_FLOAT, TensorShape({EIGEN_MAX_ALIGN_BYTES * 2}));
    341   int64 start = 1;
    342   int64 end = EIGEN_MAX_ALIGN_BYTES + 1;
    343   bool output = IsDim0SliceAligned<float>(t.shape(), start, end);
    344   EXPECT_EQ(output, false);
    345   // Checks sliced 1D tensor is misaligned for sanity.
    346   Tensor sliced;
    347   CHECK(sliced.CopyFrom(t.Slice(start, end), TensorShape({end - start})));
    348   EXPECT_EQ(sliced.IsAligned(), false);
    349 }
    350 #endif
    351 
    352 TEST_F(OpsUtilTest, Aligned2DSliceOfDim0) {
    353 #if EIGEN_MAX_ALIGN_BYTES == 0
    354   // When EIGEN_MAX_ALIGN_BYTES is 0 and the size of the first dimension is
    355   // nonzero, a multidimensional tensor is always aligned.
    356   Tensor t(DT_FLOAT, TensorShape({3, 4}));
    357   int64 start = 1;
    358   int64 end = 2;
    359   bool output = IsDim0SliceAligned<float>(t.shape(), start, end);
    360   EXPECT_EQ(output, true);
    361 #else
    362   // For multidimensional tensors, alignment is dictated by inner_dim_size.
    363   int64 inner_dim_size = EIGEN_MAX_ALIGN_BYTES;
    364   Tensor t(DT_FLOAT, TensorShape({3, inner_dim_size}));
    365   int64 start = 1;
    366   int64 end = 2;
    367   bool output = IsDim0SliceAligned<float>(t.shape(), start, end);
    368   EXPECT_EQ(output, true);
    369   // Checks sliced 2D is aligned, for sanity.
    370   Tensor sliced;
    371   CHECK(sliced.CopyFrom(t.Slice(start, end), TensorShape({1, inner_dim_size})));
    372   EXPECT_EQ(sliced.IsAligned(), true);
    373 #endif
    374 }
    375 
    376 #if EIGEN_MAX_ALIGN_BYTES > 0
    377 TEST_F(OpsUtilTest, Misaligned2DSliceOfDim0) {
    378   // For multidimensional tensors, alignment is dictated by inner_dim_size.
    379   int64 inner_dim_size = EIGEN_MAX_ALIGN_BYTES + 1;
    380   Tensor t(DT_FLOAT, TensorShape({3, inner_dim_size}));
    381   int64 start = 1;
    382   int64 end = 2;
    383   bool output = IsDim0SliceAligned<float>(t.shape(), start, end);
    384   EXPECT_EQ(output, false);
    385   // Checks sliced 2D is misaligned, for sanity.
    386   Tensor sliced;
    387   CHECK(sliced.CopyFrom(t.Slice(start, end), TensorShape({1, inner_dim_size})));
    388   EXPECT_EQ(sliced.IsAligned(), false);
    389 }
    390 #endif
    391 
    392 TEST_F(OpsUtilTest, MisalignedEmptyShape) {
    393   TensorShape shape({});
    394   int64 start = 1;
    395   int64 end = 2;
    396   bool output = IsDim0SliceAligned<float>(shape, start, end);
    397   EXPECT_EQ(output, false);
    398 }
    399 
    400 TEST_F(OpsUtilTest, MisalignedEmptyDim0) {
    401   TensorShape shape({0, 1, 2});
    402   int64 start = 0;
    403   int64 end = 1;
    404   bool output = IsDim0SliceAligned<float>(shape, start, end);
    405   EXPECT_EQ(output, false);
    406 }
    407 
    408 }  // namespace
    409 }  // namespace tensorflow
    410