# NOTE: code-browser navigation header ("Home | History | Annotate | Download")
# left over from HTML extraction; neutralized as a comment.
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for lite.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import tempfile
     23 import numpy as np
     24 
     25 from tensorflow.lite.python import lite
     26 from tensorflow.lite.python import lite_constants
     27 from tensorflow.lite.python.interpreter import Interpreter
     28 from tensorflow.python import keras
     29 from tensorflow.python.client import session
     30 from tensorflow.python.framework import constant_op
     31 from tensorflow.python.framework import dtypes
     32 from tensorflow.python.framework import test_util
     33 from tensorflow.python.ops import array_ops
     34 from tensorflow.python.ops import math_ops
     35 from tensorflow.python.ops import nn_ops
     36 from tensorflow.python.ops import variable_scope
     37 from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
     38 from tensorflow.python.platform import gfile
     39 from tensorflow.python.platform import resource_loader
     40 from tensorflow.python.platform import test
     41 from tensorflow.python.saved_model import saved_model
     42 from tensorflow.python.training.training_util import write_graph
     43 
     44 
     45 class FromConstructor(test_util.TensorFlowTestCase):
     46 
     47   # Tests invalid constructors using a dummy value for the GraphDef.
     48   def testInvalidConstructor(self):
     49     message = ('If input_tensors and output_tensors are None, both '
     50                'input_arrays_with_shape and output_arrays must be defined.')
     51 
     52     # `output_arrays` is not defined.
     53     with self.assertRaises(ValueError) as error:
     54       lite.TFLiteConverter(
     55           None, None, [], input_arrays_with_shape=[('input', [3, 9])])
     56     self.assertEqual(message, str(error.exception))
     57 
     58     # `input_arrays_with_shape` is not defined.
     59     with self.assertRaises(ValueError) as error:
     60       lite.TFLiteConverter(None, [], None, output_arrays=['output'])
     61     self.assertEqual(message, str(error.exception))
     62 
     63   # Tests valid constructors using a dummy value for the GraphDef.
     64   def testValidConstructor(self):
     65     converter = lite.TFLiteConverter(
     66         None,
     67         None,
     68         None,
     69         input_arrays_with_shape=[('input', [3, 9])],
     70         output_arrays=['output'])
     71     self.assertFalse(converter._has_valid_tensors())
     72     self.assertEqual(converter.get_input_arrays(), ['input'])
     73 
     74     with self.assertRaises(ValueError) as error:
     75       converter._set_batch_size(1)
     76     self.assertEqual(
     77         'The batch size cannot be set for this model. Please use '
     78         'input_shapes parameter.', str(error.exception))
     79 
     80     converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor'])
     81     self.assertTrue(converter._has_valid_tensors())
     82 
     83 
     84 @test_util.run_v1_only('b/120545219')
     85 class FromSessionTest(test_util.TensorFlowTestCase):
     86 
     87   def testFloat(self):
     88     in_tensor = array_ops.placeholder(
     89         shape=[1, 16, 16, 3], dtype=dtypes.float32)
     90     out_tensor = in_tensor + in_tensor
     91     sess = session.Session()
     92 
     93     # Convert model and ensure model is not None.
     94     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
     95                                                   [out_tensor])
     96     tflite_model = converter.convert()
     97     self.assertTrue(tflite_model)
     98 
     99     # Check values from converted model.
    100     interpreter = Interpreter(model_content=tflite_model)
    101     interpreter.allocate_tensors()
    102 
    103     input_details = interpreter.get_input_details()
    104     self.assertEqual(1, len(input_details))
    105     self.assertEqual('Placeholder', input_details[0]['name'])
    106     self.assertEqual(np.float32, input_details[0]['dtype'])
    107     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    108     self.assertEqual((0., 0.), input_details[0]['quantization'])
    109 
    110     output_details = interpreter.get_output_details()
    111     self.assertEqual(1, len(output_details))
    112     self.assertEqual('add', output_details[0]['name'])
    113     self.assertEqual(np.float32, output_details[0]['dtype'])
    114     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    115     self.assertEqual((0., 0.), output_details[0]['quantization'])
    116 
    117   def testString(self):
    118     in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
    119     out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
    120     sess = session.Session()
    121 
    122     # Convert model and ensure model is not None.
    123     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    124                                                   [out_tensor])
    125     tflite_model = converter.convert()
    126     self.assertTrue(tflite_model)
    127 
    128     # Check values from converted model.
    129     interpreter = Interpreter(model_content=tflite_model)
    130     interpreter.allocate_tensors()
    131 
    132     input_details = interpreter.get_input_details()
    133     self.assertEqual(1, len(input_details))
    134     self.assertEqual('Placeholder', input_details[0]['name'])
    135     self.assertEqual(np.string_, input_details[0]['dtype'])
    136     self.assertTrue(([4] == input_details[0]['shape']).all())
    137 
    138     output_details = interpreter.get_output_details()
    139     self.assertEqual(1, len(output_details))
    140     self.assertEqual('Reshape', output_details[0]['name'])
    141     self.assertEqual(np.string_, output_details[0]['dtype'])
    142     self.assertTrue(([2, 2] == output_details[0]['shape']).all())
    143     # TODO(b/122659643): Test setting/getting string data via the python
    144     # interpreter API after support has been added.
    145 
    146   def testQuantization(self):
    147     in_tensor_1 = array_ops.placeholder(
    148         shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
    149     in_tensor_2 = array_ops.placeholder(
    150         shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
    151     out_tensor = array_ops.fake_quant_with_min_max_args(
    152         in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
    153     sess = session.Session()
    154 
    155     # Convert model and ensure model is not None.
    156     converter = lite.TFLiteConverter.from_session(
    157         sess, [in_tensor_1, in_tensor_2], [out_tensor])
    158     converter.inference_type = lite_constants.QUANTIZED_UINT8
    159     converter.quantized_input_stats = {
    160         'inputA': (0., 1.),
    161         'inputB': (0., 1.)
    162     }  # mean, std_dev
    163     tflite_model = converter.convert()
    164     self.assertTrue(tflite_model)
    165 
    166     # Check values from converted model.
    167     interpreter = Interpreter(model_content=tflite_model)
    168     interpreter.allocate_tensors()
    169 
    170     input_details = interpreter.get_input_details()
    171     self.assertEqual(2, len(input_details))
    172     self.assertEqual('inputA', input_details[0]['name'])
    173     self.assertEqual(np.uint8, input_details[0]['dtype'])
    174     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    175     self.assertEqual((1., 0.),
    176                      input_details[0]['quantization'])  # scale, zero_point
    177 
    178     self.assertEqual('inputB', input_details[1]['name'])
    179     self.assertEqual(np.uint8, input_details[1]['dtype'])
    180     self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
    181     self.assertEqual((1., 0.),
    182                      input_details[1]['quantization'])  # scale, zero_point
    183 
    184     output_details = interpreter.get_output_details()
    185     self.assertEqual(1, len(output_details))
    186     self.assertEqual('output', output_details[0]['name'])
    187     self.assertEqual(np.uint8, output_details[0]['dtype'])
    188     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    189     self.assertTrue(output_details[0]['quantization'][0] > 0)  # scale
    190 
    191   def testQuantizationInvalid(self):
    192     in_tensor_1 = array_ops.placeholder(
    193         shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
    194     in_tensor_2 = array_ops.placeholder(
    195         shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
    196     out_tensor = array_ops.fake_quant_with_min_max_args(
    197         in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
    198     sess = session.Session()
    199 
    200     # Convert model and ensure model is not None.
    201     converter = lite.TFLiteConverter.from_session(
    202         sess, [in_tensor_1, in_tensor_2], [out_tensor])
    203     converter.inference_type = lite_constants.QUANTIZED_UINT8
    204     converter.quantized_input_stats = {'inputA': (0., 1.)}  # mean, std_dev
    205     with self.assertRaises(ValueError) as error:
    206       converter.convert()
    207     self.assertEqual(
    208         'Quantization input stats are not available for input tensors '
    209         '\'inputB\'.', str(error.exception))
    210 
    211   def testIntermediateInputArray(self):
    212     """Convert a model from an intermediate input array."""
    213     in_tensor_init = array_ops.placeholder(
    214         shape=[1, 16, 16, 3], dtype=dtypes.float32)
    215     in_tensor_final = in_tensor_init + in_tensor_init
    216     out_tensor = in_tensor_final + in_tensor_final
    217     sess = session.Session()
    218 
    219     # Convert model and ensure model is not None.
    220     converter = lite.TFLiteConverter.from_session(sess, [in_tensor_final],
    221                                                   [out_tensor])
    222     tflite_model = converter.convert()
    223     self.assertTrue(tflite_model)
    224 
    225     # Check values from converted model.
    226     interpreter = Interpreter(model_content=tflite_model)
    227     interpreter.allocate_tensors()
    228 
    229     input_details = interpreter.get_input_details()
    230     self.assertEqual(1, len(input_details))
    231     self.assertEqual('add', input_details[0]['name'])
    232     self.assertEqual(np.float32, input_details[0]['dtype'])
    233     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    234     self.assertEqual((0., 0.), input_details[0]['quantization'])
    235 
    236     output_details = interpreter.get_output_details()
    237     self.assertEqual(1, len(output_details))
    238     self.assertEqual('add_1', output_details[0]['name'])
    239     self.assertEqual(np.float32, output_details[0]['dtype'])
    240     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    241     self.assertEqual((0., 0.), output_details[0]['quantization'])
    242 
    243   def testSizeNoneInvalid(self):
    244     in_tensor = array_ops.placeholder(dtype=dtypes.float32)
    245     out_tensor = in_tensor + in_tensor
    246     sess = session.Session()
    247 
    248     # Test None as shape.
    249     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    250                                                   [out_tensor])
    251     with self.assertRaises(ValueError) as error:
    252       converter.convert()
    253     self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
    254                      str(error.exception))
    255 
    256   def testScalarValid(self):
    257     # Construct a graph using a scalar (empty shape) input.
    258     in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
    259     out_tensor = in_tensor + in_tensor
    260     sess = session.Session()
    261 
    262     # Test conversion with the scalar input shape.
    263     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    264                                                   [out_tensor])
    265     tflite_model = converter.convert()
    266     self.assertTrue(tflite_model)
    267 
    268     # Check values from converted model.
    269     interpreter = Interpreter(model_content=tflite_model)
    270     interpreter.allocate_tensors()
    271 
    272     input_details = interpreter.get_input_details()
    273     self.assertEqual(1, len(input_details))
    274     self.assertEqual('Placeholder', input_details[0]['name'])
    275     self.assertEqual(np.float32, input_details[0]['dtype'])
    276     self.assertTrue(([] == input_details[0]['shape']).all())
    277 
    278     output_details = interpreter.get_output_details()
    279     self.assertEqual(1, len(output_details))
    280     self.assertEqual('add', output_details[0]['name'])
    281     self.assertEqual(np.float32, output_details[0]['dtype'])
    282     self.assertTrue(([] == input_details[0]['shape']).all())
    283 
    284     # Validate inference using the scalar inputs/outputs.
    285     test_input = np.array(4.0, dtype=np.float32)
    286     expected_output = np.array(8.0, dtype=np.float32)
    287     interpreter.set_tensor(input_details[0]['index'], test_input)
    288     interpreter.invoke()
    289 
    290     output_data = interpreter.get_tensor(output_details[0]['index'])
    291     self.assertTrue((expected_output == output_data).all())
    292 
    293   def testSizeInvalid(self):
    294     in_tensor = array_ops.placeholder(
    295         shape=[1, None, 16, 3], dtype=dtypes.float32)
    296     out_tensor = in_tensor + in_tensor
    297     sess = session.Session()
    298 
    299     # Test invalid shape. None after 1st dimension.
    300     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    301                                                   [out_tensor])
    302     with self.assertRaises(ValueError) as error:
    303       converter.convert()
    304     self.assertEqual(
    305         'None is only supported in the 1st dimension. Tensor '
    306         '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
    307         str(error.exception))
    308 
    309   def testBatchSizeValid(self):
    310     in_tensor = array_ops.placeholder(
    311         shape=[None, 16, 16, 3], dtype=dtypes.float32)
    312     out_tensor = in_tensor + in_tensor
    313     sess = session.Session()
    314 
    315     # Convert model and ensure model is not None.
    316     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    317                                                   [out_tensor])
    318     tflite_model = converter.convert()
    319     self.assertTrue(tflite_model)
    320 
    321     # Check values from converted model.
    322     interpreter = Interpreter(model_content=tflite_model)
    323     interpreter.allocate_tensors()
    324 
    325     input_details = interpreter.get_input_details()
    326     self.assertEqual(1, len(input_details))
    327     self.assertEqual('Placeholder', input_details[0]['name'])
    328     self.assertEqual(np.float32, input_details[0]['dtype'])
    329     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    330     self.assertEqual((0., 0.), input_details[0]['quantization'])
    331 
    332     output_details = interpreter.get_output_details()
    333     self.assertEqual(1, len(output_details))
    334     self.assertEqual('add', output_details[0]['name'])
    335     self.assertEqual(np.float32, output_details[0]['dtype'])
    336     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    337     self.assertEqual((0., 0.), output_details[0]['quantization'])
    338 
    339   def testFreezeGraph(self):
    340     in_tensor = array_ops.placeholder(
    341         shape=[1, 16, 16, 3], dtype=dtypes.float32)
    342     var = variable_scope.get_variable(
    343         'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
    344     out_tensor = in_tensor + var
    345     sess = session.Session()
    346     sess.run(_global_variables_initializer())
    347 
    348     # Convert model and ensure model is not None.
    349     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    350                                                   [out_tensor])
    351     tflite_model = converter.convert()
    352     self.assertTrue(tflite_model)
    353 
    354     # Check values from converted model.
    355     interpreter = Interpreter(model_content=tflite_model)
    356     interpreter.allocate_tensors()
    357 
    358     input_details = interpreter.get_input_details()
    359     self.assertEqual(1, len(input_details))
    360     self.assertEqual('Placeholder', input_details[0]['name'])
    361     self.assertEqual(np.float32, input_details[0]['dtype'])
    362     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    363     self.assertEqual((0., 0.), input_details[0]['quantization'])
    364 
    365     output_details = interpreter.get_output_details()
    366     self.assertEqual(1, len(output_details))
    367     self.assertEqual('add', output_details[0]['name'])
    368     self.assertEqual(np.float32, output_details[0]['dtype'])
    369     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    370     self.assertEqual((0., 0.), output_details[0]['quantization'])
    371 
    372   # TODO(nupurgarg): Verify value of contents in GraphViz.
    373   def testGraphviz(self):
    374     in_tensor = array_ops.placeholder(
    375         shape=[1, 16, 16, 3], dtype=dtypes.float32)
    376     out_tensor = in_tensor + in_tensor
    377     sess = session.Session()
    378 
    379     # Convert model and ensure model is not None.
    380     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    381                                                   [out_tensor])
    382     converter.output_format = lite_constants.GRAPHVIZ_DOT
    383     graphviz_output = converter.convert()
    384     self.assertTrue(graphviz_output)
    385 
    386   # TODO(nupurgarg): Verify value of contents in GraphViz.
    387   def testDumpGraphviz(self):
    388     in_tensor = array_ops.placeholder(
    389         shape=[1, 16, 16, 3], dtype=dtypes.float32)
    390     out_tensor = in_tensor + in_tensor
    391     sess = session.Session()
    392 
    393     # Convert model and ensure model is not None.
    394     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    395                                                   [out_tensor])
    396     graphviz_dir = self.get_temp_dir()
    397     converter.dump_graphviz_dir = graphviz_dir
    398     tflite_model = converter.convert()
    399     self.assertTrue(tflite_model)
    400 
    401     # Ensure interpreter is able to allocate and check graphviz data.
    402     interpreter = Interpreter(model_content=tflite_model)
    403     interpreter.allocate_tensors()
    404 
    405     num_items_graphviz = len(os.listdir(graphviz_dir))
    406     self.assertTrue(num_items_graphviz)
    407 
    408     # Convert model and ensure model is not None.
    409     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    410                                                   [out_tensor])
    411     graphviz_dir = self.get_temp_dir()
    412     converter.dump_graphviz_dir = graphviz_dir
    413     converter.dump_graphviz_video = True
    414     tflite_model = converter.convert()
    415     self.assertTrue(tflite_model)
    416 
    417     # Ensure graphviz folder has more data after using video flag.
    418     num_items_graphviz_video = len(os.listdir(graphviz_dir))
    419     self.assertTrue(num_items_graphviz_video > num_items_graphviz)
    420 
    421   def testInferenceInputType(self):
    422     in_tensor = array_ops.placeholder(
    423         shape=[1, 16, 16, 3], dtype=dtypes.float32)
    424     out_tensor = in_tensor + in_tensor
    425     sess = session.Session()
    426 
    427     # Convert model and ensure model is not None.
    428     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    429                                                   [out_tensor])
    430     converter.inference_input_type = lite_constants.QUANTIZED_UINT8
    431     converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    432     tflite_model = converter.convert()
    433     self.assertTrue(tflite_model)
    434 
    435     # Check values from converted model.
    436     interpreter = Interpreter(model_content=tflite_model)
    437     interpreter.allocate_tensors()
    438 
    439     input_details = interpreter.get_input_details()
    440     self.assertEqual(1, len(input_details))
    441     self.assertEqual('Placeholder', input_details[0]['name'])
    442     self.assertEqual(np.uint8, input_details[0]['dtype'])
    443     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    444     self.assertEqual((1., 0.), input_details[0]['quantization'])
    445 
    446     output_details = interpreter.get_output_details()
    447     self.assertEqual(1, len(output_details))
    448     self.assertEqual('add', output_details[0]['name'])
    449     self.assertEqual(np.float32, output_details[0]['dtype'])
    450     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    451 
    452   def testDefaultRangesStats(self):
    453     in_tensor = array_ops.placeholder(
    454         shape=[1, 16, 16, 3], dtype=dtypes.float32)
    455     out_tensor = in_tensor + in_tensor
    456     sess = session.Session()
    457 
    458     # Convert model and ensure model is not None.
    459     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
    460                                                   [out_tensor])
    461     converter.inference_type = lite_constants.QUANTIZED_UINT8
    462     converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    463     converter.default_ranges_stats = (0, 6)  # min, max
    464     tflite_model = converter.convert()
    465     self.assertTrue(tflite_model)
    466 
    467     # Check values from converted model.
    468     interpreter = Interpreter(model_content=tflite_model)
    469     interpreter.allocate_tensors()
    470 
    471     input_details = interpreter.get_input_details()
    472     self.assertEqual(1, len(input_details))
    473     self.assertEqual('Placeholder', input_details[0]['name'])
    474     self.assertEqual(np.uint8, input_details[0]['dtype'])
    475     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    476     self.assertEqual((1., 0.), input_details[0]['quantization'])
    477 
    478     output_details = interpreter.get_output_details()
    479     self.assertEqual(1, len(output_details))
    480     self.assertEqual('add', output_details[0]['name'])
    481     self.assertEqual(np.uint8, output_details[0]['dtype'])
    482     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    483     self.assertTrue(output_details[0]['quantization'][0] > 0)  # scale
    484 
    485   def testPostTrainingQuantizeDeprecatedAttribute(self):
    486     in_tensor_1 = array_ops.placeholder(
    487         shape=[33, 33], dtype=dtypes.float32, name='inputA')
    488     in_tensor_2 = constant_op.constant(
    489         np.random.uniform(low=-10., high=10., size=(33, 33)),
    490         shape=[33, 33],
    491         dtype=dtypes.float32,
    492         name='inputB')
    493     out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
    494     sess = session.Session()
    495 
    496     quantized_converter = lite.TFLiteConverter.from_session(
    497         sess, [in_tensor_1], [out_tensor])
    498     self.assertFalse(quantized_converter.post_training_quantize)
    499 
    500     quantized_converter.post_training_quantize = True
    501     self.assertTrue(quantized_converter.post_training_quantize)
    502     self.assertEqual(quantized_converter.optimizations,
    503                      [lite.Optimize.OPTIMIZE_FOR_SIZE])
    504 
    505     quantized_tflite = quantized_converter.convert()
    506     self.assertTrue(quantized_tflite)
    507 
    508   def testPostTrainingQuantize(self):
    509     np.random.seed(0)
    510     # We need the tensor to have more than 1024 elements for quantize_weights
    511     # to kick in. Thus, the [33, 33] shape.
    512     in_tensor_1 = array_ops.placeholder(
    513         shape=[33, 33], dtype=dtypes.float32, name='inputA')
    514     in_tensor_2 = constant_op.constant(
    515         np.random.uniform(low=-10., high=10., size=(33, 33)),
    516         shape=[33, 33],
    517         dtype=dtypes.float32,
    518         name='inputB')
    519     out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
    520     sess = session.Session()
    521 
    522     # Convert float model.
    523     float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
    524                                                         [out_tensor])
    525     float_tflite = float_converter.convert()
    526     self.assertTrue(float_tflite)
    527 
    528     # Convert quantized weights model.
    529     quantized_converter = lite.TFLiteConverter.from_session(
    530         sess, [in_tensor_1], [out_tensor])
    531     quantized_converter.optimizations = [lite.Optimize.OPTIMIZE_FOR_SIZE]
    532     quantized_tflite = quantized_converter.convert()
    533     self.assertTrue(quantized_tflite)
    534 
    535     # Ensure that the quantized weights tflite model is smaller.
    536     self.assertTrue(len(quantized_tflite) < len(float_tflite))
    537 
    538   def testPostTrainingCalibrateAndQuantize(self):
    539     np.random.seed(0)
    540     # Create a mobilenet like model.
    541     output_channel = 16
    542     depth_multiplier = 1
    543     inp = array_ops.placeholder(dtype=dtypes.float32, shape=(1, 5, 5, 3))
    544     conv = nn_ops.conv2d(
    545         inp,
    546         filter=array_ops.zeros([3, 3, 3, output_channel]),
    547         strides=[1, 1, 1, 1],
    548         padding='SAME')
    549     dconv = nn_ops.depthwise_conv2d_native(
    550         conv,
    551         filter=array_ops.zeros(
    552             [16, 16, output_channel, output_channel * depth_multiplier]),
    553         strides=[1, 1, 1, 1],
    554         padding='SAME')
    555     pool = nn_ops.pool(
    556         dconv, window_shape=[2, 2], pooling_type='AVG', padding='SAME')
    557     max_pool = nn_ops.pool(
    558         pool, window_shape=[2, 2], pooling_type='MAX', padding='SAME')
    559     output = nn_ops.softmax(max_pool)
    560 
    561     def calibration_gen():
    562       for _ in range(10):
    563         yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
    564 
    565     sess = session.Session()
    566 
    567     # Convert float model.
    568     float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    569     float_tflite = float_converter.convert()
    570     self.assertTrue(float_tflite)
    571 
    572     # Convert quantized weights model.
    573     quantized_converter = lite.TFLiteConverter.from_session(
    574         sess, [inp], [output])
    575     quantized_converter.optimizations = [lite.Optimize.OPTIMIZE_FOR_SIZE]
    576     quantized_converter.representative_dataset = lite.RepresentativeDataset(
    577         calibration_gen)
    578     quantized_tflite = quantized_converter.convert()
    579     self.assertTrue(quantized_tflite)
    580 
    581     # Ensure that the quantized weights tflite model is smaller.
    582     self.assertTrue(len(quantized_tflite) < len(float_tflite))
    583 
    584   def testFloatTocoConverter(self):
    585     """Tests deprecated test TocoConverter."""
    586     in_tensor = array_ops.placeholder(
    587         shape=[1, 16, 16, 3], dtype=dtypes.float32)
    588     out_tensor = in_tensor + in_tensor
    589     sess = session.Session()
    590 
    591     # Convert model and ensure model is not None.
    592     converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
    593     tflite_model = converter.convert()
    594     self.assertTrue(tflite_model)
    595 
    596     # Ensure the interpreter is able to load.
    597     interpreter = Interpreter(model_content=tflite_model)
    598     interpreter.allocate_tensors()
    599 
    600   def testMultipleOutputNodeNames(self):
    601     """Tests converting a graph with an op that have multiple outputs."""
    602     input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
    603     out0, out1, out2, out3 = array_ops.split(input_tensor, [1, 1, 1, 1], axis=0)
    604     sess = session.Session()
    605 
    606     # Convert model and ensure model is not None.
    607     converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
    608                                                   [out0, out1, out2, out3])
    609     tflite_model = converter.convert()
    610     self.assertTrue(tflite_model)
    611 
    612     # Check values from converted model.
    613     interpreter = Interpreter(model_content=tflite_model)
    614     interpreter.allocate_tensors()
    615 
    616     input_details = interpreter.get_input_details()
    617     self.assertEqual(1, len(input_details))
    618     interpreter.set_tensor(input_details[0]['index'],
    619                            np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
    620     interpreter.invoke()
    621 
    622     output_details = interpreter.get_output_details()
    623     self.assertEqual(4, len(output_details))
    624     self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index']))
    625     self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index']))
    626     self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index']))
    627     self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index']))
    628 
    629 
@test_util.run_v1_only('b/120545219')
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
  """Tests for `TFLiteConverter.from_frozen_graph` with GraphDef files on disk.

  Each test builds a small graph, serializes it with `write_graph`, converts
  the file to a TFLite flatbuffer, and inspects the result through the TFLite
  `Interpreter`. Expected tensor names (e.g. 'Placeholder', 'add') are the
  auto-generated names from graph construction order.
  """

  def testFloat(self):
    """Converts a float model from a binary (.pb) GraphDef file."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    _ = in_tensor + in_tensor
    sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
                                                       ['Placeholder'], ['add'])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    # (0., 0.) quantization parameters indicate an unquantized float tensor.
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testFloatWithShapesArray(self):
    """Converts with an explicit `input_shapes` map and checks the input shape."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    _ = in_tensor + in_tensor
    sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, ['Placeholder'], ['add'],
        input_shapes={'Placeholder': [1, 16, 16, 3]})
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())

  def testFreezeGraph(self):
    """A GraphDef that still contains variables must be rejected."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    var = variable_scope.get_variable(
        'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
    _ = in_tensor + var
    sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Ensure the graph with variables cannot be converted.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
                                             ['add'])
    self.assertEqual('Please freeze the graph using freeze_graph.py.',
                     str(error.exception))

  def testPbtxt(self):
    """Converts a float model from a text-format (.pbtxt) GraphDef file."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    _ = in_tensor + in_tensor
    sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
    # Final argument True writes the graph as text rather than binary proto.
    write_graph(sess.graph_def, '', graph_def_file, True)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
                                                       ['Placeholder'], ['add'])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testInvalidFileNotFound(self):
    """A nonexistent GraphDef path raises IOError with a clear message."""
    with self.assertRaises(IOError) as error:
      lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
                                             ['add'])
    self.assertEqual('File \'invalid_file\' does not exist.',
                     str(error.exception))

  def testInvalidFileBadData(self):
    """A file that is not a parseable GraphDef raises IOError."""
    graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
    # NOTE(review): writes a str to a file opened 'wb'; relies on gfile
    # accepting text here — confirm if this test is ported to another file API.
    with gfile.Open(graph_def_file, 'wb') as temp_file:
      temp_file.write('bad data')
      temp_file.flush()

    # Attempts to convert the invalid model.
    with self.assertRaises(IOError) as error:
      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
                                             ['add'])
    self.assertEqual(
        'Unable to parse input file \'{}\'.'.format(graph_def_file),
        str(error.exception))

  # TODO(nupurgarg): Test model loading in open source.
  def _initObjectDetectionArgs(self):
    """Sets converter arguments for the object-detection test model.

    Populates `self._graph_def_file`, `self._input_arrays`,
    `self._output_arrays`, and `self._input_shapes`. Raises IOError if the
    model file cannot be located in either the internal or external path.
    """
    # Initializes the arguments required for the object detection model.
    # Looks for the model file which is saved in a different location internally
    # and externally.
    filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
    if not os.path.exists(filename):
      filename = os.path.join(
          resource_loader.get_root_dir_with_all_resources(),
          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
      if not os.path.exists(filename):
        raise IOError("File '{0}' does not exist.".format(filename))

    self._graph_def_file = filename
    self._input_arrays = ['normalized_input_image_tensor']
    self._output_arrays = [
        'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
        'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
    ]
    self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}

  def testTFLiteGraphDef(self):
    """Converts the object-detection model (needs allow_custom_ops)."""
    # Tests the object detection model that cannot be loaded in TensorFlow.
    self._initObjectDetectionArgs()

    converter = lite.TFLiteConverter.from_frozen_graph(
        self._graph_def_file, self._input_arrays, self._output_arrays,
        self._input_shapes)
    # The post-process op is a custom TFLite op, so custom ops must be allowed.
    converter.allow_custom_ops = True
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 300, 300, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(4, len(output_details))
    self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 10, 4] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    self.assertEqual('TFLite_Detection_PostProcess:1',
                     output_details[1]['name'])
    self.assertTrue(([1, 10] == output_details[1]['shape']).all())
    self.assertEqual('TFLite_Detection_PostProcess:2',
                     output_details[2]['name'])
    self.assertTrue(([1, 10] == output_details[2]['shape']).all())
    self.assertEqual('TFLite_Detection_PostProcess:3',
                     output_details[3]['name'])
    self.assertTrue(([1] == output_details[3]['shape']).all())

  def testTFLiteGraphDefMissingShape(self):
    """Omitting `input_shapes` for this model raises ValueError."""
    # Tests invalid cases for the model that cannot be loaded in TensorFlow.
    self._initObjectDetectionArgs()

    # Missing `input_shapes`.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_frozen_graph(
          self._graph_def_file, self._input_arrays, self._output_arrays)
    self.assertEqual('input_shapes must be defined for this model.',
                     str(error.exception))

  def testTFLiteGraphDefInvalidShape(self):
    """An `input_shapes` map keyed by a wrong name raises ValueError."""
    # Tests invalid cases for the model that cannot be loaded in TensorFlow.
    self._initObjectDetectionArgs()

    # `input_shapes` does not contain the names in `input_arrays`.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_frozen_graph(
          self._graph_def_file,
          self._input_arrays,
          self._output_arrays,
          input_shapes={'invalid-value': [1, 19]})
    self.assertEqual(
        'input_shapes must contain a value for each item in input_array.',
        str(error.exception))

  def testFloatTocoConverter(self):
    """The deprecated TocoConverter alias still converts from a frozen graph."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    _ = in_tensor + in_tensor
    sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
                                                     ['Placeholder'], ['add'])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    876 
    877 
@test_util.run_v1_only('b/120545219')
class FromSavedModelTest(test_util.TensorFlowTestCase):
  """Tests for `TFLiteConverter.from_saved_model`."""

  def _createSavedModel(self, shape):
    """Create a simple SavedModel.

    Args:
      shape: Shape list for both placeholder inputs; may contain None
        dimensions (e.g. a None batch size).

    Returns:
      Path of the directory containing the SavedModel.
    """
    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with session.Session() as sess:
      # NOTE: the first tensor is deliberately named 'inputB' and the second
      # 'inputA'. The tests below assert that input_details lists 'inputA'
      # first, i.e. ordering does not follow creation order here.
      in_tensor_1 = array_ops.placeholder(
          shape=shape, dtype=dtypes.float32, name='inputB')
      in_tensor_2 = array_ops.placeholder(
          shape=shape, dtype=dtypes.float32, name='inputA')
      out_tensor = in_tensor_1 + in_tensor_2
      inputs = {'x': in_tensor_1, 'y': in_tensor_2}
      outputs = {'z': out_tensor}
      saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def testSimpleModel(self):
    """Test a SavedModel."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual('inputA', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEqual('inputB', input_details[1]['name'])
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testNoneBatchSize(self):
    """Test a SavedModel, with None in input tensor's shape."""
    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual('inputA', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    # The None batch dimension becomes 1 in the converted model.
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEqual('inputB', input_details[1]['name'])
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testOrderInputArrays(self):
    """Test a SavedModel ordering of input arrays."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Pass the input arrays in reverse order; the resulting input_details
    # below still lists 'inputA' first.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir, input_arrays=['inputB', 'inputA'])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual('inputA', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEqual('inputB', input_details[1]['name'])
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testSubsetInputArrays(self):
    """Test a SavedModel with a subset of the input array names of the model."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Check case where input shape is given.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        input_arrays=['inputA'],
        input_shapes={'inputA': [1, 16, 16, 3]})

    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check case where input shape is None.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})

    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  def testSimpleModelTocoConverter(self):
    """Test a SavedModel with deprecated TocoConverter."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
   1021 
   1022 
   1023 @test_util.run_v1_only('b/120545219')
   1024 class FromKerasFile(test_util.TensorFlowTestCase):
   1025 
  def setUp(self):
    # Reset Keras global state between tests; the expected auto-generated
    # tensor names asserted below (e.g. 'dense_input', 'dense_1/BiasAdd')
    # presumably depend on a fresh naming scope — confirmed by the tests
    # passing only when each starts from a cleared session.
    keras.backend.clear_session()
   1028 
   1029   def _getSequentialModel(self):
   1030     with session.Session().as_default():
   1031       model = keras.models.Sequential()
   1032       model.add(keras.layers.Dense(2, input_shape=(3,)))
   1033       model.add(keras.layers.RepeatVector(3))
   1034       model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
   1035       model.compile(
   1036           loss=keras.losses.MSE,
   1037           optimizer=keras.optimizers.RMSprop(),
   1038           metrics=[keras.metrics.categorical_accuracy],
   1039           sample_weight_mode='temporal')
   1040       x = np.random.random((1, 3))
   1041       y = np.random.random((1, 3, 3))
   1042       model.train_on_batch(x, y)
   1043       model.predict(x)
   1044 
   1045       try:
   1046         fd, keras_file = tempfile.mkstemp('.h5')
   1047         keras.models.save_model(model, keras_file)
   1048       finally:
   1049         os.close(fd)
   1050       return keras_file
   1051 
   1052   def testSequentialModel(self):
   1053     """Test a Sequential tf.keras model with default inputs."""
   1054     keras_file = self._getSequentialModel()
   1055 
   1056     converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
   1057     tflite_model = converter.convert()
   1058     self.assertTrue(tflite_model)
   1059 
   1060     # Check tensor details of converted model.
   1061     interpreter = Interpreter(model_content=tflite_model)
   1062     interpreter.allocate_tensors()
   1063 
   1064     input_details = interpreter.get_input_details()
   1065     self.assertEqual(1, len(input_details))
   1066     self.assertEqual('dense_input', input_details[0]['name'])
   1067     self.assertEqual(np.float32, input_details[0]['dtype'])
   1068     self.assertTrue(([1, 3] == input_details[0]['shape']).all())
   1069     self.assertEqual((0., 0.), input_details[0]['quantization'])
   1070 
   1071     output_details = interpreter.get_output_details()
   1072     self.assertEqual(1, len(output_details))
   1073     self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
   1074     self.assertEqual(np.float32, output_details[0]['dtype'])
   1075     self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
   1076     self.assertEqual((0., 0.), output_details[0]['quantization'])
   1077 
   1078     # Check inference of converted model.
   1079     input_data = np.array([[1, 2, 3]], dtype=np.float32)
   1080     interpreter.set_tensor(input_details[0]['index'], input_data)
   1081     interpreter.invoke()
   1082     tflite_result = interpreter.get_tensor(output_details[0]['index'])
   1083 
   1084     keras_model = keras.models.load_model(keras_file)
   1085     keras_result = keras_model.predict(input_data)
   1086 
   1087     np.testing.assert_almost_equal(tflite_result, keras_result, 5)
   1088     os.remove(keras_file)
   1089 
   1090   def testSequentialModelInputArray(self):
   1091     """Test a Sequential tf.keras model testing input arrays argument."""
   1092     keras_file = self._getSequentialModel()
   1093 
   1094     # Invalid input array raises error.
   1095     with self.assertRaises(ValueError) as error:
   1096       lite.TFLiteConverter.from_keras_model_file(
   1097           keras_file, input_arrays=['invalid-input'])
   1098     self.assertEqual("Invalid tensors 'invalid-input' were found.",
   1099                      str(error.exception))
   1100 
   1101     # Valid input array.
   1102     converter = lite.TFLiteConverter.from_keras_model_file(
   1103         keras_file, input_arrays=['dense_input'])
   1104     tflite_model = converter.convert()
   1105     os.remove(keras_file)
   1106     self.assertTrue(tflite_model)
   1107 
   1108   def testSequentialModelInputShape(self):
   1109     """Test a Sequential tf.keras model testing input shapes argument."""
   1110     keras_file = self._getSequentialModel()
   1111 
   1112     # Passing in shape of invalid input array raises error.
   1113     with self.assertRaises(ValueError) as error:
   1114       converter = lite.TFLiteConverter.from_keras_model_file(
   1115           keras_file, input_shapes={'invalid-input': [2, 3]})
   1116     self.assertEqual(
   1117         "Invalid tensor 'invalid-input' found in tensor shapes map.",
   1118         str(error.exception))
   1119 
   1120     # Passing in shape of valid input array.
   1121     converter = lite.TFLiteConverter.from_keras_model_file(
   1122         keras_file, input_shapes={'dense_input': [2, 3]})
   1123     tflite_model = converter.convert()
   1124     os.remove(keras_file)
   1125     self.assertTrue(tflite_model)
   1126 
   1127     # Check input shape from converted model.
   1128     interpreter = Interpreter(model_content=tflite_model)
   1129     interpreter.allocate_tensors()
   1130 
   1131     input_details = interpreter.get_input_details()
   1132     self.assertEqual(1, len(input_details))
   1133     self.assertEqual('dense_input', input_details[0]['name'])
   1134     self.assertTrue(([2, 3] == input_details[0]['shape']).all())
   1135 
   1136   def testSequentialModelOutputArray(self):
   1137     """Test a Sequential tf.keras model testing output arrays argument."""
   1138     keras_file = self._getSequentialModel()
   1139 
   1140     # Invalid output array raises error.
   1141     with self.assertRaises(ValueError) as error:
   1142       lite.TFLiteConverter.from_keras_model_file(
   1143           keras_file, output_arrays=['invalid-output'])
   1144     self.assertEqual("Invalid tensors 'invalid-output' were found.",
   1145                      str(error.exception))
   1146 
   1147     # Valid output array.
   1148     converter = lite.TFLiteConverter.from_keras_model_file(
   1149         keras_file, output_arrays=['time_distributed/Reshape_1'])
   1150     tflite_model = converter.convert()
   1151     os.remove(keras_file)
   1152     self.assertTrue(tflite_model)
   1153 
  def testFunctionalModel(self):
    """Test a Functional tf.keras model with default inputs."""
    with session.Session().as_default():
      inputs = keras.layers.Input(shape=(3,), name='input')
      x = keras.layers.Dense(2)(inputs)
      output = keras.layers.Dense(3)(x)

      model = keras.models.Model(inputs, output)
      model.compile(
          loss=keras.losses.MSE,
          optimizer=keras.optimizers.RMSprop(),
          metrics=[keras.metrics.categorical_accuracy])
      x = np.random.random((1, 3))
      y = np.random.random((1, 3))
      # One training step and a predict so the model has initialized weights
      # before saving.
      model.train_on_batch(x, y)

      model.predict(x)
      fd, keras_file = tempfile.mkstemp('.h5')
      try:
        keras.models.save_model(model, keras_file)
      finally:
        os.close(fd)

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check tensor details of converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    # 'dense_1/BiasAdd' is an auto-generated name — presumably derived from
    # the second (unnamed) Dense layer; depends on layer creation order and
    # the session being cleared in setUp.
    self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 3] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    # Check inference of converted model.
    input_data = np.array([[1, 2, 3]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_details[0]['index'])

    keras_model = keras.models.load_model(keras_file)
    keras_result = keras_model.predict(input_data)

    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
    os.remove(keras_file)
   1211 
  def testFunctionalModelMultipleInputs(self):
    """Test a Functional tf.keras model with multiple inputs and outputs."""
    with session.Session().as_default():
      a = keras.layers.Input(shape=(3,), name='input_a')
      b = keras.layers.Input(shape=(3,), name='input_b')
      # The same Dense layer is applied to both inputs (shared weights).
      dense = keras.layers.Dense(4, name='dense')
      c = dense(a)
      d = dense(b)
      e = keras.layers.Dropout(0.5, name='dropout')(c)

      model = keras.models.Model([a, b], [d, e])
      model.compile(
          loss=keras.losses.MSE,
          optimizer=keras.optimizers.RMSprop(),
          metrics=[keras.metrics.mae],
          loss_weights=[1., 0.5])

      input_a_np = np.random.random((10, 3))
      input_b_np = np.random.random((10, 3))
      output_d_np = np.random.random((10, 4))
      output_e_np = np.random.random((10, 4))
      model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])

      model.predict([input_a_np, input_b_np], batch_size=5)
      fd, keras_file = tempfile.mkstemp('.h5')
      try:
        keras.models.save_model(model, keras_file)
      finally:
        os.close(fd)

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    os.remove(keras_file)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual('input_a', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertEqual('input_b', input_details[1]['name'])
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertTrue(([1, 3] == input_details[1]['shape']).all())
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(2, len(output_details))
    # NOTE(review): 'dense_1/BiasAdd' presumably comes from the shared Dense
    # layer's second application (d = dense(b)); auto-generated, order-
    # dependent name — confirm before reordering layer construction.
    self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 4] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

    # Dropout is an identity op at inference time, hence 'dropout/Identity'.
    self.assertEqual('dropout/Identity', output_details[1]['name'])
    self.assertEqual(np.float32, output_details[1]['dtype'])
    self.assertTrue(([1, 4] == output_details[1]['shape']).all())
    self.assertEqual((0., 0.), output_details[1]['quantization'])
   1276 
   1277   def testFunctionalSequentialModel(self):
   1278     """Test a Functional tf.keras model containing a Sequential model."""
   1279     with session.Session().as_default():
   1280       model = keras.models.Sequential()
   1281       model.add(keras.layers.Dense(2, input_shape=(3,)))
   1282       model.add(keras.layers.RepeatVector(3))
   1283       model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
   1284       model = keras.models.Model(model.input, model.output)
   1285 
   1286       model.compile(
   1287           loss=keras.losses.MSE,
   1288           optimizer=keras.optimizers.RMSprop(),
   1289           metrics=[keras.metrics.categorical_accuracy],
   1290           sample_weight_mode='temporal')
   1291       x = np.random.random((1, 3))
   1292       y = np.random.random((1, 3, 3))
   1293       model.train_on_batch(x, y)
   1294       model.predict(x)
   1295 
   1296       model.predict(x)
   1297       fd, keras_file = tempfile.mkstemp('.h5')
   1298       try:
   1299         keras.models.save_model(model, keras_file)
   1300       finally:
   1301         os.close(fd)
   1302 
   1303     # Convert to TFLite model.
   1304     converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
   1305     tflite_model = converter.convert()
   1306     self.assertTrue(tflite_model)
   1307 
   1308     # Check tensor details of converted model.
   1309     interpreter = Interpreter(model_content=tflite_model)
   1310     interpreter.allocate_tensors()
   1311 
   1312     input_details = interpreter.get_input_details()
   1313     self.assertEqual(1, len(input_details))
   1314     self.assertEqual('dense_input', input_details[0]['name'])
   1315     self.assertEqual(np.float32, input_details[0]['dtype'])
   1316     self.assertTrue(([1, 3] == input_details[0]['shape']).all())
   1317     self.assertEqual((0., 0.), input_details[0]['quantization'])
   1318 
   1319     output_details = interpreter.get_output_details()
   1320     self.assertEqual(1, len(output_details))
   1321     self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
   1322     self.assertEqual(np.float32, output_details[0]['dtype'])
   1323     self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
   1324     self.assertEqual((0., 0.), output_details[0]['quantization'])
   1325 
   1326     # Check inference of converted model.
   1327     input_data = np.array([[1, 2, 3]], dtype=np.float32)
   1328     interpreter.set_tensor(input_details[0]['index'], input_data)
   1329     interpreter.invoke()
   1330     tflite_result = interpreter.get_tensor(output_details[0]['index'])
   1331 
   1332     keras_model = keras.models.load_model(keras_file)
   1333     keras_result = keras_model.predict(input_data)
   1334 
   1335     np.testing.assert_almost_equal(tflite_result, keras_result, 5)
   1336     os.remove(keras_file)
   1337 
   1338   def testSequentialModelTocoConverter(self):
   1339     """Test a Sequential tf.keras model with deprecated TocoConverter."""
   1340     keras_file = self._getSequentialModel()
   1341 
   1342     converter = lite.TocoConverter.from_keras_model_file(keras_file)
   1343     tflite_model = converter.convert()
   1344     self.assertTrue(tflite_model)
   1345 
   1346     # Ensure the model is able to load.
   1347     interpreter = Interpreter(model_content=tflite_model)
   1348     interpreter.allocate_tensors()
   1349 
   1350 
   1351 if __name__ == '__main__':
   1352   test.main()
   1353