/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {
namespace {
// Returns an error unless the tensor is a 1-D vector of the expected width.
Status ErrorIfNotVector(const Tensor& input, const string& input_name,
                        int expected_width) {
  if ((input.shape().dims() != 1) ||
      (input.shape().dim_size(0) != expected_width)) {
    return errors::InvalidArgument(
        input_name,
        " input to batch norm has bad shape: ", input.shape().DebugString());
  }
  return Status::OK();
}

Status GetScaleAndOffsetValues(const NodeMatch& match,
                               std::vector<float>* scale_values,
                               std::vector<float>* offset_values) {
  // Find all the nodes we expect in the subgraph.
  const NodeDef& batch_norm_node = match.node;
  // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ
  // by input order and attribute names.
  CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" ||
        batch_norm_node.op() == "FusedBatchNorm");
  const bool is_fused = batch_norm_node.op() == "FusedBatchNorm";
  const int mean_idx = is_fused ? 3 : 1;
  const int var_idx = is_fused ? 4 : 2;
  const int beta_idx = is_fused ? 2 : 3;
  const int gamma_idx = is_fused ? 1 : 4;
  const string epsilon_attr = is_fused ? "epsilon" : "variance_epsilon";
  // FusedBatchNorm always scales after normalization.
  const bool scale_after_normalization =
      is_fused || batch_norm_node.attr().at("scale_after_normalization").b();

  const NodeDef& mean_node = match.inputs[mean_idx].node;
  CHECK_EQ("Const", mean_node.op());
  const NodeDef& variance_node = match.inputs[var_idx].node;
  CHECK_EQ("Const", variance_node.op());
  const NodeDef& beta_node = match.inputs[beta_idx].node;
  CHECK_EQ("Const", beta_node.op());
  const NodeDef& gamma_node = match.inputs[gamma_idx].node;
  CHECK_EQ("Const", gamma_node.op());

  // We have a set of vectors that we want to combine into a vector of
  // scale values and offset values.
  Tensor mean = GetNodeTensorAttr(mean_node, "value");
  Tensor variance = GetNodeTensorAttr(variance_node, "value");
  Tensor beta = GetNodeTensorAttr(beta_node, "value");
  Tensor gamma = GetNodeTensorAttr(gamma_node, "value");
  const float variance_epsilon = batch_norm_node.attr().at(epsilon_attr).f();
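  // The folding below rewrites the normalization as a single per-channel
  // affine transform:
  //   scale[i]  = gamma[i] / sqrt(variance[i] + epsilon)
  //               (gamma is omitted when scale_after_normalization is false)
  //   offset[i] = beta[i] - mean[i] * scale[i]
  // so that batch_norm(x)[i] == x[i] * scale[i] + offset[i], which can be
  // absorbed into constant convolution weights plus a BiasAdd.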
  // Make sure all the inputs really are vectors with the same shape.
  const int64 num_cols = mean.shape().dim_size(0);
  TF_RETURN_IF_ERROR(ErrorIfNotVector(variance, "Variance", num_cols));
  TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", num_cols));
  TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "Gamma", num_cols));

  scale_values->resize(num_cols);
  offset_values->resize(num_cols);

  // Calculate the scale and offset values to apply.
  if (scale_after_normalization) {
    for (int i = 0; i < num_cols; ++i) {
      (*scale_values)[i] =
          (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) *
          gamma.flat<float>()(i);
    }
  } else {
    for (int i = 0; i < num_cols; ++i) {
      (*scale_values)[i] =
          (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon));
    }
  }
  for (int i = 0; i < num_cols; ++i) {
    (*offset_values)[i] =
        (-mean.flat<float>()(i) * (*scale_values)[i]) + beta.flat<float>()(i);
  }
  return Status::OK();
}

Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values,
                                    const std::vector<float>& offset_values,
                                    const NodeMatch& conv_node_match,
                                    const string& conv_output_name,
                                    std::vector<NodeDef>* new_nodes) {
  const NodeDef& conv_node = conv_node_match.node;
  CHECK_EQ("Conv2D", conv_node.op());
  const NodeDef& input_node = conv_node_match.inputs[0].node;
  const NodeDef& weights_node = conv_node_match.inputs[1].node;
  CHECK_EQ("Const", weights_node.op());

  Tensor weights = GetNodeTensorAttr(weights_node, "value");
  const int64 weights_cols = weights.shape().dim_size(3);
  CHECK_EQ(weights_cols, scale_values.size());

  // Multiply the original weights by the scale vector.
  auto weights_matrix = weights.flat_inner_dims<float>();
  Tensor scaled_weights(DT_FLOAT, weights.shape());
  auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
  for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
    for (int64 col = 0; col < weights_cols; ++col) {
      scaled_weights_matrix(row, col) =
          weights_matrix(row, col) * scale_values[col];
    }
  }
  // Figure out the remaining bias to add on.
  Tensor bias_offset(DT_FLOAT, {weights_cols});
  auto bias_offset_vector = bias_offset.flat<float>();
  for (int64 col = 0; col < weights_cols; ++col) {
    bias_offset_vector(col) = offset_values[col];
  }

  // Construct the new nodes.
  NodeDef scaled_weights_node;
  scaled_weights_node.set_op("Const");
  scaled_weights_node.set_name(weights_node.name());
  SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
  SetNodeTensorAttr<float>("value", scaled_weights, &scaled_weights_node);
  new_nodes->push_back(scaled_weights_node);
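  // Scaling the filter per output channel is equivalent to scaling the
  // convolution's output, since Conv2D is linear in its weights:
  //   conv(input, weights * scale) == conv(input, weights) * scale.
  // The additive part of the batch norm is handled by the BiasAdd below.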
  // The input and convolution can be copied straight over, since the
  // name of the scaled weights constant is the same as the original.
  new_nodes->push_back(input_node);
  new_nodes->push_back(conv_node);

  NodeDef bias_offset_node;
  bias_offset_node.set_op("Const");
  bias_offset_node.set_name(conv_node.name() + "_bn_offset");
  SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node);
  SetNodeTensorAttr<float>("value", bias_offset, &bias_offset_node);
  new_nodes->push_back(bias_offset_node);

  NodeDef bias_add_node;
  bias_add_node.set_op("BiasAdd");
  bias_add_node.set_name(conv_output_name);
  CopyNodeAttr(conv_node, "T", "T", &bias_add_node);
  AddNodeInput(conv_node.name(), &bias_add_node);
  AddNodeInput(bias_offset_node.name(), &bias_add_node);
  new_nodes->push_back(bias_add_node);
  return Status::OK();
}

Status FuseBatchNormWithConv(const NodeMatch& match,
                             std::vector<NodeDef>* new_nodes) {
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Fuse the conv weights, and set the final output node name to the batch
  // norm node's name.
  const NodeDef& batch_norm_node = match.node;
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale_values, offset_values, match.inputs[0],
                                   batch_norm_node.name(), new_nodes));
  return Status::OK();
}

Status FuseBatchNormWithConvConcat(const NodeMatch& match,
                                   std::vector<NodeDef>* new_nodes) {
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Find all the nodes we expect in the subgraph.
  const NodeDef& batch_norm_node = match.node;
  const NodeMatch& concat_node_match = match.inputs[0];
  NodeDef concat_node = concat_node_match.node;
  CHECK_EQ("ConcatV2", concat_node.op());

  // First process the axis.
  NodeDef axis_node = concat_node_match.inputs[2].node;
  CHECK_EQ("Const", axis_node.op());
  Tensor axis = GetNodeTensorAttr(axis_node, "value");
  int32 axis_scalar = (axis.scalar<int32>())();

  // By default, give both conv0 and conv1 the full scale and offset vectors.
  std::vector<float> scale0(scale_values);
  std::vector<float> offset0(offset_values);
  std::vector<float> scale1(scale_values);
  std::vector<float> offset1(offset_values);
  if (axis_scalar == 3) {
    // If the concat is along the channel axis (3 for NHWC), each convolution
    // produces a contiguous slice of the output channels, so the scale and
    // offset vectors are split into two halves at the same boundary.
    const NodeDef& weights0_node = concat_node_match.inputs[0].inputs[1].node;
    Tensor weights0 = GetNodeTensorAttr(weights0_node, "value");
    const int64 split_cols = weights0.shape().dim_size(3);
    // Only keep the first half for scale0/offset0.
    scale0.erase(scale0.begin() + split_cols, scale0.end());
    offset0.erase(offset0.begin() + split_cols, offset0.end());
    // Only keep the second half for scale1/offset1.
    scale1.erase(scale1.begin(), scale1.begin() + split_cols);
    offset1.erase(offset1.begin(), offset1.begin() + split_cols);
  }

  // Fuse the weights for the first Conv2D feeding the concat.
  const string concat0_output_name = concat_node.name() + "_bn_in0";
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale0, offset0, concat_node_match.inputs[0],
                                   concat0_output_name, new_nodes));
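  // Each concat input is rewritten into its own Conv2D + BiasAdd pair with
  // the matching slice of the scale/offset vectors; the concat below is then
  // renamed to the batch norm's output name so downstream nodes are
  // unaffected.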
  // Fuse the weights for the second Conv2D feeding the concat.
  const string concat1_output_name = concat_node.name() + "_bn_in1";
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale1, offset1, concat_node_match.inputs[1],
                                   concat1_output_name, new_nodes));

  // Push the axis node.
  new_nodes->push_back(concat_node_match.inputs[2].node);

  // Set the final output op name to the batch norm node's name.
  concat_node.set_name(batch_norm_node.name());
  concat_node.set_input(0, concat0_output_name);
  concat_node.set_input(1, concat1_output_name);
  new_nodes->push_back(concat_node);
  return Status::OK();
}
}  // namespace

// Finds monolithic batch norm ops (as used in early versions of TensorFlow)
// and converts them into premultiplied weight inputs to convolutions.
Status FoldOldBatchNorms(const GraphDef& input_graph_def,
                         const TransformFuncContext& context,
                         GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  // We have to do several passes to catch all the old BN nodes, since many of
  // them may share inputs and so be excluded from replacement in one pass.
  bool did_graph_change;
  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
        {"BatchNormWithGlobalNormalization|FusedBatchNorm",  // batch_norm_node
          {
            {"Conv2D",    // conv_node
              {
                {"*"},      // input_node
                {"Const"},  // weights_node
              }
            },
            {"Const"},  // mean_node
            {"Const"},  // variance_node
            {"Const"},  // beta_node
            {"Const"},  // gamma_node
          }
        },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithConv(match, new_nodes));
          did_graph_change = true;
          return Status::OK();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    // Replace a batch norm whose input is a concat of two convolutions.
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
        {"BatchNormWithGlobalNormalization|FusedBatchNorm",  // batch_norm_node
          {
            {"ConcatV2|Concat",  // concat of two Conv2Ds
              {
                {"Conv2D",    // conv_node
                  {
                    {"*"},      // input_node
                    {"Const"},  // weights_node
                  }
                },
                {"Conv2D",    // conv_node
                  {
                    {"*"},      // input_node
                    {"Const"},  // weights_node
                  }
                },
                {"Const"},  // axis
              },
            },
            {"Const"},  // mean_node
            {"Const"},  // variance_node
            {"Const"},  // beta_node
            {"Const"},  // gamma_node
          }
        },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithConvConcat(match, new_nodes));
          did_graph_change = true;
          return Status::OK();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  *output_graph_def = current_graph_def;
  return Status::OK();
}

REGISTER_GRAPH_TRANSFORM("fold_old_batch_norms", FoldOldBatchNorms);

}  // namespace graph_transforms
}  // namespace tensorflow
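// An illustrative invocation of this transform via the standard
// transform_graph tool (graph and node names here are placeholders):
//
//   bazel run tensorflow/tools/graph_transforms:transform_graph -- \
//     --in_graph=frozen_graph.pb --out_graph=folded_graph.pb \
//     --inputs='input' --outputs='output' \
//     --transforms='fold_old_batch_norms'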