// Home | History | Annotate | Download | only in tools
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/lite/tools/verifier.h"
     17 #include <climits>
     18 #include "tensorflow/contrib/lite/schema/schema_generated.h"
     19 #include "tensorflow/contrib/lite/string_util.h"
     20 #include "tensorflow/contrib/lite/version.h"
     21 
     22 namespace tflite {
     23 
     24 namespace {
     25 
     26 // Reports error message when the reporter is set.
     27 void ReportError(ErrorReporter* error_reporter, const char* format, ...) {
     28   if (error_reporter) {
     29     va_list args;
     30     va_start(args, format);
     31     error_reporter->Report(format, args);
     32     va_end(args);
     33   }
     34 }
     35 
     36 // Returns the int32_t value pointed by ptr.
     37 const uint32_t* GetIntPtr(const char* ptr) {
     38   return reinterpret_cast<const uint32_t*>(ptr);
     39 }
     40 
     41 // Verifies flatbuffer format of the model contents and returns the in-memory
     42 // model.
     43 const Model* VerifyFlatbufferAndGetModel(const void* buf, size_t len) {
     44   ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
     45   if (VerifyModelBuffer(verifier)) {
     46     return ::tflite::GetModel(buf);
     47   } else {
     48     return nullptr;
     49   }
     50 }
     51 
// Upper bound on the string count in a string tensor buffer: the header holds
// (num_strings + 2) int32 slots, and its total byte size must fit in uint32_t.
const uint32_t kMaxNumString = UINT_MAX / sizeof(int32_t) - 2;
     53 
     54 // Verifies string tensor has legit buffer contents that follow the schema
     55 // defined in lite/string_util.h
     56 bool VerifyStringTensorBuffer(const Buffer& buffer,
     57                               ErrorReporter* error_reporter) {
     58   uint32_t buffer_size = buffer.data()->size();
     59   const char* buffer_ptr = reinterpret_cast<const char*>(buffer.data()->data());
     60 
     61   uint32_t num_strings = *GetIntPtr(buffer_ptr);
     62   if (num_strings > kMaxNumString) {
     63     ReportError(error_reporter,
     64                 "String tensor has invalid num of string set: %d", num_strings);
     65     return false;
     66   }
     67   uint32_t header_offsets =
     68       static_cast<uint32_t>(num_strings + 2) * sizeof(int32_t);
     69 
     70   if (buffer_size < header_offsets) {
     71     ReportError(error_reporter,
     72                 "String tensor buffer requires at least %d bytes, but is "
     73                 "allocated with %d bytes",
     74                 header_offsets, buffer_size);
     75     return false;
     76   }
     77 
     78   uint32_t prev_ptr = header_offsets;
     79   uint32_t offset = sizeof(int32_t);
     80 
     81   if (*GetIntPtr(buffer_ptr + offset) != header_offsets) {
     82     ReportError(error_reporter,
     83                 "String tensor buffer initial offset must be: %d",
     84                 header_offsets);
     85     return false;
     86   }
     87   offset += sizeof(int32_t);
     88   for (int i = 1; i <= num_strings; i++, offset += sizeof(int32_t)) {
     89     int string_offset = *GetIntPtr(buffer_ptr + offset);
     90     if (string_offset < prev_ptr || string_offset > buffer_size) {
     91       ReportError(error_reporter, "String tensor buffer is invalid: index %d",
     92                   i);
     93       return false;
     94     }
     95   }
     96   if (*GetIntPtr(buffer_ptr + offset - sizeof(int32_t)) != buffer_size) {
     97     ReportError(error_reporter, "String tensor buffer last offset must be %d",
     98                 buffer_size);
     99     return false;
    100   }
    101   return true;
    102 }
    103 
    104 // Verifies numeric tensor has legit buffer.
    105 bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
    106                                ErrorReporter* error_reporter) {
    107   uint64_t bytes_required = 1;
    108   for (int dim : *tensor.shape()) {
    109     bytes_required *= dim;
    110     if (bytes_required > UINT_MAX) {
    111       ReportError(error_reporter, "Tensor dimension overflow");
    112       return false;
    113     }
    114   }
    115   switch (tensor.type()) {
    116     case TensorType_FLOAT32:
    117       bytes_required *= sizeof(float);
    118       break;
    119     case TensorType_INT32:
    120       bytes_required *= sizeof(int32_t);
    121       break;
    122     case TensorType_UINT8:
    123       bytes_required *= sizeof(uint8_t);
    124       break;
    125     case TensorType_INT64:
    126       bytes_required *= sizeof(int64_t);
    127       break;
    128     case TensorType_FLOAT16:
    129       // FALLTHROUGH_INTENDED;
    130     default:
    131       ReportError(error_reporter, "Invalid tensor type: %d", tensor.type());
    132       return false;
    133   }
    134   if (bytes_required > UINT_MAX) {
    135     ReportError(error_reporter, "Tensor dimension overflow");
    136     return false;
    137   }
    138 
    139   if (bytes_required != buffer.data()->size()) {
    140     ReportError(
    141         error_reporter,
    142         "Tensor requires %d bytes, but is allocated with %d bytes buffer",
    143         bytes_required, buffer.data()->size());
    144     return false;
    145   }
    146   return true;
    147 
    148   // TODO(yichengfan): verify quantized tensors.
    149 }
    150 
    151 // Verifies tensors have valid properties and legit buffer if set.
    152 bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) {
    153   if (!model.subgraphs()) {
    154     return true;
    155   }
    156   for (const auto& subgraph : *model.subgraphs()) {
    157     if (!subgraph->tensors()) {
    158       continue;
    159     }
    160     for (const auto& tensor : *subgraph->tensors()) {
    161       if (!tensor->buffer()) {
    162         continue;
    163       }
    164       if (tensor->buffer() >= model.buffers()->size()) {
    165         ReportError(error_reporter, "Invalid tensor buffer index: %d",
    166                     tensor->buffer());
    167         return false;
    168       }
    169       auto* buffer = model.buffers()->Get(tensor->buffer());
    170       if (!buffer || !buffer->data()) {
    171         ReportError(error_reporter, "Tensor buffer %d not set",
    172                     tensor->buffer());
    173         return false;
    174       }
    175 
    176       if (tensor->type() == TensorType_STRING) {
    177         if (!VerifyStringTensorBuffer(*buffer, error_reporter)) {
    178           return false;
    179         }
    180       } else {
    181         if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) {
    182           return false;
    183         }
    184       }
    185     }
    186   }
    187   return true;
    188 }
    189 
    190 bool VerifyOps(const Model& model, const OpResolver& resolver,
    191                ErrorReporter* error_reporter) {
    192   if (!model.operator_codes()) {
    193     return true;
    194   }
    195   for (const auto& opcode : *model.operator_codes()) {
    196     if (opcode->builtin_code() == BuiltinOperator_CUSTOM) {
    197       if (!resolver.FindOp(opcode->custom_code()->c_str())) {
    198         ReportError(error_reporter, "Unsupported custom op: %s",
    199                     opcode->custom_code()->c_str());
    200         return false;
    201       }
    202     } else {
    203       if (!resolver.FindOp(opcode->builtin_code())) {
    204         ReportError(error_reporter, "Unsupported builtin op: %s",
    205                     EnumNameBuiltinOperator(opcode->builtin_code()));
    206         return false;
    207       }
    208     }
    209   }
    210   return true;
    211 }
    212 
    213 }  // namespace
    214 
    215 bool Verify(const void* buf, size_t len, const OpResolver& resolver,
    216             ErrorReporter* error_reporter) {
    217   const Model* model = VerifyFlatbufferAndGetModel(buf, len);
    218   if (model == nullptr) {
    219     ReportError(error_reporter, "Invalid flatbuffer format");
    220     return false;
    221   }
    222   if (model->version() != TFLITE_SCHEMA_VERSION) {
    223     ReportError(error_reporter, "Invalid model version %d", model->version());
    224     return false;
    225   }
    226   if (!VerifyTensors(*model, error_reporter)) {
    227     return false;
    228   }
    229   if (!VerifyOps(*model, resolver, error_reporter)) {
    230     return false;
    231   }
    232   return true;
    233 }
    234 }  // namespace tflite
    235