/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {
namespace {
// Returns an error unless the input tensor is a 1-D vector of the expected
// width.
Status ErrorIfNotVector(const Tensor& input, const string& input_name,
                        int expected_width) {
  if ((input.shape().dims() != 1) ||
      (input.shape().dim_size(0) != expected_width)) {
    return errors::InvalidArgument(
        input_name,
        " input to batch norm has bad shape: ", input.shape().DebugString());
  }
  return Status::OK();
}

Status GetScaleAndOffsetValues(const NodeMatch& match,
                               std::vector<float>* scale_values,
                               std::vector<float>* offset_values) {
  // Find all the nodes we expect in the subgraph.
  const NodeDef& batch_norm_node = match.node;
  // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ
  // by input order and attribute names.
  CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" ||
        batch_norm_node.op() == "FusedBatchNorm");
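  // Input orderings for the two ops, matching the indices chosen below:
  //   BatchNormWithGlobalNormalization: input, mean, variance, beta, gamma
  //   FusedBatchNorm: input, scale (gamma), offset (beta), mean, variance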
  const bool is_fused = batch_norm_node.op() == "FusedBatchNorm";
  const int mean_idx = is_fused ? 3 : 1;
  const int var_idx = is_fused ? 4 : 2;
  const int beta_idx = is_fused ? 2 : 3;
  const int gamma_idx = is_fused ? 1 : 4;
  const string epsilon_attr = is_fused ? "epsilon" : "variance_epsilon";
  // FusedBatchNorm always scales after normalization.
  const bool scale_after_normalization =
      is_fused || batch_norm_node.attr().at("scale_after_normalization").b();

  const NodeDef& mean_node = match.inputs[mean_idx].node;
  CHECK_EQ("Const", mean_node.op());
  const NodeDef& variance_node = match.inputs[var_idx].node;
  CHECK_EQ("Const", variance_node.op());
  const NodeDef& beta_node = match.inputs[beta_idx].node;
  CHECK_EQ("Const", beta_node.op());
  const NodeDef& gamma_node = match.inputs[gamma_idx].node;
  CHECK_EQ("Const", gamma_node.op());

  // Pull the constant tensors out of the graph so they can be combined into
  // per-channel scale and offset vectors.
  Tensor mean = GetNodeTensorAttr(mean_node, "value");
  Tensor variance = GetNodeTensorAttr(variance_node, "value");
  Tensor beta = GetNodeTensorAttr(beta_node, "value");
  Tensor gamma = GetNodeTensorAttr(gamma_node, "value");
  const float variance_epsilon = batch_norm_node.attr().at(epsilon_attr).f();

  // Make sure all the inputs really are vectors with the same shape.
  const int64 num_cols = mean.shape().dim_size(0);
  TF_RETURN_IF_ERROR(ErrorIfNotVector(variance, "Variance", num_cols));
  TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", num_cols));
  TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "Gamma", num_cols));

  scale_values->resize(num_cols);
  offset_values->resize(num_cols);

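  // With scaling enabled, batch normalization computes
  //   output = (input - mean) * gamma / sqrt(variance + epsilon) + beta
  // which is a per-channel affine transform:
  //   output = input * scale + offset
  // with scale = gamma / sqrt(variance + epsilon) (the gamma factor is
  // dropped when scale_after_normalization is false) and
  // offset = beta - mean * scale. The loops below precompute both vectors.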
  // Calculate the scale and offset values to apply.
  if (scale_after_normalization) {
    for (int i = 0; i < num_cols; ++i) {
      (*scale_values)[i] =
          (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) *
          gamma.flat<float>()(i);
    }
  } else {
    for (int i = 0; i < num_cols; ++i) {
      (*scale_values)[i] =
          (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon));
    }
  }
  for (int i = 0; i < num_cols; ++i) {
    (*offset_values)[i] =
        (-mean.flat<float>()(i) * (*scale_values)[i]) + beta.flat<float>()(i);
  }
  return Status::OK();
}

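// Convolution is linear in its weights, so a per-output-channel scale can be
// folded into the filter constant itself:
//   Conv2D(input, weights) * scale + offset
//       == Conv2D(input, weights * scale) + offset
// The batch norm therefore collapses into a rescaled weights Const followed
// by a BiasAdd that applies the remaining offset.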
Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values,
                                    const std::vector<float>& offset_values,
                                    const NodeMatch& conv_node_match,
                                    const string& conv_output_name,
                                    std::vector<NodeDef>* new_nodes) {
  const NodeDef& conv_node = conv_node_match.node;
  CHECK_EQ("Conv2D", conv_node.op());
  const NodeDef& input_node = conv_node_match.inputs[0].node;
  const NodeDef& weights_node = conv_node_match.inputs[1].node;
  CHECK_EQ("Const", weights_node.op());

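  // Conv2D filters are shaped
  // [filter_height, filter_width, in_channels, out_channels], so dimension 3
  // is the output channel count that the scale and offset vectors must match.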
  Tensor weights = GetNodeTensorAttr(weights_node, "value");
  const int64 weights_cols = weights.shape().dim_size(3);
  CHECK_EQ(weights_cols, scale_values.size());

  // Multiply the original weights by the scale vector.
  auto weights_matrix = weights.flat_inner_dims<float>();
  Tensor scaled_weights(DT_FLOAT, weights.shape());
  auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
  for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
    for (int64 col = 0; col < weights_cols; ++col) {
      scaled_weights_matrix(row, col) =
          weights_matrix(row, col) * scale_values[col];
    }
  }
  // Figure out the remaining bias to add on.
  Tensor bias_offset(DT_FLOAT, {weights_cols});
  auto bias_offset_vector = bias_offset.flat<float>();
  for (int64 col = 0; col < weights_cols; ++col) {
    bias_offset_vector(col) = offset_values[col];
  }

  // Construct the new nodes.
  NodeDef scaled_weights_node;
  scaled_weights_node.set_op("Const");
  scaled_weights_node.set_name(weights_node.name());
  SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
  SetNodeTensorAttr<float>("value", scaled_weights, &scaled_weights_node);
  new_nodes->push_back(scaled_weights_node);

  // The input and convolution can be copied straight over, since the
  // name of the scaled weights constant is the same as the original.
  new_nodes->push_back(input_node);
  new_nodes->push_back(conv_node);

  NodeDef bias_offset_node;
  bias_offset_node.set_op("Const");
  bias_offset_node.set_name(conv_node.name() + "_bn_offset");
  SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node);
  SetNodeTensorAttr<float>("value", bias_offset, &bias_offset_node);
  new_nodes->push_back(bias_offset_node);

  NodeDef bias_add_node;
  bias_add_node.set_op("BiasAdd");
  bias_add_node.set_name(conv_output_name);
  CopyNodeAttr(conv_node, "T", "T", &bias_add_node);
  AddNodeInput(conv_node.name(), &bias_add_node);
  AddNodeInput(bias_offset_node.name(), &bias_add_node);
  new_nodes->push_back(bias_add_node);
  return Status::OK();
}

Status FuseBatchNormWithConv(const NodeMatch& match,
                             std::vector<NodeDef>* new_nodes) {
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Fuse the conv weights, and give the fused output the batch norm node's
  // name so downstream consumers are rewired automatically.
  const NodeDef& batch_norm_node = match.node;
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale_values, offset_values, match.inputs[0],
                                   batch_norm_node.name(), new_nodes));
  return Status::OK();
}

Status FuseBatchNormWithConvConcat(const NodeMatch& match,
                                   std::vector<NodeDef>* new_nodes) {
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Find all the nodes we expect in the subgraph.
  const NodeDef& batch_norm_node = match.node;
  const NodeMatch& concat_node_match = match.inputs[0];
  NodeDef concat_node = concat_node_match.node;
  CHECK_EQ("ConcatV2", concat_node.op());

  // First process the axis.
  const NodeDef& axis_node = concat_node_match.inputs[2].node;
  CHECK_EQ("Const", axis_node.op());
  Tensor axis = GetNodeTensorAttr(axis_node, "value");
  int32 axis_scalar = (axis.scalar<int32>())();

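  // When the concat axis is 3 (the channel dimension, assuming the NHWC
  // layout this transform expects), each convolution produces a disjoint
  // slice of the channels that the batch norm parameters describe, so the
  // scale and offset vectors have to be split at the boundary between the
  // two convolutions' output channels. For any other axis the convolutions
  // share the full channel set and both receive complete copies.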
  // By default, give both convolutions the same full scale and offset
  // vectors.
  std::vector<float> scale0(scale_values);
  std::vector<float> offset0(offset_values);
  std::vector<float> scale1(scale_values);
  std::vector<float> offset1(offset_values);
  if (axis_scalar == 3) {
    // When concatenating along the channel axis, split the scale and offset
    // vectors at the first convolution's output channel count.
    const NodeDef& weights0_node = concat_node_match.inputs[0].inputs[1].node;
    Tensor weights0 = GetNodeTensorAttr(weights0_node, "value");
    const int64 split_cols = weights0.shape().dim_size(3);
    // Only keep the first part for scale0/offset0.
    scale0.erase(scale0.begin() + split_cols, scale0.end());
    offset0.erase(offset0.begin() + split_cols, offset0.end());
    // Only keep the second part for scale1/offset1.
    scale1.erase(scale1.begin(), scale1.begin() + split_cols);
    offset1.erase(offset1.begin(), offset1.begin() + split_cols);
  }

  // Fuse the scale and offset into the first Conv2D's weights.
  const string concat0_output_name = concat_node.name() + "_bn_in0";
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale0, offset0, concat_node_match.inputs[0],
                                   concat0_output_name, new_nodes));

  // Fuse the scale and offset into the second Conv2D's weights.
  const string concat1_output_name = concat_node.name() + "_bn_in1";
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale1, offset1, concat_node_match.inputs[1],
                                   concat1_output_name, new_nodes));

  // Keep the axis constant node.
  new_nodes->push_back(concat_node_match.inputs[2].node);

  // Rename the concat node to the batch norm node's name so downstream
  // consumers are rewired to the fused output.
  concat_node.set_name(batch_norm_node.name());
  concat_node.set_input(0, concat0_output_name);
  concat_node.set_input(1, concat1_output_name);
  new_nodes->push_back(concat_node);
  return Status::OK();
}
}  // namespace

// Finds monolithic batch norm ops (as used in early versions of TensorFlow) and
// converts them into premultiplied weight inputs to convolutions.
Status FoldOldBatchNorms(const GraphDef& input_graph_def,
                         const TransformFuncContext& context,
                         GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  // We have to do several passes to catch all the old BN nodes, since many of
  // them may share inputs and so be excluded from replacement in one pass.
  bool did_graph_change;
  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
      {"BatchNormWithGlobalNormalization|FusedBatchNorm",    // batch_norm_node
        {
          {"Conv2D",                          // conv_node
            {
              {"*"},                          // input_node
              {"Const"},                      // weights_node
            }
          },
          {"Const"},                          // mean_node
          {"Const"},                          // variance_node
          {"Const"},                          // beta_node
          {"Const"},                          // gamma_node
        }
      },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithConv(match, new_nodes));
          did_graph_change = true;
          return Status::OK();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    // Replace batch norm nodes whose input is a concat of two convolutions.
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
      {"BatchNormWithGlobalNormalization|FusedBatchNorm",    // batch_norm_node
        {
          {"ConcatV2|Concat",                     // concat of the two Conv2Ds
            {
              {"Conv2D",                          // conv_node
                {
                  {"*"},                          // input_node
                  {"Const"},                      // weights_node
                }
              },
              {"Conv2D",                          // conv_node
                {
                  {"*"},                          // input_node
                  {"Const"},                      // weights_node
                }
              },
              {"Const"},                          // axis
            },
          },
          {"Const"},                          // mean_node
          {"Const"},                          // variance_node
          {"Const"},                          // beta_node
          {"Const"},                          // gamma_node
        }
      },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithConvConcat(match, new_nodes));
          did_graph_change = true;
          return Status::OK();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  *output_graph_def = current_graph_def;
  return Status::OK();
}

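// Example invocation through the transform_graph tool (a sketch; the file and
// node names below are placeholders):
//
//   bazel run tensorflow/tools/graph_transforms:transform_graph -- \
//     --in_graph=frozen_graph.pb --out_graph=folded_graph.pb \
//     --inputs='input' --outputs='output' \
//     --transforms='fold_old_batch_norms'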
REGISTER_GRAPH_TRANSFORM("fold_old_batch_norms", FoldOldBatchNorms);

}  // namespace graph_transforms
}  // namespace tensorflow