Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
     17 #include "tensorflow/core/framework/fake_input.h"
     18 #include "tensorflow/core/framework/node_def_builder.h"
     19 #include "tensorflow/core/framework/tensor.h"
     20 #include "tensorflow/core/kernels/ops_testutil.h"
     21 #include "tensorflow/core/platform/test.h"
     22 #include "tensorflow/core/platform/test_benchmark.h"
     23 
     24 namespace tensorflow {
     25 
     26 class ResizeBicubicOpTest : public OpsTestBase {
     27  protected:
     28   ResizeBicubicOpTest() {
     29     TF_EXPECT_OK(NodeDefBuilder("resize_bicubic_op", "ResizeBicubic")
     30                      .Input(FakeInput(DT_FLOAT))
     31                      .Input(FakeInput(DT_INT32))
     32                      .Attr("align_corners", false)
     33                      .Finalize(node_def()));
     34     TF_EXPECT_OK(InitOp());
     35   }
     36 
     37   const Tensor* SetRandomImageInput(const TensorShape& shape) {
     38     inputs_.clear();
     39 
     40     CHECK_EQ(shape.dims(), 4) << "All images must have 4 dimensions.";
     41     bool is_ref = IsRefType(input_types_[inputs_.size()]);
     42     Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
     43                                DataTypeToEnum<float>::v(), shape);
     44     input->flat<float>().setRandom();
     45     tensors_.push_back(input);
     46     if (is_ref) {
     47       CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
     48                DataTypeToEnum<float>::v());
     49       inputs_.push_back({&lock_for_refs_, input});
     50     } else {
     51       CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<float>::v());
     52       inputs_.push_back({nullptr, input});
     53     }
     54     return input;
     55   }
     56 
     57  private:
     58   static const int64 kTableSize = (1 << 10);
     59 
     60   const float* InitCoeffsTable() {
     61     // Allocate and initialize coefficients table using Bicubic
     62     // convolution algorithm.
     63     // https://en.wikipedia.org/wiki/Bicubic_interpolation
     64     float* coeffs_tab = new float[(kTableSize + 1) * 2];
     65     static const double A = -0.75;
     66     for (int i = 0; i <= kTableSize; ++i) {
     67       float x = i * 1.0 / kTableSize;
     68       coeffs_tab[i * 2] = ((A + 2) * x - (A + 3)) * x * x + 1;
     69       x += 1.0;
     70       coeffs_tab[i * 2 + 1] = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
     71     }
     72     return coeffs_tab;
     73   }
     74 
     75   const float* GetCoeffsTable() {
     76     // Static so that we initialize it on first use
     77     static const float* coeffs_tab = InitCoeffsTable();
     78     return coeffs_tab;
     79   }
     80 
     81   // Used in the baseline implementation
     82   inline int64 Bound(int64 val, int64 limit) {
     83     return std::min(limit - 1ll, std::max(0ll, val));
     84   }
     85 
     86   // Used in the baseline implementation
     87   inline void GetWeightsAndIndices(float scale, int64 out_loc, int64 limit,
     88                                    std::array<float, 4>* weights,
     89                                    std::array<int64, 4>* indices) {
     90     const int64 in_loc = scale * out_loc;
     91     const float delta = scale * out_loc - in_loc;
     92     const int64 offset = lrintf(delta * kTableSize);
     93     const float* coeffs_tab = GetCoeffsTable();
     94     *weights = {{coeffs_tab[offset * 2 + 1], coeffs_tab[offset * 2],
     95                  coeffs_tab[(kTableSize - offset) * 2],
     96                  coeffs_tab[(kTableSize - offset) * 2 + 1]}};
     97     *indices = {{Bound(in_loc - 1, limit), Bound(in_loc, limit),
     98                  Bound(in_loc + 1, limit), Bound(in_loc + 2, limit)}};
     99   }
    100 
    101   // Used in the baseline implementation
    102   inline float Interpolate1D(const std::array<float, 4>& weights,
    103                              const std::array<float, 4>& values) {
    104     return values[0] * weights[0] + values[1] * weights[1] +
    105            values[2] * weights[2] + values[3] * weights[3];
    106   }
    107 
    108   // This is the straight forward unoptimized implementation of resize bicubic
    109   // We use this to confirm that the optimized version is exactly identical.
    110   void ResizeBicubicBaseline(TTypes<float, 4>::ConstTensor images,
    111                              TTypes<float, 4>::Tensor output) {
    112     const int batch_size = images.dimension(0);
    113     const int64 in_height = images.dimension(1);
    114     const int64 in_width = images.dimension(2);
    115     const int channels = images.dimension(3);
    116 
    117     ASSERT_EQ(batch_size, output.dimension(0));
    118     ASSERT_EQ(channels, output.dimension(3));
    119 
    120     const int64 out_height = output.dimension(1);
    121     const int64 out_width = output.dimension(2);
    122 
    123     const float height_scale = in_height / static_cast<float>(out_height);
    124     const float width_scale = in_width / static_cast<float>(out_width);
    125 
    126     std::array<float, 4> coeff = {{0.0, 0.0, 0.0, 0.0}};
    127     for (int64 b = 0; b < batch_size; ++b) {
    128       for (int64 y = 0; y < out_height; ++y) {
    129         std::array<float, 4> y_weights;
    130         std::array<int64, 4> y_indices;
    131         GetWeightsAndIndices(height_scale, y, in_height, &y_weights,
    132                              &y_indices);
    133         for (int64 x = 0; x < out_width; ++x) {
    134           std::array<float, 4> x_weights;
    135           std::array<int64, 4> x_indices;
    136           GetWeightsAndIndices(width_scale, x, in_width, &x_weights,
    137                                &x_indices);
    138           for (int64 c = 0; c < channels; ++c) {
    139             // Use a 4x4 patch to compute the interpolated output value at
    140             // (b, y, x, c).
    141             for (int64 i = 0; i < 4; ++i) {
    142               const std::array<float, 4> values = {
    143                   {static_cast<float>(images(b, y_indices[i], x_indices[0], c)),
    144                    static_cast<float>(images(b, y_indices[i], x_indices[1], c)),
    145                    static_cast<float>(images(b, y_indices[i], x_indices[2], c)),
    146                    static_cast<float>(
    147                        images(b, y_indices[i], x_indices[3], c))}};
    148               coeff[i] = Interpolate1D(x_weights, values);
    149             }
    150             output(b, y, x, c) = Interpolate1D(y_weights, coeff);
    151           }
    152         }
    153       }
    154     }
    155   }
    156 
    157  protected:
    158   void RunRandomTest(const int batch_size, const int64 in_height,
    159                      const int64 in_width, const int target_height,
    160                      const int target_width, int channels) {
    161     LOG(INFO) << "Running random test " << in_height << "x" << in_width << "x"
    162               << channels << " to " << target_height << "x" << target_width
    163               << "x" << channels;
    164     const Tensor* input = SetRandomImageInput(
    165         TensorShape({batch_size, in_height, in_width, channels}));
    166     AddInputFromArray<int32>(TensorShape({2}), {target_height, target_width});
    167 
    168     TF_ASSERT_OK(RunOpKernel());
    169 
    170     std::unique_ptr<Tensor> expected(new Tensor(
    171         device_->GetAllocator(AllocatorAttributes()),
    172         DataTypeToEnum<float>::v(),
    173         TensorShape({batch_size, target_height, target_width, channels})));
    174 
    175     ResizeBicubicBaseline(input->tensor<float, 4>(),
    176                           expected->tensor<float, 4>());
    177     // Note: the baseline implementation reduces first in the x direction, and
    178     // then in the y direction. The optimized version reduces first in the y
    179     // direction, and then the X direction. As a result, there may be
    180     // some slight floating point inaccuracies. We thus ensure we're within
    181     // 0.00001 of the previous implementation.
    182     test::ExpectTensorNear<float>(*expected, *GetOutput(0), 0.00001);
    183   }
    184 
    185   void RunManyRandomTests(int channels) {
    186     for (int batch_size : {1, 2, 5}) {
    187       for (int in_w : {2, 4, 7, 20, 165}) {
    188         for (int in_h : {1, 3, 5, 8, 100, 233}) {
    189           for (int target_height : {1, 2, 3, 50, 113}) {
    190             for (int target_width : {target_height, target_height / 2 + 1}) {
    191               RunRandomTest(batch_size, in_h, in_w, target_height, target_width,
    192                             channels);
    193             }
    194           }
    195         }
    196       }
    197     }
    198   }
    199 };
    200 
    201 TEST_F(ResizeBicubicOpTest, TestBicubic2x2To1x1) {
    202   // Input:
    203   // 1, 2
    204   // 3, 4
    205   AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
    206   AddInputFromArray<int32>(TensorShape({2}), {1, 1});
    207   TF_ASSERT_OK(RunOpKernel());
    208 
    209   // When scaling down, we have to arbitrarily pick a pixel from the
    210   // original input. In this case, we choose the top/left most pixel.
    211   Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1}));
    212   test::FillValues<float>(&expected, {1.0});
    213   test::ExpectTensorEqual<float>(expected, *GetOutput(0));
    214 }
    215 
    216 TEST_F(ResizeBicubicOpTest, TestBicubic2x2To0x0) {
    217   AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
    218   AddInputFromArray<int32>(TensorShape({2}), {0, 0});
    219 
    220   Status s = RunOpKernel();
    221   EXPECT_TRUE(
    222       StringPiece(s.ToString())
    223           .contains("Invalid argument: output dimensions must be positive"))
    224       << s;
    225 }
    226 
    227 TEST_F(ResizeBicubicOpTest, TestBicubicRandom141x186) {
    228   RunRandomTest(2, 141, 186, 299, 299, 1 /* channels */);
    229   RunRandomTest(2, 141, 186, 299, 299, 3 /* channels */);
    230 }
    231 
    232 TEST_F(ResizeBicubicOpTest, TestBicubicRandom183x229) {
    233   RunRandomTest(2, 183, 229, 299, 299, 1 /* channels */);
    234   RunRandomTest(2, 183, 229, 299, 299, 3 /* channels */);
    235 }
    236 
    237 TEST_F(ResizeBicubicOpTest, TestBicubicRandom749x603) {
    238   RunRandomTest(2, 749, 603, 299, 299, 1 /* channels */);
    239   RunRandomTest(2, 749, 603, 299, 299, 3 /* channels */);
    240 }
    241 
    242 TEST_F(ResizeBicubicOpTest, TestAreaRandomDataSeveralInputsSizes1Channel) {
    243   RunManyRandomTests(1);
    244 }
    245 
    246 TEST_F(ResizeBicubicOpTest, TestAreaRandomDataSeveralInputsSizes3Channels) {
    247   RunManyRandomTests(3);
    248 }
    249 
    250 TEST_F(ResizeBicubicOpTest, TestAreaRandomDataSeveralInputsSizes4Channels) {
    251   RunManyRandomTests(4);
    252 }
    253 
    254 static Graph* ResizeBicubic(int batch_size, int size, int channels,
    255                             float scale_y = 0.3, float scale_x = 0.7) {
    256   Graph* g = new Graph(OpRegistry::Global());
    257   Tensor input(DT_FLOAT, TensorShape({batch_size, size, size, channels}));
    258   input.flat<float>().setRandom();
    259   Tensor shape(DT_INT32, TensorShape({2}));
    260   auto shape_t = shape.flat<int32>();
    261   shape_t(0) = scale_y * size;
    262   shape_t(1) = scale_x * size;
    263   test::graph::Binary(g, "ResizeBicubic", test::graph::Constant(g, input),
    264                       test::graph::Constant(g, shape));
    265   return g;
    266 }
    267 
    268 #define BM_ResizeBicubicDev(BATCH, SIZE, CHANNELS)                            \
    269   static void BM_ResizeBicubic##_##BATCH##_##SIZE##_##CHANNELS(int iters) {   \
    270     testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * SIZE * SIZE * \
    271                             CHANNELS);                                        \
    272     test::Benchmark("cpu", ResizeBicubic(BATCH, SIZE, CHANNELS)).Run(iters);  \
    273   }                                                                           \
    274   BENCHMARK(BM_ResizeBicubic##_##BATCH##_##SIZE##_##CHANNELS);
    275 
    276 BM_ResizeBicubicDev(8, 32, 3);
    277 BM_ResizeBicubicDev(8, 128, 3);
    278 BM_ResizeBicubicDev(8, 512, 3);
    279 BM_ResizeBicubicDev(8, 1024, 3);
    280 BM_ResizeBicubicDev(16, 32, 3);
    281 BM_ResizeBicubicDev(16, 128, 3);
    282 BM_ResizeBicubicDev(16, 512, 3);
    283 BM_ResizeBicubicDev(16, 1024, 3);
    284 BM_ResizeBicubicDev(32, 32, 3);
    285 BM_ResizeBicubicDev(32, 128, 3);
    286 BM_ResizeBicubicDev(32, 512, 3);
    287 BM_ResizeBicubicDev(32, 1024, 3);
    288 
    289 #define BM_ResizeBicubicExpand(BATCH, SIZE, CHANNELS)                         \
    290   static void BM_ResizeBicubicExpand##_##BATCH##_##SIZE##_##CHANNELS(         \
    291       int iters) {                                                            \
    292     testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * SIZE * SIZE * \
    293                             CHANNELS * 8 * 8);                                \
    294     test::Benchmark("cpu", ResizeBicubic(BATCH, SIZE, CHANNELS, 8, 8))        \
    295         .Run(iters);                                                          \
    296   }                                                                           \
    297   BENCHMARK(BM_ResizeBicubicExpand##_##BATCH##_##SIZE##_##CHANNELS);
    298 
    299 BM_ResizeBicubicExpand(12, 48, 1);
    300 BM_ResizeBicubicExpand(12, 48, 3);
    301 BM_ResizeBicubicExpand(12, 48, 40);
    302 
    303 }  // end namespace tensorflow
    304