Home | History | Annotate | Download | only in util
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Utility that converts a "frozen" inference graph (output from the
     17 // freeze_graph utility) into a format in which large Const ops are converted to
     18 // ImmutableConst ops which are memmapped when the graph is executed by
     19 // TensorFlow.
     20 //
     21 //  tensorflow/contrib/util/convert_graphdef_memmapped_format
     22 //        --in_graph=frozen.model --out_graph=memmapped.mmodel
     23 //
// Parameters:
// in_graph - name of a file with a frozen GraphDef proto in binary format.
// out_graph - name of the output file where the graph in memmapped format will
// be saved.
// min_conversion_tensor_size - Const tensors smaller than this threshold
// (default 10000, must be > 0) are not converted to ImmutableConst format and
// are kept in the graph as-is. NOTE: the flag's help text counts this in
// elements; confirm the unit against ConvertConstantsToImmutable().
     30 
     31 #include <vector>
     32 
     33 #include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
     34 #include "tensorflow/core/platform/init_main.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 #include "tensorflow/core/util/command_line_flags.h"
     37 
     38 namespace tensorflow {
     39 namespace {
     40 
     41 int ParseFlagsAndConvertGraph(int argc, char* argv[]) {
     42   string in_graph = "";
     43   string out_graph = "";
     44   int min_conversion_tensor_size = 10000;
     45   std::vector<Flag> flag_list = {
     46       Flag("in_graph", &in_graph, "input graph"),
     47       Flag("out_graph", &out_graph, "output graph"),
     48       Flag("min_conversion_tensor_size", &min_conversion_tensor_size,
     49            "constants with tensors that have less than this number elements "
     50            "won't be converted into ImmutableConst (be memmapped)"),
     51   };
     52   string usage = Flags::Usage(argv[0], flag_list);
     53   const bool parse_result = Flags::Parse(&argc, argv, flag_list);
     54   // We need to call this to set up global state for TensorFlow.
     55   port::InitMain(usage.c_str(), &argc, &argv);
     56   if (!parse_result) {
     57     LOG(ERROR) << "\n" << usage;
     58     return -1;
     59   }
     60   if (argc > 1) {
     61     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
     62     return -1;
     63   }
     64   if (in_graph.empty()) {
     65     LOG(ERROR) << "in_graph graph can't be empty";
     66     return -1;
     67   }
     68   if (out_graph.empty()) {
     69     LOG(ERROR) << "out_graph graph can't be empty";
     70     return -1;
     71   }
     72   if (min_conversion_tensor_size <= 0) {
     73     LOG(ERROR) << "min_conversion_tensor_size must be > 0";
     74     return -1;
     75   }
     76   const auto result = ConvertConstantsToImmutable(in_graph, out_graph,
     77                                                   min_conversion_tensor_size);
     78   if (!result.ok()) {
     79     LOG(ERROR) << "Conversion failed " << result.error_message();
     80     return -1;
     81   }
     82   return 0;
     83 }
     84 
     85 }  // namespace
     86 }  // namespace tensorflow
     87 
     88 int main(int argc, char* argv[]) {
     89   return tensorflow::ParseFlagsAndConvertGraph(argc, argv);
     90 }
     91