1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/python/eager/python_eager_op_gen.h" 17 18 #include <memory> 19 #include <string> 20 #include <unordered_set> 21 #include <vector> 22 23 #include "tensorflow/core/framework/op.h" 24 #include "tensorflow/core/framework/op_def.pb.h" 25 #include "tensorflow/core/framework/op_gen_lib.h" 26 #include "tensorflow/core/lib/core/errors.h" 27 #include "tensorflow/core/lib/io/inputbuffer.h" 28 #include "tensorflow/core/lib/io/path.h" 29 #include "tensorflow/core/lib/strings/scanner.h" 30 #include "tensorflow/core/platform/env.h" 31 #include "tensorflow/core/platform/init_main.h" 32 #include "tensorflow/core/platform/logging.h" 33 34 namespace tensorflow { 35 namespace { 36 37 Status ReadOpListFromFile(const string& filename, 38 std::vector<string>* op_list) { 39 std::unique_ptr<RandomAccessFile> file; 40 TF_CHECK_OK(Env::Default()->NewRandomAccessFile(filename, &file)); 41 std::unique_ptr<io::InputBuffer> input_buffer( 42 new io::InputBuffer(file.get(), 256 << 10)); 43 string line_contents; 44 Status s = input_buffer->ReadLine(&line_contents); 45 while (s.ok()) { 46 // The parser assumes that the op name is the first string on each 47 // line with no preceding whitespace, and ignores lines that do 48 // not start with an op name as a comment. 49 strings::Scanner scanner{StringPiece(line_contents)}; 50 StringPiece op_name; 51 if (scanner.One(strings::Scanner::LETTER_DIGIT_DOT) 52 .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) 53 .GetResult(nullptr, &op_name)) { 54 op_list->emplace_back(op_name.ToString()); 55 } 56 s = input_buffer->ReadLine(&line_contents); 57 } 58 if (!errors::IsOutOfRange(s)) return s; 59 return Status::OK(); 60 } 61 62 // The argument parsing is deliberately simplistic to support our only 63 // known use cases: 64 // 65 // 1. Read all op names from a file. 66 // 2. Read all op names from the arg as a comma-delimited list. 67 // 68 // Expected command-line argument syntax: 69 // ARG ::= '@' FILENAME 70 // | OP_NAME [',' OP_NAME]* 71 // | '' 72 Status ParseOpListCommandLine(const char* arg, std::vector<string>* op_list) { 73 std::vector<string> op_names = str_util::Split(arg, ','); 74 if (op_names.size() == 1 && op_names[0].empty()) { 75 return Status::OK(); 76 } else if (op_names.size() == 1 && op_names[0].substr(0, 1) == "@") { 77 const string filename = op_names[0].substr(1); 78 return tensorflow::ReadOpListFromFile(filename, op_list); 79 } else { 80 *op_list = std::move(op_names); 81 } 82 return Status::OK(); 83 } 84 85 // Use the name of the current executable to infer the C++ source file 86 // where the REGISTER_OP() call for the operator can be found. 87 // Returns the name of the file. 88 // Returns an empty string if the current executable's name does not 89 // follow a known pattern. 90 string InferSourceFileName(const char* argv_zero) { 91 StringPiece command_str = io::Basename(argv_zero); 92 93 // For built-in ops, the Bazel build creates a separate executable 94 // with the name gen_<op type>_ops_py_wrappers_cc containing the 95 // operators defined in <op type>_ops.cc 96 const char* kExecPrefix = "gen_"; 97 const char* kExecSuffix = "_py_wrappers_cc"; 98 if (command_str.Consume(kExecPrefix) && command_str.ends_with(kExecSuffix)) { 99 command_str.remove_suffix(strlen(kExecSuffix)); 100 return strings::StrCat(command_str, ".cc"); 101 } else { 102 return string(""); 103 } 104 } 105 106 void PrintAllPythonOps(const std::vector<string>& op_list, 107 const std::vector<string>& api_def_dirs, 108 const string& source_file_name, bool require_shapes, 109 bool op_list_is_whitelist) { 110 OpList ops; 111 OpRegistry::Global()->Export(false, &ops); 112 113 ApiDefMap api_def_map(ops); 114 if (!api_def_dirs.empty()) { 115 Env* env = Env::Default(); 116 117 for (const auto& api_def_dir : api_def_dirs) { 118 std::vector<string> api_files; 119 TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"), 120 &api_files)); 121 TF_CHECK_OK(api_def_map.LoadFileList(env, api_files)); 122 } 123 api_def_map.UpdateDocs(); 124 } 125 126 if (op_list_is_whitelist) { 127 std::unordered_set<string> whitelist(op_list.begin(), op_list.end()); 128 OpList pruned_ops; 129 for (const auto& op_def : ops.op()) { 130 if (whitelist.find(op_def.name()) != whitelist.end()) { 131 *pruned_ops.mutable_op()->Add() = op_def; 132 } 133 } 134 PrintEagerPythonOps(pruned_ops, api_def_map, {}, require_shapes, 135 source_file_name); 136 } else { 137 PrintEagerPythonOps(ops, api_def_map, op_list, require_shapes, 138 source_file_name); 139 } 140 } 141 142 } // namespace 143 } // namespace tensorflow 144 145 int main(int argc, char* argv[]) { 146 tensorflow::port::InitMain(argv[0], &argc, &argv); 147 148 tensorflow::string source_file_name = 149 tensorflow::InferSourceFileName(argv[0]); 150 151 // Usage: 152 // gen_main api_def_dir1,api_def_dir2,... 153 // [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1] 154 if (argc < 3) { 155 return -1; 156 } 157 std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split( 158 argv[1], ",", tensorflow::str_util::SkipEmpty()); 159 160 if (argc == 3) { 161 tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name, 162 tensorflow::string(argv[2]) == "1", 163 false /* op_list_is_whitelist */); 164 } else if (argc == 4) { 165 std::vector<tensorflow::string> hidden_ops; 166 TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops)); 167 tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name, 168 tensorflow::string(argv[3]) == "1", 169 false /* op_list_is_whitelist */); 170 } else if (argc == 5) { 171 std::vector<tensorflow::string> op_list; 172 TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list)); 173 tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name, 174 tensorflow::string(argv[3]) == "1", 175 tensorflow::string(argv[4]) == "1"); 176 } else { 177 return -1; 178 } 179 return 0; 180 } 181