Home | History | Annotate | Download | only in smartselect
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "smartselect/model-params.h"
     18 
     19 #include "common/memory_image/memory-image-reader.h"
     20 
     21 namespace libtextclassifier {
     22 
     23 using nlp_core::EmbeddingNetworkProto;
     24 using nlp_core::MemoryImageReader;
     25 
     26 ModelParams* ModelParamsBuilder(
     27     const void* start, uint64 num_bytes,
     28     std::shared_ptr<EmbeddingParams> external_embedding_params) {
     29   MemoryImageReader<EmbeddingNetworkProto> reader(start, num_bytes);
     30 
     31   ModelOptions model_options;
     32   auto model_options_extension_id = model_options_in_embedding_network_proto;
     33   if (reader.trimmed_proto().HasExtension(model_options_extension_id)) {
     34     model_options =
     35         reader.trimmed_proto().GetExtension(model_options_extension_id);
     36   }
     37 
     38   FeatureProcessorOptions feature_processor_options;
     39   auto feature_processor_extension_id =
     40       feature_processor_options_in_embedding_network_proto;
     41   if (reader.trimmed_proto().HasExtension(feature_processor_extension_id)) {
     42     feature_processor_options =
     43         reader.trimmed_proto().GetExtension(feature_processor_extension_id);
     44 
     45     // If no tokenization codepoint config is present, tokenize on space.
     46     if (feature_processor_options.tokenization_codepoint_config_size() == 0) {
     47       TokenizationCodepointRange* config;
     48       // New line character.
     49       config = feature_processor_options.add_tokenization_codepoint_config();
     50       config->set_start(10);
     51       config->set_end(11);
     52       config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
     53 
     54       // Space character.
     55       config = feature_processor_options.add_tokenization_codepoint_config();
     56       config->set_start(32);
     57       config->set_end(33);
     58       config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
     59     }
     60   } else {
     61     return nullptr;
     62   }
     63 
     64   SelectionModelOptions selection_options;
     65   auto selection_options_extension_id =
     66       selection_model_options_in_embedding_network_proto;
     67   if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) {
     68     selection_options =
     69         reader.trimmed_proto().GetExtension(selection_options_extension_id);
     70   } else {
     71     // Default values when SelectionModelOptions is not present.
     72     for (const auto codepoint_pair : std::vector<std::pair<int, int>>(
     73              {{33, 35},       {37, 39},       {42, 42},       {44, 47},
     74               {58, 59},       {63, 64},       {91, 93},       {95, 95},
     75               {123, 123},     {125, 125},     {161, 161},     {171, 171},
     76               {183, 183},     {187, 187},     {191, 191},     {894, 894},
     77               {903, 903},     {1370, 1375},   {1417, 1418},   {1470, 1470},
     78               {1472, 1472},   {1475, 1475},   {1478, 1478},   {1523, 1524},
     79               {1548, 1549},   {1563, 1563},   {1566, 1567},   {1642, 1645},
     80               {1748, 1748},   {1792, 1805},   {2404, 2405},   {2416, 2416},
     81               {3572, 3572},   {3663, 3663},   {3674, 3675},   {3844, 3858},
     82               {3898, 3901},   {3973, 3973},   {4048, 4049},   {4170, 4175},
     83               {4347, 4347},   {4961, 4968},   {5741, 5742},   {5787, 5788},
     84               {5867, 5869},   {5941, 5942},   {6100, 6102},   {6104, 6106},
     85               {6144, 6154},   {6468, 6469},   {6622, 6623},   {6686, 6687},
     86               {8208, 8231},   {8240, 8259},   {8261, 8273},   {8275, 8286},
     87               {8317, 8318},   {8333, 8334},   {9001, 9002},   {9140, 9142},
     88               {10088, 10101}, {10181, 10182}, {10214, 10219}, {10627, 10648},
     89               {10712, 10715}, {10748, 10749}, {11513, 11516}, {11518, 11519},
     90               {11776, 11799}, {11804, 11805}, {12289, 12291}, {12296, 12305},
     91               {12308, 12319}, {12336, 12336}, {12349, 12349}, {12448, 12448},
     92               {12539, 12539}, {64830, 64831}, {65040, 65049}, {65072, 65106},
     93               {65108, 65121}, {65123, 65123}, {65128, 65128}, {65130, 65131},
     94               {65281, 65283}, {65285, 65290}, {65292, 65295}, {65306, 65307},
     95               {65311, 65312}, {65339, 65341}, {65343, 65343}, {65371, 65371},
     96               {65373, 65373}, {65375, 65381}, {65792, 65793}, {66463, 66463},
     97               {68176, 68184}})) {
     98       for (int i = codepoint_pair.first; i <= codepoint_pair.second; i++) {
     99         selection_options.add_punctuation_to_strip(i);
    100       }
    101       selection_options.set_strip_punctuation(true);
    102       selection_options.set_enforce_symmetry(true);
    103       selection_options.set_symmetry_context_size(
    104           feature_processor_options.context_size() * 2);
    105     }
    106   }
    107 
    108   SharingModelOptions sharing_options;
    109   auto sharing_options_extension_id =
    110       sharing_model_options_in_embedding_network_proto;
    111   if (reader.trimmed_proto().HasExtension(sharing_options_extension_id)) {
    112     sharing_options =
    113         reader.trimmed_proto().GetExtension(sharing_options_extension_id);
    114   } else {
    115     // Default values when SharingModelOptions is not present.
    116     sharing_options.set_always_accept_url_hint(true);
    117     sharing_options.set_always_accept_email_hint(true);
    118   }
    119 
    120   if (!model_options.use_shared_embeddings()) {
    121     std::shared_ptr<EmbeddingParams> embedding_params(new EmbeddingParams(
    122         start, num_bytes, feature_processor_options.context_size()));
    123     return new ModelParams(start, num_bytes, embedding_params,
    124                            selection_options, sharing_options,
    125                            feature_processor_options);
    126   } else {
    127     return new ModelParams(
    128         start, num_bytes, std::move(external_embedding_params),
    129         selection_options, sharing_options, feature_processor_options);
    130   }
    131 }
    132 
    133 }  // namespace libtextclassifier
    134