1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/lite/tools/verifier.h" 17 #include <climits> 18 #include "tensorflow/contrib/lite/schema/schema_generated.h" 19 #include "tensorflow/contrib/lite/string_util.h" 20 #include "tensorflow/contrib/lite/version.h" 21 22 namespace tflite { 23 24 namespace { 25 26 // Reports error message when the reporter is set. 27 void ReportError(ErrorReporter* error_reporter, const char* format, ...) { 28 if (error_reporter) { 29 va_list args; 30 va_start(args, format); 31 error_reporter->Report(format, args); 32 va_end(args); 33 } 34 } 35 36 // Returns the int32_t value pointed by ptr. 37 const uint32_t* GetIntPtr(const char* ptr) { 38 return reinterpret_cast<const uint32_t*>(ptr); 39 } 40 41 // Verifies flatbuffer format of the model contents and returns the in-memory 42 // model. 43 const Model* VerifyFlatbufferAndGetModel(const void* buf, size_t len) { 44 ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len); 45 if (VerifyModelBuffer(verifier)) { 46 return ::tflite::GetModel(buf); 47 } else { 48 return nullptr; 49 } 50 } 51 52 const uint32_t kMaxNumString = UINT_MAX / sizeof(int32_t) - 2; 53 54 // Verifies string tensor has legit buffer contents that follow the schema 55 // defined in lite/string_util.h 56 bool VerifyStringTensorBuffer(const Buffer& buffer, 57 ErrorReporter* error_reporter) { 58 uint32_t buffer_size = buffer.data()->size(); 59 const char* buffer_ptr = reinterpret_cast<const char*>(buffer.data()->data()); 60 61 uint32_t num_strings = *GetIntPtr(buffer_ptr); 62 if (num_strings > kMaxNumString) { 63 ReportError(error_reporter, 64 "String tensor has invalid num of string set: %d", num_strings); 65 return false; 66 } 67 uint32_t header_offsets = 68 static_cast<uint32_t>(num_strings + 2) * sizeof(int32_t); 69 70 if (buffer_size < header_offsets) { 71 ReportError(error_reporter, 72 "String tensor buffer requires at least %d bytes, but is " 73 "allocated with %d bytes", 74 header_offsets, buffer_size); 75 return false; 76 } 77 78 uint32_t prev_ptr = header_offsets; 79 uint32_t offset = sizeof(int32_t); 80 81 if (*GetIntPtr(buffer_ptr + offset) != header_offsets) { 82 ReportError(error_reporter, 83 "String tensor buffer initial offset must be: %d", 84 header_offsets); 85 return false; 86 } 87 offset += sizeof(int32_t); 88 for (int i = 1; i <= num_strings; i++, offset += sizeof(int32_t)) { 89 int string_offset = *GetIntPtr(buffer_ptr + offset); 90 if (string_offset < prev_ptr || string_offset > buffer_size) { 91 ReportError(error_reporter, "String tensor buffer is invalid: index %d", 92 i); 93 return false; 94 } 95 } 96 if (*GetIntPtr(buffer_ptr + offset - sizeof(int32_t)) != buffer_size) { 97 ReportError(error_reporter, "String tensor buffer last offset must be %d", 98 buffer_size); 99 return false; 100 } 101 return true; 102 } 103 104 // Verifies numeric tensor has legit buffer. 105 bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer, 106 ErrorReporter* error_reporter) { 107 uint64_t bytes_required = 1; 108 for (int dim : *tensor.shape()) { 109 bytes_required *= dim; 110 if (bytes_required > UINT_MAX) { 111 ReportError(error_reporter, "Tensor dimension overflow"); 112 return false; 113 } 114 } 115 switch (tensor.type()) { 116 case TensorType_FLOAT32: 117 bytes_required *= sizeof(float); 118 break; 119 case TensorType_INT32: 120 bytes_required *= sizeof(int32_t); 121 break; 122 case TensorType_UINT8: 123 bytes_required *= sizeof(uint8_t); 124 break; 125 case TensorType_INT64: 126 bytes_required *= sizeof(int64_t); 127 break; 128 case TensorType_FLOAT16: 129 // FALLTHROUGH_INTENDED; 130 default: 131 ReportError(error_reporter, "Invalid tensor type: %d", tensor.type()); 132 return false; 133 } 134 if (bytes_required > UINT_MAX) { 135 ReportError(error_reporter, "Tensor dimension overflow"); 136 return false; 137 } 138 139 if (bytes_required != buffer.data()->size()) { 140 ReportError( 141 error_reporter, 142 "Tensor requires %d bytes, but is allocated with %d bytes buffer", 143 bytes_required, buffer.data()->size()); 144 return false; 145 } 146 return true; 147 148 // TODO(yichengfan): verify quantized tensors. 149 } 150 151 // Verifies tensors have valid properties and legit buffer if set. 152 bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) { 153 if (!model.subgraphs()) { 154 return true; 155 } 156 for (const auto& subgraph : *model.subgraphs()) { 157 if (!subgraph->tensors()) { 158 continue; 159 } 160 for (const auto& tensor : *subgraph->tensors()) { 161 if (!tensor->buffer()) { 162 continue; 163 } 164 if (tensor->buffer() >= model.buffers()->size()) { 165 ReportError(error_reporter, "Invalid tensor buffer index: %d", 166 tensor->buffer()); 167 return false; 168 } 169 auto* buffer = model.buffers()->Get(tensor->buffer()); 170 if (!buffer || !buffer->data()) { 171 ReportError(error_reporter, "Tensor buffer %d not set", 172 tensor->buffer()); 173 return false; 174 } 175 176 if (tensor->type() == TensorType_STRING) { 177 if (!VerifyStringTensorBuffer(*buffer, error_reporter)) { 178 return false; 179 } 180 } else { 181 if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) { 182 return false; 183 } 184 } 185 } 186 } 187 return true; 188 } 189 190 bool VerifyOps(const Model& model, const OpResolver& resolver, 191 ErrorReporter* error_reporter) { 192 if (!model.operator_codes()) { 193 return true; 194 } 195 for (const auto& opcode : *model.operator_codes()) { 196 if (opcode->builtin_code() == BuiltinOperator_CUSTOM) { 197 if (!resolver.FindOp(opcode->custom_code()->c_str())) { 198 ReportError(error_reporter, "Unsupported custom op: %s", 199 opcode->custom_code()->c_str()); 200 return false; 201 } 202 } else { 203 if (!resolver.FindOp(opcode->builtin_code())) { 204 ReportError(error_reporter, "Unsupported builtin op: %s", 205 EnumNameBuiltinOperator(opcode->builtin_code())); 206 return false; 207 } 208 } 209 } 210 return true; 211 } 212 213 } // namespace 214 215 bool Verify(const void* buf, size_t len, const OpResolver& resolver, 216 ErrorReporter* error_reporter) { 217 const Model* model = VerifyFlatbufferAndGetModel(buf, len); 218 if (model == nullptr) { 219 ReportError(error_reporter, "Invalid flatbuffer format"); 220 return false; 221 } 222 if (model->version() != TFLITE_SCHEMA_VERSION) { 223 ReportError(error_reporter, "Invalid model version %d", model->version()); 224 return false; 225 } 226 if (!VerifyTensors(*model, error_reporter)) { 227 return false; 228 } 229 if (!VerifyOps(*model, resolver, error_reporter)) { 230 return false; 231 } 232 return true; 233 } 234 } // namespace tflite 235