Home | History | Annotate | Download | only in proto_text
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <stdio.h>
     17 #include <set>
     18 
     19 #include "tensorflow/core/platform/logging.h"
     20 #include "tensorflow/core/platform/protobuf.h"
     21 #include "tensorflow/core/platform/types.h"
     22 #include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h"
     23 
     24 namespace tensorflow {
     25 
     26 namespace {
     27 class CrashOnErrorCollector
     28     : public tensorflow::protobuf::compiler::MultiFileErrorCollector {
     29  public:
     30   ~CrashOnErrorCollector() override {}
     31 
     32   void AddError(const string& filename, int line, int column,
     33                 const string& message) override {
     34     LOG(FATAL) << "Unexpected error at " << filename << "@" << line << ":"
     35                << column << " - " << message;
     36   }
     37 };
     38 
     39 static const char kTensorFlowHeaderPrefix[] = "";
     40 
     41 static const char kPlaceholderFile[] =
     42     "tensorflow/tools/proto_text/placeholder.txt";
     43 
     44 bool IsPlaceholderFile(const char* s) {
     45   string ph(kPlaceholderFile);
     46   string str(s);
     47   return str.size() >= strlen(kPlaceholderFile) &&
     48          ph == str.substr(str.size() - ph.size());
     49 }
     50 
     51 }  // namespace
     52 
     53 // Main program to take input protos and write output pb_text source files that
     54 // contain generated proto text input and output functions.
     55 //
     56 // Main expects:
     57 // - First argument is output path
     58 // - Second argument is the relative path of the protos to the root. E.g.,
     59 //   for protos built by a rule in tensorflow/core, this will be
     60 //   tensorflow/core.
     61 // - Then any number of source proto file names, plus one source name must be
     62 //   placeholder.txt from this gen tool's package.  placeholder.txt is
     63 //   ignored for proto resolution, but is used to determine the root at which
     64 //   the build tool has placed the source proto files.
     65 //
     66 // Note that this code doesn't use tensorflow's command line parsing, because of
     67 // circular dependencies between libraries if that were done.
     68 //
     69 // This is meant to be invoked by a genrule. See BUILD for more information.
     70 int MainImpl(int argc, char** argv) {
     71   if (argc < 4) {
     72     LOG(ERROR) << "Pass output path, relative path, and at least proto file";
     73     return -1;
     74   }
     75 
     76   const string output_root = argv[1];
     77   const string output_relative_path = kTensorFlowHeaderPrefix + string(argv[2]);
     78 
     79   string src_relative_path;
     80   bool has_placeholder = false;
     81   for (int i = 3; i < argc; ++i) {
     82     if (IsPlaceholderFile(argv[i])) {
     83       const string s(argv[i]);
     84       src_relative_path = s.substr(0, s.size() - strlen(kPlaceholderFile));
     85       has_placeholder = true;
     86     }
     87   }
     88   if (!has_placeholder) {
     89     LOG(ERROR) << kPlaceholderFile << " must be passed";
     90     return -1;
     91   }
     92 
     93   tensorflow::protobuf::compiler::DiskSourceTree source_tree;
     94 
     95   source_tree.MapPath("", src_relative_path.empty() ? "." : src_relative_path);
     96   CrashOnErrorCollector crash_on_error;
     97   tensorflow::protobuf::compiler::Importer importer(&source_tree,
     98                                                     &crash_on_error);
     99 
    100   for (int i = 3; i < argc; i++) {
    101     if (IsPlaceholderFile(argv[i])) continue;
    102     const string proto_path = string(argv[i]).substr(src_relative_path.size());
    103 
    104     const tensorflow::protobuf::FileDescriptor* fd =
    105         importer.Import(proto_path);
    106 
    107     const int index = proto_path.find_last_of(".");
    108     string proto_path_no_suffix = proto_path.substr(0, index);
    109 
    110     proto_path_no_suffix =
    111         proto_path_no_suffix.substr(output_relative_path.size());
    112 
    113     const auto code =
    114         tensorflow::GetProtoTextFunctionCode(*fd, kTensorFlowHeaderPrefix);
    115 
    116     // Three passes, one for each output file.
    117     for (int pass = 0; pass < 3; ++pass) {
    118       string suffix;
    119       string data;
    120       if (pass == 0) {
    121         suffix = ".pb_text.h";
    122         data = code.header;
    123       } else if (pass == 1) {
    124         suffix = ".pb_text-impl.h";
    125         data = code.header_impl;
    126       } else {
    127         suffix = ".pb_text.cc";
    128         data = code.cc;
    129       }
    130 
    131       const string path = output_root + "/" + proto_path_no_suffix + suffix;
    132       FILE* f = fopen(path.c_str(), "w");
    133       if (f == nullptr) return -1;
    134       if (fwrite(data.c_str(), 1, data.size(), f) != data.size()) {
    135         fclose(f);
    136         return -1;
    137       }
    138       if (fclose(f) != 0) {
    139         return -1;
    140       }
    141     }
    142   }
    143   return 0;
    144 }
    145 
    146 }  // namespace tensorflow
    147 
    148 int main(int argc, char** argv) { return tensorflow::MainImpl(argc, argv); }
    149