1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "lang_id/fb_model/model-provider-from-fb.h" 18 19 #include "lang_id/common/file/file-utils.h" 20 #include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h" 21 #include "lang_id/common/flatbuffers/model-utils.h" 22 #include "lang_id/common/lite_strings/str-split.h" 23 24 namespace libtextclassifier3 { 25 namespace mobile { 26 namespace lang_id { 27 28 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(const string &filename) 29 30 // Using mmap as a fast way to read the model bytes. As the file is 31 // unmapped only when the field scoped_mmap_ is destructed, the model bytes 32 // stay alive for the entire lifetime of this object. 33 : scoped_mmap_(new ScopedMmap(filename)) { 34 Initialize(scoped_mmap_->handle().to_stringpiece()); 35 } 36 37 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(int fd) 38 39 // Using mmap as a fast way to read the model bytes. As the file is 40 // unmapped only when the field scoped_mmap_ is destructed, the model bytes 41 // stay alive for the entire lifetime of this object. 42 : scoped_mmap_(new ScopedMmap(fd)) { 43 Initialize(scoped_mmap_->handle().to_stringpiece()); 44 } 45 46 void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) { 47 // Note: valid_ was initialized to false. In the code below, we set valid_ to 48 // true only if all initialization steps completed successfully. Otherwise, 49 // we return early, leaving valid_ to its default value false. 50 model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes); 51 if (model_ == nullptr) { 52 SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer"; 53 return; 54 } 55 56 // Initialize context_ parameters. 57 if (!saft_fbs::FillParameters(*model_, &context_)) { 58 // FillParameters already performs error logging. 59 return; 60 } 61 62 // Init languages_. 63 const string known_languages_str = context_.Get("supported_languages", ""); 64 for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) { 65 languages_.emplace_back(sp); 66 } 67 if (languages_.empty()) { 68 SAFTM_LOG(ERROR) << "Unable to find list of supported_languages"; 69 return; 70 } 71 72 // Init nn_params_. 73 if (!InitNetworkParams()) { 74 // InitNetworkParams already performs error logging. 75 return; 76 } 77 78 // Everything looks fine. 79 valid_ = true; 80 } 81 82 bool ModelProviderFromFlatbuffer::InitNetworkParams() { 83 const string kInputName = "language-identifier-network"; 84 StringPiece bytes = 85 saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName)); 86 if ((bytes.data() == nullptr) || bytes.empty()) { 87 SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName; 88 return false; 89 } 90 std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb( 91 new EmbeddingNetworkParamsFromFlatbuffer(bytes)); 92 if (!nn_params_from_fb->is_valid()) { 93 SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid"; 94 return false; 95 } 96 nn_params_ = std::move(nn_params_from_fb); 97 return true; 98 } 99 100 } // namespace lang_id 101 } // namespace mobile 102 } // namespace nlp_saft 103