1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/tools/graph_transforms/transform_graph.h" 17 18 #include "tensorflow/core/framework/function.pb.h" 19 #include "tensorflow/core/lib/strings/scanner.h" 20 #include "tensorflow/core/lib/strings/str_util.h" 21 #include "tensorflow/core/platform/env.h" 22 #include "tensorflow/core/platform/init_main.h" 23 #include "tensorflow/core/platform/logging.h" 24 #include "tensorflow/core/util/command_line_flags.h" 25 #include "tensorflow/tools/graph_transforms/file_utils.h" 26 #include "tensorflow/tools/graph_transforms/transform_utils.h" 27 #if !defined(PLATFORM_WINDOWS) 28 #include <pwd.h> 29 #endif 30 31 namespace tensorflow { 32 namespace graph_transforms { 33 34 using tensorflow::strings::Scanner; 35 36 Status ParseTransformParameters(const string& transforms_string, 37 TransformParameters* params_list) { 38 params_list->clear(); 39 enum { 40 TRANSFORM_NAME, 41 TRANSFORM_PARAM_NAME, 42 TRANSFORM_PARAM_VALUE, 43 } state = TRANSFORM_NAME; 44 StringPiece remaining(transforms_string); 45 StringPiece match; 46 StringPiece transform_name; 47 StringPiece parameter_name; 48 StringPiece parameter_value; 49 TransformFuncParameters func_parameters; 50 while (!remaining.empty()) { 51 if (state == TRANSFORM_NAME) { 52 // Reset the list of parameters. 53 func_parameters.clear(); 54 // Eat up any leading spaces. 55 Scanner(remaining).AnySpace().GetResult(&remaining, &match); 56 if (remaining.empty()) { 57 // Nothing remains after consuming trailing spaces. 58 // Consumed all transform parameter string without errors. 59 return Status::OK(); 60 } 61 // See if we have a valid transform name. 62 const bool found_transform_name = 63 Scanner(remaining) 64 .Many(Scanner::LETTER_DIGIT_UNDERSCORE) 65 .GetResult(&remaining, &transform_name); 66 if (!found_transform_name) { 67 return errors::InvalidArgument("Looking for transform name, but found ", 68 string(remaining).c_str()); 69 } 70 if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) { 71 state = TRANSFORM_PARAM_NAME; 72 } else { 73 // Add a transform with no parameters. 74 params_list->push_back({string(transform_name), func_parameters}); 75 transform_name = ""; 76 state = TRANSFORM_NAME; 77 } 78 } else if (state == TRANSFORM_PARAM_NAME) { 79 if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) { 80 params_list->push_back({string(transform_name), func_parameters}); 81 transform_name = ""; 82 state = TRANSFORM_NAME; 83 } else { 84 // Eat up any leading spaces or commas. 85 Scanner(remaining).ZeroOrOneLiteral(",").GetResult(&remaining, &match); 86 Scanner(remaining).AnySpace().GetResult(&remaining, &match); 87 // See if we have a valid parameter name. 88 const bool found_parameter_name = 89 Scanner(remaining) 90 .Many(Scanner::LETTER_DIGIT_UNDERSCORE) 91 .GetResult(&remaining, ¶meter_name); 92 if (!found_parameter_name) { 93 return errors::InvalidArgument( 94 "Looking for parameter name, but found ", 95 string(remaining).c_str()); 96 } 97 if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) { 98 state = TRANSFORM_PARAM_VALUE; 99 } else { 100 return errors::InvalidArgument("Looking for =, but found ", 101 string(remaining).c_str()); 102 } 103 } 104 } else if (state == TRANSFORM_PARAM_VALUE) { 105 bool found_parameter_value; 106 // Deal with quoted values. 107 if (Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match)) { 108 found_parameter_value = 109 Scanner(remaining).ScanEscapedUntil('"').GetResult( 110 &remaining, ¶meter_value); 111 if (found_parameter_value) { 112 Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match); 113 } 114 } else { 115 // See if we have a valid parameter name. 116 found_parameter_value = 117 Scanner(remaining) 118 .Many(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) 119 .GetResult(&remaining, ¶meter_value); 120 } 121 if (!found_parameter_value) { 122 return errors::InvalidArgument("Looking for parameter name, but found ", 123 string(remaining).c_str()); 124 } 125 func_parameters[string(parameter_name)].emplace_back(parameter_value); 126 // Eat up any trailing quotes. 127 Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match); 128 Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match); 129 state = TRANSFORM_PARAM_NAME; 130 } 131 } 132 return Status::OK(); 133 } 134 135 std::string ExpandPath(const std::string& path_string) { 136 #if defined(PLATFORM_WINDOWS) 137 return path_string; 138 #else 139 if (path_string.empty() || path_string[0] != '~') { 140 return path_string; 141 } 142 143 const char* home = nullptr; 144 std::string::size_type prefix = path_string.find_first_of('/'); 145 if (path_string.length() == 1 || prefix == 1) { 146 // The value of $HOME, e.g., ~/foo 147 home = getenv("HOME"); 148 if (!home) { 149 // If HOME is not available, get uid 150 struct passwd* pw = getpwuid(getuid()); 151 if (pw) { 152 home = pw->pw_dir; 153 } 154 } 155 } else { 156 // The value of ~user, e.g., ~user/foo 157 std::string user(path_string, 1, (prefix == std::string::npos) 158 ? std::string::npos 159 : prefix - 1); 160 struct passwd* pw = getpwnam(user.c_str()); 161 if (pw) { 162 home = pw->pw_dir; 163 } 164 } 165 166 if (!home) { 167 return path_string; 168 } 169 170 string path(home); 171 if (prefix == std::string::npos) { 172 return path; 173 } 174 175 if (path.length() == 0 || path[path.length() - 1] != '/') { 176 path += '/'; 177 } 178 path += path_string.substr(prefix + 1); 179 return path; 180 #endif 181 } 182 183 int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) { 184 string in_graph_string = ""; 185 string out_graph_string = ""; 186 string inputs_string = ""; 187 string outputs_string = ""; 188 string transforms_string = ""; 189 bool output_as_text = false; 190 std::vector<Flag> flag_list = { 191 Flag("in_graph", &in_graph_string, "input graph file name"), 192 Flag("out_graph", &out_graph_string, "output graph file name"), 193 Flag("inputs", &inputs_string, "inputs"), 194 Flag("outputs", &outputs_string, "outputs"), 195 Flag("transforms", &transforms_string, "list of transforms"), 196 Flag("output_as_text", &output_as_text, 197 "whether to write the graph in text protobuf format"), 198 }; 199 string usage = Flags::Usage(argv[0], flag_list); 200 usage += "\nTransforms are:\n"; 201 TransformRegistry* transform_registry = GetTransformRegistry(); 202 for (const auto& pair : *transform_registry) { 203 usage += pair.first + "\n"; 204 } 205 206 const bool parse_result = Flags::Parse(&argc, argv, flag_list); 207 // We need to call this to set up global state for TensorFlow. 208 if (init_main) { 209 port::InitMain(argv[0], &argc, &argv); 210 } 211 if (!parse_result) { 212 LOG(ERROR) << usage; 213 return -1; 214 } 215 if (argc > 1) { 216 LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage; 217 return -1; 218 } 219 if (in_graph_string.empty()) { 220 LOG(ERROR) << "in_graph graph can't be empty.\n" << usage; 221 return -1; 222 } 223 if (out_graph_string.empty()) { 224 LOG(ERROR) << "out_graph graph can't be empty.\n" << usage; 225 return -1; 226 } 227 if (transforms_string.empty()) { 228 LOG(ERROR) << "You must specify at least one transform.\n" << usage; 229 return -1; 230 } 231 232 string in_graph = ExpandPath(in_graph_string); 233 string out_graph = ExpandPath(out_graph_string); 234 235 std::vector<string> inputs = str_util::Split(inputs_string, ','); 236 std::vector<string> outputs = str_util::Split(outputs_string, ','); 237 TransformParameters transform_params; 238 Status parse_status = 239 ParseTransformParameters(transforms_string, &transform_params); 240 if (!parse_status.ok()) { 241 LOG(ERROR) << "Failed to parse --transform argument, error was " 242 << parse_status.error_message(); 243 return -1; 244 } 245 if (transform_params.empty()) { 246 LOG(ERROR) << "You must specify at least one transform.\n" << usage; 247 return -1; 248 } 249 250 GraphDef graph_def; 251 Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def); 252 if (!load_status.ok()) { 253 LOG(ERROR) << "Loading graph '" << in_graph_string << "' failed with " 254 << load_status.error_message(); 255 LOG(ERROR) << usage; 256 return -1; 257 } 258 259 Status transform_result = 260 TransformGraph(inputs, outputs, transform_params, &graph_def); 261 262 if (!transform_result.ok()) { 263 LOG(ERROR) << transform_result.error_message(); 264 LOG(ERROR) << usage; 265 return -1; 266 } 267 268 Status save_status; 269 if (output_as_text) { 270 save_status = WriteTextProto(Env::Default(), out_graph, graph_def); 271 } else { 272 save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def); 273 } 274 if (!save_status.ok()) { 275 LOG(ERROR) << "Saving graph '" << out_graph_string << "' failed with " 276 << save_status.error_message(); 277 return -1; 278 } 279 280 return 0; 281 } 282 283 Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params, 284 bool* ignore_errors) { 285 *ignore_errors = false; 286 if (transform_params.count("ignore_errors") && 287 (!transform_params.at("ignore_errors").empty())) { 288 const string& ignore_errors_string = 289 str_util::Lowercase(transform_params.at("ignore_errors").at(0)); 290 if (ignore_errors_string == "true") { 291 *ignore_errors = true; 292 } else if (ignore_errors_string == "false") { 293 *ignore_errors = false; 294 } else { 295 return errors::InvalidArgument( 296 "ignore_errors should be true or false, found ", 297 ignore_errors_string); 298 } 299 } 300 return Status::OK(); 301 } 302 303 Status TransformGraph(const std::vector<string>& inputs, 304 const std::vector<string>& outputs, 305 const TransformParameters& transform_params, 306 GraphDef* graph_def) { 307 TransformRegistry* transform_registry = GetTransformRegistry(); 308 for (const auto& transform_info : transform_params) { 309 const string& transform_name = transform_info.first; 310 if (transform_name.empty()) { 311 continue; 312 } 313 if (!transform_registry->count(transform_name)) { 314 return errors::InvalidArgument("Transform '", transform_name, 315 "' not recognized."); 316 } 317 LOG(INFO) << "Applying " << transform_name; 318 const TransformFunc& transform_func = 319 transform_registry->at(transform_name); 320 TransformFuncContext context; 321 context.input_names = inputs; 322 context.output_names = outputs; 323 context.params = transform_info.second; 324 bool ignore_errors; 325 TF_RETURN_IF_ERROR( 326 ShouldIgnoreErrors(transform_info.second, &ignore_errors)); 327 GraphDef transformed_graph_def; 328 Status transform_result = 329 transform_func(*graph_def, context, &transformed_graph_def); 330 if (!transform_result.ok()) { 331 if (ignore_errors) { 332 LOG(ERROR) << transform_name << ": Ignoring error " 333 << transform_result.error_message(); 334 transformed_graph_def = *graph_def; 335 } else { 336 return transform_result; 337 } 338 } 339 // Copy over the library from the original input graph. 340 *transformed_graph_def.mutable_library() = graph_def->library(); 341 TF_RETURN_IF_ERROR(IsGraphValid(transformed_graph_def)); 342 343 *graph_def = transformed_graph_def; 344 } 345 return Status::OK(); 346 } 347 } // namespace graph_transforms 348 } // namespace tensorflow 349