1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Script to test TF-TRT INT8 conversion without calibration on Mnist model.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import ops 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import nn 29 from tensorflow.python.platform import test 30 31 32 class DynamicInputShapesTest(trt_test.TfTrtIntegrationTestBase): 33 34 def GetParams(self): 35 # TODO(laigd): we should test the following cases: 36 # - batch size is not changed, other dims are changing 37 # - batch size is decreasing, other dims are identical 38 # - batch size is decreasing, other dims are changing 39 # - batch size is increasing, other dims are identical 40 # - batch size is increasing, other dims are changing 41 input_dims = [[[1, 5, 5, 1]], [[10, 5, 5, 1]], [[3, 5, 5, 1]], 42 [[1, 5, 5, 1]], [[1, 3, 1, 1]], [[2, 9, 9, 1]], 43 [[1, 224, 224, 1]], [[1, 128, 224, 1]]] 44 expected_output_dims = input_dims 45 46 g = ops.Graph() 47 with g.as_default(): 48 x = array_ops.placeholder( 49 shape=(None, None, None, 1), dtype=dtypes.float32, name="input") 50 conv_filter1 = constant_op.constant( 51 np.ones([3, 3, 1, 8]), name="weights1", dtype=dtypes.float32) 52 bias1 = constant_op.constant(np.random.randn(8), dtype=dtypes.float32) 53 x = nn.conv2d( 54 input=x, 55 filter=conv_filter1, 56 strides=[1, 1, 1, 1], 57 padding="SAME", 58 name="conv") 59 x = nn.bias_add(x, bias1) 60 x = nn.relu(x) 61 conv_filter2 = constant_op.constant( 62 np.ones([3, 3, 8, 1]), name="weights2", dtype=dtypes.float32) 63 bias2 = constant_op.constant(np.random.randn(1), dtype=dtypes.float32) 64 x = nn.conv2d( 65 input=x, 66 filter=conv_filter2, 67 strides=[1, 1, 1, 1], 68 padding="SAME", 69 name="conv") 70 x = nn.bias_add(x, bias2) 71 x = array_ops.identity(x, name="output") 72 73 return trt_test.TfTrtIntegrationTestParams( 74 gdef=g.as_graph_def(), 75 input_names=["input"], 76 input_dims=input_dims, 77 output_names=["output"], 78 expected_output_dims=expected_output_dims) 79 80 def GetConversionParams(self, run_params): 81 """Return a ConversionParams for test.""" 82 conversion_params = super(DynamicInputShapesTest, 83 self).GetConversionParams(run_params) 84 return conversion_params._replace( 85 maximum_cached_engines=10, 86 # Disable layout optimizer, since it will convert BiasAdd with NHWC 87 # format to NCHW format under four dimentional input. 88 rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) 89 90 def ExpectedEnginesToBuild(self, run_params): 91 return ["TRTEngineOp_0"] 92 93 def ShouldRunTest(self, run_params): 94 return (run_params.dynamic_engine and 95 not trt_test.IsQuantizationMode(run_params.precision_mode)) 96 97 def ExpectedAbsoluteTolerance(self, run_params): 98 """The absolute tolerance to compare floating point results.""" 99 return 1.e-03 if run_params.precision_mode == "FP32" else 1.e-01 100 101 def ExpectedRelativeTolerance(self, run_params): 102 """The relative tolerance to compare floating point results.""" 103 return 1.e-03 if run_params.precision_mode == "FP32" else 1.e-01 104 105 106 if __name__ == "__main__": 107 test.main() 108