Home | History | Annotate | Download | only in tflite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/lite/toco/tflite/import.h"
     16 
     17 #include "flatbuffers/flexbuffers.h"
     18 #include "tensorflow/contrib/lite/schema/schema_generated.h"
     19 #include "tensorflow/contrib/lite/toco/tflite/operator.h"
     20 #include "tensorflow/contrib/lite/toco/tflite/types.h"
     21 #include "tensorflow/contrib/lite/toco/tooling_util.h"
     22 
     23 namespace toco {
     24 
     25 namespace tflite {
     26 
     27 namespace details {
     28 void LoadTensorsTable(const ::tflite::Model& input_model,
     29                       TensorsTable* tensors_table) {
     30   // TODO(aselle): add support to toco for multiple subgraphs.
     31   auto tensors = (*input_model.subgraphs())[0]->tensors();
     32   if (!tensors) return;
     33   for (const auto* tensor : *tensors) {
     34     tensors_table->push_back(tensor->name()->c_str());
     35   }
     36 }
     37 
     38 void LoadOperatorsTable(const ::tflite::Model& input_model,
     39                         OperatorsTable* operators_table) {
     40   auto opcodes = input_model.operator_codes();
     41   if (!opcodes) return;
     42   for (const auto* opcode : *opcodes) {
     43     if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
     44       operators_table->push_back(
     45           EnumNameBuiltinOperator(opcode->builtin_code()));
     46     } else {
     47       operators_table->push_back(opcode->custom_code()->c_str());
     48     }
     49   }
     50 }
     51 }  // namespace details
     52 
     53 void ImportTensors(const ::tflite::Model& input_model, Model* model) {
     54   auto tensors = (*input_model.subgraphs())[0]->tensors();
     55   auto* buffers = input_model.buffers();
     56   // auto tensors = input_model.tensors();
     57   if (!tensors) return;
     58   for (const auto* input_tensor : *tensors) {
     59     Array& array = model->GetOrCreateArray(input_tensor->name()->c_str());
     60     array.data_type = DataType::Deserialize(input_tensor->type());
     61     int buffer_index = input_tensor->buffer();
     62     auto* buffer = buffers->Get(buffer_index);
     63     DataBuffer::Deserialize(*input_tensor, *buffer, &array);
     64 
     65     auto shape = input_tensor->shape();
     66     if (shape) {
     67       for (int i = 0; i < shape->Length(); ++i) {
     68         auto d = shape->Get(i);
     69         array.mutable_shape()->mutable_dims()->push_back(d);
     70       }
     71     }
     72 
     73     auto quantization = input_tensor->quantization();
     74     if (quantization) {
     75       // Note that tf.mini only supports a single quantization parameters for
     76       // the whole array.
     77       if (quantization->min() && quantization->max()) {
     78         CHECK_EQ(1, quantization->min()->Length());
     79         CHECK_EQ(1, quantization->max()->Length());
     80         MinMax& minmax = array.GetOrCreateMinMax();
     81         minmax.min = quantization->min()->Get(0);
     82         minmax.max = quantization->max()->Get(0);
     83       }
     84       if (quantization->scale() && quantization->zero_point()) {
     85         CHECK_EQ(1, quantization->scale()->Length());
     86         CHECK_EQ(1, quantization->zero_point()->Length());
     87         QuantizationParams& q = array.GetOrCreateQuantizationParams();
     88         q.scale = quantization->scale()->Get(0);
     89         q.zero_point = quantization->zero_point()->Get(0);
     90       }
     91     }
     92   }
     93 }
     94 
     95 void ImportOperators(
     96     const ::tflite::Model& input_model,
     97     const std::map<string, std::unique_ptr<BaseOperator>>& ops_by_name,
     98     const details::TensorsTable& tensors_table,
     99     const details::OperatorsTable& operators_table, Model* model) {
    100   // TODO(aselle): add support for multiple subgraphs.
    101   auto ops = (*input_model.subgraphs())[0]->operators();
    102 
    103   if (!ops) return;
    104   for (const auto* input_op : *ops) {
    105     int index = input_op->opcode_index();
    106     if (index < 0 || index > operators_table.size()) {
    107       LOG(FATAL) << "Index " << index << " must be between zero and "
    108                  << operators_table.size();
    109     }
    110     string opname = operators_table.at(index);
    111     if (ops_by_name.count(opname) == 0) {
    112       LOG(FATAL) << "Op '" << opname << "' not supported";
    113     }
    114 
    115     auto new_op = ops_by_name.at(opname)->Deserialize(
    116         input_op->builtin_options(), input_op->custom_options());
    117     model->operators.emplace_back(new_op.release());
    118     auto* op = model->operators.back().get();
    119 
    120     auto inputs = input_op->inputs();
    121     for (int i = 0; i < inputs->Length(); i++) {
    122       auto input_index = inputs->Get(i);
    123       // input_index == -1 indicates optional tensor.
    124       if (input_index != -1) {
    125         const string& input_name = tensors_table.at(input_index);
    126         op->inputs.push_back(input_name);
    127       } else {
    128         const string& tensor_name =
    129             toco::AvailableArrayName(*model, "OptionalTensor");
    130         model->CreateOptionalArray(tensor_name);
    131         op->inputs.push_back(tensor_name);
    132       }
    133     }
    134     auto outputs = input_op->outputs();
    135     for (int i = 0; i < outputs->Length(); i++) {
    136       auto output_index = outputs->Get(i);
    137       const string& output_name = tensors_table.at(output_index);
    138       op->outputs.push_back(output_name);
    139     }
    140   }
    141 }
    142 
    143 void ImportIOTensors(const ::tflite::Model& input_model,
    144                      const details::TensorsTable& tensors_table, Model* model) {
    145   auto inputs = (*input_model.subgraphs())[0]->inputs();
    146   if (inputs) {
    147     for (int input : *inputs) {
    148       const string& input_name = tensors_table.at(input);
    149       model->flags.add_input_arrays()->set_name(input_name);
    150     }
    151   }
    152 
    153   auto outputs = (*input_model.subgraphs())[0]->outputs();
    154   if (outputs) {
    155     for (int output : *outputs) {
    156       const string& output_name = tensors_table.at(output);
    157       model->flags.add_output_arrays(output_name);
    158     }
    159   }
    160 }
    161 
    162 std::unique_ptr<Model> Import(const ModelFlags& model_flags,
    163                               const string& input_file_contents) {
    164   const ::tflite::Model* input_model =
    165       ::tflite::GetModel(input_file_contents.data());
    166 
    167   // Full list of all known operators.
    168   const auto ops_by_name = BuildOperatorByNameMap();
    169 
    170   if (input_model->subgraphs()->size() != 1) {
    171     LOG(FATAL) << "# of subgraphs in tflite should be exactly 1 for now.";
    172   }
    173   std::unique_ptr<Model> model;
    174   model.reset(new Model);
    175 
    176   details::TensorsTable tensors_table;
    177   details::LoadTensorsTable(*input_model, &tensors_table);
    178 
    179   details::OperatorsTable operators_table;
    180   details::LoadOperatorsTable(*input_model, &operators_table);
    181 
    182   ImportTensors(*input_model, model.get());
    183   ImportOperators(*input_model, ops_by_name, tensors_table, operators_table,
    184                   model.get());
    185   ImportIOTensors(*input_model, tensors_table, model.get());
    186 
    187   return model;
    188 }
    189 
    190 }  // namespace tflite
    191 
    192 }  // namespace toco
    193