Home | History | Annotate | Download | only in fb_model
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "lang_id/fb_model/model-provider-from-fb.h"
     18 
     19 #include "lang_id/common/file/file-utils.h"
     20 #include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
     21 #include "lang_id/common/flatbuffers/model-utils.h"
     22 #include "lang_id/common/lite_strings/str-split.h"
     23 
     24 namespace libtextclassifier3 {
     25 namespace mobile {
     26 namespace lang_id {
     27 
     28 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(const string &filename)
     29 
     30     // Using mmap as a fast way to read the model bytes.  As the file is
     31     // unmapped only when the field scoped_mmap_ is destructed, the model bytes
     32     // stay alive for the entire lifetime of this object.
     33     : scoped_mmap_(new ScopedMmap(filename)) {
     34   Initialize(scoped_mmap_->handle().to_stringpiece());
     35 }
     36 
     37 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(int fd)
     38 
     39     // Using mmap as a fast way to read the model bytes.  As the file is
     40     // unmapped only when the field scoped_mmap_ is destructed, the model bytes
     41     // stay alive for the entire lifetime of this object.
     42     : scoped_mmap_(new ScopedMmap(fd)) {
     43   Initialize(scoped_mmap_->handle().to_stringpiece());
     44 }
     45 
     46 void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
     47   // Note: valid_ was initialized to false.  In the code below, we set valid_ to
     48   // true only if all initialization steps completed successfully.  Otherwise,
     49   // we return early, leaving valid_ to its default value false.
     50   model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
     51   if (model_ == nullptr) {
     52     SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
     53     return;
     54   }
     55 
     56   // Initialize context_ parameters.
     57   if (!saft_fbs::FillParameters(*model_, &context_)) {
     58     // FillParameters already performs error logging.
     59     return;
     60   }
     61 
     62   // Init languages_.
     63   const string known_languages_str = context_.Get("supported_languages", "");
     64   for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
     65     languages_.emplace_back(sp);
     66   }
     67   if (languages_.empty()) {
     68     SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
     69     return;
     70   }
     71 
     72   // Init nn_params_.
     73   if (!InitNetworkParams()) {
     74     // InitNetworkParams already performs error logging.
     75     return;
     76   }
     77 
     78   // Everything looks fine.
     79   valid_ = true;
     80 }
     81 
     82 bool ModelProviderFromFlatbuffer::InitNetworkParams() {
     83   const string kInputName = "language-identifier-network";
     84   StringPiece bytes =
     85       saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
     86   if ((bytes.data() == nullptr) || bytes.empty()) {
     87     SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
     88     return false;
     89   }
     90   std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
     91       new EmbeddingNetworkParamsFromFlatbuffer(bytes));
     92   if (!nn_params_from_fb->is_valid()) {
     93     SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
     94     return false;
     95   }
     96   nn_params_ = std::move(nn_params_from_fb);
     97   return true;
     98 }
     99 
    100 }  // namespace lang_id
    101 }  // namespace mobile
    102 }  // namespace libtextclassifier3
    103