Home | History | Annotate | Download | only in flatbuffers
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "lang_id/common/flatbuffers/model-utils.h"
     18 
     19 #include <string.h>
     20 
     21 #include "lang_id/common/lite_base/logging.h"
     22 #include "lang_id/common/math/checksum.h"
     23 
     24 namespace libtextclassifier3 {
     25 namespace saft_fbs {
     26 
     27 namespace {
     28 
     29 // Returns true if we have clear evidence that |model| fails its checksum.
     30 //
     31 // E.g., if |model| has the crc32 field, and the value of that field does not
     32 // match the checksum, then this function returns true.  If there is no crc32
     33 // field, then we don't know what the original (at build time) checksum was, so
     34 // we don't know anything clear and this function returns false.
     35 bool ClearlyFailsChecksum(const Model &model) {
     36   if (!flatbuffers::IsFieldPresent(&model, Model::VT_CRC32)) {
     37     SAFTM_LOG(WARNING)
     38         << "No CRC32, most likely an old model; skip CRC32 check";
     39     return false;
     40   }
     41   const mobile::uint32 expected_crc32 = model.crc32();
     42   const mobile::uint32 actual_crc32 = ComputeCrc2Checksum(&model);
     43   if (actual_crc32 != expected_crc32) {
     44     SAFTM_LOG(ERROR) << "Corrupt model: different CRC32: " << actual_crc32
     45                      << " vs " << expected_crc32;
     46     return true;
     47   }
     48   SAFTM_LOG(INFO) << "Successfully checked CRC32 " << actual_crc32;
     49   return false;
     50 }
     51 }  // namespace
     52 
     53 const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) {
     54   if ((data == nullptr) || (num_bytes == 0)) {
     55     SAFTM_LOG(ERROR) << "GetModel called on an empty sequence of bytes";
     56     return nullptr;
     57   }
     58   const uint8_t *start = reinterpret_cast<const uint8_t *>(data);
     59   flatbuffers::Verifier verifier(start, num_bytes);
     60   if (!VerifyModelBuffer(verifier)) {
     61     SAFTM_LOG(ERROR) << "Not a valid Model flatbuffer";
     62     return nullptr;
     63   }
     64   const Model *model = GetModel(start);
     65   if (model == nullptr) {
     66     return nullptr;
     67   }
     68   if (ClearlyFailsChecksum(*model)) {
     69     return nullptr;
     70   }
     71   return model;
     72 }
     73 
     74 const ModelInput *GetInputByName(const Model *model, const string &name) {
     75   if (model == nullptr) {
     76     SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr";
     77     return nullptr;
     78   }
     79   const auto *inputs = model->inputs();
     80   if (inputs == nullptr) {
     81     // We should always have a list of inputs; maybe an empty one, if no inputs,
     82     // but the list should be there.
     83     SAFTM_LOG(ERROR) << "null inputs";
     84     return nullptr;
     85   }
     86   for (const ModelInput *input : *inputs) {
     87     if (input != nullptr) {
     88       const flatbuffers::String *input_name = input->name();
     89       if (input_name && input_name->str() == name) {
     90         return input;
     91       }
     92     }
     93   }
     94   return nullptr;
     95 }
     96 
     97 mobile::StringPiece GetInputBytes(const ModelInput *input) {
     98   if ((input == nullptr) || (input->data() == nullptr)) {
     99     SAFTM_LOG(ERROR) << "ModelInput has no content";
    100     return mobile::StringPiece(nullptr, 0);
    101   }
    102   const flatbuffers::Vector<uint8_t> *input_data = input->data();
    103   if (input_data == nullptr) {
    104     SAFTM_LOG(ERROR) << "null input data";
    105     return mobile::StringPiece(nullptr, 0);
    106   }
    107   return mobile::StringPiece(reinterpret_cast<const char *>(input_data->data()),
    108                              input_data->size());
    109 }
    110 
    111 bool FillParameters(const Model &model, mobile::TaskContext *context) {
    112   if (context == nullptr) {
    113     SAFTM_LOG(ERROR) << "null context";
    114     return false;
    115   }
    116   const auto *parameters = model.parameters();
    117   if (parameters == nullptr) {
    118     // We should always have a list of parameters; maybe an empty one, if no
    119     // parameters, but the list should be there.
    120     SAFTM_LOG(ERROR) << "null list of parameters";
    121     return false;
    122   }
    123   for (const ModelParameter *p : *parameters) {
    124     if (p == nullptr) {
    125       SAFTM_LOG(ERROR) << "null parameter";
    126       return false;
    127     }
    128     if (p->name() == nullptr) {
    129       SAFTM_LOG(ERROR) << "null parameter name";
    130       return false;
    131     }
    132     const string name = p->name()->str();
    133     if (name.empty()) {
    134       SAFTM_LOG(ERROR) << "empty parameter name";
    135       return false;
    136     }
    137     if (p->value() == nullptr) {
    138       SAFTM_LOG(ERROR) << "null parameter name";
    139       return false;
    140     }
    141     context->SetParameter(name, p->value()->str());
    142   }
    143   return true;
    144 }
    145 
    146 namespace {
    147 // Updates |*crc| with the information from |s|.  Auxiliary for
    148 // ComputeCrc2Checksum.
    149 //
    150 // The bytes from |info| are also used to update the CRC32 checksum.  |info|
    151 // should be a brief tag that indicates what |s| represents.  The idea is to add
    152 // some structure to the information that goes into the CRC32 computation.
    153 template <typename T>
    154 void UpdateCrc(mobile::Crc32 *crc, const flatbuffers::Vector<T> *s,
    155                mobile::StringPiece info) {
    156   crc->Update("|");
    157   crc->Update(info.data(), info.size());
    158   crc->Update(":");
    159   if (s == nullptr) {
    160     crc->Update("empty");
    161   } else {
    162     crc->Update(reinterpret_cast<const char *>(s->data()),
    163                 s->size() * sizeof(T));
    164   }
    165 }
    166 }  // namespace
    167 
    168 mobile::uint32 ComputeCrc2Checksum(const Model *model) {
    169   // Implementation note: originally, I (salcianu@) thought we can just compute
    170   // a CRC32 checksum of the model bytes.  Unfortunately, the expected checksum
    171   // is there too (and because we don't control the flatbuffer format, we can't
    172   // "arrange" for it to be placed at the head / tail of those bytes).  Instead,
    173   // we traverse |model| and feed into the CRC32 computation those parts we are
    174   // interested in (which excludes the crc32 field).
    175   //
    176   // Note: storing the checksum outside the Model would be too disruptive for
    177   // the way we currently ship our models.
    178   mobile::Crc32 crc;
    179   if (model == nullptr) {
    180     return crc.Get();
    181   }
    182   crc.Update("|Parameters:");
    183   const auto *parameters = model->parameters();
    184   if (parameters != nullptr) {
    185     for (const ModelParameter *p : *parameters) {
    186       if (p != nullptr) {
    187         UpdateCrc(&crc, p->name(), "name");
    188         UpdateCrc(&crc, p->value(), "value");
    189       }
    190     }
    191   }
    192   crc.Update("|Inputs:");
    193   const auto *inputs = model->inputs();
    194   if (inputs != nullptr) {
    195     for (const ModelInput *input : *inputs) {
    196       if (input != nullptr) {
    197         UpdateCrc(&crc, input->name(), "name");
    198         UpdateCrc(&crc, input->type(), "type");
    199         UpdateCrc(&crc, input->sub_type(), "sub-type");
    200         UpdateCrc(&crc, input->data(), "data");
    201       }
    202     }
    203   }
    204   return crc.Get();
    205 }
    206 
    207 }  // namespace saft_fbs
}  // namespace libtextclassifier3
    209