Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "common/task-context.h"
     18 
     19 #include <stdlib.h>
     20 
     21 #include <string>
     22 
     23 #include "util/base/integral_types.h"
     24 #include "util/base/logging.h"
     25 #include "util/strings/numbers.h"
     26 
     27 namespace libtextclassifier {
     28 namespace nlp_core {
     29 
     30 namespace {
     31 int32 ParseInt32WithDefault(const std::string &s, int32 defval) {
     32   int32 value = defval;
     33   return ParseInt32(s.c_str(), &value) ? value : defval;
     34 }
     35 
     36 int64 ParseInt64WithDefault(const std::string &s, int64 defval) {
     37   int64 value = defval;
     38   return ParseInt64(s.c_str(), &value) ? value : defval;
     39 }
     40 
     41 double ParseDoubleWithDefault(const std::string &s, double defval) {
     42   double value = defval;
     43   return ParseDouble(s.c_str(), &value) ? value : defval;
     44 }
     45 }  // namespace
     46 
     47 TaskInput *TaskContext::GetInput(const std::string &name) {
     48   // Return existing input if it exists.
     49   for (int i = 0; i < spec_.input_size(); ++i) {
     50     if (spec_.input(i).name() == name) return spec_.mutable_input(i);
     51   }
     52 
     53   // Create new input.
     54   TaskInput *input = spec_.add_input();
     55   input->set_name(name);
     56   return input;
     57 }
     58 
     59 TaskInput *TaskContext::GetInput(const std::string &name,
     60                                  const std::string &file_format,
     61                                  const std::string &record_format) {
     62   TaskInput *input = GetInput(name);
     63   if (!file_format.empty()) {
     64     bool found = false;
     65     for (int i = 0; i < input->file_format_size(); ++i) {
     66       if (input->file_format(i) == file_format) found = true;
     67     }
     68     if (!found) input->add_file_format(file_format);
     69   }
     70   if (!record_format.empty()) {
     71     bool found = false;
     72     for (int i = 0; i < input->record_format_size(); ++i) {
     73       if (input->record_format(i) == record_format) found = true;
     74     }
     75     if (!found) input->add_record_format(record_format);
     76   }
     77   return input;
     78 }
     79 
     80 void TaskContext::SetParameter(const std::string &name,
     81                                const std::string &value) {
     82   TC_LOG(INFO) << "SetParameter(" << name << ", " << value << ")";
     83 
     84   // If the parameter already exists update the value.
     85   for (int i = 0; i < spec_.parameter_size(); ++i) {
     86     if (spec_.parameter(i).name() == name) {
     87       spec_.mutable_parameter(i)->set_value(value);
     88       return;
     89     }
     90   }
     91 
     92   // Add new parameter.
     93   TaskSpec::Parameter *param = spec_.add_parameter();
     94   param->set_name(name);
     95   param->set_value(value);
     96 }
     97 
     98 std::string TaskContext::GetParameter(const std::string &name) const {
     99   // First try to find parameter in task specification.
    100   for (int i = 0; i < spec_.parameter_size(); ++i) {
    101     if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
    102   }
    103 
    104   // Parameter not found, return empty std::string.
    105   return "";
    106 }
    107 
    108 int TaskContext::GetIntParameter(const std::string &name) const {
    109   std::string value = GetParameter(name);
    110   return ParseInt32WithDefault(value, 0);
    111 }
    112 
    113 int64 TaskContext::GetInt64Parameter(const std::string &name) const {
    114   std::string value = GetParameter(name);
    115   return ParseInt64WithDefault(value, 0);
    116 }
    117 
    118 bool TaskContext::GetBoolParameter(const std::string &name) const {
    119   std::string value = GetParameter(name);
    120   return value == "true";
    121 }
    122 
    123 double TaskContext::GetFloatParameter(const std::string &name) const {
    124   std::string value = GetParameter(name);
    125   return ParseDoubleWithDefault(value, 0.0);
    126 }
    127 
    128 std::string TaskContext::Get(const std::string &name,
    129                              const char *defval) const {
    130   // First try to find parameter in task specification.
    131   for (int i = 0; i < spec_.parameter_size(); ++i) {
    132     if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
    133   }
    134 
    135   // Parameter not found, return default value.
    136   return defval;
    137 }
    138 
    139 std::string TaskContext::Get(const std::string &name,
    140                              const std::string &defval) const {
    141   return Get(name, defval.c_str());
    142 }
    143 
    144 int TaskContext::Get(const std::string &name, int defval) const {
    145   std::string value = Get(name, "");
    146   return ParseInt32WithDefault(value, defval);
    147 }
    148 
    149 int64 TaskContext::Get(const std::string &name, int64 defval) const {
    150   std::string value = Get(name, "");
    151   return ParseInt64WithDefault(value, defval);
    152 }
    153 
    154 double TaskContext::Get(const std::string &name, double defval) const {
    155   std::string value = Get(name, "");
    156   return ParseDoubleWithDefault(value, defval);
    157 }
    158 
    159 bool TaskContext::Get(const std::string &name, bool defval) const {
    160   std::string value = Get(name, "");
    161   return value.empty() ? defval : value == "true";
    162 }
    163 
    164 std::string TaskContext::InputFile(const TaskInput &input) {
    165   if (input.part_size() == 0) {
    166     TC_LOG(ERROR) << "No file for TaskInput " << input.name();
    167     return "";
    168   }
    169   if (input.part_size() > 1) {
    170     TC_LOG(ERROR) << "Ambiguous: multiple files for TaskInput " << input.name();
    171   }
    172   return input.part(0).file_pattern();
    173 }
    174 
    175 bool TaskContext::Supports(const TaskInput &input,
    176                            const std::string &file_format,
    177                            const std::string &record_format) {
    178   // Check file format.
    179   if (input.file_format_size() > 0) {
    180     bool found = false;
    181     for (int i = 0; i < input.file_format_size(); ++i) {
    182       if (input.file_format(i) == file_format) {
    183         found = true;
    184         break;
    185       }
    186     }
    187     if (!found) return false;
    188   }
    189 
    190   // Check record format.
    191   if (input.record_format_size() > 0) {
    192     bool found = false;
    193     for (int i = 0; i < input.record_format_size(); ++i) {
    194       if (input.record_format(i) == record_format) {
    195         found = true;
    196         break;
    197       }
    198     }
    199     if (!found) return false;
    200   }
    201 
    202   return true;
    203 }
    204 
    205 }  // namespace nlp_core
    206 }  // namespace libtextclassifier
    207