/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This abstracts command line arguments in toco.
// Arg<T> is a parseable type that can register a default value, parse
// itself, and keep track of whether it was specified.
#ifndef TENSORFLOW_LITE_TOCO_ARGS_H_
#define TENSORFLOW_LITE_TOCO_ARGS_H_

#include <functional>
#include <unordered_map>
#include <vector>
#include "tensorflow/lite/toco/toco_port.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/lite/toco/toco_types.h"

namespace toco {

// Since std::vector<int32> is in the std namespace, and we are not allowed
// to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type
// to use as the flag type:
struct IntList {
  std::vector<int32> elements;
};
// Wrapper around a list of string-to-string maps; it is the value type held
// by Arg<StringMapList>.
struct StringMapList {
  std::vector<std::unordered_map<string, string>> elements;
};

// command_line_flags.h does not track whether or not a flag is specified. Arg
// contains the value (which will be the default if not specified) and also
// whether the flag is specified.
// TODO(aselle): consider putting doc string and ability to construct the
// tensorflow argument into this, so declaration of parameters can be less
// distributed.
47 // Every template specialization of Arg is required to implement 48 // default_value(), specified(), value(), parse(), bind(). 49 template <class T> 50 class Arg final { 51 public: 52 explicit Arg(T default_ = T()) : value_(default_) {} 53 virtual ~Arg() {} 54 55 // Provide default_value() to arg list 56 T default_value() const { return value_; } 57 // Return true if the command line argument was specified on the command line. 58 bool specified() const { return specified_; } 59 // Const reference to parsed value. 60 const T& value() const { return value_; } 61 62 // Parsing callback for the tensorflow::Flags code 63 bool Parse(T value_in) { 64 value_ = value_in; 65 specified_ = true; 66 return true; 67 } 68 69 // Bind the parse member function so tensorflow::Flags can call it. 70 std::function<bool(T)> bind() { 71 return std::bind(&Arg::Parse, this, std::placeholders::_1); 72 } 73 74 private: 75 // Becomes true after parsing if the value was specified 76 bool specified_ = false; 77 // Value of the argument (initialized to the default in the constructor). 78 T value_; 79 }; 80 81 template <> 82 class Arg<toco::IntList> final { 83 public: 84 // Provide default_value() to arg list 85 string default_value() const { return ""; } 86 // Return true if the command line argument was specified on the command line. 87 bool specified() const { return specified_; } 88 // Bind the parse member function so tensorflow::Flags can call it. 89 bool Parse(string text); 90 91 std::function<bool(string)> bind() { 92 return std::bind(&Arg::Parse, this, std::placeholders::_1); 93 } 94 95 const toco::IntList& value() const { return parsed_value_; } 96 97 private: 98 toco::IntList parsed_value_; 99 bool specified_ = false; 100 }; 101 102 template <> 103 class Arg<toco::StringMapList> final { 104 public: 105 // Provide default_value() to StringMapList 106 string default_value() const { return ""; } 107 // Return true if the command line argument was specified on the command line. 
108 bool specified() const { return specified_; } 109 // Bind the parse member function so tensorflow::Flags can call it. 110 111 bool Parse(string text); 112 113 std::function<bool(string)> bind() { 114 return std::bind(&Arg::Parse, this, std::placeholders::_1); 115 } 116 117 const toco::StringMapList& value() const { return parsed_value_; } 118 119 private: 120 toco::StringMapList parsed_value_; 121 bool specified_ = false; 122 }; 123 124 // Flags that describe a model. See model_cmdline_flags.cc for details. 125 struct ParsedModelFlags { 126 Arg<string> input_array; 127 Arg<string> input_arrays; 128 Arg<string> output_array; 129 Arg<string> output_arrays; 130 Arg<string> input_shapes; 131 Arg<int> batch_size = Arg<int>(1); 132 Arg<float> mean_value = Arg<float>(0.f); 133 Arg<string> mean_values; 134 Arg<float> std_value = Arg<float>(1.f); 135 Arg<string> std_values; 136 Arg<string> input_data_type; 137 Arg<string> input_data_types; 138 Arg<bool> variable_batch = Arg<bool>(false); 139 Arg<toco::IntList> input_shape; 140 Arg<toco::StringMapList> rnn_states; 141 Arg<toco::StringMapList> model_checks; 142 Arg<bool> change_concat_input_ranges = Arg<bool>(true); 143 // Debugging output options. 144 // TODO(benoitjacob): these shouldn't be ModelFlags. 145 Arg<string> graphviz_first_array; 146 Arg<string> graphviz_last_array; 147 Arg<string> dump_graphviz; 148 Arg<bool> dump_graphviz_video = Arg<bool>(false); 149 Arg<bool> allow_nonexistent_arrays = Arg<bool>(false); 150 Arg<bool> allow_nonascii_arrays = Arg<bool>(false); 151 Arg<string> arrays_extra_info_file; 152 Arg<string> model_flags_file; 153 }; 154 155 // Flags that describe the operation you would like to do (what conversion 156 // you want). See toco_cmdline_flags.cc for details. 
157 struct ParsedTocoFlags { 158 Arg<string> input_file; 159 Arg<string> savedmodel_directory; 160 Arg<string> output_file; 161 Arg<string> input_format = Arg<string>("TENSORFLOW_GRAPHDEF"); 162 Arg<string> output_format = Arg<string>("TFLITE"); 163 Arg<string> savedmodel_tagset; 164 // TODO(aselle): command_line_flags doesn't support doubles 165 Arg<float> default_ranges_min = Arg<float>(0.); 166 Arg<float> default_ranges_max = Arg<float>(0.); 167 Arg<float> default_int16_ranges_min = Arg<float>(0.); 168 Arg<float> default_int16_ranges_max = Arg<float>(0.); 169 Arg<string> inference_type; 170 Arg<string> inference_input_type; 171 Arg<bool> drop_fake_quant = Arg<bool>(false); 172 Arg<bool> reorder_across_fake_quant = Arg<bool>(false); 173 Arg<bool> allow_custom_ops = Arg<bool>(false); 174 Arg<bool> post_training_quantize = Arg<bool>(false); 175 // Deprecated flags 176 Arg<bool> quantize_weights = Arg<bool>(false); 177 Arg<string> input_type; 178 Arg<string> input_types; 179 Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false); 180 Arg<bool> drop_control_dependency = Arg<bool>(false); 181 Arg<bool> propagate_fake_quant_num_bits = Arg<bool>(false); 182 Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false); 183 Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64); 184 Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true); 185 // WARNING: Experimental interface, subject to change 186 Arg<bool> enable_select_tf_ops = Arg<bool>(false); 187 // WARNING: Experimental interface, subject to change 188 Arg<bool> force_select_tf_ops = Arg<bool>(false); 189 }; 190 191 } // namespace toco 192 #endif // TENSORFLOW_LITE_TOCO_ARGS_H_ 193