1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <string> 16 #include <vector> 17 18 #include <gmock/gmock.h> 19 #include <gtest/gtest.h> 20 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" 21 #include "tensorflow/lite/toco/model.h" 22 #include "tensorflow/lite/toco/tooling_util.h" 23 24 namespace toco { 25 26 namespace { 27 // A gmock matcher that check that elements of a float vector match to a given 28 // tolerance. 29 std::vector<testing::Matcher<float>> ArrayFloatNear( 30 const std::vector<float>& values, float max_abs_error = 1e-5) { 31 std::vector<testing::Matcher<float>> matchers; 32 matchers.reserve(values.size()); 33 for (const float& v : values) { 34 matchers.emplace_back(testing::FloatNear(v, max_abs_error)); 35 } 36 return matchers; 37 } 38 } // namespace 39 40 // The following 3 tests make sure the concatenation operation on different axis 41 // values match TensorFlow results listed below: 42 // 43 // x0 = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] 44 // x1 = [[[10, 11], [12, 13]], [[14, 15], [16, 17]]] 45 // x2 = [[[20, 21], [22, 23]], [[24, 25], [26, 27]]] 46 // x3 = [[[30, 31], [32, 33]], [[34, 35], [36, 37]]] 47 // 48 // ConcatAtAxis0 test: 49 // t0 = tf.concat([x0, x1, x2, x3], 0) 50 // [[[ 0 1] 51 // [ 2 3]] 52 // 53 // [[ 4 5] 54 // [ 6 7]] 55 // 56 // [[10 11] 57 // [12 13]] 58 // 59 // [[14 15] 60 // [16 17]] 61 // 62 // [[20 21] 63 // [22 23]] 64 // 65 // [[24 25] 66 // [26 27]] 67 // 68 // [[30 31] 69 // [32 33]] 70 // 71 // [[34 35] 72 // [36 37]]] 73 // 74 // ConcatAtAxis1 test: 75 // t1 = tf.concat([x0, x1, x2, x3], 1) 76 // [[[ 0 1] 77 // [ 2 3] 78 // [10 11] 79 // [12 13] 80 // [20 21] 81 // [22 23] 82 // [30 31] 83 // [32 33]] 84 // 85 // [[ 4 5] 86 // [ 6 7] 87 // [14 15] 88 // [16 17] 89 // [24 25] 90 // [26 27] 91 // [34 35] 92 // [36 37]]] 93 // 94 // ConcatAtAxis2 test: 95 // t2 = tf.concat([x0, x1, x2, x3], 2) 96 // [[[ 0 1 10 11 20 21 30 31] 97 // [ 2 3 12 13 22 23 32 33]] 98 // 99 // [[ 4 5 14 15 24 25 34 35] 100 // [ 6 7 16 17 26 27 36 37]]] 101 102 class ResolveConstantConcatenationTest : public ::testing::Test { 103 protected: 104 ResolveConstantConcatenationTest() {} 105 106 // Prepare a hypothetical TOCO model with one Concatenation operator in it 107 // together with 4 arrays as its inputs. 108 // It receives the dimension of concatenation as input. 109 void PrepareModel(Model* model, int axis) { 110 std::vector<string> concat_input_names = {"array0", "array1", "array2", 111 "array3"}; 112 113 const int kDim = 3; 114 const int kElementPerDim = 2; 115 const int kBufSize = 8; 116 const int kNumArrays = 4; 117 static float in_buf[kNumArrays][kBufSize] = { 118 {0., 1., 2., 3., 4., 5., 6., 7.}, 119 {10., 11., 12., 13., 14., 15., 16., 17.}, 120 {20., 21., 22., 23., 24., 25., 26., 27.}, 121 {30., 31., 32., 33., 34., 35., 36., 37.}}; 122 int cnt = 0; 123 for (const string& concat_input_name : concat_input_names) { 124 Array& in_array = model->GetOrCreateArray(concat_input_name); 125 in_array.data_type = ArrayDataType::kFloat; 126 127 // Initialize shape for the input array. 128 Shape* in_array_shape = in_array.mutable_shape(); 129 std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims(); 130 for (int i = 0; i < kDim; i++) { 131 in_array_shape_dim->push_back(kElementPerDim); 132 } 133 auto& in_array_buffer = 134 in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>(); 135 in_array_buffer.data.resize(kBufSize); 136 float* buf_ptr = 137 in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>().data.data(); 138 std::copy(in_buf[cnt], in_buf[cnt] + kBufSize, buf_ptr); 139 cnt++; 140 } 141 auto* concatenation_op = new ConcatenationOperator; 142 concatenation_op->axis = axis; 143 concatenation_op->inputs = concat_input_names; 144 concatenation_op->outputs = {"concat_op_outputs"}; 145 Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]); 146 out_array.data_type = ArrayDataType::kFloat; 147 Shape* out_array_shape = out_array.mutable_shape(); 148 std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims(); 149 out_array_shape_dim->resize(kDim); 150 for (int i = 0; i < kDim; i++) { 151 if (i == axis) { 152 (*out_array_shape_dim)[i] = kNumArrays * kElementPerDim; 153 } else { 154 (*out_array_shape_dim)[i] = kElementPerDim; 155 } 156 } 157 model->operators.push_back(std::unique_ptr<Operator>(concatenation_op)); 158 } 159 }; 160 161 TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) { 162 Model model; 163 const int axis = 0; 164 PrepareModel(&model, axis); 165 166 GraphTransformationsSet graph_transformation_set; 167 graph_transformation_set.Add(new toco::ResolveConstantConcatenation); 168 EXPECT_THAT(model.GetArrayMap().size(), 5); 169 bool modified; 170 ASSERT_TRUE((*graph_transformation_set.begin()) 171 ->Run(&model, /*op_index=*/0, &modified) 172 .ok()); 173 EXPECT_THAT(model.GetArrayMap().size(), 1); 174 175 auto& concatenated_array = (*model.GetArrayMap().begin()).second; 176 EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data, 177 ElementsAreArray(ArrayFloatNear( 178 {0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 179 13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25., 180 26., 27., 30., 31., 32., 33., 34., 35., 36., 37.}))); 181 } 182 183 TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) { 184 Model model; 185 const int axis = 1; 186 PrepareModel(&model, axis); 187 188 GraphTransformationsSet graph_transformation_set; 189 graph_transformation_set.Add(new toco::ResolveConstantConcatenation); 190 EXPECT_THAT(model.GetArrayMap().size(), 5); 191 bool modified; 192 ASSERT_TRUE((*graph_transformation_set.begin()) 193 ->Run(&model, /*op_index=*/0, &modified) 194 .ok()); 195 EXPECT_THAT(model.GetArrayMap().size(), 1); 196 197 auto& concatenated_array = (*model.GetArrayMap().begin()).second; 198 EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data, 199 ElementsAreArray(ArrayFloatNear( 200 {0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22., 201 23., 30., 31., 32., 33., 4., 5., 6., 7., 14., 15., 202 16., 17., 24., 25., 26., 27., 34., 35., 36., 37.}))); 203 } 204 205 TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) { 206 Model model; 207 const int axis = 2; 208 PrepareModel(&model, axis); 209 210 GraphTransformationsSet graph_transformation_set; 211 graph_transformation_set.Add(new toco::ResolveConstantConcatenation); 212 EXPECT_THAT(model.GetArrayMap().size(), 5); 213 bool modified; 214 ASSERT_TRUE((*graph_transformation_set.begin()) 215 ->Run(&model, /*op_index=*/0, &modified) 216 .ok()); 217 EXPECT_THAT(model.GetArrayMap().size(), 1); 218 219 auto& concatenated_array = (*model.GetArrayMap().begin()).second; 220 EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data, 221 ElementsAreArray(ArrayFloatNear( 222 {0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12., 223 13., 22., 23., 32., 33., 4., 5., 14., 15., 24., 25., 224 34., 35., 6., 7., 16., 17., 26., 27., 36., 37.}))); 225 } 226 227 } // namespace toco 228