Home | History | Annotate | Download | only in toco
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 // This abstracts command line arguments in toco.
// Arg<T> is a parseable type that can register a default value, parse itself,
// and keep track of whether it was specified.
     18 #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
     19 #define TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
     20 
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>
#if defined(PLATFORM_GOOGLE)
#include "strings/split.h"
#endif
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
     31 
     32 namespace toco {
     33 
     34 // Since std::vector<int32> is in the std namespace, and we are not allowed
     35 // to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type
     36 // to use as the flag type:
     37 struct IntList {
     38   std::vector<int32> elements;
     39 };
     40 struct StringMapList {
     41   std::vector<std::unordered_map<string, string>> elements;
     42 };
     43 
     44 // command_line_flags.h don't track whether or not a flag is specified. Arg
     45 // contains the value (which will be default if not specified) and also
     46 // whether the flag is specified.
     47 // TODO(aselle): consider putting doc string and ability to construct the
     48 // tensorflow argument into this, so declaration of parameters can be less
     49 // distributed.
     50 // Every template specialization of Arg is required to implement
     51 // default_value(), specified(), value(), parse(), bind().
     52 template <class T>
     53 class Arg final {
     54  public:
     55   explicit Arg(T default_ = T()) : value_(default_) {}
     56   virtual ~Arg() {}
     57 
     58   // Provide default_value() to arg list
     59   T default_value() const { return value_; }
     60   // Return true if the command line argument was specified on the command line.
     61   bool specified() const { return specified_; }
     62   // Const reference to parsed value.
     63   const T& value() const { return value_; }
     64 
     65   // Parsing callback for the tensorflow::Flags code
     66   bool parse(T value_in) {
     67     value_ = value_in;
     68     specified_ = true;
     69     return true;
     70   }
     71 
     72   // Bind the parse member function so tensorflow::Flags can call it.
     73   std::function<bool(T)> bind() {
     74     return std::bind(&Arg::parse, this, std::placeholders::_1);
     75   }
     76 
     77  private:
     78   // Becomes true after parsing if the value was specified
     79   bool specified_ = false;
     80   // Value of the argument (initialized to the default in the constructor).
     81   T value_;
     82 };
     83 
     84 template <>
     85 class Arg<toco::IntList> final {
     86  public:
     87   // Provide default_value() to arg list
     88   string default_value() const { return ""; }
     89   // Return true if the command line argument was specified on the command line.
     90   bool specified() const { return specified_; }
     91   // Bind the parse member function so tensorflow::Flags can call it.
     92   bool parse(string text) {
     93     parsed_value_.elements.clear();
     94     specified_ = true;
     95     // strings::Split("") produces {""}, but we need {} on empty input.
     96     // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could
     97     // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements)
     98     if (!text.empty()) {
     99       int32 element;
    100       for (absl::string_view part : absl::StrSplit(text, ',')) {
    101         if (!SimpleAtoi(part, &element)) return false;
    102         parsed_value_.elements.push_back(element);
    103       }
    104     }
    105     return true;
    106   }
    107 
    108   std::function<bool(string)> bind() {
    109     return std::bind(&Arg::parse, this, std::placeholders::_1);
    110   }
    111 
    112   const toco::IntList& value() const { return parsed_value_; }
    113 
    114  private:
    115   toco::IntList parsed_value_;
    116   bool specified_ = false;
    117 };
    118 
    119 template <>
    120 class Arg<toco::StringMapList> final {
    121  public:
    122   // Provide default_value() to StringMapList
    123   string default_value() const { return ""; }
    124   // Return true if the command line argument was specified on the command line.
    125   bool specified() const { return specified_; }
    126   // Bind the parse member function so tensorflow::Flags can call it.
    127 
    128   bool parse(string text) {
    129     parsed_value_.elements.clear();
    130     specified_ = true;
    131 
    132     if (text.empty()) {
    133       return true;
    134     }
    135 
    136 #if defined(PLATFORM_GOOGLE)
    137     std::vector<absl::string_view> outer_vector;
    138     absl::string_view text_disposable_copy = text;
    139     SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector);
    140     for (const absl::string_view& outer_member_stringpiece : outer_vector) {
    141       string outer_member(outer_member_stringpiece);
    142       if (outer_member.empty()) {
    143         continue;
    144       }
    145       string outer_member_copy = outer_member;
    146       absl::StripAsciiWhitespace(&outer_member);
    147       if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false;
    148       if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false;
    149       const std::vector<string> inner_fields_vector =
    150           absl::StrSplit(outer_member, ',');
    151 
    152       std::unordered_map<string, string> element;
    153       for (const string& member_field : inner_fields_vector) {
    154         std::vector<string> outer_member_key_value =
    155             absl::StrSplit(member_field, ':');
    156         if (outer_member_key_value.size() != 2) return false;
    157         string& key = outer_member_key_value[0];
    158         string& value = outer_member_key_value[1];
    159         absl::StripAsciiWhitespace(&key);
    160         absl::StripAsciiWhitespace(&value);
    161         if (element.count(key) != 0) return false;
    162         element[key] = value;
    163       }
    164       parsed_value_.elements.push_back(element);
    165     }
    166     return true;
    167 #else
    168     // TODO(aselle): Fix argument parsing when absl supports structuredline
    169     fprintf(stderr, "%s:%d StringMapList arguments not supported\n", __FILE__,
    170             __LINE__);
    171     abort();
    172 #endif
    173   }
    174 
    175   std::function<bool(string)> bind() {
    176     return std::bind(&Arg::parse, this, std::placeholders::_1);
    177   }
    178 
    179   const toco::StringMapList& value() const { return parsed_value_; }
    180 
    181  private:
    182   toco::StringMapList parsed_value_;
    183   bool specified_ = false;
    184 };
    185 
    186 // Flags that describe a model. See model_cmdline_flags.cc for details.
    187 struct ParsedModelFlags {
    188   Arg<string> input_array;
    189   Arg<string> input_arrays;
    190   Arg<string> output_array;
    191   Arg<string> output_arrays;
    192   Arg<string> input_shapes;
    193   Arg<float> mean_value = Arg<float>(0.f);
    194   Arg<string> mean_values;
    195   Arg<float> std_value = Arg<float>(1.f);
    196   Arg<string> std_values;
    197   Arg<string> input_data_type;
    198   Arg<string> input_data_types;
    199   Arg<bool> variable_batch = Arg<bool>(false);
    200   Arg<toco::IntList> input_shape;
    201   Arg<toco::StringMapList> rnn_states;
    202   Arg<toco::StringMapList> model_checks;
    203   // Debugging output options.
    204   // TODO(benoitjacob): these shouldn't be ModelFlags.
    205   Arg<string> graphviz_first_array;
    206   Arg<string> graphviz_last_array;
    207   Arg<string> dump_graphviz;
    208   Arg<bool> dump_graphviz_video = Arg<bool>(false);
    209   Arg<bool> allow_nonexistent_arrays = Arg<bool>(false);
    210   Arg<bool> allow_nonascii_arrays = Arg<bool>(false);
    211   Arg<string> arrays_extra_info_file;
    212 };
    213 
    214 // Flags that describe the operation you would like to do (what conversion
    215 // you want). See toco_cmdline_flags.cc for details.
    216 struct ParsedTocoFlags {
    217   Arg<string> input_file;
    218   Arg<string> output_file;
    219   Arg<string> input_format;
    220   Arg<string> output_format;
    221   // TODO(aselle): command_line_flags  doesn't support doubles
    222   Arg<float> default_ranges_min = Arg<float>(0.);
    223   Arg<float> default_ranges_max = Arg<float>(0.);
    224   Arg<string> inference_type;
    225   Arg<string> inference_input_type;
    226   Arg<bool> drop_fake_quant = Arg<bool>(false);
    227   Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
    228   Arg<bool> allow_custom_ops = Arg<bool>(false);
    229   // Deprecated flags
    230   Arg<string> input_type;
    231   Arg<string> input_types;
    232   Arg<bool> drop_control_dependency = Arg<bool>(false);
    233 };
    234 
    235 }  // namespace toco
    236 #endif  // TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
    237