Home | History | Annotate | Download | only in fel
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "lang_id/common/fel/feature-extractor.h"
     18 
     19 #include "lang_id/common/fel/feature-types.h"
     20 #include "lang_id/common/fel/fel-parser.h"
     21 #include "lang_id/common/lite_base/logging.h"
     22 #include "lang_id/common/lite_strings/numbers.h"
     23 
     24 namespace libtextclassifier3 {
     25 namespace mobile {
     26 
// Out-of-line definition of the static constexpr member kNone (its value is
// given in the class declaration). Before C++17 this definition is required
// for kNone to have storage when it is odr-used (e.g. bound to a const
// reference); in C++17 and later it is redundant but harmless.
constexpr FeatureValue GenericFeatureFunction::kNone;
     28 
     29 GenericFeatureExtractor::GenericFeatureExtractor() {}
     30 
     31 GenericFeatureExtractor::~GenericFeatureExtractor() {}
     32 
     33 bool GenericFeatureExtractor::Parse(const string &source) {
     34   // Parse feature specification into descriptor.
     35   FELParser parser;
     36 
     37   if (!parser.Parse(source, mutable_descriptor())) {
     38     SAFTM_LOG(ERROR) << "Error parsing the FEL spec " << source;
     39     return false;
     40   }
     41 
     42   // Initialize feature extractor from descriptor.
     43   return InitializeFeatureFunctions();
     44 }
     45 
     46 bool GenericFeatureExtractor::InitializeFeatureTypes() {
     47   // Register all feature types.
     48   GetFeatureTypes(&feature_types_);
     49   for (size_t i = 0; i < feature_types_.size(); ++i) {
     50     FeatureType *ft = feature_types_[i];
     51     ft->set_base(i);
     52 
     53     // Check for feature space overflow.
     54     double domain_size = ft->GetDomainSize();
     55     if (domain_size < 0) {
     56       SAFTM_LOG(ERROR) << "Illegal domain size for feature " << ft->name()
     57                        << ": " << domain_size;
     58       return false;
     59     }
     60   }
     61   return true;
     62 }
     63 
     64 string GenericFeatureFunction::GetParameter(const string &name,
     65                                             const string &default_value) const {
     66   // Find named parameter in feature descriptor.
     67   for (int i = 0; i < descriptor_->parameter_size(); ++i) {
     68     if (name == descriptor_->parameter(i).name()) {
     69       return descriptor_->parameter(i).value();
     70     }
     71   }
     72   return default_value;
     73 }
     74 
     75 GenericFeatureFunction::GenericFeatureFunction() {}
     76 
     77 GenericFeatureFunction::~GenericFeatureFunction() { delete feature_type_; }
     78 
     79 int GenericFeatureFunction::GetIntParameter(const string &name,
     80                                             int default_value) const {
     81   string value_str = GetParameter(name, "");
     82   if (value_str.empty()) {
     83     // Parameter not specified, use default value for it.
     84     return default_value;
     85   }
     86   int value = 0;
     87   if (!LiteAtoi(value_str, &value)) {
     88     SAFTM_LOG(DFATAL) << "Unable to parse '" << value_str
     89                       << "' as int for parameter " << name;
     90     return default_value;
     91   }
     92   return value;
     93 }
     94 
     95 bool GenericFeatureFunction::GetBoolParameter(const string &name,
     96                                               bool default_value) const {
     97   string value = GetParameter(name, "");
     98   if (value.empty()) return default_value;
     99   if (value == "true") return true;
    100   if (value == "false") return false;
    101   SAFTM_LOG(DFATAL) << "Illegal value '" << value << "' for bool parameter "
    102                     << name;
    103   return default_value;
    104 }
    105 
    106 void GenericFeatureFunction::GetFeatureTypes(
    107     std::vector<FeatureType *> *types) const {
    108   if (feature_type_ != nullptr) types->push_back(feature_type_);
    109 }
    110 
    111 FeatureType *GenericFeatureFunction::GetFeatureType() const {
    112   // If a single feature type has been registered return it.
    113   if (feature_type_ != nullptr) return feature_type_;
    114 
    115   // Get feature types for function.
    116   std::vector<FeatureType *> types;
    117   GetFeatureTypes(&types);
    118 
    119   // If there is exactly one feature type return this, else return null.
    120   if (types.size() == 1) return types[0];
    121   return nullptr;
    122 }
    123 
    124 string GenericFeatureFunction::name() const {
    125   string output;
    126   if (descriptor_->name().empty()) {
    127     if (!prefix_.empty()) {
    128       output.append(prefix_);
    129       output.append(".");
    130     }
    131     ToFEL(*descriptor_, &output);
    132   } else {
    133     output = descriptor_->name();
    134   }
    135   return output;
    136 }
    137 
    138 }  // namespace mobile
}  // namespace libtextclassifier3
    140