Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <string>
     16 #include <vector>
     17 
     18 #include <gmock/gmock.h>
     19 #include <gtest/gtest.h>
     20 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
     21 #include "tensorflow/lite/toco/model.h"
     22 #include "tensorflow/lite/toco/tooling_util.h"
     23 
     24 namespace toco {
     25 
     26 namespace {
     27 // A gmock matcher that check that elements of a float vector match to a given
     28 // tolerance.
     29 std::vector<testing::Matcher<float>> ArrayFloatNear(
     30     const std::vector<float>& values, float max_abs_error = 1e-5) {
     31   std::vector<testing::Matcher<float>> matchers;
     32   matchers.reserve(values.size());
     33   for (const float& v : values) {
     34     matchers.emplace_back(testing::FloatNear(v, max_abs_error));
     35   }
     36   return matchers;
     37 }
     38 }  // namespace
     39 
// The following 3 tests check that concatenation along each axis matches the
// TensorFlow results listed below:
     42 //
     43 // x0 = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
     44 // x1 = [[[10, 11], [12, 13]], [[14, 15], [16, 17]]]
     45 // x2 = [[[20, 21], [22, 23]], [[24, 25], [26, 27]]]
     46 // x3 = [[[30, 31], [32, 33]], [[34, 35], [36, 37]]]
     47 //
     48 // ConcatAtAxis0 test:
     49 // t0 = tf.concat([x0, x1, x2, x3], 0)
     50 // [[[ 0  1]
     51 //   [ 2  3]]
     52 //
     53 //  [[ 4  5]
     54 //   [ 6  7]]
     55 //
     56 //  [[10 11]
     57 //   [12 13]]
     58 //
     59 //  [[14 15]
     60 //   [16 17]]
     61 //
     62 //  [[20 21]
     63 //   [22 23]]
     64 //
     65 //  [[24 25]
     66 //   [26 27]]
     67 //
     68 //  [[30 31]
     69 //   [32 33]]
     70 //
     71 //  [[34 35]
     72 //   [36 37]]]
     73 //
     74 // ConcatAtAxis1 test:
     75 // t1 = tf.concat([x0, x1, x2, x3], 1)
     76 // [[[ 0  1]
     77 //   [ 2  3]
     78 //   [10 11]
     79 //   [12 13]
     80 //   [20 21]
     81 //   [22 23]
     82 //   [30 31]
     83 //   [32 33]]
     84 //
     85 //  [[ 4  5]
     86 //   [ 6  7]
     87 //   [14 15]
     88 //   [16 17]
     89 //   [24 25]
     90 //   [26 27]
     91 //   [34 35]
     92 //   [36 37]]]
     93 //
     94 // ConcatAtAxis2 test:
     95 // t2 = tf.concat([x0, x1, x2, x3], 2)
     96 // [[[ 0  1 10 11 20 21 30 31]
     97 //   [ 2  3 12 13 22 23 32 33]]
     98 //
     99 //  [[ 4  5 14 15 24 25 34 35]
    100 //   [ 6  7 16 17 26 27 36 37]]]
    101 
    102 class ResolveConstantConcatenationTest : public ::testing::Test {
    103  protected:
    104   ResolveConstantConcatenationTest() {}
    105 
    106   // Prepare a hypothetical TOCO model with one Concatenation operator in it
    107   // together with 4 arrays as its inputs.
    108   // It receives the dimension of concatenation as input.
    109   void PrepareModel(Model* model, int axis) {
    110     std::vector<string> concat_input_names = {"array0", "array1", "array2",
    111                                               "array3"};
    112 
    113     const int kDim = 3;
    114     const int kElementPerDim = 2;
    115     const int kBufSize = 8;
    116     const int kNumArrays = 4;
    117     static float in_buf[kNumArrays][kBufSize] = {
    118         {0., 1., 2., 3., 4., 5., 6., 7.},
    119         {10., 11., 12., 13., 14., 15., 16., 17.},
    120         {20., 21., 22., 23., 24., 25., 26., 27.},
    121         {30., 31., 32., 33., 34., 35., 36., 37.}};
    122     int cnt = 0;
    123     for (const string& concat_input_name : concat_input_names) {
    124       Array& in_array = model->GetOrCreateArray(concat_input_name);
    125       in_array.data_type = ArrayDataType::kFloat;
    126 
    127       // Initialize shape for the input array.
    128       Shape* in_array_shape = in_array.mutable_shape();
    129       std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
    130       for (int i = 0; i < kDim; i++) {
    131         in_array_shape_dim->push_back(kElementPerDim);
    132       }
    133       auto& in_array_buffer =
    134           in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>();
    135       in_array_buffer.data.resize(kBufSize);
    136       float* buf_ptr =
    137           in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>().data.data();
    138       std::copy(in_buf[cnt], in_buf[cnt] + kBufSize, buf_ptr);
    139       cnt++;
    140     }
    141     auto* concatenation_op = new ConcatenationOperator;
    142     concatenation_op->axis = axis;
    143     concatenation_op->inputs = concat_input_names;
    144     concatenation_op->outputs = {"concat_op_outputs"};
    145     Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]);
    146     out_array.data_type = ArrayDataType::kFloat;
    147     Shape* out_array_shape = out_array.mutable_shape();
    148     std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims();
    149     out_array_shape_dim->resize(kDim);
    150     for (int i = 0; i < kDim; i++) {
    151       if (i == axis) {
    152         (*out_array_shape_dim)[i] = kNumArrays * kElementPerDim;
    153       } else {
    154         (*out_array_shape_dim)[i] = kElementPerDim;
    155       }
    156     }
    157     model->operators.push_back(std::unique_ptr<Operator>(concatenation_op));
    158   }
    159 };
    160 
    161 TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
    162   Model model;
    163   const int axis = 0;
    164   PrepareModel(&model, axis);
    165 
    166   GraphTransformationsSet graph_transformation_set;
    167   graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
    168   EXPECT_THAT(model.GetArrayMap().size(), 5);
    169   bool modified;
    170   ASSERT_TRUE((*graph_transformation_set.begin())
    171                   ->Run(&model, /*op_index=*/0, &modified)
    172                   .ok());
    173   EXPECT_THAT(model.GetArrayMap().size(), 1);
    174 
    175   auto& concatenated_array = (*model.GetArrayMap().begin()).second;
    176   EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
    177               ElementsAreArray(ArrayFloatNear(
    178                   {0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  10., 11., 12.,
    179                    13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25.,
    180                    26., 27., 30., 31., 32., 33., 34., 35., 36., 37.})));
    181 }
    182 
    183 TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
    184   Model model;
    185   const int axis = 1;
    186   PrepareModel(&model, axis);
    187 
    188   GraphTransformationsSet graph_transformation_set;
    189   graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
    190   EXPECT_THAT(model.GetArrayMap().size(), 5);
    191   bool modified;
    192   ASSERT_TRUE((*graph_transformation_set.begin())
    193                   ->Run(&model, /*op_index=*/0, &modified)
    194                   .ok());
    195   EXPECT_THAT(model.GetArrayMap().size(), 1);
    196 
    197   auto& concatenated_array = (*model.GetArrayMap().begin()).second;
    198   EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
    199               ElementsAreArray(ArrayFloatNear(
    200                   {0.,  1.,  2.,  3.,  10., 11., 12., 13., 20., 21., 22.,
    201                    23., 30., 31., 32., 33., 4.,  5.,  6.,  7.,  14., 15.,
    202                    16., 17., 24., 25., 26., 27., 34., 35., 36., 37.})));
    203 }
    204 
    205 TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
    206   Model model;
    207   const int axis = 2;
    208   PrepareModel(&model, axis);
    209 
    210   GraphTransformationsSet graph_transformation_set;
    211   graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
    212   EXPECT_THAT(model.GetArrayMap().size(), 5);
    213   bool modified;
    214   ASSERT_TRUE((*graph_transformation_set.begin())
    215                   ->Run(&model, /*op_index=*/0, &modified)
    216                   .ok());
    217   EXPECT_THAT(model.GetArrayMap().size(), 1);
    218 
    219   auto& concatenated_array = (*model.GetArrayMap().begin()).second;
    220   EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
    221               ElementsAreArray(ArrayFloatNear(
    222                   {0.,  1.,  10., 11., 20., 21., 30., 31., 2.,  3.,  12.,
    223                    13., 22., 23., 32., 33., 4.,  5.,  14., 15., 24., 25.,
    224                    34., 35., 6.,  7.,  16., 17., 26., 27., 36., 37.})));
    225 }
    226 
    227 }  // namespace toco
    228