Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include "tensorflow/cc/ops/const_op.h"
     19 #include "tensorflow/cc/ops/image_ops.h"
     20 #include "tensorflow/cc/ops/nn_ops.h"
     21 #include "tensorflow/cc/ops/sendrecv_ops.h"
     22 #include "tensorflow/cc/ops/standard_ops.h"
     23 #include "tensorflow/core/framework/tensor_testutil.h"
     24 #include "tensorflow/core/kernels/quantization_utils.h"
     25 #include "tensorflow/core/lib/core/status_test_util.h"
     26 #include "tensorflow/core/platform/test.h"
     27 #include "tensorflow/core/platform/test_benchmark.h"
     28 #include "tensorflow/core/public/session.h"
     29 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     30 
     31 namespace tensorflow {
     32 namespace graph_transforms {
     33 
     34 // Declare here, so we don't need a public header.
     35 Status QuantizeNodes(const GraphDef& input_graph_def,
     36                      const TransformFuncContext& context,
     37                      GraphDef* output_graph_def);
     38 Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
     39                                     const TransformFuncContext& context,
     40                                     GraphDef* output_graph_def);
     41 Status QuantizePlaceholders(const GraphDef& input_graph_def,
     42                             const TransformFuncContext& context,
     43                             GraphDef* output_graph_def);
     44 Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
     45                                      const TransformFuncContext& context,
     46                                      GraphDef* output_graph_def);
     47 Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
     48                                 const TransformFuncContext& context,
     49                                 GraphDef* output_graph_def);
     50 Status HoistFakeQuants(const GraphDef& input_graph_def,
     51                        const TransformFuncContext& context,
     52                        GraphDef* output_graph_def);
     53 Status MergeDuplicateNodes(const GraphDef& input_graph_def,
     54                            const TransformFuncContext& context,
     55                            GraphDef* output_graph_def);
     56 
     57 class QuantizeNodesTest : public ::testing::Test {
     58  protected:
     59   void TestTransformedVersusFloatGraph(
     60       const TransformFunc& transform_function, const GraphDef& float_graph_def,
     61       const std::vector<std::pair<string, Tensor>>& float_inputs,
     62       const std::vector<std::pair<string, Tensor>>& transformed_inputs,
     63       const std::vector<string>& output_names,
     64       const TransformFuncContext& in_context, double threshold,
     65       GraphDef* transformed_graph_def) {
     66     std::unique_ptr<Session> float_session(NewSession(SessionOptions()));
     67     TF_ASSERT_OK(float_session->Create(float_graph_def));
     68     std::vector<Tensor> float_outputs;
     69     TF_ASSERT_OK(
     70         float_session->Run(float_inputs, output_names, {}, &float_outputs));
     71 
     72     TransformFuncContext context(in_context);
     73     std::vector<string> input_names;
     74     for (const std::pair<const string&, const Tensor&> float_input :
     75          float_inputs) {
     76       context.input_names.push_back(float_input.first);
     77     }
     78 
     79     context.output_names = output_names;
     80     TF_ASSERT_OK(
     81         transform_function(float_graph_def, context, transformed_graph_def));
     82 
     83     std::unique_ptr<Session> transformed_session(NewSession(SessionOptions()));
     84     TF_ASSERT_OK(transformed_session->Create(*transformed_graph_def));
     85     std::vector<Tensor> transformed_outputs;
     86     TF_ASSERT_OK(transformed_session->Run(transformed_inputs, output_names, {},
     87                                           &transformed_outputs));
     88 
     89     const int output_count = output_names.size();
     90     EXPECT_EQ(output_count, float_outputs.size());
     91     EXPECT_EQ(output_count, transformed_outputs.size());
     92     for (int i = 0; i < output_count; ++i) {
     93       test::ExpectTensorNear<float>(float_outputs[i], transformed_outputs[i],
     94                                     threshold);
     95     }
     96   }
     97 
     98   void TestQuantizedVersusFloatGraph(
     99       const GraphDef& float_graph_def,
    100       const std::vector<std::pair<string, Tensor>>& inputs,
    101       const std::vector<string>& output_names) {
    102     GraphDef quantized_graph_def;
    103     TestTransformedVersusFloatGraph(QuantizeNodes, float_graph_def, inputs,
    104                                     inputs, output_names, {}, 1.0,
    105                                     &quantized_graph_def);
    106     // Reshape is not included here because it can be added as part of the
    107     // quantization process.
    108     const std::set<string> quantizable_ops = {
    109         "Add",   "BiasAdd",        "Concat",  "Conv2D",  "MatMul", "Relu",
    110         "Relu6", "ResizeBilinear", "AvgPool", "MaxPool", "Mul"};
    111     for (const NodeDef& node : quantized_graph_def.node()) {
    112       EXPECT_EQ(0, quantizable_ops.count(node.op()))
    113           << "Found quantizable node " << node.op() << " for node named "
    114           << node.name();
    115     }
    116   }
    117 
    118   void TestGraphWithInputRange(
    119       const GraphDef& float_graph_def,
    120       const std::vector<std::pair<string, Tensor>>& float_inputs,
    121       const std::vector<string>& output_names, float range_min,
    122       float range_max) {
    123     TransformFuncContext context;
    124     context.params["input_min"] = {strings::StrCat(range_min)};
    125     context.params["input_max"] = {strings::StrCat(range_max)};
    126 
    127     std::vector<std::pair<string, Tensor>> quantized_inputs;
    128     for (const std::pair<string, Tensor>& float_input : float_inputs) {
    129       const Tensor& float_tensor = float_input.second;
    130       Tensor quantized_tensor(DT_QUINT8, float_tensor.shape());
    131       FloatTensorToQuantizedInPlace<quint8>(float_tensor, range_min, range_max,
    132                                             &quantized_tensor);
    133       quantized_inputs.push_back({float_input.first, quantized_tensor});
    134     }
    135 
    136     GraphDef quantized_graph_def;
    137     TestTransformedVersusFloatGraph(
    138         QuantizeNodes, float_graph_def, float_inputs, quantized_inputs,
    139         output_names, context, 1.0, &quantized_graph_def);
    140   }
    141 
    142   void TestGraphWithFallbackRange(
    143       const GraphDef& float_graph_def,
    144       const std::vector<std::pair<string, Tensor>>& float_inputs,
    145       const std::vector<string>& output_names, float range_min, float range_max,
    146       GraphDef* quantized_graph_def) {
    147     TransformFuncContext context;
    148     context.params["fallback_min"] = {strings::StrCat(range_min)};
    149     context.params["fallback_max"] = {strings::StrCat(range_max)};
    150     TestTransformedVersusFloatGraph(QuantizeNodes, float_graph_def,
    151                                     float_inputs, float_inputs, output_names,
    152                                     context, 2.0, quantized_graph_def);
    153   }
    154 
    155   void TestIgnoreOps(std::initializer_list<string> ops_to_ignore) {
    156     auto root = tensorflow::Scope::NewRootScope();
    157     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    158 
    159     // A small helper to construct a Const op.
    160     auto const_op = [&](const string& name, const TensorShape& shape,
    161                         std::initializer_list<float> values) {
    162       Tensor tensor(DT_FLOAT, shape);
    163       test::FillValues<float>(&tensor, values);
    164       return Const(root.WithOpName(name), Input::Initializer(tensor));
    165     };
    166 
    167     // A simple graph with two different quantizable ops.
    168     int m = 1;
    169     int n = 1;
    170     int k = 1;
    171     Output a_op = const_op("a_op", {m, k}, {2});
    172     Output b_op = const_op("b_op", {k, n}, {3});
    173     Output c_op = const_op("c_op", {m, k}, {1});
    174     Output d_op = const_op("d_op", {k, n}, {4});
    175     Output mat_mul_op = MatMul(root.WithOpName("mat_mul_op"), a_op, b_op);
    176     Output mul_op = Mul(root.WithOpName("mul"), c_op, d_op);
    177 
    178     GraphDef float_graph_def;
    179     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    180 
    181     TransformFuncContext context;
    182     if (ops_to_ignore.size() > 0) {
    183       context.params["ignore_op"] = ops_to_ignore;
    184     }
    185 
    186     GraphDef quantized_graph_def;
    187     TestTransformedVersusFloatGraph(QuantizeNodes, float_graph_def, {}, {},
    188                                     {"mat_mul_op", "mul"}, context, 1.0,
    189                                     &quantized_graph_def);
    190 
    191     // Make sure the quantized graph still contains the op that should have
    192     // been ignored by QuantizeNodes.
    193     for (const string& op_name : ops_to_ignore) {
    194       bool exists_in_quantized_graph = false;
    195       for (const NodeDef& node : quantized_graph_def.node()) {
    196         if (node.op() == op_name) {
    197           exists_in_quantized_graph = true;
    198           break;
    199         }
    200       }
    201       EXPECT_TRUE(exists_in_quantized_graph)
    202           << "Op " << op_name
    203           << " should not have been replace by a quantized version";
    204     }
    205   }
    206 
    207   void TestQuantizeMatMul(int m, int n, int k,
    208                           const std::vector<float>& a_values,
    209                           const std::vector<float>& b_values) {
    210     auto root = tensorflow::Scope::NewRootScope();
    211     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    212 
    213     Tensor a_tensor(DT_FLOAT, TensorShape({m, k}));
    214     test::FillValues<float>(&a_tensor, a_values);
    215     Output a_op = Const(root.WithOpName("a_op"), Input::Initializer(a_tensor));
    216 
    217     Tensor b_tensor(DT_FLOAT, TensorShape({k, n}));
    218     test::FillValues<float>(&b_tensor, b_values);
    219     Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor));
    220 
    221     Output mat_mul_op = MatMul(root.WithOpName("mat_mul_op"), a_op, b_op);
    222 
    223     GraphDef float_graph_def;
    224     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    225 
    226     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"mat_mul_op"});
    227   }
    228 
    229   void TestQuantizeMatMulTiny() {
    230     // These tests are added to test the generate case where
    231     // min(matrix) == max(matrix), which used to cause problems.
    232     TestQuantizeMatMul(1, 1, 1, {2}, {3});
    233     TestQuantizeMatMul(1, 2, 1, {1}, {2, 3});
    234     TestQuantizeMatMul(1, 1, 2, {1, 1}, {1, 1});
    235     TestQuantizeMatMul(1, 1, 2, {0, 0}, {1, 1});
    236     // The general case.
    237     TestQuantizeMatMul(1, 1, 2, {1, 2}, {1, 2});
    238   }
    239 
    240   void TestQuantizeMatMulSmall() {
    241     TestQuantizeMatMul(2, 4, 3, {1, 2, 3, 4, 5, 6},
    242                        {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
    243   }
    244 
    245   void TestQuantizeMul() {
    246     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    247 
    248     std::vector<int64> x_shape({10, 100});
    249     const size_t x_num_elements = TensorShape(x_shape).num_elements();
    250     std::vector<float> x_values(x_num_elements);
    251     for (int i = 0; i < x_num_elements; ++i) {
    252       x_values[i] = (i % 256) / 256.0f;
    253     }
    254 
    255     std::vector<int64> y_shape({100});
    256     const size_t y_num_elements = TensorShape(y_shape).num_elements();
    257     std::vector<float> y_values(y_num_elements);
    258     for (int i = 0; i < y_num_elements; ++i) {
    259       y_values[i] = ((i + 23) % 123) - 50;
    260     }
    261 
    262     Scope root = Scope::NewRootScope();
    263 
    264     Tensor x_float_tensor(DT_FLOAT, TensorShape(x_shape));
    265     test::FillValues<float>(&x_float_tensor, x_values);
    266     Output x = Const(root.WithOpName("x"), Input::Initializer(x_float_tensor));
    267 
    268     Tensor y_float_tensor(DT_FLOAT, TensorShape(y_shape));
    269     test::FillValues<float>(&y_float_tensor, y_values);
    270     Output y = Const(root.WithOpName("y"), Input::Initializer(y_float_tensor));
    271 
    272     Mul mul = Mul(root.WithOpName("mul"), x, y);
    273 
    274     GraphDef float_graph_def;
    275     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    276 
    277     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"mul"});
    278   }
    279 
    280   void TestQuantizeAdd() {
    281     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    282 
    283     std::vector<int64> x_shape({10, 100});
    284     const size_t x_num_elements = TensorShape(x_shape).num_elements();
    285     std::vector<float> x_values(x_num_elements);
    286     for (int i = 0; i < x_num_elements; ++i) {
    287       x_values[i] = (i % 256) / 256.0f;
    288     }
    289 
    290     std::vector<int64> y_shape({100});
    291     const size_t y_num_elements = TensorShape(y_shape).num_elements();
    292     std::vector<float> y_values(y_num_elements);
    293     for (int i = 0; i < y_num_elements; ++i) {
    294       y_values[i] = ((i + 23) % 123) - 50;
    295     }
    296 
    297     Scope root = Scope::NewRootScope();
    298 
    299     Tensor x_float_tensor(DT_FLOAT, TensorShape(x_shape));
    300     test::FillValues<float>(&x_float_tensor, x_values);
    301     Output x = Const(root.WithOpName("x"), Input::Initializer(x_float_tensor));
    302 
    303     Tensor y_float_tensor(DT_FLOAT, TensorShape(y_shape));
    304     test::FillValues<float>(&y_float_tensor, y_values);
    305     Output y = Const(root.WithOpName("y"), Input::Initializer(y_float_tensor));
    306 
    307     Add add = Add(root.WithOpName("add"), x, y);
    308 
    309     GraphDef float_graph_def;
    310     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    311 
    312     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"add"});
    313   }
    314 
    315   void TestQuantizeConv2D(int depth, int input_width, int input_height,
    316                           int input_batch_count, int filter_size,
    317                           int filter_count, int stride, const string& padding,
    318                           const std::vector<float>& input_values,
    319                           const std::vector<float>& filter_values) {
    320     auto root = tensorflow::Scope::NewRootScope();
    321     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    322 
    323     Tensor input_tensor(DT_FLOAT, TensorShape({input_batch_count, input_height,
    324                                                input_width, depth}));
    325     test::FillValues<float>(&input_tensor, input_values);
    326     Output input_op =
    327         Const(root.WithOpName("input_op"), Input::Initializer(input_tensor));
    328 
    329     Tensor filter_tensor(
    330         DT_FLOAT, TensorShape({filter_size, filter_size, depth, filter_count}));
    331     test::FillValues<float>(&filter_tensor, filter_values);
    332     Output filter_op =
    333         Const(root.WithOpName("filter_op"), Input::Initializer(filter_tensor));
    334 
    335     Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, filter_op,
    336                             {1, stride, stride, 1}, padding);
    337 
    338     GraphDef float_graph_def;
    339     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    340 
    341     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"conv_op"});
    342   }
    343 
    344   void TestQuantizeBiasAdd() {
    345     auto root = tensorflow::Scope::NewRootScope();
    346     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    347 
    348     Tensor input_tensor(DT_FLOAT, TensorShape({1, 1, 2, 6}));
    349     test::FillIota<float>(&input_tensor, 1);
    350     Output input_op =
    351         Const(root.WithOpName("input_op"), Input::Initializer(input_tensor));
    352 
    353     Tensor offset_tensor(DT_FLOAT, TensorShape({6}));
    354     test::FillIota<float>(&offset_tensor, 1);
    355     Output offset_op =
    356         Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
    357 
    358     Output bias_add_op =
    359         BiasAdd(root.WithOpName("bias_add_op"), input_op, offset_op);
    360 
    361     GraphDef float_graph_def;
    362     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    363 
    364     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"bias_add_op"});
    365   }
    366 
    367   void TestQuantizeConcat() {
    368     auto root = tensorflow::Scope::NewRootScope();
    369     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    370 
    371     Tensor shape_tensor(DT_INT32, TensorShape({}));
    372     test::FillValues<int32>(&shape_tensor, {0});
    373     Output shape_op =
    374         Const(root.WithOpName("shape_op"), Input::Initializer(shape_tensor));
    375 
    376     Tensor a_tensor(DT_FLOAT, TensorShape({2, 2, 3}));
    377     test::FillValues<float>(&a_tensor, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    378     Output a_op = Const(root.WithOpName("a_op"), Input::Initializer(a_tensor));
    379 
    380     Tensor b_tensor(DT_FLOAT, TensorShape({2, 2, 3}));
    381     test::FillValues<float>(&b_tensor,
    382                             {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
    383     Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor));
    384 
    385     Output concat_op =
    386         Concat(root.WithOpName("concat_op"), {a_op, b_op}, shape_op);
    387 
    388     GraphDef float_graph_def;
    389     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    390 
    391     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"concat_op"});
    392   }
    393 
    394   void TestQuantizeRelu() {
    395     auto root = tensorflow::Scope::NewRootScope();
    396     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    397 
    398     Tensor constant_tensor(DT_FLOAT, TensorShape({1, 2, 6, 1}));
    399     test::FillValues<float>(&constant_tensor,
    400                             {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    401     Output constant_op = Const(root.WithOpName("constant_op"),
    402                                Input::Initializer(constant_tensor));
    403 
    404     Output relu_op = Relu(root.WithOpName("relu_op"), constant_op);
    405 
    406     GraphDef float_graph_def;
    407     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    408 
    409     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"relu_op"});
    410   }
    411 
    412   void TestQuantizeRelu6() {
    413     auto root = tensorflow::Scope::NewRootScope();
    414     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    415 
    416     Tensor constant_tensor(DT_FLOAT, TensorShape({1, 2, 6, 1}));
    417     test::FillValues<float>(&constant_tensor,
    418                             {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    419     Output constant_op = Const(root.WithOpName("constant_op"),
    420                                Input::Initializer(constant_tensor));
    421 
    422     Output relu6_op = Relu6(root.WithOpName("relu6_op"), constant_op);
    423 
    424     GraphDef float_graph_def;
    425     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    426 
    427     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"relu6_op"});
    428   }
    429 
    430   void TestQuantizeMaxPool() {
    431     auto root = tensorflow::Scope::NewRootScope();
    432     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    433 
    434     Tensor constant_tensor(DT_FLOAT, TensorShape({1, 2, 6, 1}));
    435     test::FillValues<float>(&constant_tensor,
    436                             {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    437     Output constant_op = Const(root.WithOpName("constant_op"),
    438                                Input::Initializer(constant_tensor));
    439 
    440     Output max_pool_op = MaxPool(root.WithOpName("max_pool_op"), constant_op,
    441                                  {1, 2, 2, 1}, {1, 1, 1, 1}, "SAME");
    442 
    443     GraphDef float_graph_def;
    444     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    445 
    446     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"max_pool_op"});
    447   }
    448 
    449   void TestQuantizeAvgPool() {
    450     auto root = tensorflow::Scope::NewRootScope();
    451     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    452 
    453     Tensor constant_tensor(DT_FLOAT, TensorShape({1, 2, 6, 1}));
    454     test::FillValues<float>(&constant_tensor,
    455                             {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    456     Output constant_op = Const(root.WithOpName("constant_op"),
    457                                Input::Initializer(constant_tensor));
    458 
    459     Output avg_pool_op = AvgPool(root.WithOpName("avg_pool_op"), constant_op,
    460                                  {1, 2, 2, 1}, {1, 1, 1, 1}, "SAME");
    461 
    462     GraphDef float_graph_def;
    463     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    464 
    465     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"avg_pool_op"});
    466   }
    467 
    468   void TestQuantizeReshape() {
    469     auto root = tensorflow::Scope::NewRootScope();
    470     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    471 
    472     Tensor constant_tensor(DT_FLOAT, TensorShape({4, 5}));
    473     test::FillValues<float>(&constant_tensor,
    474                             {1,  2,  3,  4,  5,  6,  7,  8,  9,  10,
    475                              11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
    476     Output constant_op = Const(root.WithOpName("constant_op"),
    477                                Input::Initializer(constant_tensor));
    478 
    479     Output reshape_op =
    480         Reshape(root.WithOpName("reshape_op"), constant_op, {10, 2});
    481 
    482     GraphDef float_graph_def;
    483     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    484 
    485     TestQuantizedVersusFloatGraph(float_graph_def, {}, {"reshape_op"});
    486   }
    487 
    488   void TestRemoveRedundantQuantization() {
    489     auto root = tensorflow::Scope::NewRootScope();
    490     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    491 
    492     Tensor quantized_tensor(DT_QUINT8, TensorShape({}));
    493     test::FillValues<quint8>(&quantized_tensor, {0});
    494     Output quantized_op = Const(root.WithOpName("quantized_op"),
    495                                 Input::Initializer(quantized_tensor));
    496 
    497     Tensor quantized_min_tensor(DT_FLOAT, TensorShape({}));
    498     test::FillValues<float>(&quantized_min_tensor, {2.0f});
    499     Output quantized_min_op = Const(root.WithOpName("quantized_min_op"),
    500                                     Input::Initializer(quantized_min_tensor));
    501 
    502     Tensor quantized_max_tensor(DT_FLOAT, TensorShape({}));
    503     test::FillValues<float>(&quantized_max_tensor, {2.0f});
    504     Output quantized_max_op = Const(root.WithOpName("quantized_max_op"),
    505                                     Input::Initializer(quantized_min_tensor));
    506 
    507     Output dequantize_op =
    508         Dequantize(root.WithOpName("dequantize_op"), quantized_op,
    509                    quantized_min_op, quantized_max_op);
    510 
    511     Tensor dequantize_reshape_dims_tensor(DT_INT32, TensorShape({1}));
    512     test::FillValues<int32>(&dequantize_reshape_dims_tensor, {-1});
    513     Output dequantize_reshape_dims =
    514         Const(root.WithOpName("dequantize_reshape_dims"),
    515               Input::Initializer(dequantize_reshape_dims_tensor));
    516 
    517     Tensor dequantize_reduction_dims_tensor(DT_INT32, TensorShape({}));
    518     test::FillValues<int32>(&dequantize_reduction_dims_tensor, {0});
    519     Output dequantize_reduction_dims =
    520         Const(root.WithOpName("dequantize_reduction_dims"),
    521               Input::Initializer(dequantize_reduction_dims_tensor));
    522 
    523     Output dequantize_reshape = Reshape(root.WithOpName("dequantize_reshape"),
    524                                         dequantize_op, dequantize_reshape_dims);
    525 
    526     Output dequantize_min =
    527         Min(root.WithOpName("dequantize_min"), dequantize_reshape,
    528             dequantize_reduction_dims, Min::Attrs().KeepDims(false));
    529 
    530     Output dequantize_max =
    531         Max(root.WithOpName("dequantize_max"), dequantize_reshape,
    532             dequantize_reduction_dims, Max::Attrs().KeepDims(false));
    533 
    534     QuantizeV2 quantize_op(root.WithOpName("quantize_op"), dequantize_op,
    535                            dequantize_min, dequantize_max, DT_QUINT8,
    536                            QuantizeV2::Attrs().Mode("MIN_FIRST"));
    537 
    538     Output final_dequantize =
    539         Dequantize(root.WithOpName("final_dequantize"), quantize_op.output,
    540                    quantize_op.output_min, quantize_op.output_max);
    541 
    542     GraphDef float_graph_def;
    543     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    544 
    545     GraphDef removed_graph_def;
    546     TestTransformedVersusFloatGraph(
    547         RemoveRedundantQuantizations, float_graph_def, {}, {},
    548         {"final_dequantize"}, {}, 1.0, &removed_graph_def);
    549 
    550     std::map<string, const NodeDef*> node_map;
    551     MapNamesToNodes(removed_graph_def, &node_map);
    552     EXPECT_EQ(1, node_map.count("final_dequantize"));
    553     EXPECT_EQ("quantized_op", node_map.at("final_dequantize")->input(0));
    554   }
    555 
    556   void TestRemoveRedundantQuantizationWithBiasAdd() {
    557     auto root = tensorflow::Scope::NewRootScope();
    558     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    559 
    560     Tensor quantized_tensor(DT_QUINT8, TensorShape({1, 6}));
    561     test::FillValues<quint8>(&quantized_tensor, {0, 0, 0, 0, 0, 0});
    562     Output quantized_op = Const(root.WithOpName("quantized_op"),
    563                                 Input::Initializer(quantized_tensor));
    564 
    565     Tensor quantized_min_tensor(DT_FLOAT, TensorShape({}));
    566     test::FillValues<float>(&quantized_min_tensor, {2.0f});
    567     Output quantized_min_op = Const(root.WithOpName("quantized_min_op"),
    568                                     Input::Initializer(quantized_min_tensor));
    569 
    570     Tensor quantized_max_tensor(DT_FLOAT, TensorShape({}));
    571     test::FillValues<float>(&quantized_max_tensor, {2.0f});
    572     Output quantized_max_op = Const(root.WithOpName("quantized_max_op"),
    573                                     Input::Initializer(quantized_min_tensor));
    574 
    575     Tensor offset_tensor(DT_QUINT8, TensorShape({6}));
    576     test::FillValues<quint8>(&offset_tensor, {1, 2, 3, 4, 5, 6});
    577     Output offset_op =
    578         Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
    579 
    580     Tensor offset_min_tensor(DT_FLOAT, TensorShape({}));
    581     test::FillValues<float>(&offset_min_tensor, {0.0f});
    582     Output offset_min_op = Const(root.WithOpName("offset_min_op"),
    583                                  Input::Initializer(offset_min_tensor));
    584 
    585     Tensor offset_max_tensor(DT_FLOAT, TensorShape({}));
    586     test::FillValues<float>(&offset_max_tensor, {255.0f});
    587     Output offset_max_op = Const(root.WithOpName("offset_max_op"),
    588                                  Input::Initializer(offset_max_tensor));
    589 
    590     QuantizedBiasAdd quantized_bias_add_op(
    591         root.WithOpName("bias_add_op"), quantized_op, offset_op,
    592         quantized_min_op, quantized_max_op, offset_min_op, offset_max_op,
    593         DT_QINT32);
    594 
    595     RequantizationRange requantization_range_op(
    596         root.WithOpName("requantization_range_op"),
    597         quantized_bias_add_op.output, quantized_bias_add_op.min_out,
    598         quantized_bias_add_op.max_out);
    599 
    600     Requantize requantize_op(
    601         root.WithOpName("requantize_op"), quantized_bias_add_op.output,
    602         quantized_bias_add_op.min_out, quantized_bias_add_op.max_out,
    603         requantization_range_op.output_min, requantization_range_op.output_max,
    604         DT_QUINT8);
    605 
    606     Output dequantize_op =
    607         Dequantize(root.WithOpName("dequantize_op"), requantize_op.output,
    608                    requantize_op.output_min, requantize_op.output_max);
    609 
    610     Tensor dequantize_reshape_dims_tensor(DT_INT32, TensorShape({1}));
    611     test::FillValues<int32>(&dequantize_reshape_dims_tensor, {-1});
    612     Output dequantize_reshape_dims =
    613         Const(root.WithOpName("dequantize_reshape_dims"),
    614               Input::Initializer(dequantize_reshape_dims_tensor));
    615 
    616     Tensor dequantize_reduction_dims_tensor(DT_INT32, TensorShape({}));
    617     test::FillValues<int32>(&dequantize_reduction_dims_tensor, {0});
    618     Output dequantize_reduction_dims =
    619         Const(root.WithOpName("dequantize_reduction_dims"),
    620               Input::Initializer(dequantize_reduction_dims_tensor));
    621 
    622     Output dequantize_reshape = Reshape(root.WithOpName("dequantize_reshape"),
    623                                         dequantize_op, dequantize_reshape_dims);
    624 
    625     Output dequantize_min =
    626         Min(root.WithOpName("dequantize_min"), dequantize_reshape,
    627             dequantize_reduction_dims, Min::Attrs().KeepDims(false));
    628 
    629     Output dequantize_max =
    630         Max(root.WithOpName("dequantize_max"), dequantize_reshape,
    631             dequantize_reduction_dims, Max::Attrs().KeepDims(false));
    632 
    633     QuantizeV2 quantize_op(root.WithOpName("quantize_op"), dequantize_op,
    634                            dequantize_min, dequantize_max, DT_QUINT8,
    635                            QuantizeV2::Attrs().Mode("MIN_FIRST"));
    636 
    637     Output final_dequantize =
    638         Dequantize(root.WithOpName("final_dequantize"), quantize_op.output,
    639                    quantize_op.output_min, quantize_op.output_max);
    640 
    641     GraphDef float_graph_def;
    642     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    643 
    644     GraphDef removed_graph_def;
    645     TestTransformedVersusFloatGraph(
    646         RemoveRedundantQuantizations, float_graph_def, {}, {},
    647         {"final_dequantize"}, {}, 1.0, &removed_graph_def);
    648 
    649     std::map<string, const NodeDef*> node_map;
    650     MapNamesToNodes(removed_graph_def, &node_map);
    651     EXPECT_EQ(1, node_map.count("final_dequantize"));
    652     EXPECT_EQ("requantize_op", node_map.at("final_dequantize")->input(0));
    653   }
    654 
    655   void TestQuantizeResizeBilinear() {
    656     auto root = tensorflow::Scope::NewRootScope();
    657     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    658 
    659     Tensor size_tensor(DT_INT32, TensorShape({2}));
    660     test::FillValues<int32>(&size_tensor, {256, 256});
    661 
    662     Output constant_op = Const(root.WithOpName("size_tensor_op"),
    663                                Input::Initializer(size_tensor));
    664 
    665     Output placeholder_op =
    666         Placeholder(root.WithOpName("placeholder_op"), DT_FLOAT);
    667 
    668     Output resize_bilinear_op = ResizeBilinear(
    669         root.WithOpName("resize_bilinear_op"), placeholder_op, constant_op);
    670 
    671     GraphDef float_graph_def;
    672     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    673 
    674     Tensor input_tensor(DT_FLOAT, {1, 128, 128, 3});
    675     test::FillFn<float>(&input_tensor, [](int) { return 100.0f; });
    676 
    677     TestQuantizedVersusFloatGraph(float_graph_def,
    678                                   {{"placeholder_op", input_tensor}},
    679                                   {"resize_bilinear_op"});
    680   }
    681 
    682   void TestRemoveRedundantQuantizationWithMultipleOutputs() {
    683     auto root = tensorflow::Scope::NewRootScope();
    684     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    685 
    686     Tensor quantized_tensor(DT_QUINT8, TensorShape({1, 6}));
    687     test::FillValues<quint8>(&quantized_tensor, {0, 0, 0, 0, 0, 0});
    688     Output quantized_op = Const(root.WithOpName("quantized_op"),
    689                                 Input::Initializer(quantized_tensor));
    690 
    691     Tensor quantized_min_tensor(DT_FLOAT, TensorShape({}));
    692     test::FillValues<float>(&quantized_min_tensor, {2.0f});
    693     Output quantized_min_op = Const(root.WithOpName("quantized_min_op"),
    694                                     Input::Initializer(quantized_min_tensor));
    695 
    696     Tensor quantized_max_tensor(DT_FLOAT, TensorShape({}));
    697     test::FillValues<float>(&quantized_max_tensor, {2.0f});
    698     Output quantized_max_op = Const(root.WithOpName("quantized_max_op"),
    699                                     Input::Initializer(quantized_min_tensor));
    700 
    701     Tensor offset_tensor(DT_QUINT8, TensorShape({6}));
    702     test::FillValues<quint8>(&offset_tensor, {1, 2, 3, 4, 5, 6});
    703     Output offset_op =
    704         Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
    705 
    706     Tensor offset_min_tensor(DT_FLOAT, TensorShape({}));
    707     test::FillValues<float>(&offset_min_tensor, {0.0f});
    708     Output offset_min_op = Const(root.WithOpName("offset_min_op"),
    709                                  Input::Initializer(offset_min_tensor));
    710 
    711     Tensor offset_max_tensor(DT_FLOAT, TensorShape({}));
    712     test::FillValues<float>(&offset_max_tensor, {255.0f});
    713     Output offset_max_op = Const(root.WithOpName("offset_max_op"),
    714                                  Input::Initializer(offset_max_tensor));
    715 
    716     QuantizedBiasAdd quantized_bias_add_op(
    717         root.WithOpName("bias_add_op"), quantized_op, offset_op,
    718         quantized_min_op, quantized_max_op, offset_min_op, offset_max_op,
    719         DT_QINT32);
    720 
    721     RequantizationRange requantization_range_op(
    722         root.WithOpName("requantization_range_op"),
    723         quantized_bias_add_op.output, quantized_bias_add_op.min_out,
    724         quantized_bias_add_op.max_out);
    725 
    726     Requantize requantize_op(
    727         root.WithOpName("requantize_op"), quantized_bias_add_op.output,
    728         quantized_bias_add_op.min_out, quantized_bias_add_op.max_out,
    729         requantization_range_op.output_min, requantization_range_op.output_max,
    730         DT_QUINT8);
    731 
    732     Output dequantize_op =
    733         Dequantize(root.WithOpName("dequantize_op"), requantize_op.output,
    734                    requantize_op.output_min, requantize_op.output_max);
    735 
    736     Tensor dequantize_reshape_dims_tensor(DT_INT32, TensorShape({1}));
    737     test::FillValues<int32>(&dequantize_reshape_dims_tensor, {-1});
    738     Output dequantize_reshape_dims =
    739         Const(root.WithOpName("dequantize_reshape_dims"),
    740               Input::Initializer(dequantize_reshape_dims_tensor));
    741 
    742     Tensor dequantize_reduction_dims_tensor(DT_INT32, TensorShape({}));
    743     test::FillValues<int32>(&dequantize_reduction_dims_tensor, {0});
    744     Output dequantize_reduction_dims =
    745         Const(root.WithOpName("dequantize_reduction_dims"),
    746               Input::Initializer(dequantize_reduction_dims_tensor));
    747 
    748     Output dequantize_reshape = Reshape(root.WithOpName("dequantize_reshape"),
    749                                         dequantize_op, dequantize_reshape_dims);
    750 
    751     Output dequantize_min =
    752         Min(root.WithOpName("dequantize_min"), dequantize_reshape,
    753             dequantize_reduction_dims, Min::Attrs().KeepDims(false));
    754 
    755     Output dequantize_max =
    756         Max(root.WithOpName("dequantize_max"), dequantize_reshape,
    757             dequantize_reduction_dims, Max::Attrs().KeepDims(false));
    758 
    759     QuantizeV2 quantize_op(root.WithOpName("quantize_op"), dequantize_op,
    760                            dequantize_min, dequantize_max, DT_QUINT8,
    761                            QuantizeV2::Attrs().Mode("MIN_FIRST"));
    762 
    763     Output final_dequantize =
    764         Dequantize(root.WithOpName("final_dequantize"), quantize_op.output,
    765                    quantize_op.output_min, quantize_op.output_max);
    766 
    767     Output relu_op = Relu(root.WithOpName("relu_op"), dequantize_op);
    768 
    769     GraphDef float_graph_def;
    770     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    771 
    772     GraphDef removed_graph_def;
    773     TestTransformedVersusFloatGraph(
    774         RemoveRedundantQuantizations, float_graph_def, {}, {},
    775         {"final_dequantize", "relu_op"}, {}, 1.0, &removed_graph_def);
    776 
    777     std::map<string, int> op_type_count;
    778     for (const NodeDef& node : removed_graph_def.node()) {
    779       ++op_type_count[node.op()];
    780     }
    781     EXPECT_EQ(2, op_type_count["Dequantize"]);
    782   }
    783 
    784   void TestQuantizePlaceholders() {
    785     auto root = tensorflow::Scope::NewRootScope();
    786     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    787 
    788     Output placeholder_op =
    789         Placeholder(root.WithOpName("placeholder_op"), DT_FLOAT);
    790 
    791     Output relu_op = Relu(root.WithOpName("relu_op"), placeholder_op);
    792 
    793     GraphDef float_graph_def;
    794     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    795 
    796     TransformFuncContext context;
    797     context.input_names = {"placeholder_op"};
    798     context.output_names = {"relu_op"};
    799     context.params = {{"input_min", {"-10.0"}}, {"input_max", {"10.0"}}};
    800 
    801     GraphDef quantized_graph_def;
    802     TF_ASSERT_OK(
    803         QuantizePlaceholders(float_graph_def, context, &quantized_graph_def));
    804 
    805     Tensor input_tensor(DT_FLOAT, {});
    806     input_tensor.flat<float>()(0) = 5.0f;
    807 
    808     TestQuantizedVersusFloatGraph(
    809         float_graph_def, {{"placeholder_op", input_tensor}}, {"relu_op"});
    810 
    811     std::map<string, const NodeDef*> node_map;
    812     MapNamesToNodes(quantized_graph_def, &node_map);
    813     EXPECT_NE("placeholder_op", node_map.at("relu_op")->input(0));
    814     EXPECT_EQ("Placeholder", node_map.at("placeholder_op")->op());
    815     EXPECT_EQ(DT_QUINT8,
    816               node_map.at("placeholder_op")->attr().at("dtype").type());
    817   }
    818 
    819   void TestInputRange() {
    820     auto root = tensorflow::Scope::NewRootScope();
    821     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    822 
    823     const int width = 100;
    824 
    825     Tensor a_data(DT_FLOAT, TensorShape({1, width}));
    826     test::FillIota<float>(&a_data, 1.0f);
    827     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    828 
    829     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    830 
    831     Output bias_add =
    832         BiasAdd(root.WithOpName("bias_add"), a_const, placeholder);
    833 
    834     GraphDef graph_def;
    835     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    836 
    837     Tensor placeholder_tensor(DT_FLOAT, TensorShape({width}));
    838     test::FillIota<float>(&placeholder_tensor, 1.0f);
    839 
    840     TestGraphWithInputRange(graph_def, {{"placeholder", placeholder_tensor}},
    841                             {"bias_add"}, 0.0f, 100.0f);
    842   }
    843 
    844   void TestFallbackRange() {
    845     auto root = tensorflow::Scope::NewRootScope();
    846     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    847 
    848     const int width = 100;
    849 
    850     Tensor a_data(DT_FLOAT, TensorShape({1, width}));
    851     test::FillIota<float>(&a_data, 1.0f);
    852     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    853 
    854     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    855 
    856     Output bias_add =
    857         BiasAdd(root.WithOpName("bias_add"), a_const, placeholder);
    858 
    859     GraphDef graph_def;
    860     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    861 
    862     Tensor placeholder_tensor(DT_FLOAT, TensorShape({width}));
    863     test::FillIota<float>(&placeholder_tensor, 1.0f);
    864 
    865     GraphDef quantized_graph_def;
    866     TestGraphWithFallbackRange(graph_def, {{"placeholder", placeholder_tensor}},
    867                                {"bias_add"}, 0.0f, 200.0f,
    868                                &quantized_graph_def);
    869 
    870     for (const NodeDef& node : quantized_graph_def.node()) {
    871       EXPECT_NE("RequantizationRange", node.op());
    872     }
    873   }
    874 
    875   void TestConvertFakeQuantsToRequantize() {
    876     auto root = tensorflow::Scope::NewRootScope();
    877     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    878 
    879     Tensor input_tensor(DT_FLOAT, TensorShape({1, 1, 2, 6}));
    880     test::FillIota<float>(&input_tensor, 1);
    881     Output input_op =
    882         Const(root.WithOpName("input_op"), Input::Initializer(input_tensor));
    883 
    884     Tensor offset_tensor(DT_FLOAT, TensorShape({6}));
    885     test::FillIota<float>(&offset_tensor, 1);
    886     Output offset_op =
    887         Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
    888 
    889     Output bias_add_op =
    890         BiasAdd(root.WithOpName("bias_add_op"), input_op, offset_op);
    891 
    892     Tensor fake_quant_min_tensor(DT_FLOAT, TensorShape({}));
    893     test::FillValues<float>(&fake_quant_min_tensor, {0.0f});
    894     Output fake_quant_min_op = Const(root.WithOpName("fake_quant_min_op"),
    895                                      Input::Initializer(fake_quant_min_tensor));
    896 
    897     Tensor fake_quant_max_tensor(DT_FLOAT, TensorShape({}));
    898     test::FillValues<float>(&fake_quant_max_tensor, {18.0f});
    899     Output fake_quant_max_op = Const(root.WithOpName("fake_quant_max_op"),
    900                                      Input::Initializer(fake_quant_max_tensor));
    901 
    902     Output fake_quant_op =
    903         FakeQuantWithMinMaxVars(root.WithOpName("fake_quant_op"), bias_add_op,
    904                                 fake_quant_min_op, fake_quant_max_op);
    905 
    906     GraphDef float_graph_def;
    907     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
    908 
    909     GraphDef converted_graph_def;
    910     TestTransformedVersusFloatGraph(ConvertFakeQuantsToRequantize,
    911                                     float_graph_def, {}, {}, {"fake_quant_op"},
    912                                     {}, 1.0, &converted_graph_def);
    913 
    914     for (const NodeDef& node : converted_graph_def.node()) {
    915       EXPECT_NE("FakeQuantWithMinMaxVars", node.op());
    916     }
    917   }
    918 
    919   void TestMergeAdjacentRequantizes() {
    920     auto root = tensorflow::Scope::NewRootScope();
    921     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    922 
    923     Tensor input_tensor(DT_QUINT8, TensorShape({1, 1, 2, 6}));
    924     test::FillValues<quint8>(&input_tensor,
    925                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    926     Output input_op =
    927         Const(root.WithOpName("input_op"), Input::Initializer(input_tensor));
    928 
    929     Tensor input_min_tensor(DT_FLOAT, TensorShape({}));
    930     test::FillValues<float>(&input_min_tensor, {0.0f});
    931     Output input_min_op = Const(root.WithOpName("input_min_op"),
    932                                 Input::Initializer(input_min_tensor));
    933 
    934     Tensor input_max_tensor(DT_FLOAT, TensorShape({}));
    935     test::FillValues<float>(&input_max_tensor, {255.0f});
    936     Output input_max_op = Const(root.WithOpName("input_max_op"),
    937                                 Input::Initializer(input_max_tensor));
    938 
    939     Tensor offset_tensor(DT_QUINT8, TensorShape({6}));
    940     test::FillValues<quint8>(&offset_tensor, {1, 2, 3, 4, 5, 6});
    941     Output offset_op =
    942         Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
    943 
    944     Tensor offset_min_tensor(DT_FLOAT, TensorShape({}));
    945     test::FillValues<float>(&offset_min_tensor, {0.0f});
    946     Output offset_min_op = Const(root.WithOpName("offset_min_op"),
    947                                  Input::Initializer(offset_min_tensor));
    948 
    949     Tensor offset_max_tensor(DT_FLOAT, TensorShape({}));
    950     test::FillValues<float>(&offset_max_tensor, {255.0f});
    951     Output offset_max_op = Const(root.WithOpName("offset_max_op"),
    952                                  Input::Initializer(offset_max_tensor));
    953 
    954     QuantizedBiasAdd quantized_bias_add_op(
    955         root.WithOpName("quantized_bias_add_op"), input_op, offset_op,
    956         input_min_op, input_max_op, offset_min_op, offset_max_op, DT_QINT32);
    957 
    958     RequantizationRange requantization_range_op(
    959         root.WithOpName("requantization_range_op"),
    960         quantized_bias_add_op.output, quantized_bias_add_op.min_out,
    961         quantized_bias_add_op.max_out);
    962 
    963     Requantize requantize_op(
    964         root.WithOpName("requantize_op"), quantized_bias_add_op.output,
    965         quantized_bias_add_op.min_out, quantized_bias_add_op.max_out,
    966         requantization_range_op.output_min, requantization_range_op.output_max,
    967         DT_QUINT8);
    968 
    969     Output dequantize_op =
    970         Dequantize(root.WithOpName("dequantize_op"), requantize_op.output,
    971                    requantize_op.output_min, requantize_op.output_max,
    972                    Dequantize::Attrs().Mode("MIN_FIRST"));
    973 
    974     Tensor quantize_min_tensor(DT_FLOAT, TensorShape({}));
    975     test::FillValues<float>(&quantize_min_tensor, {0.0f});
    976     Output quantize_min_op = Const(root.WithOpName("quantize_min_op"),
    977                                    Input::Initializer(quantize_min_tensor));
    978 
    979     Tensor quantize_max_tensor(DT_FLOAT, TensorShape({}));
    980     test::FillValues<float>(&quantize_max_tensor, {255.0f});
    981     Output quantize_max_op = Const(root.WithOpName("quantize_max_op"),
    982                                    Input::Initializer(quantize_max_tensor));
    983 
    984     QuantizeV2 quantize_op(root.WithOpName("quantize_op"), dequantize_op,
    985                            quantize_min_op, quantize_max_op, DT_QINT32,
    986                            QuantizeV2::Attrs().Mode("MIN_FIRST"));
    987 
    988     Tensor fake_requantize_min_tensor(DT_FLOAT, TensorShape({}));
    989     test::FillValues<float>(&fake_requantize_min_tensor, {0.0f});
    990     Output fake_requantize_min_op =
    991         Const(root.WithOpName("fake_requantize_min_op"),
    992               Input::Initializer(fake_requantize_min_tensor));
    993 
    994     Tensor fake_requantize_max_tensor(DT_FLOAT, TensorShape({}));
    995     test::FillValues<float>(&fake_requantize_max_tensor, {255.0f});
    996     Output fake_requantize_max_op =
    997         Const(root.WithOpName("fake_requantize_max_op"),
    998               Input::Initializer(fake_requantize_max_tensor));
    999 
   1000     Requantize fake_requantize_op(
   1001         root.WithOpName("fake_requantize_op"), quantize_op.output,
   1002         quantize_op.output_min, quantize_op.output_max, fake_requantize_min_op,
   1003         fake_requantize_max_op, DT_QUINT8);
   1004 
   1005     Output fake_dequantize_op = Dequantize(
   1006         root.WithOpName("fake_dequantize_op"), fake_requantize_op.output,
   1007         fake_requantize_op.output_min, fake_requantize_op.output_max,
   1008         Dequantize::Attrs().Mode("MIN_FIRST"));
   1009 
   1010     GraphDef float_graph_def;
   1011     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
   1012 
   1013     GraphDef converted_graph_def;
   1014     TestTransformedVersusFloatGraph(MergeAdjacentRequantizes, float_graph_def,
   1015                                     {}, {}, {"fake_dequantize_op"}, {}, 1.0,
   1016                                     &converted_graph_def);
   1017 
   1018     int requantize_count = 0;
   1019     for (const NodeDef& node : converted_graph_def.node()) {
   1020       if (node.op() == "Requantize") {
   1021         ++requantize_count;
   1022       }
   1023     }
   1024     EXPECT_EQ(1, requantize_count);
   1025   }
   1026 
   1027   void TestConvertFakeQuantsEndToEnd() {
   1028     auto root = tensorflow::Scope::NewRootScope();
   1029     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
   1030 
   1031     Tensor input_tensor(DT_FLOAT, TensorShape({1, 1, 2, 6}));
   1032     test::FillIota<float>(&input_tensor, 1);
   1033     Output input_op =
   1034         Const(root.WithOpName("input_op"), Input::Initializer(input_tensor));
   1035 
   1036     Tensor offset_tensor(DT_FLOAT, TensorShape({6}));
   1037     test::FillIota<float>(&offset_tensor, 1);
   1038     Output offset_op =
   1039         Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
   1040 
   1041     Output bias_add_op =
   1042         BiasAdd(root.WithOpName("bias_add_op"), input_op, offset_op);
   1043 
   1044     Tensor fake_quant_min_tensor(DT_FLOAT, TensorShape({}));
   1045     test::FillValues<float>(&fake_quant_min_tensor, {0.0f});
   1046     Output fake_quant_min_op = Const(root.WithOpName("fake_quant_min_op"),
   1047                                      Input::Initializer(fake_quant_min_tensor));
   1048 
   1049     Tensor fake_quant_max_tensor(DT_FLOAT, TensorShape({}));
   1050     test::FillValues<float>(&fake_quant_max_tensor, {18.0f});
   1051     Output fake_quant_max_op = Const(root.WithOpName("fake_quant_max_op"),
   1052                                      Input::Initializer(fake_quant_max_tensor));
   1053 
   1054     Output fake_quant_op =
   1055         FakeQuantWithMinMaxVars(root.WithOpName("fake_quant_op"), bias_add_op,
   1056                                 fake_quant_min_op, fake_quant_max_op);
   1057 
   1058     GraphDef float_graph_def;
   1059     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
   1060 
   1061     GraphDef converted_graph_def;
   1062     TestTransformedVersusFloatGraph(QuantizeNodes, float_graph_def, {}, {},
   1063                                     {"fake_quant_op"}, {}, 1.0,
   1064                                     &converted_graph_def);
   1065 
   1066     int requantize_count = 0;
   1067     for (const NodeDef& node : converted_graph_def.node()) {
   1068       EXPECT_NE("FakeQuantWithMinMaxVars", node.op());
   1069       if (node.op() == "Requantize") {
   1070         ++requantize_count;
   1071       }
   1072     }
   1073     EXPECT_EQ(1, requantize_count);
   1074   }
   1075 
   1076   void TestHoistFakeQuants() {
   1077     auto root = tensorflow::Scope::NewRootScope();
   1078     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
   1079 
   1080     Tensor input_tensor(DT_FLOAT, TensorShape({1, 1, 2, 6}));
   1081     test::FillIota<float>(&input_tensor, 1);
   1082     Output input_op =
   1083         Const(root.WithOpName("input_op"), Input::Initializer(input_tensor));
   1084 
   1085     Tensor offset_tensor(DT_FLOAT, TensorShape({6}));
   1086     test::FillIota<float>(&offset_tensor, 1);
   1087     Output offset_op =
   1088         Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
   1089 
   1090     Output bias_add_op =
   1091         BiasAdd(root.WithOpName("bias_add_op"), input_op, offset_op);
   1092 
   1093     Output relu_op = Relu(root.WithOpName("relu_op"), bias_add_op);
   1094 
   1095     Output max_pool_op = MaxPool(root.WithOpName("max_pool_op"), relu_op,
   1096                                  {1, 2, 2, 1}, {1, 1, 1, 1}, "SAME");
   1097 
   1098     Tensor fake_quant_min_tensor(DT_FLOAT, TensorShape({}));
   1099     test::FillValues<float>(&fake_quant_min_tensor, {0.0f});
   1100     Output fake_quant_min_op = Const(root.WithOpName("fake_quant_min_op"),
   1101                                      Input::Initializer(fake_quant_min_tensor));
   1102 
   1103     Tensor fake_quant_max_tensor(DT_FLOAT, TensorShape({}));
   1104     test::FillValues<float>(&fake_quant_max_tensor, {18.0f});
   1105     Output fake_quant_max_op = Const(root.WithOpName("fake_quant_max_op"),
   1106                                      Input::Initializer(fake_quant_max_tensor));
   1107 
   1108     Output fake_quant_op =
   1109         FakeQuantWithMinMaxVars(root.WithOpName("fake_quant_op"), max_pool_op,
   1110                                 fake_quant_min_op, fake_quant_max_op);
   1111 
   1112     GraphDef float_graph_def;
   1113     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
   1114 
   1115     GraphDef converted_graph_def;
   1116     TestTransformedVersusFloatGraph(HoistFakeQuants, float_graph_def, {}, {},
   1117                                     {"fake_quant_op"}, {}, 1.0,
   1118                                     &converted_graph_def);
   1119 
   1120     std::map<string, const NodeDef*> node_map;
   1121     MapNamesToNodes(converted_graph_def, &node_map);
   1122     EXPECT_EQ("MaxPool", node_map.at("fake_quant_op")->op());
   1123     EXPECT_EQ("FakeQuantWithMinMaxVars",
   1124               node_map.at(node_map.at("relu_op")->input(0))->op());
   1125   }
   1126 
   1127   void TestMergeDuplicateQuantizes() {
   1128     auto root = tensorflow::Scope::NewRootScope();
   1129     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
   1130 
   1131     Tensor quantized_tensor(DT_QUINT8, TensorShape({}));
   1132     test::FillValues<quint8>(&quantized_tensor, {0});
   1133     Output quantized_op = Const(root.WithOpName("quantized_op"),
   1134                                 Input::Initializer(quantized_tensor));
   1135 
   1136     Tensor quantized_min_tensor(DT_FLOAT, TensorShape({}));
   1137     test::FillValues<float>(&quantized_min_tensor, {2.0f});
   1138     Output quantized_min_op = Const(root.WithOpName("quantized_min_op"),
   1139                                     Input::Initializer(quantized_min_tensor));
   1140 
   1141     Tensor quantized_max_tensor(DT_FLOAT, TensorShape({}));
   1142     test::FillValues<float>(&quantized_max_tensor, {2.0f});
   1143     Output quantized_max_op = Const(root.WithOpName("quantized_max_op"),
   1144                                     Input::Initializer(quantized_min_tensor));
   1145 
   1146     Output dequantize_op =
   1147         Dequantize(root.WithOpName("dequantize_op"), quantized_op,
   1148                    quantized_min_op, quantized_max_op);
   1149 
   1150     Tensor quantize_reshape_dims1_tensor(DT_INT32, TensorShape({1}));
   1151     test::FillValues<int32>(&quantize_reshape_dims1_tensor, {-1});
   1152     Output quantize_reshape_dims1 =
   1153         Const(root.WithOpName("dequantize_reshape_dims1"),
   1154               Input::Initializer(quantize_reshape_dims1_tensor));
   1155 
   1156     Tensor quantize_reduction_dims1_tensor(DT_INT32, TensorShape({}));
   1157     test::FillValues<int32>(&quantize_reduction_dims1_tensor, {0});
   1158     Output quantize_reduction_dims1 =
   1159         Const(root.WithOpName("quantize_reduction_dims1"),
   1160               Input::Initializer(quantize_reduction_dims1_tensor));
   1161 
   1162     Output quantize_reshape1 = Reshape(root.WithOpName("quantize_reshape1"),
   1163                                        dequantize_op, quantize_reshape_dims1);
   1164 
   1165     Output quantize_min1 =
   1166         Min(root.WithOpName("quantize_min1"), quantize_reshape1,
   1167             quantize_reduction_dims1, Min::Attrs().KeepDims(false));
   1168 
   1169     Output quantize_max1 =
   1170         Max(root.WithOpName("quantize_max1"), quantize_reshape1,
   1171             quantize_reduction_dims1, Max::Attrs().KeepDims(false));
   1172 
   1173     QuantizeV2 quantize_op1(root.WithOpName("quantize_op1"), dequantize_op,
   1174                             quantize_min1, quantize_max1, DT_QUINT8,
   1175                             QuantizeV2::Attrs().Mode("MIN_FIRST"));
   1176 
   1177     Tensor quantize_reshape_dims2_tensor(DT_INT32, TensorShape({1}));
   1178     test::FillValues<int32>(&quantize_reshape_dims2_tensor, {-1});
   1179     Output quantize_reshape_dims2 =
   1180         Const(root.WithOpName("dequantize_reshape_dims2"),
   1181               Input::Initializer(quantize_reshape_dims2_tensor));
   1182 
   1183     Tensor quantize_reduction_dims2_tensor(DT_INT32, TensorShape({}));
   1184     test::FillValues<int32>(&quantize_reduction_dims2_tensor, {0});
   1185     Output quantize_reduction_dims2 =
   1186         Const(root.WithOpName("quantize_reduction_dims2"),
   1187               Input::Initializer(quantize_reduction_dims2_tensor));
   1188 
   1189     Output quantize_reshape2 = Reshape(root.WithOpName("quantize_reshape2"),
   1190                                        dequantize_op, quantize_reshape_dims2);
   1191 
   1192     Output quantize_min2 =
   1193         Min(root.WithOpName("quantize_min2"), quantize_reshape2,
   1194             quantize_reduction_dims2, Min::Attrs().KeepDims(false));
   1195 
   1196     Output quantize_max2 =
   1197         Max(root.WithOpName("quantize_max2"), quantize_reshape2,
   1198             quantize_reduction_dims2, Max::Attrs().KeepDims(false));
   1199 
   1200     QuantizeV2 quantize_op2(root.WithOpName("quantize_op2"), dequantize_op,
   1201                             quantize_min1, quantize_max1, DT_QUINT8,
   1202                             QuantizeV2::Attrs().Mode("MIN_FIRST"));
   1203 
   1204     Output final_dequantize1 =
   1205         Dequantize(root.WithOpName("final_dequantize1"), quantize_op1.output,
   1206                    quantize_op1.output_min, quantize_op1.output_max);
   1207 
   1208     Output final_dequantize2 =
   1209         Dequantize(root.WithOpName("final_dequantize2"), quantize_op2.output,
   1210                    quantize_op2.output_min, quantize_op2.output_max);
   1211 
   1212     Output add_op =
   1213         Add(root.WithOpName("add_op"), final_dequantize1, final_dequantize2);
   1214 
   1215     GraphDef float_graph_def;
   1216     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
   1217 
   1218     GraphDef merged_graph_def;
   1219     TestTransformedVersusFloatGraph(MergeDuplicateNodes, float_graph_def, {},
   1220                                     {}, {"add_op"}, {}, 1.0, &merged_graph_def);
   1221 
   1222     std::map<string, int> op_map;
   1223     for (const NodeDef& node : merged_graph_def.node()) {
   1224       ++op_map[node.op()];
   1225     }
   1226     EXPECT_EQ(1, op_map["QuantizeV2"]);
   1227   }
   1228 
   1229   void TestMergeDuplicateConsts() {
   1230     auto root = tensorflow::Scope::NewRootScope();
   1231     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
   1232 
   1233     const int width = 10;
   1234 
   1235     Tensor a_tensor(DT_FLOAT, TensorShape({width}));
   1236     test::FillIota<float>(&a_tensor, 1.0f);
   1237     Output a_op = Const(root.WithOpName("a_op"), Input::Initializer(a_tensor));
   1238 
   1239     Tensor b_tensor(DT_FLOAT, TensorShape({width}));
   1240     test::FillIota<float>(&b_tensor, 1.0f);
   1241     Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor));
   1242 
   1243     Output add_op = Add(root.WithOpName("add_op"), a_op, b_op);
   1244 
   1245     Tensor c_tensor(DT_FLOAT, TensorShape({width}));
   1246     test::FillIota<float>(&c_tensor, 2.0f);
   1247     Output c_op = Const(root.WithOpName("c_op"), Input::Initializer(c_tensor));
   1248 
   1249     Output mul_op = Mul(root.WithOpName("mul_op"), add_op, c_op);
   1250 
   1251     GraphDef float_graph_def;
   1252     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
   1253 
   1254     GraphDef merged_graph_def;
   1255     TestTransformedVersusFloatGraph(MergeDuplicateNodes, float_graph_def, {},
   1256                                     {}, {"mul_op"}, {}, 1.0, &merged_graph_def);
   1257 
   1258     std::map<string, const NodeDef*> node_map;
   1259     MapNamesToNodes(merged_graph_def, &node_map);
   1260     EXPECT_EQ(1, (node_map.count("a_op") + node_map.count("b_op")));
   1261     string remaining_const;
   1262     if (node_map.count("a_op")) {
   1263       remaining_const = "a_op";
   1264     } else {
   1265       remaining_const = "b_op";
   1266     }
   1267     EXPECT_EQ(remaining_const, node_map["add_op"]->input(0));
   1268     EXPECT_EQ(remaining_const, node_map["add_op"]->input(1));
   1269     EXPECT_EQ(1, node_map.count("c_op"));
   1270     EXPECT_EQ("add_op", node_map["mul_op"]->input(0));
   1271     EXPECT_EQ("c_op", node_map["mul_op"]->input(1));
   1272   }
   1273 
   1274   void TestMergeDuplicatesNested() {
   1275     auto root = tensorflow::Scope::NewRootScope();
   1276     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
   1277 
   1278     const int width = 10;
   1279 
   1280     Tensor a_tensor(DT_FLOAT, TensorShape({width}));
   1281     test::FillIota<float>(&a_tensor, 1.0f);
   1282     Output a_op = Const(root.WithOpName("a_op"), Input::Initializer(a_tensor));
   1283 
   1284     Output a_relu_op = Relu(root.WithOpName("a_relu_op"), a_op);
   1285 
   1286     Tensor b_tensor(DT_FLOAT, TensorShape({width}));
   1287     test::FillIota<float>(&b_tensor, 1.0f);
   1288     Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor));
   1289 
   1290     Output b_relu_op = Relu(root.WithOpName("b_relu_op"), b_op);
   1291 
   1292     Output add_op = Add(root.WithOpName("add_op"), a_relu_op, b_relu_op);
   1293 
   1294     Tensor c_tensor(DT_FLOAT, TensorShape({width}));
   1295     test::FillIota<float>(&c_tensor, 2.0f);
   1296     Output c_op = Const(root.WithOpName("c_op"), Input::Initializer(c_tensor));
   1297 
   1298     Output mul_op = Mul(root.WithOpName("mul_op"), add_op, c_op);
   1299 
   1300     GraphDef float_graph_def;
   1301     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
   1302 
   1303     GraphDef merged_graph_def;
   1304     TestTransformedVersusFloatGraph(MergeDuplicateNodes, float_graph_def, {},
   1305                                     {}, {"mul_op"}, {}, 1.0, &merged_graph_def);
   1306 
   1307     std::map<string, const NodeDef*> node_map;
   1308     MapNamesToNodes(merged_graph_def, &node_map);
   1309     EXPECT_EQ(1, (node_map.count("a_op") + node_map.count("b_op")));
   1310     EXPECT_EQ(1, (node_map.count("a_relu_op") + node_map.count("b_relu_op")));
   1311     string remaining_relu;
   1312     if (node_map.count("a_relu_op")) {
   1313       remaining_relu = "a_relu_op";
   1314     } else {
   1315       remaining_relu = "b_relu_op";
   1316     }
   1317     EXPECT_EQ(remaining_relu, node_map["add_op"]->input(0));
   1318     EXPECT_EQ(remaining_relu, node_map["add_op"]->input(1));
   1319     EXPECT_EQ(1, node_map.count("c_op"));
   1320     EXPECT_EQ("add_op", node_map["mul_op"]->input(0));
   1321     EXPECT_EQ("c_op", node_map["mul_op"]->input(1));
   1322   }
   1323 
   1324   void TestMergeDuplicatesInOut() {
   1325     auto root = tensorflow::Scope::NewRootScope();
   1326     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
   1327 
   1328     const int width = 10;
   1329 
   1330     Tensor a_tensor(DT_FLOAT, TensorShape({width}));
   1331     test::FillIota<float>(&a_tensor, 1.0f);
   1332     Output a_op = Const(root.WithOpName("a_op"), Input::Initializer(a_tensor));
   1333 
   1334     Output a_relu_op = Relu(root.WithOpName("a_relu_op"), a_op);
   1335 
   1336     Tensor b_tensor(DT_FLOAT, TensorShape({width}));
   1337     test::FillIota<float>(&b_tensor, 1.0f);
   1338     Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor));
   1339 
   1340     Output b_relu_op = Relu(root.WithOpName("b_relu_op"), b_op);
   1341 
   1342     Output add_op = Add(root.WithOpName("add_op"), a_relu_op, b_relu_op);
   1343 
   1344     Tensor c_tensor(DT_FLOAT, TensorShape({width}));
   1345     test::FillIota<float>(&c_tensor, 2.0f);
   1346     Output c_op = Const(root.WithOpName("c_op"), Input::Initializer(c_tensor));
   1347 
   1348     Output mul_op1 = Mul(root.WithOpName("mul_op1"), add_op, c_op);
   1349     Output mul_op2 = Mul(root.WithOpName("mul_op2"), add_op, c_op);
   1350     Output mul_op3 = Mul(root.WithOpName("mul_op3"), add_op, c_op);
   1351 
   1352     Output final_mul_op =
   1353         Mul(root.WithOpName("final_mul_op"), mul_op2, mul_op3);
   1354 
   1355     GraphDef float_graph_def;
   1356     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
   1357 
   1358     GraphDef merged_graph_def;
   1359     TestTransformedVersusFloatGraph(MergeDuplicateNodes, float_graph_def,
   1360                                     {{"a_op", a_tensor}}, {{"a_op", a_tensor}},
   1361                                     {"mul_op1", "final_mul_op"}, {}, 1.0,
   1362                                     &merged_graph_def);
   1363 
   1364     std::map<string, const NodeDef*> node_map;
   1365     MapNamesToNodes(merged_graph_def, &node_map);
   1366     EXPECT_EQ(1, node_map.count("a_op"));
   1367     EXPECT_EQ(1, node_map.count("b_op"));
   1368     EXPECT_EQ(1, node_map.count("a_relu_op"));
   1369     EXPECT_EQ(1, node_map.count("b_relu_op"));
   1370     EXPECT_EQ(1, node_map.count("mul_op1"));
   1371     EXPECT_EQ(1, node_map.count("final_mul_op"));
   1372     EXPECT_EQ(1, (node_map.count("mul_op2") + node_map.count("mul_op3")));
   1373     string remaining_mul;
   1374     if (node_map.count("mul_op2")) {
   1375       remaining_mul = "mul_op2";
   1376     } else {
   1377       remaining_mul = "mul_op3";
   1378     }
   1379     EXPECT_EQ(remaining_mul, node_map["final_mul_op"]->input(0));
   1380     EXPECT_EQ(remaining_mul, node_map["final_mul_op"]->input(1));
   1381     EXPECT_EQ(1, node_map.count("c_op"));
   1382     EXPECT_EQ("add_op", node_map["mul_op1"]->input(0));
   1383     EXPECT_EQ("c_op", node_map["mul_op1"]->input(1));
   1384   }
   1385 
   1386   void TestExcludeNonFloat() {
   1387     auto root = tensorflow::Scope::NewRootScope();
   1388     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
   1389 
   1390     Tensor int_constant_tensor(DT_INT32, TensorShape({4, 5}));
   1391     test::FillIota<int32>(&int_constant_tensor, 1);
   1392     Output int_constant = Const(root.WithOpName("int_constant"),
   1393                                 Input::Initializer(int_constant_tensor));
   1394 
   1395     Tensor float_constant_tensor(DT_FLOAT, TensorShape({4, 5}));
   1396     test::FillIota<float>(&float_constant_tensor, 2.0f);
   1397     Output float_constant = Const(root.WithOpName("float_constant"),
   1398                                   Input::Initializer(float_constant_tensor));
   1399 
   1400     Output excluded_reshape_op =
   1401         Reshape(root.WithOpName("excluded_reshape_op"), int_constant, {10, 2});
   1402 
   1403     Output included_reshape_op = Reshape(root.WithOpName("included_reshape_op"),
   1404                                          float_constant, {10, 2});
   1405 
   1406     Output excluded_relu_op =
   1407         Relu(root.WithOpName("excluded_relu_op"), excluded_reshape_op);
   1408 
   1409     Output excluded_float_caster = Cast(
   1410         root.WithOpName("excluded_float_caster"), excluded_relu_op, DT_FLOAT);
   1411 
   1412     Output included_relu_op =
   1413         Relu(root.WithOpName("included_relu_op"), included_reshape_op);
   1414 
   1415     GraphDef float_graph_def;
   1416     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
   1417 
   1418     GraphDef quantized_graph_def;
   1419     TestTransformedVersusFloatGraph(
   1420         QuantizeNodes, float_graph_def, {}, {},
   1421         {"excluded_float_caster", "included_relu_op"}, {}, 1.0,
   1422         &quantized_graph_def);
   1423 
   1424     std::map<string, const NodeDef*> node_map;
   1425     MapNamesToNodes(quantized_graph_def, &node_map);
   1426     ASSERT_EQ(1, node_map.count("excluded_reshape_op"));
   1427     EXPECT_EQ("Reshape", node_map.at("excluded_reshape_op")->op());
   1428     ASSERT_EQ(1, node_map.count("included_reshape_op"));
   1429     EXPECT_EQ("Dequantize", node_map.at("included_reshape_op")->op());
   1430   }
   1431 };
   1432 
   1433 TEST_F(QuantizeNodesTest, TestIgnoreOps) {
   1434   TestIgnoreOps({});
   1435   TestIgnoreOps({"MatMul"});
   1436   TestIgnoreOps({"MatMul", "Mul"});
   1437 }
   1438 
   1439 TEST_F(QuantizeNodesTest, TestQuantizeMatMulTiny) { TestQuantizeMatMulTiny(); }
   1440 
   1441 TEST_F(QuantizeNodesTest, TestQuantizeMatMulSmall) {
   1442   TestQuantizeMatMulSmall();
   1443 }
   1444 
   1445 TEST_F(QuantizeNodesTest, TestQuantizeMul) { TestQuantizeMul(); }
   1446 
   1447 TEST_F(QuantizeNodesTest, TestQuantizeAdd) { TestQuantizeAdd(); }
   1448 
   1449 TEST_F(QuantizeNodesTest, TestOddPaddingProblem) {
   1450   // Tests one error case we ran into in a real graph.
   1451   TestQuantizeConv2D(1, 4, 4, 1, 3, 1, 2, "SAME",
   1452                      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   1453                      {1, 2, 3, 4, 5, 6, 7, 8, 9});
   1454 }
   1455 
   1456 TEST_F(QuantizeNodesTest, TestQuantizeConv2D) {
   1457   TestQuantizeConv2D(1, 4, 3, 1, 3, 1, 1, "SAME",
   1458                      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
   1459                      {1, 4, 7, 2, 5, 8, 3, 6, 9});
   1460 }
   1461 
   1462 TEST_F(QuantizeNodesTest, TestQuantizeBiasAdd) { TestQuantizeBiasAdd(); }
   1463 
   1464 TEST_F(QuantizeNodesTest, TestQuantizeConcat) { TestQuantizeConcat(); }
   1465 
   1466 TEST_F(QuantizeNodesTest, TestQuantizeRelu) { TestQuantizeRelu(); }
   1467 
   1468 TEST_F(QuantizeNodesTest, TestQuantizeRelu6) { TestQuantizeRelu6(); }
   1469 
   1470 TEST_F(QuantizeNodesTest, TestQuantizeMaxPool) { TestQuantizeMaxPool(); }
   1471 
   1472 TEST_F(QuantizeNodesTest, TestQuantizeAvgPool) { TestQuantizeAvgPool(); }
   1473 
   1474 TEST_F(QuantizeNodesTest, TestQuantizeReshape) { TestQuantizeReshape(); }
   1475 
   1476 TEST_F(QuantizeNodesTest, TestQuantizeResizeBilinear) {
   1477   TestQuantizeResizeBilinear();
   1478 }
   1479 
   1480 TEST_F(QuantizeNodesTest, TestRemoveRedundantQuantization) {
   1481   TestRemoveRedundantQuantization();
   1482 }
   1483 
   1484 TEST_F(QuantizeNodesTest, TestRemoveRedundantQuantizationWithBiasAdd) {
   1485   TestRemoveRedundantQuantizationWithBiasAdd();
   1486 }
   1487 
   1488 TEST_F(QuantizeNodesTest, TestRemoveRedundantQuantizationWithMultipleOutputs) {
   1489   TestRemoveRedundantQuantizationWithMultipleOutputs();
   1490 }
   1491 
   1492 TEST_F(QuantizeNodesTest, TestQuantizePlaceholders) {
   1493   TestQuantizePlaceholders();
   1494 }
   1495 
   1496 TEST_F(QuantizeNodesTest, TestInputRange) { TestInputRange(); }
   1497 
   1498 TEST_F(QuantizeNodesTest, TestFallbackRange) { TestFallbackRange(); }
   1499 
   1500 TEST_F(QuantizeNodesTest, TestConvertFakeQuantsToRequantize) {
   1501   TestConvertFakeQuantsToRequantize();
   1502 }
   1503 
   1504 TEST_F(QuantizeNodesTest, TestMergeAdjacentRequantizes) {
   1505   TestMergeAdjacentRequantizes();
   1506 }
   1507 
   1508 TEST_F(QuantizeNodesTest, TestConvertFakeQuantsEndToEnd) {
   1509   TestConvertFakeQuantsEndToEnd();
   1510 }
   1511 
   1512 TEST_F(QuantizeNodesTest, TestHoistFakeQuants) { TestHoistFakeQuants(); }
   1513 
   1514 TEST_F(QuantizeNodesTest, TestMergeDuplicateQuantizes) {
   1515   TestMergeDuplicateQuantizes();
   1516 }
   1517 
   1518 TEST_F(QuantizeNodesTest, TestMergeDuplicateConsts) {
   1519   TestMergeDuplicateConsts();
   1520 }
   1521 
   1522 TEST_F(QuantizeNodesTest, TestMergeDuplicatesNested) {
   1523   TestMergeDuplicatesNested();
   1524 }
   1525 
   1526 TEST_F(QuantizeNodesTest, TestMergeDuplicateInOut) {
   1527   TestMergeDuplicatesInOut();
   1528 }
   1529 
   1530 TEST_F(QuantizeNodesTest, TestExcludeNonFloat) { TestExcludeNonFloat(); }
   1531 
   1532 }  // namespace graph_transforms
   1533 }  // namespace tensorflow
   1534