Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/tools/graph_transforms/transform_graph.h"
     17 
     18 #include "tensorflow/core/framework/function.pb.h"
     19 #include "tensorflow/core/lib/strings/scanner.h"
     20 #include "tensorflow/core/lib/strings/str_util.h"
     21 #include "tensorflow/core/platform/env.h"
     22 #include "tensorflow/core/platform/init_main.h"
     23 #include "tensorflow/core/platform/logging.h"
     24 #include "tensorflow/core/util/command_line_flags.h"
     25 #include "tensorflow/tools/graph_transforms/file_utils.h"
     26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     27 #if !defined(PLATFORM_WINDOWS)
     28 #include <pwd.h>
     29 #endif
     30 
     31 namespace tensorflow {
     32 namespace graph_transforms {
     33 
     34 using tensorflow::strings::Scanner;
     35 
     36 Status ParseTransformParameters(const string& transforms_string,
     37                                 TransformParameters* params_list) {
     38   params_list->clear();
     39   enum {
     40     TRANSFORM_NAME,
     41     TRANSFORM_PARAM_NAME,
     42     TRANSFORM_PARAM_VALUE,
     43   } state = TRANSFORM_NAME;
     44   StringPiece remaining(transforms_string);
     45   StringPiece match;
     46   StringPiece transform_name;
     47   StringPiece parameter_name;
     48   StringPiece parameter_value;
     49   TransformFuncParameters func_parameters;
     50   while (!remaining.empty()) {
     51     if (state == TRANSFORM_NAME) {
     52       // Reset the list of parameters.
     53       func_parameters.clear();
     54       // Eat up any leading spaces.
     55       Scanner(remaining).AnySpace().GetResult(&remaining, &match);
     56       if (remaining.empty()) {
     57         // Nothing remains after consuming trailing spaces.
     58         // Consumed all transform parameter string without errors.
     59         return Status::OK();
     60       }
     61       // See if we have a valid transform name.
     62       const bool found_transform_name =
     63           Scanner(remaining)
     64               .Many(Scanner::LETTER_DIGIT_UNDERSCORE)
     65               .GetResult(&remaining, &transform_name);
     66       if (!found_transform_name) {
     67         return errors::InvalidArgument("Looking for transform name, but found ",
     68                                        string(remaining).c_str());
     69       }
     70       if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) {
     71         state = TRANSFORM_PARAM_NAME;
     72       } else {
     73         // Add a transform with no parameters.
     74         params_list->push_back({string(transform_name), func_parameters});
     75         transform_name = "";
     76         state = TRANSFORM_NAME;
     77       }
     78     } else if (state == TRANSFORM_PARAM_NAME) {
     79       if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) {
     80         params_list->push_back({string(transform_name), func_parameters});
     81         transform_name = "";
     82         state = TRANSFORM_NAME;
     83       } else {
     84         // Eat up any leading spaces or commas.
     85         Scanner(remaining).ZeroOrOneLiteral(",").GetResult(&remaining, &match);
     86         Scanner(remaining).AnySpace().GetResult(&remaining, &match);
     87         // See if we have a valid parameter name.
     88         const bool found_parameter_name =
     89             Scanner(remaining)
     90                 .Many(Scanner::LETTER_DIGIT_UNDERSCORE)
     91                 .GetResult(&remaining, &parameter_name);
     92         if (!found_parameter_name) {
     93           return errors::InvalidArgument(
     94               "Looking for parameter name, but found ",
     95               string(remaining).c_str());
     96         }
     97         if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) {
     98           state = TRANSFORM_PARAM_VALUE;
     99         } else {
    100           return errors::InvalidArgument("Looking for =, but found ",
    101                                          string(remaining).c_str());
    102         }
    103       }
    104     } else if (state == TRANSFORM_PARAM_VALUE) {
    105       bool found_parameter_value;
    106       // Deal with quoted values.
    107       if (Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match)) {
    108         found_parameter_value =
    109             Scanner(remaining).ScanEscapedUntil('"').GetResult(
    110                 &remaining, &parameter_value);
    111         if (found_parameter_value) {
    112           Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match);
    113         }
    114       } else {
    115         // See if we have a valid parameter name.
    116         found_parameter_value =
    117             Scanner(remaining)
    118                 .Many(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
    119                 .GetResult(&remaining, &parameter_value);
    120       }
    121       if (!found_parameter_value) {
    122         return errors::InvalidArgument("Looking for parameter name, but found ",
    123                                        string(remaining).c_str());
    124       }
    125       func_parameters[string(parameter_name)].emplace_back(parameter_value);
    126       // Eat up any trailing quotes.
    127       Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match);
    128       Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match);
    129       state = TRANSFORM_PARAM_NAME;
    130     }
    131   }
    132   return Status::OK();
    133 }
    134 
    135 std::string ExpandPath(const std::string& path_string) {
    136 #if defined(PLATFORM_WINDOWS)
    137   return path_string;
    138 #else
    139   if (path_string.empty() || path_string[0] != '~') {
    140     return path_string;
    141   }
    142 
    143   const char* home = nullptr;
    144   std::string::size_type prefix = path_string.find_first_of('/');
    145   if (path_string.length() == 1 || prefix == 1) {
    146     // The value of $HOME, e.g., ~/foo
    147     home = getenv("HOME");
    148     if (!home) {
    149       // If HOME is not available, get uid
    150       struct passwd* pw = getpwuid(getuid());
    151       if (pw) {
    152         home = pw->pw_dir;
    153       }
    154     }
    155   } else {
    156     // The value of ~user, e.g., ~user/foo
    157     std::string user(path_string, 1, (prefix == std::string::npos)
    158                                          ? std::string::npos
    159                                          : prefix - 1);
    160     struct passwd* pw = getpwnam(user.c_str());
    161     if (pw) {
    162       home = pw->pw_dir;
    163     }
    164   }
    165 
    166   if (!home) {
    167     return path_string;
    168   }
    169 
    170   string path(home);
    171   if (prefix == std::string::npos) {
    172     return path;
    173   }
    174 
    175   if (path.length() == 0 || path[path.length() - 1] != '/') {
    176     path += '/';
    177   }
    178   path += path_string.substr(prefix + 1);
    179   return path;
    180 #endif
    181 }
    182 
    183 int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
    184   string in_graph_string = "";
    185   string out_graph_string = "";
    186   string inputs_string = "";
    187   string outputs_string = "";
    188   string transforms_string = "";
    189   bool output_as_text = false;
    190   std::vector<Flag> flag_list = {
    191       Flag("in_graph", &in_graph_string, "input graph file name"),
    192       Flag("out_graph", &out_graph_string, "output graph file name"),
    193       Flag("inputs", &inputs_string, "inputs"),
    194       Flag("outputs", &outputs_string, "outputs"),
    195       Flag("transforms", &transforms_string, "list of transforms"),
    196       Flag("output_as_text", &output_as_text,
    197            "whether to write the graph in text protobuf format"),
    198   };
    199   string usage = Flags::Usage(argv[0], flag_list);
    200   usage += "\nTransforms are:\n";
    201   TransformRegistry* transform_registry = GetTransformRegistry();
    202   for (const auto& pair : *transform_registry) {
    203     usage += pair.first + "\n";
    204   }
    205 
    206   const bool parse_result = Flags::Parse(&argc, argv, flag_list);
    207   // We need to call this to set up global state for TensorFlow.
    208   if (init_main) {
    209     port::InitMain(argv[0], &argc, &argv);
    210   }
    211   if (!parse_result) {
    212     LOG(ERROR) << usage;
    213     return -1;
    214   }
    215   if (argc > 1) {
    216     LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage;
    217     return -1;
    218   }
    219   if (in_graph_string.empty()) {
    220     LOG(ERROR) << "in_graph graph can't be empty.\n" << usage;
    221     return -1;
    222   }
    223   if (out_graph_string.empty()) {
    224     LOG(ERROR) << "out_graph graph can't be empty.\n" << usage;
    225     return -1;
    226   }
    227   if (transforms_string.empty()) {
    228     LOG(ERROR) << "You must specify at least one transform.\n" << usage;
    229     return -1;
    230   }
    231 
    232   string in_graph = ExpandPath(in_graph_string);
    233   string out_graph = ExpandPath(out_graph_string);
    234 
    235   std::vector<string> inputs = str_util::Split(inputs_string, ',');
    236   std::vector<string> outputs = str_util::Split(outputs_string, ',');
    237   TransformParameters transform_params;
    238   Status parse_status =
    239       ParseTransformParameters(transforms_string, &transform_params);
    240   if (!parse_status.ok()) {
    241     LOG(ERROR) << "Failed to parse --transform argument, error was "
    242                << parse_status.error_message();
    243     return -1;
    244   }
    245   if (transform_params.empty()) {
    246     LOG(ERROR) << "You must specify at least one transform.\n" << usage;
    247     return -1;
    248   }
    249 
    250   GraphDef graph_def;
    251   Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def);
    252   if (!load_status.ok()) {
    253     LOG(ERROR) << "Loading graph '" << in_graph_string << "' failed with "
    254                << load_status.error_message();
    255     LOG(ERROR) << usage;
    256     return -1;
    257   }
    258 
    259   Status transform_result =
    260       TransformGraph(inputs, outputs, transform_params, &graph_def);
    261 
    262   if (!transform_result.ok()) {
    263     LOG(ERROR) << transform_result.error_message();
    264     LOG(ERROR) << usage;
    265     return -1;
    266   }
    267 
    268   Status save_status;
    269   if (output_as_text) {
    270     save_status = WriteTextProto(Env::Default(), out_graph, graph_def);
    271   } else {
    272     save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
    273   }
    274   if (!save_status.ok()) {
    275     LOG(ERROR) << "Saving graph '" << out_graph_string << "' failed with "
    276                << save_status.error_message();
    277     return -1;
    278   }
    279 
    280   return 0;
    281 }
    282 
    283 Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params,
    284                           bool* ignore_errors) {
    285   *ignore_errors = false;
    286   if (transform_params.count("ignore_errors") &&
    287       (!transform_params.at("ignore_errors").empty())) {
    288     const string& ignore_errors_string =
    289         str_util::Lowercase(transform_params.at("ignore_errors").at(0));
    290     if (ignore_errors_string == "true") {
    291       *ignore_errors = true;
    292     } else if (ignore_errors_string == "false") {
    293       *ignore_errors = false;
    294     } else {
    295       return errors::InvalidArgument(
    296           "ignore_errors should be true or false, found ",
    297           ignore_errors_string);
    298     }
    299   }
    300   return Status::OK();
    301 }
    302 
    303 Status TransformGraph(const std::vector<string>& inputs,
    304                       const std::vector<string>& outputs,
    305                       const TransformParameters& transform_params,
    306                       GraphDef* graph_def) {
    307   TransformRegistry* transform_registry = GetTransformRegistry();
    308   for (const auto& transform_info : transform_params) {
    309     const string& transform_name = transform_info.first;
    310     if (transform_name.empty()) {
    311       continue;
    312     }
    313     if (!transform_registry->count(transform_name)) {
    314       return errors::InvalidArgument("Transform '", transform_name,
    315                                      "' not recognized.");
    316     }
    317     LOG(INFO) << "Applying " << transform_name;
    318     const TransformFunc& transform_func =
    319         transform_registry->at(transform_name);
    320     TransformFuncContext context;
    321     context.input_names = inputs;
    322     context.output_names = outputs;
    323     context.params = transform_info.second;
    324     bool ignore_errors;
    325     TF_RETURN_IF_ERROR(
    326         ShouldIgnoreErrors(transform_info.second, &ignore_errors));
    327     GraphDef transformed_graph_def;
    328     Status transform_result =
    329         transform_func(*graph_def, context, &transformed_graph_def);
    330     if (!transform_result.ok()) {
    331       if (ignore_errors) {
    332         LOG(ERROR) << transform_name << ": Ignoring error "
    333                    << transform_result.error_message();
    334         transformed_graph_def = *graph_def;
    335       } else {
    336         return transform_result;
    337       }
    338     }
    339     // Copy over the library from the original input graph.
    340     *transformed_graph_def.mutable_library() = graph_def->library();
    341     TF_RETURN_IF_ERROR(IsGraphValid(transformed_graph_def));
    342 
    343     *graph_def = transformed_graph_def;
    344   }
    345   return Status::OK();
    346 }
    347 }  // namespace graph_transforms
    348 }  // namespace tensorflow
    349