Home | History | Annotate | Download | only in data
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/optimizers/data/make_sloppy.h"
     17 
     18 #include "tensorflow/core/framework/attr_value_util.h"
     19 #include "tensorflow/core/framework/function_testlib.h"
     20 #include "tensorflow/core/framework/tensor_testutil.h"
     21 #include "tensorflow/core/grappler/grappler_item.h"
     22 #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
     23 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
     24 
     25 #include "tensorflow/core/lib/core/status_test_util.h"
     26 #include "tensorflow/core/platform/test.h"
     27 
     28 namespace tensorflow {
     29 namespace grappler {
     30 namespace {
     31 
     32 using graph_tests_utils::MakeParallelInterleaveNode;
     33 using graph_tests_utils::MakeParallelMapNode;
     34 using graph_tests_utils::MakeParseExampleNode;
     35 
     36 TEST(MakeSloppy, ParallelInterleave) {
     37   using test::function::NDef;
     38   GrapplerItem item;
     39   item.graph = test::function::GDef(
     40       {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
     41        NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
     42        NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
     43        NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
     44        NDef("cycle_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
     45        NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
     46        NDef("num_parallel_calls", "Const", {},
     47             {{"value", 1}, {"dtype", DT_INT32}}),
     48        MakeParallelInterleaveNode("interleave", "range", "cycle_length",
     49                                   "block_length", "num_parallel_calls",
     50                                   "XTimesTwo", /*sloppy=*/false)},
     51       // FunctionLib
     52       {
     53           test::function::XTimesTwo(),
     54       });
     55 
     56   MakeSloppy optimizer;
     57   GraphDef output;
     58   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
     59   EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("interleave", output));
     60   int index = graph_utils::FindGraphNodeWithName("interleave", output);
     61   EXPECT_TRUE(output.node(index).attr().at("sloppy").b());
     62 }
     63 
     64 TEST(MakeSloppy, ParallelMap) {
     65   using test::function::NDef;
     66   GrapplerItem item;
     67   item.graph = test::function::GDef(
     68       {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
     69        NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
     70        NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
     71        NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
     72        NDef("num_parallel_calls", "Const", {},
     73             {{"value", 1}, {"dtype", DT_INT32}}),
     74        MakeParallelMapNode("map", "range", "num_parallel_calls", "XTimesTwo",
     75                            /*sloppy=*/false)},
     76       // FunctionLib
     77       {
     78           test::function::XTimesTwo(),
     79       });
     80 
     81   MakeSloppy optimizer;
     82   GraphDef output;
     83   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
     84   EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map", output));
     85   int index = graph_utils::FindGraphNodeWithName("map", output);
     86   EXPECT_TRUE(output.node(index).attr().at("sloppy").b());
     87 }
     88 
     89 TEST(MakeSloppy, ParseExampleDataset) {
     90   using test::function::NDef;
     91   GrapplerItem item;
     92   item.graph = test::function::GDef(
     93       {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
     94        NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
     95        NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
     96        NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
     97        NDef("num_parallel_calls", "Const", {},
     98             {{"value", 1}, {"dtype", DT_INT32}}),
     99        MakeParseExampleNode("parse_example", "range", "num_parallel_calls",
    100                             /*sloppy=*/false)},
    101       // FunctionLib
    102       {});
    103 
    104   MakeSloppy optimizer;
    105   GraphDef output;
    106   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
    107   EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("parse_example", output));
    108   int index = graph_utils::FindGraphNodeWithName("parse_example", output);
    109   EXPECT_TRUE(output.node(index).attr().at("sloppy").b());
    110 }
    111 
    112 }  // namespace
    113 }  // namespace grappler
    114 }  // namespace tensorflow
    115