1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h" 16 17 #include <unordered_set> 18 #include "tensorflow/core/framework/attr_value.pb.h" 19 #include "tensorflow/core/framework/graph.pb.h" 20 #include "tensorflow/core/framework/node_def.pb.h" 21 #include "tensorflow/core/framework/register_types.h" 22 #include "tensorflow/core/framework/tensor.h" 23 #include "tensorflow/core/framework/tensor.pb.h" 24 #include "tensorflow/core/framework/tensor_shape.pb.h" 25 #include "tensorflow/core/framework/types.pb.h" 26 #include "tensorflow/core/kernels/immutable_constant_op.h" 27 #include "tensorflow/core/platform/env.h" 28 #include "tensorflow/core/platform/logging.h" 29 #include "tensorflow/core/util/memmapped_file_system_writer.h" 30 31 namespace tensorflow { 32 namespace { 33 class NodeConverter { 34 public: 35 // Converts one node. In-place updates node_def, writes the tensor in 36 // memmapped 37 // format, using writer. If the conversion has been done, convert_counter is 38 // increased. 39 Status ConvertConstantsToImmutable(NodeDef* node_def, 40 MemmappedFileSystemWriter* writer, 41 int* convert_counter, 42 int min_conversion_size_bytes) { 43 // Check the size. 44 const AttrValue& value = node_def->attr().at("value"); 45 const TensorProto& tensor_proto = value.tensor(); 46 47 // Create copies of tensor datatype and shape, to put into the operator 48 // after 49 // the tensor is destroyed. 50 const DataType tensor_data_type = tensor_proto.dtype(); 51 const TensorShapeProto tensor_shape = tensor_proto.tensor_shape(); 52 53 // Check that the tensor type is POD, only these types are supported for 54 // memmapping. 55 // DataType enum is explicitly converted to int to avoid errors with passing 56 // enum type are a parameter type to std::unordered_set. 57 static std::unordered_set<int> supported_types{ 58 #define TYPE_FOR_SET(type) static_cast<int>(DataTypeToEnum<type>::value), 59 TF_CALL_POD_TYPES(TYPE_FOR_SET) 60 #undef ADD_TYPE 61 }; 62 63 if (supported_types.count(static_cast<int>(tensor_data_type)) == 0) { 64 return Status::OK(); 65 } 66 67 // Create Tensor from value and write it in memmapped format. 68 Tensor parsed(tensor_proto.dtype()); 69 if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { 70 return errors::InvalidArgument("Cannot parse tensor from proto: ", 71 tensor_proto.DebugString()); 72 } 73 if (parsed.TotalBytes() < static_cast<size_t>(min_conversion_size_bytes)) { 74 return Status::OK(); 75 } 76 77 const string memmapped_region_name = 78 MemmappedFileSystem::kMemmappedPackagePrefix + 79 ConvertVariableNameToUniqueRegionName(node_def->name()); 80 81 TF_RETURN_IF_ERROR(writer->SaveTensor(parsed, memmapped_region_name)); 82 83 node_def->set_op("ImmutableConst"); 84 85 // Erase all attributes and leave only attributes that can be understood by 86 // ImmutableConst. 87 auto* mutable_attr = node_def->mutable_attr(); 88 mutable_attr->clear(); 89 90 { 91 AttrValue attr_value; 92 attr_value.set_type(tensor_data_type); 93 mutable_attr->insert({ImmutableConstantOp::kDTypeAttr, attr_value}); 94 } 95 { 96 AttrValue attr_value; 97 *(attr_value.mutable_shape()) = tensor_shape; 98 mutable_attr->insert({ImmutableConstantOp::kShapeAttr, attr_value}); 99 } 100 { 101 AttrValue attr_value; 102 attr_value.set_s(memmapped_region_name); 103 mutable_attr->insert( 104 {ImmutableConstantOp::kMemoryRegionNameAttr, attr_value}); 105 } 106 ++*convert_counter; 107 return Status::OK(); 108 } 109 110 private: 111 string ConvertVariableNameToUniqueRegionName(const string& variable_name) { 112 string region_name = SanitizeVariableName(variable_name); 113 while (!used_names_.insert(region_name).second) { 114 region_name += '_'; 115 } 116 return region_name; 117 } 118 119 static string SanitizeVariableName(const string& variable_name) { 120 string result; 121 for (char c : variable_name) { 122 if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || 123 (c >= '0' && c <= '9') || c == '_' || c == '.') { 124 result += c; 125 } else { 126 result += '_'; 127 } 128 } 129 return result; 130 } 131 std::unordered_set<string> used_names_; 132 }; 133 134 } // namespace 135 136 // Loads the graph, replaces operators, and writes it out. 137 Status ConvertConstantsToImmutable(const string& in_graph_filename, 138 const string& out_graph_filename, 139 int min_conversion_size_bytes) { 140 Env* default_env = Env::Default(); 141 GraphDef graph_def; 142 const auto load_graph_status = 143 ReadBinaryProto(default_env, in_graph_filename, &graph_def); 144 if (!load_graph_status.ok()) { 145 return tensorflow::errors::NotFound( 146 "Failed to load graph at '", in_graph_filename, 147 "' : ", load_graph_status.error_message()); 148 } 149 150 NodeConverter node_converter; 151 152 // Create output writer. 153 MemmappedFileSystemWriter writer; 154 TF_RETURN_IF_ERROR(writer.InitializeToFile(default_env, out_graph_filename)); 155 156 // Iterate over graph nodes, looking for Const and replacing it with 157 // ImmutableConst. 158 int convert_counter = 0; 159 for (int i = 0; i < graph_def.node_size(); ++i) { 160 const NodeDef& node = graph_def.node(i); 161 if (node.op() == "Const") { 162 // Try to convert to ImmutableConst 163 TF_RETURN_IF_ERROR(node_converter.ConvertConstantsToImmutable( 164 graph_def.mutable_node(i), &writer, &convert_counter, 165 min_conversion_size_bytes)); 166 } 167 } 168 TF_RETURN_IF_ERROR(writer.SaveProtobuf( 169 graph_def, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef)); 170 TF_RETURN_IF_ERROR(writer.FlushAndClose()); 171 LOG(INFO) << "Converted " << convert_counter << " nodes"; 172 return Status::OK(); 173 } 174 175 } // namespace tensorflow 176