1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/lite/toco/tflite/import.h" 16 17 #include "flatbuffers/flexbuffers.h" 18 #include "tensorflow/contrib/lite/schema/schema_generated.h" 19 #include "tensorflow/contrib/lite/toco/tflite/operator.h" 20 #include "tensorflow/contrib/lite/toco/tflite/types.h" 21 #include "tensorflow/contrib/lite/toco/tooling_util.h" 22 23 namespace toco { 24 25 namespace tflite { 26 27 namespace details { 28 void LoadTensorsTable(const ::tflite::Model& input_model, 29 TensorsTable* tensors_table) { 30 // TODO(aselle): add support to toco for multiple subgraphs. 31 auto tensors = (*input_model.subgraphs())[0]->tensors(); 32 if (!tensors) return; 33 for (const auto* tensor : *tensors) { 34 tensors_table->push_back(tensor->name()->c_str()); 35 } 36 } 37 38 void LoadOperatorsTable(const ::tflite::Model& input_model, 39 OperatorsTable* operators_table) { 40 auto opcodes = input_model.operator_codes(); 41 if (!opcodes) return; 42 for (const auto* opcode : *opcodes) { 43 if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) { 44 operators_table->push_back( 45 EnumNameBuiltinOperator(opcode->builtin_code())); 46 } else { 47 operators_table->push_back(opcode->custom_code()->c_str()); 48 } 49 } 50 } 51 } // namespace details 52 53 void ImportTensors(const ::tflite::Model& input_model, Model* model) { 54 auto tensors = (*input_model.subgraphs())[0]->tensors(); 55 auto* buffers = input_model.buffers(); 56 // auto tensors = input_model.tensors(); 57 if (!tensors) return; 58 for (const auto* input_tensor : *tensors) { 59 Array& array = model->GetOrCreateArray(input_tensor->name()->c_str()); 60 array.data_type = DataType::Deserialize(input_tensor->type()); 61 int buffer_index = input_tensor->buffer(); 62 auto* buffer = buffers->Get(buffer_index); 63 DataBuffer::Deserialize(*input_tensor, *buffer, &array); 64 65 auto shape = input_tensor->shape(); 66 if (shape) { 67 for (int i = 0; i < shape->Length(); ++i) { 68 auto d = shape->Get(i); 69 array.mutable_shape()->mutable_dims()->push_back(d); 70 } 71 } 72 73 auto quantization = input_tensor->quantization(); 74 if (quantization) { 75 // Note that tf.mini only supports a single quantization parameters for 76 // the whole array. 77 if (quantization->min() && quantization->max()) { 78 CHECK_EQ(1, quantization->min()->Length()); 79 CHECK_EQ(1, quantization->max()->Length()); 80 MinMax& minmax = array.GetOrCreateMinMax(); 81 minmax.min = quantization->min()->Get(0); 82 minmax.max = quantization->max()->Get(0); 83 } 84 if (quantization->scale() && quantization->zero_point()) { 85 CHECK_EQ(1, quantization->scale()->Length()); 86 CHECK_EQ(1, quantization->zero_point()->Length()); 87 QuantizationParams& q = array.GetOrCreateQuantizationParams(); 88 q.scale = quantization->scale()->Get(0); 89 q.zero_point = quantization->zero_point()->Get(0); 90 } 91 } 92 } 93 } 94 95 void ImportOperators( 96 const ::tflite::Model& input_model, 97 const std::map<string, std::unique_ptr<BaseOperator>>& ops_by_name, 98 const details::TensorsTable& tensors_table, 99 const details::OperatorsTable& operators_table, Model* model) { 100 // TODO(aselle): add support for multiple subgraphs. 101 auto ops = (*input_model.subgraphs())[0]->operators(); 102 103 if (!ops) return; 104 for (const auto* input_op : *ops) { 105 int index = input_op->opcode_index(); 106 if (index < 0 || index > operators_table.size()) { 107 LOG(FATAL) << "Index " << index << " must be between zero and " 108 << operators_table.size(); 109 } 110 string opname = operators_table.at(index); 111 if (ops_by_name.count(opname) == 0) { 112 LOG(FATAL) << "Op '" << opname << "' not supported"; 113 } 114 115 auto new_op = ops_by_name.at(opname)->Deserialize( 116 input_op->builtin_options(), input_op->custom_options()); 117 model->operators.emplace_back(new_op.release()); 118 auto* op = model->operators.back().get(); 119 120 auto inputs = input_op->inputs(); 121 for (int i = 0; i < inputs->Length(); i++) { 122 auto input_index = inputs->Get(i); 123 // input_index == -1 indicates optional tensor. 124 if (input_index != -1) { 125 const string& input_name = tensors_table.at(input_index); 126 op->inputs.push_back(input_name); 127 } else { 128 const string& tensor_name = 129 toco::AvailableArrayName(*model, "OptionalTensor"); 130 model->CreateOptionalArray(tensor_name); 131 op->inputs.push_back(tensor_name); 132 } 133 } 134 auto outputs = input_op->outputs(); 135 for (int i = 0; i < outputs->Length(); i++) { 136 auto output_index = outputs->Get(i); 137 const string& output_name = tensors_table.at(output_index); 138 op->outputs.push_back(output_name); 139 } 140 } 141 } 142 143 void ImportIOTensors(const ::tflite::Model& input_model, 144 const details::TensorsTable& tensors_table, Model* model) { 145 auto inputs = (*input_model.subgraphs())[0]->inputs(); 146 if (inputs) { 147 for (int input : *inputs) { 148 const string& input_name = tensors_table.at(input); 149 model->flags.add_input_arrays()->set_name(input_name); 150 } 151 } 152 153 auto outputs = (*input_model.subgraphs())[0]->outputs(); 154 if (outputs) { 155 for (int output : *outputs) { 156 const string& output_name = tensors_table.at(output); 157 model->flags.add_output_arrays(output_name); 158 } 159 } 160 } 161 162 std::unique_ptr<Model> Import(const ModelFlags& model_flags, 163 const string& input_file_contents) { 164 const ::tflite::Model* input_model = 165 ::tflite::GetModel(input_file_contents.data()); 166 167 // Full list of all known operators. 168 const auto ops_by_name = BuildOperatorByNameMap(); 169 170 if (input_model->subgraphs()->size() != 1) { 171 LOG(FATAL) << "# of subgraphs in tflite should be exactly 1 for now."; 172 } 173 std::unique_ptr<Model> model; 174 model.reset(new Model); 175 176 details::TensorsTable tensors_table; 177 details::LoadTensorsTable(*input_model, &tensors_table); 178 179 details::OperatorsTable operators_table; 180 details::LoadOperatorsTable(*input_model, &operators_table); 181 182 ImportTensors(*input_model, model.get()); 183 ImportOperators(*input_model, ops_by_name, tensors_table, operators_table, 184 model.get()); 185 ImportIOTensors(*input_model, tensors_table, model.get()); 186 187 return model; 188 } 189 190 } // namespace tflite 191 192 } // namespace toco 193