Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/python/eager/python_eager_op_gen.h"
     17 
     18 #include <memory>
     19 #include <string>
     20 #include <unordered_set>
     21 #include <vector>
     22 
     23 #include "tensorflow/core/framework/op.h"
     24 #include "tensorflow/core/framework/op_def.pb.h"
     25 #include "tensorflow/core/framework/op_gen_lib.h"
     26 #include "tensorflow/core/lib/core/errors.h"
     27 #include "tensorflow/core/lib/io/inputbuffer.h"
     28 #include "tensorflow/core/lib/io/path.h"
     29 #include "tensorflow/core/lib/strings/scanner.h"
     30 #include "tensorflow/core/platform/env.h"
     31 #include "tensorflow/core/platform/init_main.h"
     32 #include "tensorflow/core/platform/logging.h"
     33 
     34 namespace tensorflow {
     35 namespace {
     36 
     37 Status ReadOpListFromFile(const string& filename,
     38                           std::vector<string>* op_list) {
     39   std::unique_ptr<RandomAccessFile> file;
     40   TF_CHECK_OK(Env::Default()->NewRandomAccessFile(filename, &file));
     41   std::unique_ptr<io::InputBuffer> input_buffer(
     42       new io::InputBuffer(file.get(), 256 << 10));
     43   string line_contents;
     44   Status s = input_buffer->ReadLine(&line_contents);
     45   while (s.ok()) {
     46     // The parser assumes that the op name is the first string on each
     47     // line with no preceding whitespace, and ignores lines that do
     48     // not start with an op name as a comment.
     49     strings::Scanner scanner{StringPiece(line_contents)};
     50     StringPiece op_name;
     51     if (scanner.One(strings::Scanner::LETTER_DIGIT_DOT)
     52             .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
     53             .GetResult(nullptr, &op_name)) {
     54       op_list->emplace_back(op_name.ToString());
     55     }
     56     s = input_buffer->ReadLine(&line_contents);
     57   }
     58   if (!errors::IsOutOfRange(s)) return s;
     59   return Status::OK();
     60 }
     61 
     62 // The argument parsing is deliberately simplistic to support our only
     63 // known use cases:
     64 //
     65 // 1. Read all op names from a file.
     66 // 2. Read all op names from the arg as a comma-delimited list.
     67 //
     68 // Expected command-line argument syntax:
     69 // ARG ::= '@' FILENAME
     70 //       |  OP_NAME [',' OP_NAME]*
     71 //       |  ''
     72 Status ParseOpListCommandLine(const char* arg, std::vector<string>* op_list) {
     73   std::vector<string> op_names = str_util::Split(arg, ',');
     74   if (op_names.size() == 1 && op_names[0].empty()) {
     75     return Status::OK();
     76   } else if (op_names.size() == 1 && op_names[0].substr(0, 1) == "@") {
     77     const string filename = op_names[0].substr(1);
     78     return tensorflow::ReadOpListFromFile(filename, op_list);
     79   } else {
     80     *op_list = std::move(op_names);
     81   }
     82   return Status::OK();
     83 }
     84 
     85 // Use the name of the current executable to infer the C++ source file
     86 // where the REGISTER_OP() call for the operator can be found.
     87 // Returns the name of the file.
     88 // Returns an empty string if the current executable's name does not
     89 // follow a known pattern.
     90 string InferSourceFileName(const char* argv_zero) {
     91   StringPiece command_str = io::Basename(argv_zero);
     92 
     93   // For built-in ops, the Bazel build creates a separate executable
     94   // with the name gen_<op type>_ops_py_wrappers_cc containing the
     95   // operators defined in <op type>_ops.cc
     96   const char* kExecPrefix = "gen_";
     97   const char* kExecSuffix = "_py_wrappers_cc";
     98   if (command_str.Consume(kExecPrefix) && command_str.ends_with(kExecSuffix)) {
     99     command_str.remove_suffix(strlen(kExecSuffix));
    100     return strings::StrCat(command_str, ".cc");
    101   } else {
    102     return string("");
    103   }
    104 }
    105 
    106 void PrintAllPythonOps(const std::vector<string>& op_list,
    107                        const std::vector<string>& api_def_dirs,
    108                        const string& source_file_name, bool require_shapes,
    109                        bool op_list_is_whitelist) {
    110   OpList ops;
    111   OpRegistry::Global()->Export(false, &ops);
    112 
    113   ApiDefMap api_def_map(ops);
    114   if (!api_def_dirs.empty()) {
    115     Env* env = Env::Default();
    116 
    117     for (const auto& api_def_dir : api_def_dirs) {
    118       std::vector<string> api_files;
    119       TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
    120                                         &api_files));
    121       TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
    122     }
    123     api_def_map.UpdateDocs();
    124   }
    125 
    126   if (op_list_is_whitelist) {
    127     std::unordered_set<string> whitelist(op_list.begin(), op_list.end());
    128     OpList pruned_ops;
    129     for (const auto& op_def : ops.op()) {
    130       if (whitelist.find(op_def.name()) != whitelist.end()) {
    131         *pruned_ops.mutable_op()->Add() = op_def;
    132       }
    133     }
    134     PrintEagerPythonOps(pruned_ops, api_def_map, {}, require_shapes,
    135                         source_file_name);
    136   } else {
    137     PrintEagerPythonOps(ops, api_def_map, op_list, require_shapes,
    138                         source_file_name);
    139   }
    140 }
    141 
    142 }  // namespace
    143 }  // namespace tensorflow
    144 
    145 int main(int argc, char* argv[]) {
    146   tensorflow::port::InitMain(argv[0], &argc, &argv);
    147 
    148   tensorflow::string source_file_name =
    149       tensorflow::InferSourceFileName(argv[0]);
    150 
    151   // Usage:
    152   //   gen_main api_def_dir1,api_def_dir2,...
    153   //       [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
    154   if (argc < 3) {
    155     return -1;
    156   }
    157   std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
    158       argv[1], ",", tensorflow::str_util::SkipEmpty());
    159 
    160   if (argc == 3) {
    161     tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
    162                                   tensorflow::string(argv[2]) == "1",
    163                                   false /* op_list_is_whitelist */);
    164   } else if (argc == 4) {
    165     std::vector<tensorflow::string> hidden_ops;
    166     TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
    167     tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
    168                                   tensorflow::string(argv[3]) == "1",
    169                                   false /* op_list_is_whitelist */);
    170   } else if (argc == 5) {
    171     std::vector<tensorflow::string> op_list;
    172     TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
    173     tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
    174                                   tensorflow::string(argv[3]) == "1",
    175                                   tensorflow::string(argv[4]) == "1");
    176   } else {
    177     return -1;
    178   }
    179   return 0;
    180 }
    181