1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/lite/toco/tflite/operator.h" 16 17 #include "flatbuffers/flexbuffers.h" 18 #include <gmock/gmock.h> 19 #include <gtest/gtest.h> 20 #include "tensorflow/contrib/lite/toco/tooling_util.h" 21 22 #include "tensorflow/core/framework/attr_value.pb.h" 23 #include "tensorflow/core/framework/node_def.pb.h" 24 25 namespace toco { 26 27 namespace tflite { 28 namespace { 29 30 class OperatorTest : public ::testing::Test { 31 protected: 32 // Return the operator for the given name and type. 33 const BaseOperator& GetOperator(const string& name, OperatorType type) { 34 using OpsByName = std::map<string, std::unique_ptr<BaseOperator>>; 35 using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>; 36 37 static auto* by_name = new OpsByName(BuildOperatorByNameMap()); 38 static auto* by_type = new OpsByType(BuildOperatorByTypeMap()); 39 40 // Make sure the two maps were consitently built. 41 CHECK(by_name->count(name)) << "No operator for '" << name << "'."; 42 BaseOperator* op1 = by_name->at(name).get(); 43 CHECK(op1->type() == type) << "while verifying '" << name << "'."; 44 45 CHECK(by_type->count(type)) 46 << "No operator for '" << OperatorTypeName(type) << "'."; 47 BaseOperator* op2 = by_type->at(type).get(); 48 CHECK(op2->name() == name) 49 << "while verifying '" << OperatorTypeName(type) << "'."; 50 51 return *op1; 52 } 53 54 // Use the given BaseOperator to serialize the tf.mini operator into a set of 55 // TF Lite options. Proceed to deserialize the options back into a new 56 // tf.mini operator, which is then returned. If `options` is given, it will 57 // be populated with the serialized options. 58 template <typename T> 59 std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op, 60 const T& toco_op, 61 Options* options = nullptr) { 62 flatbuffers::FlatBufferBuilder builder; 63 Options input_options = op.Serialize(toco_op, &builder); 64 65 if (options) { 66 *options = input_options; 67 } 68 69 builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type, 70 input_options.builtin, input_options.custom, 71 ::tflite::CustomOptionsFormat_FLEXBUFFERS)); 72 auto* output_options = 73 flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer()); 74 auto new_toco_op = op.Deserialize(output_options->builtin_options(), 75 output_options->custom_options()); 76 77 CHECK(dynamic_cast<T*>(new_toco_op.get())) 78 << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to " 79 << HelpfulOperatorTypeName(toco_op); 80 81 return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release())); 82 } 83 84 // Verify serialization and deserialization of simple operators (those 85 // that don't have any configuration parameters). 86 template <typename T> 87 void CheckSimpleOperator(const string& name, OperatorType type) { 88 Options options; 89 auto output_toco_op = 90 SerializeAndDeserialize(GetOperator(name, type), T(), &options); 91 92 ASSERT_EQ(0, options.builtin.o); 93 ASSERT_EQ(0, options.custom.o); 94 ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type); 95 96 ASSERT_NE(nullptr, output_toco_op.get()); 97 } 98 }; 99 100 TEST_F(OperatorTest, SimpleOperators) { 101 CheckSimpleOperator<DequantizeOperator>("DEQUANTIZE", 102 OperatorType::kDequantize); 103 CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor); 104 CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu); 105 CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1); 106 CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6); 107 CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic); 108 CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh); 109 CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp); 110 } 111 112 TEST_F(OperatorTest, BuiltinAdd) { 113 AddOperator op; 114 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 115 auto output_toco_op = 116 SerializeAndDeserialize(GetOperator("ADD", OperatorType::kAdd), op); 117 EXPECT_EQ(op.fused_activation_function, 118 output_toco_op->fused_activation_function); 119 } 120 121 TEST_F(OperatorTest, BuiltinMean) { 122 MeanOperator op; 123 op.keep_dims = false; 124 125 auto output_toco_op = 126 SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op); 127 EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims); 128 } 129 130 TEST_F(OperatorTest, CustomCast) { 131 CastOperator op; 132 op.src_data_type = ArrayDataType::kFloat; 133 op.dst_data_type = ArrayDataType::kUint8; 134 auto output_toco_op = 135 SerializeAndDeserialize(GetOperator("CAST", OperatorType::kCast), op); 136 EXPECT_EQ(op.src_data_type, output_toco_op->src_data_type); 137 EXPECT_EQ(op.dst_data_type, output_toco_op->dst_data_type); 138 } 139 140 TEST_F(OperatorTest, CustomConcatenation) { 141 ConcatenationOperator op; 142 op.axis = 123; 143 auto output_toco_op = SerializeAndDeserialize( 144 GetOperator("CONCATENATION", OperatorType::kConcatenation), op); 145 EXPECT_EQ(op.axis, output_toco_op->axis); 146 } 147 148 TEST_F(OperatorTest, CustomDepthToSpace) { 149 DepthToSpaceOperator op; 150 op.block_size = 123; 151 auto output_toco_op = SerializeAndDeserialize( 152 GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace), op); 153 EXPECT_EQ(op.block_size, output_toco_op->block_size); 154 } 155 156 TEST_F(OperatorTest, CustomFakeQuant) { 157 FakeQuantOperator op; 158 auto* minmax = new MinMax; 159 minmax->min = -10; 160 minmax->max = 200; 161 op.minmax.reset(minmax); 162 auto output_toco_op = SerializeAndDeserialize( 163 GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op); 164 EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min); 165 EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max); 166 } 167 168 TEST_F(OperatorTest, CustomFullyConnected) { 169 FullyConnectedOperator op; 170 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 171 auto output_toco_op = SerializeAndDeserialize( 172 GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected), op); 173 EXPECT_EQ(op.fused_activation_function, 174 output_toco_op->fused_activation_function); 175 } 176 177 TEST_F(OperatorTest, BuiltinGather) { 178 GatherOperator op; 179 auto output_toco_op = 180 SerializeAndDeserialize(GetOperator("GATHER", OperatorType::kGather), op); 181 ASSERT_NE(nullptr, output_toco_op.get()); 182 } 183 184 TEST_F(OperatorTest, BuiltinL2Pool) { 185 L2PoolOperator op; 186 op.stride_width = 123; 187 op.stride_height = 124; 188 op.padding.type = PaddingType::kValid; 189 op.kwidth = 480; 190 op.kheight = 1080; 191 auto output_toco_op = SerializeAndDeserialize( 192 GetOperator("L2_POOL_2D", OperatorType::kL2Pool), op); 193 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 194 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 195 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 196 EXPECT_EQ(op.kwidth, output_toco_op->kwidth); 197 EXPECT_EQ(op.kheight, output_toco_op->kheight); 198 } 199 200 TEST_F(OperatorTest, BuiltinLocalResponseNormalization) { 201 LocalResponseNormalizationOperator op; 202 op.range = 123; 203 op.bias = 1.23; 204 op.alpha = 12.3; 205 op.beta = .123; 206 auto output_toco_op = SerializeAndDeserialize( 207 GetOperator("LOCAL_RESPONSE_NORMALIZATION", 208 OperatorType::kLocalResponseNormalization), 209 op); 210 EXPECT_EQ(op.range, output_toco_op->range); 211 EXPECT_EQ(op.bias, output_toco_op->bias); 212 EXPECT_EQ(op.alpha, output_toco_op->alpha); 213 EXPECT_EQ(op.beta, output_toco_op->beta); 214 } 215 216 TEST_F(OperatorTest, BuiltinMaxPool) { 217 MaxPoolOperator op; 218 op.stride_width = 123; 219 op.stride_height = 124; 220 op.padding.type = PaddingType::kValid; 221 op.kwidth = 480; 222 op.kheight = 1080; 223 auto output_toco_op = SerializeAndDeserialize( 224 GetOperator("MAX_POOL_2D", OperatorType::kMaxPool), op); 225 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 226 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 227 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 228 EXPECT_EQ(op.kwidth, output_toco_op->kwidth); 229 EXPECT_EQ(op.kheight, output_toco_op->kheight); 230 } 231 232 TEST_F(OperatorTest, BuiltinReshape) { 233 TensorFlowReshapeOperator op; 234 op.shape = {1, 2, 4, 5, 8}; 235 auto output_toco_op = SerializeAndDeserialize( 236 GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op); 237 EXPECT_EQ(op.shape, output_toco_op->shape); 238 } 239 240 TEST_F(OperatorTest, CustomSoftmax) { 241 SoftmaxOperator op; 242 op.beta = 123.1; 243 auto output_toco_op = SerializeAndDeserialize( 244 GetOperator("SOFTMAX", OperatorType::kSoftmax), op); 245 EXPECT_EQ(op.beta, output_toco_op->beta); 246 } 247 248 TEST_F(OperatorTest, BuiltinSpaceToDepth) { 249 SpaceToDepthOperator op; 250 op.block_size = 123; 251 auto output_toco_op = SerializeAndDeserialize( 252 GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth), op); 253 EXPECT_EQ(op.block_size, output_toco_op->block_size); 254 } 255 256 TEST_F(OperatorTest, CustomSplit) { 257 TensorFlowSplitOperator op; 258 op.num_split = 123; 259 auto output_toco_op = SerializeAndDeserialize( 260 GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op); 261 EXPECT_EQ(op.num_split, output_toco_op->num_split); 262 } 263 264 TEST_F(OperatorTest, BuiltinAveragePool) { 265 AveragePoolOperator op; 266 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 267 op.stride_width = 123; 268 op.stride_height = 124; 269 op.padding.type = PaddingType::kValid; 270 op.kwidth = 480; 271 op.kheight = 1080; 272 auto output_toco_op = SerializeAndDeserialize( 273 GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool), op); 274 EXPECT_EQ(op.fused_activation_function, 275 output_toco_op->fused_activation_function); 276 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 277 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 278 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 279 EXPECT_EQ(op.kwidth, output_toco_op->kwidth); 280 EXPECT_EQ(op.kheight, output_toco_op->kheight); 281 } 282 283 TEST_F(OperatorTest, BuiltinConvolution) { 284 ConvOperator op; 285 op.stride_width = 123; 286 op.stride_height = 124; 287 op.padding.type = PaddingType::kValid; 288 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 289 auto output_toco_op = 290 SerializeAndDeserialize(GetOperator("CONV_2D", OperatorType::kConv), op); 291 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 292 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 293 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 294 EXPECT_EQ(op.fused_activation_function, 295 output_toco_op->fused_activation_function); 296 } 297 298 TEST_F(OperatorTest, BuiltinDepthwiseConvolution) { 299 DepthwiseConvOperator op; 300 op.stride_width = 123; 301 op.stride_height = 124; 302 op.padding.type = PaddingType::kValid; 303 op.depth_multiplier = 6; 304 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 305 auto output_toco_op = SerializeAndDeserialize( 306 GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv), op); 307 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 308 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 309 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 310 EXPECT_EQ(op.depth_multiplier, output_toco_op->depth_multiplier); 311 EXPECT_EQ(op.fused_activation_function, 312 output_toco_op->fused_activation_function); 313 } 314 315 TEST_F(OperatorTest, BuiltinL2Norm) { 316 L2NormalizationOperator op; 317 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 318 auto output_toco_op = SerializeAndDeserialize( 319 GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization), op); 320 EXPECT_EQ(op.fused_activation_function, 321 output_toco_op->fused_activation_function); 322 } 323 324 TEST_F(OperatorTest, BuiltinMul) { 325 MulOperator op; 326 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 327 auto output_toco_op = 328 SerializeAndDeserialize(GetOperator("MUL", OperatorType::kMul), op); 329 EXPECT_EQ(op.fused_activation_function, 330 output_toco_op->fused_activation_function); 331 } 332 333 TEST_F(OperatorTest, ResizeBilinear) { 334 ResizeBilinearOperator op; 335 op.align_corners = true; 336 auto output_toco_op = SerializeAndDeserialize( 337 GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op); 338 EXPECT_EQ(op.align_corners, output_toco_op->align_corners); 339 } 340 341 TEST_F(OperatorTest, Svdf) { 342 SvdfOperator op; 343 op.fused_activation_function = FusedActivationFunctionType::kRelu; 344 op.rank = 1; 345 auto output_toco_op = 346 SerializeAndDeserialize(GetOperator("SVDF", OperatorType::kSvdf), op); 347 EXPECT_EQ(op.fused_activation_function, 348 output_toco_op->fused_activation_function); 349 EXPECT_EQ(op.rank, output_toco_op->rank); 350 } 351 352 TEST_F(OperatorTest, Squeeze) { 353 SqueezeOperator op; 354 op.squeeze_dims = {-2, -3, 4, 1, 4}; 355 356 auto output_toco_op = SerializeAndDeserialize( 357 GetOperator("SQUEEZE", OperatorType::kSqueeze), op); 358 EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims); 359 } 360 361 TEST_F(OperatorTest, StridedSlice) { 362 StridedSliceOperator op; 363 364 op.begin_mask = 1; 365 op.end_mask = 2; 366 op.ellipsis_mask = 1; 367 op.new_axis_mask = 1; 368 op.shrink_axis_mask = 2; 369 370 auto output_toco_op = SerializeAndDeserialize( 371 GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op); 372 EXPECT_EQ(op.start_indices, output_toco_op->start_indices); 373 EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices); 374 EXPECT_EQ(op.strides, output_toco_op->strides); 375 EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask); 376 EXPECT_EQ(op.end_mask, output_toco_op->end_mask); 377 EXPECT_EQ(op.end_mask, output_toco_op->end_mask); 378 EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask); 379 EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask); 380 EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask); 381 } 382 383 TEST_F(OperatorTest, BuiltinTopKV2) { 384 TopKV2Operator op; 385 auto output_toco_op = SerializeAndDeserialize( 386 GetOperator("TOPK_V2", OperatorType::kTopK_V2), op); 387 ASSERT_NE(nullptr, output_toco_op.get()); 388 } 389 390 TEST_F(OperatorTest, TensorFlowUnsupported) { 391 TensorFlowUnsupportedOperator op; 392 op.tensorflow_op = "MyCustomUnsupportedOp"; 393 394 ::tensorflow::NodeDef node_def; 395 auto attr = node_def.mutable_attr(); 396 (*attr)["float_attr"].set_f(2.0); 397 (*attr)["str_attr"].set_s("Hello World"); 398 (*attr)["int_attr"].set_i(17); 399 (*attr)["bool_attr"].set_b(true); 400 node_def.SerializeToString(&op.tensorflow_node_def); 401 402 auto output_toco_op = 403 SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", 404 OperatorType::kTensorFlowUnsupported), 405 op); 406 407 ::tensorflow::NodeDef output_node_def; 408 output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); 409 const auto& output_attr = output_node_def.attr(); 410 EXPECT_EQ(2.0, output_attr.at("float_attr").f()); 411 EXPECT_EQ("Hello World", output_attr.at("str_attr").s()); 412 EXPECT_EQ(17, output_attr.at("int_attr").i()); 413 EXPECT_EQ(true, output_attr.at("bool_attr").b()); 414 } 415 416 TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) { 417 TensorFlowUnsupportedOperator op; 418 op.tensorflow_op = "MyCustomUnsupportedOp"; 419 auto output_toco_op = 420 SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", 421 OperatorType::kTensorFlowUnsupported), 422 op); 423 424 ::tensorflow::NodeDef output_node_def; 425 output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); 426 EXPECT_TRUE(output_node_def.attr().empty()); 427 } 428 429 } // namespace 430 } // namespace tflite 431 432 } // namespace toco 433