// Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     19 #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
     20 #include "tensorflow/core/framework/fake_input.h"
     21 #include "tensorflow/core/framework/node_def_builder.h"
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/framework/tensor_testutil.h"
     24 #include "tensorflow/core/framework/types.h"
     25 #include "tensorflow/core/framework/types.pb.h"
     26 #include "tensorflow/core/kernels/batch_norm_op.h"
     27 #include "tensorflow/core/kernels/ops_testutil.h"
     28 #include "tensorflow/core/kernels/quantization_utils.h"
     29 #include "tensorflow/core/lib/core/status_test_util.h"
     30 #include "tensorflow/core/lib/core/threadpool.h"
     31 #include "tensorflow/core/platform/test.h"
     32 
     33 namespace tensorflow {
     34 
     35 class QuantizedBatchNormOpTest : public OpsTestBase {};
     36 
     37 TEST_F(QuantizedBatchNormOpTest, Simple) {
     38   TF_EXPECT_OK(NodeDefBuilder("quantized_batch_norm_op",
     39                               "QuantizedBatchNormWithGlobalNormalization")
     40                    .Input(FakeInput(DT_QUINT8))
     41                    .Input(FakeInput(DT_FLOAT))
     42                    .Input(FakeInput(DT_FLOAT))
     43                    .Input(FakeInput(DT_QUINT8))
     44                    .Input(FakeInput(DT_FLOAT))
     45                    .Input(FakeInput(DT_FLOAT))
     46                    .Input(FakeInput(DT_QUINT8))
     47                    .Input(FakeInput(DT_FLOAT))
     48                    .Input(FakeInput(DT_FLOAT))
     49                    .Input(FakeInput(DT_QUINT8))
     50                    .Input(FakeInput(DT_FLOAT))
     51                    .Input(FakeInput(DT_FLOAT))
     52                    .Input(FakeInput(DT_QUINT8))
     53                    .Input(FakeInput(DT_FLOAT))
     54                    .Input(FakeInput(DT_FLOAT))
     55                    .Attr("scale_after_normalization", false)
     56                    .Attr("variance_epsilon", 0.001)
     57                    .Attr("Tinput", DT_QUINT8)
     58                    .Attr("out_type", DT_QINT32)
     59                    .Finalize(node_def()));
     60   TF_ASSERT_OK(InitOp());
     61   const float input_min = -128.0f;
     62   const float input_max = 127.0f;
     63   const int input_batch = 1;
     64   const int input_height = 1;
     65   const int input_width = 6;
     66   const int input_depth = 2;
     67   Tensor input_float(DT_FLOAT,
     68                      {input_batch, input_height, input_width, input_depth});
     69   test::FillValues<float>(&input_float,
     70                           {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6});
     71   Tensor input_quantized =
     72       FloatTensorToQuantized<quint8>(input_float, input_min, input_max);
     73   const float mean_min = 0.0f;
     74   const float mean_max = 20.0f;
     75   Tensor mean_float(DT_FLOAT, {input_depth});
     76   test::FillValues<float>(&mean_float, {10, 20});
     77   Tensor mean_quantized =
     78       FloatTensorToQuantized<quint8>(mean_float, mean_min, mean_max);
     79   const float variance_min = 0.0f;
     80   const float variance_max = 1.0f;
     81   Tensor variance_float(DT_FLOAT, {input_depth});
     82   test::FillValues<float>(&variance_float, {0.25, 0.5});
     83   Tensor variance_quantized = FloatTensorToQuantized<quint8>(
     84       variance_float, variance_min, variance_max);
     85   const float beta_min = 0.0f;
     86   const float beta_max = 1.0f;
     87   Tensor beta_float(DT_FLOAT, {input_depth});
     88   test::FillValues<float>(&beta_float, {0.1, 0.6});
     89   Tensor beta_quantized =
     90       FloatTensorToQuantized<quint8>(beta_float, beta_min, beta_max);
     91   const float gamma_min = 0.0f;
     92   const float gamma_max = 1.0f;
     93   Tensor gamma_float(DT_FLOAT, {input_depth});
     94   test::FillValues<float>(&gamma_float, {0.0, 0.0});
     95   Tensor gamma_quantized =
     96       FloatTensorToQuantized<quint8>(gamma_float, gamma_min, gamma_max);
     97 
     98   AddInputFromArray<quint8>(input_quantized.shape(),
     99                             input_quantized.flat<quint8>());
    100   AddInputFromArray<float>(TensorShape({1}), {input_min});
    101   AddInputFromArray<float>(TensorShape({1}), {input_max});
    102   AddInputFromArray<quint8>(mean_quantized.shape(),
    103                             mean_quantized.flat<quint8>());
    104   AddInputFromArray<float>(TensorShape({1}), {mean_min});
    105   AddInputFromArray<float>(TensorShape({1}), {mean_max});
    106   AddInputFromArray<quint8>(variance_quantized.shape(),
    107                             variance_quantized.flat<quint8>());
    108   AddInputFromArray<float>(TensorShape({1}), {variance_min});
    109   AddInputFromArray<float>(TensorShape({1}), {variance_max});
    110   AddInputFromArray<quint8>(beta_quantized.shape(),
    111                             beta_quantized.flat<quint8>());
    112   AddInputFromArray<float>(TensorShape({1}), {beta_min});
    113   AddInputFromArray<float>(TensorShape({1}), {beta_max});
    114   AddInputFromArray<quint8>(gamma_quantized.shape(),
    115                             gamma_quantized.flat<quint8>());
    116   AddInputFromArray<float>(TensorShape({1}), {gamma_min});
    117   AddInputFromArray<float>(TensorShape({1}), {gamma_max});
    118   TF_ASSERT_OK(RunOpKernel());
    119 
    120   Tensor expected_float(
    121       allocator(), DT_FLOAT,
    122       TensorShape({input_batch, input_height, input_width, input_depth}));
    123   test::FillValues<float>(
    124       &expected_float, {-17.86, -22.00, -15.87, -20.59, -13.87, -19.18, -21.86,
    125                         -33.31, -23.85, -34.72, -25.85, -36.13});
    126   const Tensor& output_quantized = *GetOutput(0);
    127   const float output_min = GetOutput(1)->flat<float>()(0);
    128   const float output_max = GetOutput(2)->flat<float>()(0);
    129   Tensor output_float =
    130       QuantizedTensorToFloat<qint32>(output_quantized, output_min, output_max);
    131   test::ExpectTensorNear<float>(expected_float, output_float, 0.1);
    132 }
    133 
    134 TEST_F(QuantizedBatchNormOpTest, SameAsFloat) {
    135   TF_EXPECT_OK(NodeDefBuilder("quantized_batch_norm_op",
    136                               "QuantizedBatchNormWithGlobalNormalization")
    137                    .Input(FakeInput(DT_QUINT8))
    138                    .Input(FakeInput(DT_FLOAT))
    139                    .Input(FakeInput(DT_FLOAT))
    140                    .Input(FakeInput(DT_QUINT8))
    141                    .Input(FakeInput(DT_FLOAT))
    142                    .Input(FakeInput(DT_FLOAT))
    143                    .Input(FakeInput(DT_QUINT8))
    144                    .Input(FakeInput(DT_FLOAT))
    145                    .Input(FakeInput(DT_FLOAT))
    146                    .Input(FakeInput(DT_QUINT8))
    147                    .Input(FakeInput(DT_FLOAT))
    148                    .Input(FakeInput(DT_FLOAT))
    149                    .Input(FakeInput(DT_QUINT8))
    150                    .Input(FakeInput(DT_FLOAT))
    151                    .Input(FakeInput(DT_FLOAT))
    152                    .Attr("scale_after_normalization", false)
    153                    .Attr("variance_epsilon", 0.001)
    154                    .Attr("Tinput", DT_QUINT8)
    155                    .Attr("out_type", DT_QINT32)
    156                    .Finalize(node_def()));
    157   TF_ASSERT_OK(InitOp());
    158   const float input_min = -128.0f;
    159   const float input_max = 127.0f;
    160   const int input_batch = 1;
    161   const int input_height = 1;
    162   const int input_width = 6;
    163   const int input_depth = 2;
    164   Tensor input_float(DT_FLOAT,
    165                      {input_batch, input_height, input_width, input_depth});
    166   test::FillValues<float>(&input_float,
    167                           {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6});
    168   Tensor input_quantized =
    169       FloatTensorToQuantized<quint8>(input_float, input_min, input_max);
    170   const float mean_min = 0.0f;
    171   const float mean_max = 20.0f;
    172   Tensor mean_float(DT_FLOAT, {input_depth});
    173   test::FillValues<float>(&mean_float, {10, 20});
    174   Tensor mean_quantized =
    175       FloatTensorToQuantized<quint8>(mean_float, mean_min, mean_max);
    176   const float variance_min = 0.0f;
    177   const float variance_max = 1.0f;
    178   Tensor variance_float(DT_FLOAT, {input_depth});
    179   test::FillValues<float>(&variance_float, {0.25, 0.5});
    180   Tensor variance_quantized = FloatTensorToQuantized<quint8>(
    181       variance_float, variance_min, variance_max);
    182   const float beta_min = 0.0f;
    183   const float beta_max = 1.0f;
    184   Tensor beta_float(DT_FLOAT, {input_depth});
    185   test::FillValues<float>(&beta_float, {0.1, 0.6});
    186   Tensor beta_quantized =
    187       FloatTensorToQuantized<quint8>(beta_float, beta_min, beta_max);
    188   const float gamma_min = 0.0f;
    189   const float gamma_max = 1.0f;
    190   Tensor gamma_float(DT_FLOAT, {input_depth});
    191   test::FillValues<float>(&gamma_float, {0.0, 0.0});
    192   Tensor gamma_quantized =
    193       FloatTensorToQuantized<quint8>(gamma_float, gamma_min, gamma_max);
    194 
    195   AddInputFromArray<quint8>(input_quantized.shape(),
    196                             input_quantized.flat<quint8>());
    197   AddInputFromArray<float>(TensorShape({1}), {input_min});
    198   AddInputFromArray<float>(TensorShape({1}), {input_max});
    199   AddInputFromArray<quint8>(mean_quantized.shape(),
    200                             mean_quantized.flat<quint8>());
    201   AddInputFromArray<float>(TensorShape({1}), {mean_min});
    202   AddInputFromArray<float>(TensorShape({1}), {mean_max});
    203   AddInputFromArray<quint8>(variance_quantized.shape(),
    204                             variance_quantized.flat<quint8>());
    205   AddInputFromArray<float>(TensorShape({1}), {variance_min});
    206   AddInputFromArray<float>(TensorShape({1}), {variance_max});
    207   AddInputFromArray<quint8>(beta_quantized.shape(),
    208                             beta_quantized.flat<quint8>());
    209   AddInputFromArray<float>(TensorShape({1}), {beta_min});
    210   AddInputFromArray<float>(TensorShape({1}), {beta_max});
    211   AddInputFromArray<quint8>(gamma_quantized.shape(),
    212                             gamma_quantized.flat<quint8>());
    213   AddInputFromArray<float>(TensorShape({1}), {gamma_min});
    214   AddInputFromArray<float>(TensorShape({1}), {gamma_max});
    215   TF_ASSERT_OK(RunOpKernel());
    216 
    217   Tensor expected_float(
    218       allocator(), DT_FLOAT,
    219       TensorShape({input_batch, input_height, input_width, input_depth}));
    220   thread::ThreadPool threadpool(Env::Default(), "test", 1);
    221   EigenThreadPoolWrapper wrapper(&threadpool);
    222   Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, 1);
    223   const Tensor& const_input_float = input_float;
    224   const Tensor& const_mean_float = mean_float;
    225   const Tensor& const_variance_float = variance_float;
    226   const Tensor& const_beta_float = beta_float;
    227   const Tensor& const_gamma_float = gamma_float;
    228   functor::BatchNorm<Eigen::ThreadPoolDevice, float>()(
    229       eigen_cpu_device, const_input_float.tensor<float, 4>(),
    230       const_mean_float.vec<float>(), const_variance_float.vec<float>(),
    231       const_beta_float.vec<float>(), const_gamma_float.vec<float>(), 0.001,
    232       false, expected_float.tensor<float, 4>());
    233 
    234   const Tensor& output_quantized = *GetOutput(0);
    235   const float output_min = GetOutput(1)->flat<float>()(0);
    236   const float output_max = GetOutput(2)->flat<float>()(0);
    237   Tensor output_float =
    238       QuantizedTensorToFloat<qint32>(output_quantized, output_min, output_max);
    239   test::ExpectTensorNear<float>(expected_float, output_float, 0.1);
    240 }
    241 
    242 }  // namespace tensorflow
    243