1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "smartselect/model-params.h" 18 19 #include "common/memory_image/memory-image-reader.h" 20 21 namespace libtextclassifier { 22 23 using nlp_core::EmbeddingNetworkProto; 24 using nlp_core::MemoryImageReader; 25 26 ModelParams* ModelParamsBuilder( 27 const void* start, uint64 num_bytes, 28 std::shared_ptr<EmbeddingParams> external_embedding_params) { 29 MemoryImageReader<EmbeddingNetworkProto> reader(start, num_bytes); 30 31 ModelOptions model_options; 32 auto model_options_extension_id = model_options_in_embedding_network_proto; 33 if (reader.trimmed_proto().HasExtension(model_options_extension_id)) { 34 model_options = 35 reader.trimmed_proto().GetExtension(model_options_extension_id); 36 } 37 38 FeatureProcessorOptions feature_processor_options; 39 auto feature_processor_extension_id = 40 feature_processor_options_in_embedding_network_proto; 41 if (reader.trimmed_proto().HasExtension(feature_processor_extension_id)) { 42 feature_processor_options = 43 reader.trimmed_proto().GetExtension(feature_processor_extension_id); 44 45 // If no tokenization codepoint config is present, tokenize on space. 46 if (feature_processor_options.tokenization_codepoint_config_size() == 0) { 47 TokenizationCodepointRange* config; 48 // New line character. 49 config = feature_processor_options.add_tokenization_codepoint_config(); 50 config->set_start(10); 51 config->set_end(11); 52 config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); 53 54 // Space character. 55 config = feature_processor_options.add_tokenization_codepoint_config(); 56 config->set_start(32); 57 config->set_end(33); 58 config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); 59 } 60 } else { 61 return nullptr; 62 } 63 64 SelectionModelOptions selection_options; 65 auto selection_options_extension_id = 66 selection_model_options_in_embedding_network_proto; 67 if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) { 68 selection_options = 69 reader.trimmed_proto().GetExtension(selection_options_extension_id); 70 } else { 71 // Default values when SelectionModelOptions is not present. 72 for (const auto codepoint_pair : std::vector<std::pair<int, int>>( 73 {{33, 35}, {37, 39}, {42, 42}, {44, 47}, 74 {58, 59}, {63, 64}, {91, 93}, {95, 95}, 75 {123, 123}, {125, 125}, {161, 161}, {171, 171}, 76 {183, 183}, {187, 187}, {191, 191}, {894, 894}, 77 {903, 903}, {1370, 1375}, {1417, 1418}, {1470, 1470}, 78 {1472, 1472}, {1475, 1475}, {1478, 1478}, {1523, 1524}, 79 {1548, 1549}, {1563, 1563}, {1566, 1567}, {1642, 1645}, 80 {1748, 1748}, {1792, 1805}, {2404, 2405}, {2416, 2416}, 81 {3572, 3572}, {3663, 3663}, {3674, 3675}, {3844, 3858}, 82 {3898, 3901}, {3973, 3973}, {4048, 4049}, {4170, 4175}, 83 {4347, 4347}, {4961, 4968}, {5741, 5742}, {5787, 5788}, 84 {5867, 5869}, {5941, 5942}, {6100, 6102}, {6104, 6106}, 85 {6144, 6154}, {6468, 6469}, {6622, 6623}, {6686, 6687}, 86 {8208, 8231}, {8240, 8259}, {8261, 8273}, {8275, 8286}, 87 {8317, 8318}, {8333, 8334}, {9001, 9002}, {9140, 9142}, 88 {10088, 10101}, {10181, 10182}, {10214, 10219}, {10627, 10648}, 89 {10712, 10715}, {10748, 10749}, {11513, 11516}, {11518, 11519}, 90 {11776, 11799}, {11804, 11805}, {12289, 12291}, {12296, 12305}, 91 {12308, 12319}, {12336, 12336}, {12349, 12349}, {12448, 12448}, 92 {12539, 12539}, {64830, 64831}, {65040, 65049}, {65072, 65106}, 93 {65108, 65121}, {65123, 65123}, {65128, 65128}, {65130, 65131}, 94 {65281, 65283}, {65285, 65290}, {65292, 65295}, {65306, 65307}, 95 {65311, 65312}, {65339, 65341}, {65343, 65343}, {65371, 65371}, 96 {65373, 65373}, {65375, 65381}, {65792, 65793}, {66463, 66463}, 97 {68176, 68184}})) { 98 for (int i = codepoint_pair.first; i <= codepoint_pair.second; i++) { 99 selection_options.add_punctuation_to_strip(i); 100 } 101 selection_options.set_strip_punctuation(true); 102 selection_options.set_enforce_symmetry(true); 103 selection_options.set_symmetry_context_size( 104 feature_processor_options.context_size() * 2); 105 } 106 } 107 108 SharingModelOptions sharing_options; 109 auto sharing_options_extension_id = 110 sharing_model_options_in_embedding_network_proto; 111 if (reader.trimmed_proto().HasExtension(sharing_options_extension_id)) { 112 sharing_options = 113 reader.trimmed_proto().GetExtension(sharing_options_extension_id); 114 } else { 115 // Default values when SharingModelOptions is not present. 116 sharing_options.set_always_accept_url_hint(true); 117 sharing_options.set_always_accept_email_hint(true); 118 } 119 120 if (!model_options.use_shared_embeddings()) { 121 std::shared_ptr<EmbeddingParams> embedding_params(new EmbeddingParams( 122 start, num_bytes, feature_processor_options.context_size())); 123 return new ModelParams(start, num_bytes, embedding_params, 124 selection_options, sharing_options, 125 feature_processor_options); 126 } else { 127 return new ModelParams( 128 start, num_bytes, std::move(external_embedding_params), 129 selection_options, sharing_options, feature_processor_options); 130 } 131 } 132 133 } // namespace libtextclassifier 134