1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "common/task-context.h" 18 19 #include <stdlib.h> 20 21 #include <string> 22 23 #include "util/base/integral_types.h" 24 #include "util/base/logging.h" 25 #include "util/strings/numbers.h" 26 27 namespace libtextclassifier { 28 namespace nlp_core { 29 30 namespace { 31 int32 ParseInt32WithDefault(const std::string &s, int32 defval) { 32 int32 value = defval; 33 return ParseInt32(s.c_str(), &value) ? value : defval; 34 } 35 36 int64 ParseInt64WithDefault(const std::string &s, int64 defval) { 37 int64 value = defval; 38 return ParseInt64(s.c_str(), &value) ? value : defval; 39 } 40 41 double ParseDoubleWithDefault(const std::string &s, double defval) { 42 double value = defval; 43 return ParseDouble(s.c_str(), &value) ? value : defval; 44 } 45 } // namespace 46 47 TaskInput *TaskContext::GetInput(const std::string &name) { 48 // Return existing input if it exists. 49 for (int i = 0; i < spec_.input_size(); ++i) { 50 if (spec_.input(i).name() == name) return spec_.mutable_input(i); 51 } 52 53 // Create new input. 54 TaskInput *input = spec_.add_input(); 55 input->set_name(name); 56 return input; 57 } 58 59 TaskInput *TaskContext::GetInput(const std::string &name, 60 const std::string &file_format, 61 const std::string &record_format) { 62 TaskInput *input = GetInput(name); 63 if (!file_format.empty()) { 64 bool found = false; 65 for (int i = 0; i < input->file_format_size(); ++i) { 66 if (input->file_format(i) == file_format) found = true; 67 } 68 if (!found) input->add_file_format(file_format); 69 } 70 if (!record_format.empty()) { 71 bool found = false; 72 for (int i = 0; i < input->record_format_size(); ++i) { 73 if (input->record_format(i) == record_format) found = true; 74 } 75 if (!found) input->add_record_format(record_format); 76 } 77 return input; 78 } 79 80 void TaskContext::SetParameter(const std::string &name, 81 const std::string &value) { 82 TC_LOG(INFO) << "SetParameter(" << name << ", " << value << ")"; 83 84 // If the parameter already exists update the value. 85 for (int i = 0; i < spec_.parameter_size(); ++i) { 86 if (spec_.parameter(i).name() == name) { 87 spec_.mutable_parameter(i)->set_value(value); 88 return; 89 } 90 } 91 92 // Add new parameter. 93 TaskSpec::Parameter *param = spec_.add_parameter(); 94 param->set_name(name); 95 param->set_value(value); 96 } 97 98 std::string TaskContext::GetParameter(const std::string &name) const { 99 // First try to find parameter in task specification. 100 for (int i = 0; i < spec_.parameter_size(); ++i) { 101 if (spec_.parameter(i).name() == name) return spec_.parameter(i).value(); 102 } 103 104 // Parameter not found, return empty std::string. 105 return ""; 106 } 107 108 int TaskContext::GetIntParameter(const std::string &name) const { 109 std::string value = GetParameter(name); 110 return ParseInt32WithDefault(value, 0); 111 } 112 113 int64 TaskContext::GetInt64Parameter(const std::string &name) const { 114 std::string value = GetParameter(name); 115 return ParseInt64WithDefault(value, 0); 116 } 117 118 bool TaskContext::GetBoolParameter(const std::string &name) const { 119 std::string value = GetParameter(name); 120 return value == "true"; 121 } 122 123 double TaskContext::GetFloatParameter(const std::string &name) const { 124 std::string value = GetParameter(name); 125 return ParseDoubleWithDefault(value, 0.0); 126 } 127 128 std::string TaskContext::Get(const std::string &name, 129 const char *defval) const { 130 // First try to find parameter in task specification. 131 for (int i = 0; i < spec_.parameter_size(); ++i) { 132 if (spec_.parameter(i).name() == name) return spec_.parameter(i).value(); 133 } 134 135 // Parameter not found, return default value. 136 return defval; 137 } 138 139 std::string TaskContext::Get(const std::string &name, 140 const std::string &defval) const { 141 return Get(name, defval.c_str()); 142 } 143 144 int TaskContext::Get(const std::string &name, int defval) const { 145 std::string value = Get(name, ""); 146 return ParseInt32WithDefault(value, defval); 147 } 148 149 int64 TaskContext::Get(const std::string &name, int64 defval) const { 150 std::string value = Get(name, ""); 151 return ParseInt64WithDefault(value, defval); 152 } 153 154 double TaskContext::Get(const std::string &name, double defval) const { 155 std::string value = Get(name, ""); 156 return ParseDoubleWithDefault(value, defval); 157 } 158 159 bool TaskContext::Get(const std::string &name, bool defval) const { 160 std::string value = Get(name, ""); 161 return value.empty() ? defval : value == "true"; 162 } 163 164 std::string TaskContext::InputFile(const TaskInput &input) { 165 if (input.part_size() == 0) { 166 TC_LOG(ERROR) << "No file for TaskInput " << input.name(); 167 return ""; 168 } 169 if (input.part_size() > 1) { 170 TC_LOG(ERROR) << "Ambiguous: multiple files for TaskInput " << input.name(); 171 } 172 return input.part(0).file_pattern(); 173 } 174 175 bool TaskContext::Supports(const TaskInput &input, 176 const std::string &file_format, 177 const std::string &record_format) { 178 // Check file format. 179 if (input.file_format_size() > 0) { 180 bool found = false; 181 for (int i = 0; i < input.file_format_size(); ++i) { 182 if (input.file_format(i) == file_format) { 183 found = true; 184 break; 185 } 186 } 187 if (!found) return false; 188 } 189 190 // Check record format. 191 if (input.record_format_size() > 0) { 192 bool found = false; 193 for (int i = 0; i < input.record_format_size(); ++i) { 194 if (input.record_format(i) == record_format) { 195 found = true; 196 break; 197 } 198 } 199 if (!found) return false; 200 } 201 202 return true; 203 } 204 205 } // namespace nlp_core 206 } // namespace libtextclassifier 207