Home | History | Annotate | Download | only in util
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
     16 
     17 #include <unordered_set>
     18 #include "tensorflow/core/framework/attr_value.pb.h"
     19 #include "tensorflow/core/framework/graph.pb.h"
     20 #include "tensorflow/core/framework/node_def.pb.h"
     21 #include "tensorflow/core/framework/register_types.h"
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/framework/tensor.pb.h"
     24 #include "tensorflow/core/framework/tensor_shape.pb.h"
     25 #include "tensorflow/core/framework/types.pb.h"
     26 #include "tensorflow/core/kernels/immutable_constant_op.h"
     27 #include "tensorflow/core/platform/env.h"
     28 #include "tensorflow/core/platform/logging.h"
     29 #include "tensorflow/core/util/memmapped_file_system_writer.h"
     30 
     31 namespace tensorflow {
     32 namespace {
     33 class NodeConverter {
     34  public:
     35   // Converts one node. In-place updates node_def, writes the tensor in
     36   // memmapped
     37   // format, using writer. If the conversion has been done, convert_counter is
     38   // increased.
     39   Status ConvertConstantsToImmutable(NodeDef* node_def,
     40                                      MemmappedFileSystemWriter* writer,
     41                                      int* convert_counter,
     42                                      int min_conversion_size_bytes) {
     43     // Check the size.
     44     const AttrValue& value = node_def->attr().at("value");
     45     const TensorProto& tensor_proto = value.tensor();
     46 
     47     // Create copies of tensor datatype and shape, to put into the operator
     48     // after
     49     // the tensor is destroyed.
     50     const DataType tensor_data_type = tensor_proto.dtype();
     51     const TensorShapeProto tensor_shape = tensor_proto.tensor_shape();
     52 
     53     // Check that the tensor type is POD, only these types are supported for
     54     // memmapping.
     55     // DataType enum is explicitly converted to int to avoid errors with passing
     56     // enum type are a parameter type to std::unordered_set.
     57     static std::unordered_set<int> supported_types{
     58 #define TYPE_FOR_SET(type) static_cast<int>(DataTypeToEnum<type>::value),
     59         TF_CALL_POD_TYPES(TYPE_FOR_SET)
     60 #undef ADD_TYPE
     61     };
     62 
     63     if (supported_types.count(static_cast<int>(tensor_data_type)) == 0) {
     64       return Status::OK();
     65     }
     66 
     67     // Create Tensor from value and write it in memmapped format.
     68     Tensor parsed(tensor_proto.dtype());
     69     if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
     70       return errors::InvalidArgument("Cannot parse tensor from proto: ",
     71                                      tensor_proto.DebugString());
     72     }
     73     if (parsed.TotalBytes() < static_cast<size_t>(min_conversion_size_bytes)) {
     74       return Status::OK();
     75     }
     76 
     77     const string memmapped_region_name =
     78         MemmappedFileSystem::kMemmappedPackagePrefix +
     79         ConvertVariableNameToUniqueRegionName(node_def->name());
     80 
     81     TF_RETURN_IF_ERROR(writer->SaveTensor(parsed, memmapped_region_name));
     82 
     83     node_def->set_op("ImmutableConst");
     84 
     85     // Erase all attributes and leave only attributes that can be understood by
     86     // ImmutableConst.
     87     auto* mutable_attr = node_def->mutable_attr();
     88     mutable_attr->clear();
     89 
     90     {
     91       AttrValue attr_value;
     92       attr_value.set_type(tensor_data_type);
     93       mutable_attr->insert({ImmutableConstantOp::kDTypeAttr, attr_value});
     94     }
     95     {
     96       AttrValue attr_value;
     97       *(attr_value.mutable_shape()) = tensor_shape;
     98       mutable_attr->insert({ImmutableConstantOp::kShapeAttr, attr_value});
     99     }
    100     {
    101       AttrValue attr_value;
    102       attr_value.set_s(memmapped_region_name);
    103       mutable_attr->insert(
    104           {ImmutableConstantOp::kMemoryRegionNameAttr, attr_value});
    105     }
    106     ++*convert_counter;
    107     return Status::OK();
    108   }
    109 
    110  private:
    111   string ConvertVariableNameToUniqueRegionName(const string& variable_name) {
    112     string region_name = SanitizeVariableName(variable_name);
    113     while (!used_names_.insert(region_name).second) {
    114       region_name += '_';
    115     }
    116     return region_name;
    117   }
    118 
    119   static string SanitizeVariableName(const string& variable_name) {
    120     string result;
    121     for (char c : variable_name) {
    122       if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
    123           (c >= '0' && c <= '9') || c == '_' || c == '.') {
    124         result += c;
    125       } else {
    126         result += '_';
    127       }
    128     }
    129     return result;
    130   }
    131   std::unordered_set<string> used_names_;
    132 };
    133 
    134 }  // namespace
    135 
    136 // Loads the graph, replaces operators, and writes it out.
    137 Status ConvertConstantsToImmutable(const string& in_graph_filename,
    138                                    const string& out_graph_filename,
    139                                    int min_conversion_size_bytes) {
    140   Env* default_env = Env::Default();
    141   GraphDef graph_def;
    142   const auto load_graph_status =
    143       ReadBinaryProto(default_env, in_graph_filename, &graph_def);
    144   if (!load_graph_status.ok()) {
    145     return tensorflow::errors::NotFound(
    146         "Failed to load graph at '", in_graph_filename,
    147         "' : ", load_graph_status.error_message());
    148   }
    149 
    150   NodeConverter node_converter;
    151 
    152   // Create output writer.
    153   MemmappedFileSystemWriter writer;
    154   TF_RETURN_IF_ERROR(writer.InitializeToFile(default_env, out_graph_filename));
    155 
    156   // Iterate over graph nodes, looking for Const and replacing it with
    157   // ImmutableConst.
    158   int convert_counter = 0;
    159   for (int i = 0; i < graph_def.node_size(); ++i) {
    160     const NodeDef& node = graph_def.node(i);
    161     if (node.op() == "Const") {
    162       // Try to convert to ImmutableConst
    163       TF_RETURN_IF_ERROR(node_converter.ConvertConstantsToImmutable(
    164           graph_def.mutable_node(i), &writer, &convert_counter,
    165           min_conversion_size_bytes));
    166     }
    167   }
    168   TF_RETURN_IF_ERROR(writer.SaveProtobuf(
    169       graph_def, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef));
    170   TF_RETURN_IF_ERROR(writer.FlushAndClose());
    171   LOG(INFO) << "Converted " << convert_counter << " nodes";
    172   return Status::OK();
    173 }
    174 
    175 }  // namespace tensorflow
    176