Home | History | Annotate | Download | only in schema
      1 # ==============================================================================
      2 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 """Upgrade script to move from pre-release schema to new schema.
     16 
     17 Usage examples:
     18 
     19 bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.json
     20 bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.bin
     21 bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.json
     22 bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.bin
     23 bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.tflite out.tflite
     24 """
     25 from __future__ import absolute_import
     26 from __future__ import division
     27 from __future__ import print_function
     28 
     29 import argparse
     30 import contextlib
     31 import json
     32 import os
     33 import shutil
     34 import subprocess
     35 import sys
     36 import tempfile
     37 
     38 import tensorflow as tf
     39 from tensorflow.python.platform import resource_loader
     40 
     41 parser = argparse.ArgumentParser(
     42     description="Script to move TFLite models from pre-release schema to"
     43     " new schema.")
     44 parser.add_argument(
     45     "input",
     46     type=str,
     47     help="Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.")
     48 parser.add_argument(
     49     "output",
     50     type=str,
     51     help="Output json or bin TensorFlow lite model compliant with"
     52     "the new schema. Extension must be `.json`, `.bin` or `.tflite`.")
     53 
     54 
     55 # RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles.
     56 @contextlib.contextmanager
     57 def TemporaryDirectoryResource():
     58   temporary = tempfile.mkdtemp()
     59   try:
     60     yield temporary
     61   finally:
     62     shutil.rmtree(temporary)
     63 
     64 
     65 class Converter(object):
     66   """Converts TensorFlow flatbuffer models from old to new version of schema.
     67 
     68   This can convert between any version to the latest version. It uses
     69   an incremental upgrade strategy to go from version to version.
     70 
     71   Usage:
     72     converter = Converter()
     73     converter.Convert("a.tflite", "a.json")
     74     converter.Convert("b.json", "b.tflite")
     75   """
     76 
     77   def __init__(self):
     78     # TODO(aselle): make this work in the open source version with better
     79     # path.
     80     paths_to_try = [
     81         "../../../../flatbuffers/flatc",  # not bazel
     82         "../../../../external/flatbuffers/flatc"  # bazel
     83     ]
     84     for p in paths_to_try:
     85       self._flatc_path = resource_loader.get_path_to_datafile(p)
     86       if os.path.exists(self._flatc_path): break
     87 
     88     def FindSchema(base_name):
     89       return resource_loader.get_path_to_datafile("%s" % base_name)
     90 
     91     # Supported schemas for upgrade.
     92     self._schemas = [
     93         (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1),
     94         (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2),
     95         (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3),
     96         (3, FindSchema("schema_v3.fbs"), False, None)  # Non-callable by design.
     97     ]
     98     # Ensure schemas are sorted, and extract latest version and upgrade
     99     # dispatch function table.
    100     self._schemas.sort()
    101     self._new_version, self._new_schema = self._schemas[-1][:2]
    102     self._upgrade_dispatch = dict(
    103         (version, dispatch)
    104         for version, unused1, unused2, dispatch in self._schemas)
    105 
    106   def _Read(self, input_file, schema, raw_binary=False):
    107     """Read a tflite model assuming the given flatbuffer schema.
    108 
    109     If `input_file` is in bin, then we must use flatc to convert the schema
    110     from binary to json.
    111 
    112     Args:
    113       input_file: a binary (flatbuffer) or json file to read from. Extension
    114         must  be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or
    115         FlatBuffer JSON.
    116       schema: which schema to use for reading
    117       raw_binary: whether to assume raw_binary (versions previous to v3)
    118         that lacked file_identifier require this.
    119 
    120     Raises:
    121       RuntimeError: When flatc cannot be invoked.
    122       ValueError: When the extension is not json or bin.
    123 
    124     Returns:
    125       A dictionary representing the read tflite model.
    126     """
    127     raw_binary = ["--raw-binary"] if raw_binary else []
    128     with TemporaryDirectoryResource() as tempdir:
    129       basename = os.path.basename(input_file)
    130       basename_no_extension, extension = os.path.splitext(basename)
    131       if extension in [".bin", ".tflite"]:
    132         # Convert to json using flatc
    133         returncode = subprocess.call([
    134             self._flatc_path,
    135             "-t",
    136             "--strict-json",
    137             "--defaults-json",
    138         ] + raw_binary + ["-o", tempdir, schema, "--", input_file])
    139         if returncode != 0:
    140           raise RuntimeError("flatc failed to convert from binary to json.")
    141         json_file = os.path.join(tempdir, basename_no_extension + ".json")
    142         if not os.path.exists(json_file):
    143           raise RuntimeError("Could not find %r" % json_file)
    144       elif extension == ".json":
    145         json_file = input_file
    146       else:
    147         raise ValueError("Invalid extension on input file %r" % input_file)
    148       return json.load(open(json_file))
    149 
    150   def _Write(self, data, output_file):
    151     """Output a json or bin version of the flatbuffer model.
    152 
    153     Args:
    154       data: Dict representing the TensorFlow Lite model to write.
    155       output_file: filename to write the converted flatbuffer to. (json,
    156         tflite, or bin extension is required).
    157     Raises:
    158       ValueError: When the extension is not json or bin
    159       RuntimeError: When flatc fails to convert json data to binary.
    160     """
    161     _, extension = os.path.splitext(output_file)
    162     with TemporaryDirectoryResource() as tempdir:
    163       if extension == ".json":
    164         json.dump(data, open(output_file, "w"), sort_keys=True, indent=2)
    165       elif extension in [".tflite", ".bin"]:
    166         input_json = os.path.join(tempdir, "temp.json")
    167         with open(input_json, "w") as fp:
    168           json.dump(data, fp, sort_keys=True, indent=2)
    169         returncode = subprocess.call([
    170             self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o",
    171             tempdir, self._new_schema, input_json
    172         ])
    173         if returncode != 0:
    174           raise RuntimeError("flatc failed to convert upgraded json to binary.")
    175 
    176         shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file)
    177       else:
    178         raise ValueError("Invalid extension on output file %r" % output_file)
    179 
    180   def _Upgrade0To1(self, data):
    181     """Upgrade data from Version 0 to Version 1.
    182 
    183     Changes: Added subgraphs (which contains a subset of formally global
    184     entries).
    185 
    186     Args:
    187       data: Dictionary representing the TensorFlow lite data to be upgraded.
    188         This will be modified in-place to be an upgraded version.
    189     """
    190     subgraph = {}
    191     for key_to_promote in ["tensors", "operators", "inputs", "outputs"]:
    192       subgraph[key_to_promote] = data[key_to_promote]
    193       del data[key_to_promote]
    194     data["subgraphs"] = [subgraph]
    195 
    196   def _Upgrade1To2(self, data):
    197     """Upgrade data from Version 1 to Version 2.
    198 
    199     Changes: Rename operators to Conform to NN API.
    200 
    201     Args:
    202       data: Dictionary representing the TensorFlow lite data to be upgraded.
    203         This will be modified in-place to be an upgraded version.
    204     Raises:
    205       ValueError: Throws when model builtins are numeric rather than symbols.
    206     """
    207 
    208     def RemapOperator(opcode_name):
    209       """Go from old schema op name to new schema op name.
    210 
    211       Args:
    212         opcode_name: String representing the ops (see :schema.fbs).
    213       Returns:
    214         Converted opcode_name from V1 to V2.
    215       """
    216       old_name_to_new_name = {
    217           "CONVOLUTION": "CONV_2D",
    218           "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D",
    219           "AVERAGE_POOL": "AVERAGE_POOL_2D",
    220           "MAX_POOL": "MAX_POOL_2D",
    221           "L2_POOL": "L2_POOL_2D",
    222           "SIGMOID": "LOGISTIC",
    223           "L2NORM": "L2_NORMALIZATION",
    224           "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION",
    225           "Basic_RNN": "RNN",
    226       }
    227 
    228       return (old_name_to_new_name[opcode_name]
    229               if opcode_name in old_name_to_new_name else opcode_name)
    230 
    231     def RemapOperatorType(operator_type):
    232       """Remap operator structs from old names to new names.
    233 
    234       Args:
    235         operator_type: String representing the builtin operator data type
    236           string.
    237         (see :schema.fbs).
    238       Returns:
    239         Upgraded builtin operator data type as a string.
    240       """
    241       old_to_new = {
    242           "PoolOptions": "Pool2DOptions",
    243           "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions",
    244           "ConvolutionOptions": "Conv2DOptions",
    245           "LocalResponseNormOptions": "LocalResponseNormalizationOptions",
    246           "BasicRNNOptions": "RNNOptions",
    247       }
    248       return (old_to_new[operator_type]
    249               if operator_type in old_to_new else operator_type)
    250 
    251     for subgraph in data["subgraphs"]:
    252       for ops in subgraph["operators"]:
    253         ops["builtin_options_type"] = RemapOperatorType(
    254             ops["builtin_options_type"])
    255 
    256     # Upgrade the operator codes
    257     for operator_code in data["operator_codes"]:
    258       # Check if builtin_code is the appropriate string type
    259       # use type("") instead of str or unicode. for py2and3
    260       if not isinstance(operator_code["builtin_code"], type(u"")):
    261         raise ValueError("builtin_code %r is non-string. this usually means"
    262                          "your model has consistency problems." %
    263                          (operator_code["builtin_code"]))
    264       operator_code["builtin_code"] = (RemapOperator(
    265           operator_code["builtin_code"]))
    266 
    267   def _Upgrade2To3(self, data):
    268     """Upgrade data from Version 2 to Version 3.
    269 
    270     Changed actual read-only tensor data to be in a buffers table instead
    271     of inline with the tensor.
    272 
    273     Args:
    274       data: Dictionary representing the TensorFlow lite data to be upgraded.
    275         This will be modified in-place to be an upgraded version.
    276     """
    277     buffers = [{"data": []}]  # Start with 1 empty buffer
    278     for subgraph in data["subgraphs"]:
    279       if "tensors" not in subgraph:
    280         continue
    281       for tensor in subgraph["tensors"]:
    282         if "data_buffer" not in tensor:
    283           tensor["buffer"] = 0
    284         else:
    285           if tensor["data_buffer"]:
    286             tensor[u"buffer"] = len(buffers)
    287             buffers.append({"data": tensor["data_buffer"]})
    288           else:
    289             tensor["buffer"] = 0
    290           del tensor["data_buffer"]
    291     data["buffers"] = buffers
    292 
    293   def _PerformUpgrade(self, data):
    294     """Manipulate the `data` (parsed JSON) based on changes in format.
    295 
    296     This incrementally will upgrade from version to version within data.
    297 
    298     Args:
    299       data: Dictionary representing the TensorFlow data. This will be upgraded
    300         in place.
    301     """
    302     while data["version"] < self._new_version:
    303       self._upgrade_dispatch[data["version"]](data)
    304       data["version"] += 1
    305 
    306   def Convert(self, input_file, output_file):
    307     """Perform schema conversion from input_file to output_file.
    308 
    309     Args:
    310       input_file: Filename of TensorFlow Lite data to convert from. Must
    311         be `.json` or `.bin` extension files for JSON or Binary forms of
    312         the TensorFlow FlatBuffer schema.
    313       output_file: Filename to write to. Extension also must be `.json`
    314         or `.bin`.
    315 
    316     Raises:
    317       RuntimeError: Generated when none of the upgrader supported schemas
    318         matche the `input_file` data.
    319     """
    320     # Read data in each schema (since they are incompatible). Version is
    321     # always present. Use the read data that matches the version of the
    322     # schema.
    323     for version, schema, raw_binary, _ in self._schemas:
    324       try:
    325         data_candidate = self._Read(input_file, schema, raw_binary)
    326       except RuntimeError:
    327         continue  # Skip and hope another schema works
    328       if "version" not in data_candidate:  # Assume version 1 if not present.
    329         data_candidate["version"] = 1
    330       elif data_candidate["version"] == 0:  # Version 0 doesn't exist in wild.
    331         data_candidate["version"] = 1
    332 
    333       if data_candidate["version"] == version:
    334         self._PerformUpgrade(data_candidate)
    335         self._Write(data_candidate, output_file)
    336         return
    337     raise RuntimeError("No schema that the converter understands worked with "
    338                        "the data file you provided.")
    339 
    340 
    341 def main(argv):
    342   del argv
    343   Converter().Convert(FLAGS.input, FLAGS.output)
    344 
    345 
if __name__ == "__main__":
  # Parse only the two positional arguments this script declares; any
  # unrecognized arguments are forwarded to tf.app.run so TensorFlow's own
  # flag machinery can process them. FLAGS is the module-level global that
  # main() reads.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
    349