# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for post training quantization with calibration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util.lazy_loader import LazyLoader

# Lazy load since some of the performance benchmark skylark rules
# break dependencies. Must use double quotes to match code internal rewrite
# rule.
_calibration_wrapper = LazyLoader(
    "_calibration_wrapper", globals(),
    "tensorflow.lite.python.optimize."
    "tensorflow_lite_wrap_calibration_wrapper")


class Calibrator(object):
  """Calibrates a floating point model and then quantizes it.

  This is an internal class, not a public interface.
  """

  def __init__(self, model_content):
    """Constructor.

    Args:
      model_content: Content of a TF-Lite Flatbuffer file.

    Raises:
      ValueError: If `model_content` is empty, or if the calibrator was
        unable to open the model.
    """
    if not model_content:
      raise ValueError("`model_content` must be specified.")
    try:
      self._calibrator = (_calibration_wrapper.CalibrationWrapper
                          .CreateWrapperCPPFromBuffer(model_content))
    except Exception as e:  # pylint: disable=broad-except
      # Deliberately broad: any failure from the native wrapper is surfaced
      # to callers as a single ValueError. `raise ... from e` is not used
      # because the file keeps Python 2 compatible syntax (note the
      # __future__ imports).
      raise ValueError("Failed to parse the model: %s." % e)
    # The wrapper may also report failure by returning a falsy value instead
    # of raising.
    if not self._calibrator:
      raise ValueError("Failed to parse the model.")

  def calibrate_and_quantize(self, dataset_gen):
    """Calibrates the model with specified generator and then quantizes it.

    Args:
      dataset_gen: A zero-argument callable (for example, a generator
        function) that returns an iterable of calibration samples. It is
        called exactly once, and every sample it produces is fed to the
        calibrator in order. Note that a generator *object* is not accepted,
        since it is not callable. Each sample is passed unchanged to the
        native `FeedTensor`; presumably one entry per model input (confirm
        against the wrapper).

    Returns:
      A quantized model, as returned by the native wrapper's
      `QuantizeModel` (presumably the serialized flatbuffer; confirm).
    """
    self._calibrator.Prepare()
    for calibration_sample in dataset_gen():
      self._calibrator.FeedTensor(calibration_sample)
    return self._calibrator.QuantizeModel()