1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // This abstracts command line arguments in toco. 16 // Arg<T> is a parseable type that can register a default value, be able to 17 // parse itself, and keep track of whether it was specified. 18 #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ 19 #define TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ 20 21 #include <functional> 22 #include <unordered_map> 23 #include <vector> 24 #if defined(PLATFORM_GOOGLE) 25 #include "strings/split.h" 26 #endif 27 #include "absl/strings/numbers.h" 28 #include "absl/strings/str_split.h" 29 #include "tensorflow/contrib/lite/toco/toco_port.h" 30 #include "tensorflow/contrib/lite/toco/toco_types.h" 31 32 namespace toco { 33 34 // Since std::vector<int32> is in the std namespace, and we are not allowed 35 // to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type 36 // to use as the flag type: 37 struct IntList { 38 std::vector<int32> elements; 39 }; 40 struct StringMapList { 41 std::vector<std::unordered_map<string, string>> elements; 42 }; 43 44 // command_line_flags.h don't track whether or not a flag is specified. Arg 45 // contains the value (which will be default if not specified) and also 46 // whether the flag is specified. 47 // TODO(aselle): consider putting doc string and ability to construct the 48 // tensorflow argument into this, so declaration of parameters can be less 49 // distributed. 50 // Every template specialization of Arg is required to implement 51 // default_value(), specified(), value(), parse(), bind(). 52 template <class T> 53 class Arg final { 54 public: 55 explicit Arg(T default_ = T()) : value_(default_) {} 56 virtual ~Arg() {} 57 58 // Provide default_value() to arg list 59 T default_value() const { return value_; } 60 // Return true if the command line argument was specified on the command line. 61 bool specified() const { return specified_; } 62 // Const reference to parsed value. 63 const T& value() const { return value_; } 64 65 // Parsing callback for the tensorflow::Flags code 66 bool parse(T value_in) { 67 value_ = value_in; 68 specified_ = true; 69 return true; 70 } 71 72 // Bind the parse member function so tensorflow::Flags can call it. 73 std::function<bool(T)> bind() { 74 return std::bind(&Arg::parse, this, std::placeholders::_1); 75 } 76 77 private: 78 // Becomes true after parsing if the value was specified 79 bool specified_ = false; 80 // Value of the argument (initialized to the default in the constructor). 81 T value_; 82 }; 83 84 template <> 85 class Arg<toco::IntList> final { 86 public: 87 // Provide default_value() to arg list 88 string default_value() const { return ""; } 89 // Return true if the command line argument was specified on the command line. 90 bool specified() const { return specified_; } 91 // Bind the parse member function so tensorflow::Flags can call it. 92 bool parse(string text) { 93 parsed_value_.elements.clear(); 94 specified_ = true; 95 // strings::Split("") produces {""}, but we need {} on empty input. 96 // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could 97 // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements) 98 if (!text.empty()) { 99 int32 element; 100 for (absl::string_view part : absl::StrSplit(text, ',')) { 101 if (!SimpleAtoi(part, &element)) return false; 102 parsed_value_.elements.push_back(element); 103 } 104 } 105 return true; 106 } 107 108 std::function<bool(string)> bind() { 109 return std::bind(&Arg::parse, this, std::placeholders::_1); 110 } 111 112 const toco::IntList& value() const { return parsed_value_; } 113 114 private: 115 toco::IntList parsed_value_; 116 bool specified_ = false; 117 }; 118 119 template <> 120 class Arg<toco::StringMapList> final { 121 public: 122 // Provide default_value() to StringMapList 123 string default_value() const { return ""; } 124 // Return true if the command line argument was specified on the command line. 125 bool specified() const { return specified_; } 126 // Bind the parse member function so tensorflow::Flags can call it. 127 128 bool parse(string text) { 129 parsed_value_.elements.clear(); 130 specified_ = true; 131 132 if (text.empty()) { 133 return true; 134 } 135 136 #if defined(PLATFORM_GOOGLE) 137 std::vector<absl::string_view> outer_vector; 138 absl::string_view text_disposable_copy = text; 139 SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector); 140 for (const absl::string_view& outer_member_stringpiece : outer_vector) { 141 string outer_member(outer_member_stringpiece); 142 if (outer_member.empty()) { 143 continue; 144 } 145 string outer_member_copy = outer_member; 146 absl::StripAsciiWhitespace(&outer_member); 147 if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false; 148 if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false; 149 const std::vector<string> inner_fields_vector = 150 absl::StrSplit(outer_member, ','); 151 152 std::unordered_map<string, string> element; 153 for (const string& member_field : inner_fields_vector) { 154 std::vector<string> outer_member_key_value = 155 absl::StrSplit(member_field, ':'); 156 if (outer_member_key_value.size() != 2) return false; 157 string& key = outer_member_key_value[0]; 158 string& value = outer_member_key_value[1]; 159 absl::StripAsciiWhitespace(&key); 160 absl::StripAsciiWhitespace(&value); 161 if (element.count(key) != 0) return false; 162 element[key] = value; 163 } 164 parsed_value_.elements.push_back(element); 165 } 166 return true; 167 #else 168 // TODO(aselle): Fix argument parsing when absl supports structuredline 169 fprintf(stderr, "%s:%d StringMapList arguments not supported\n", __FILE__, 170 __LINE__); 171 abort(); 172 #endif 173 } 174 175 std::function<bool(string)> bind() { 176 return std::bind(&Arg::parse, this, std::placeholders::_1); 177 } 178 179 const toco::StringMapList& value() const { return parsed_value_; } 180 181 private: 182 toco::StringMapList parsed_value_; 183 bool specified_ = false; 184 }; 185 186 // Flags that describe a model. See model_cmdline_flags.cc for details. 187 struct ParsedModelFlags { 188 Arg<string> input_array; 189 Arg<string> input_arrays; 190 Arg<string> output_array; 191 Arg<string> output_arrays; 192 Arg<string> input_shapes; 193 Arg<float> mean_value = Arg<float>(0.f); 194 Arg<string> mean_values; 195 Arg<float> std_value = Arg<float>(1.f); 196 Arg<string> std_values; 197 Arg<string> input_data_type; 198 Arg<string> input_data_types; 199 Arg<bool> variable_batch = Arg<bool>(false); 200 Arg<toco::IntList> input_shape; 201 Arg<toco::StringMapList> rnn_states; 202 Arg<toco::StringMapList> model_checks; 203 // Debugging output options. 204 // TODO(benoitjacob): these shouldn't be ModelFlags. 205 Arg<string> graphviz_first_array; 206 Arg<string> graphviz_last_array; 207 Arg<string> dump_graphviz; 208 Arg<bool> dump_graphviz_video = Arg<bool>(false); 209 Arg<bool> allow_nonexistent_arrays = Arg<bool>(false); 210 Arg<bool> allow_nonascii_arrays = Arg<bool>(false); 211 Arg<string> arrays_extra_info_file; 212 }; 213 214 // Flags that describe the operation you would like to do (what conversion 215 // you want). See toco_cmdline_flags.cc for details. 216 struct ParsedTocoFlags { 217 Arg<string> input_file; 218 Arg<string> output_file; 219 Arg<string> input_format; 220 Arg<string> output_format; 221 // TODO(aselle): command_line_flags doesn't support doubles 222 Arg<float> default_ranges_min = Arg<float>(0.); 223 Arg<float> default_ranges_max = Arg<float>(0.); 224 Arg<string> inference_type; 225 Arg<string> inference_input_type; 226 Arg<bool> drop_fake_quant = Arg<bool>(false); 227 Arg<bool> reorder_across_fake_quant = Arg<bool>(false); 228 Arg<bool> allow_custom_ops = Arg<bool>(false); 229 // Deprecated flags 230 Arg<string> input_type; 231 Arg<string> input_types; 232 Arg<bool> drop_control_dependency = Arg<bool>(false); 233 }; 234 235 } // namespace toco 236 #endif // TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ 237