Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/cc/ops/const_op.h"
     16 #include "tensorflow/cc/ops/sendrecv_ops.h"
     17 #include "tensorflow/cc/ops/standard_ops.h"
     18 #include "tensorflow/core/framework/tensor_testutil.h"
     19 #include "tensorflow/core/lib/core/status_test_util.h"
     20 #include "tensorflow/core/lib/io/path.h"
     21 #include "tensorflow/core/lib/strings/strcat.h"
     22 #include "tensorflow/core/platform/test.h"
     23 #include "tensorflow/core/platform/test_benchmark.h"
     24 #include "tensorflow/core/public/session.h"
     25 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
     26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     27 
     28 namespace tensorflow {
     29 namespace graph_transforms {
     30 
     31 // Declarations so we don't need a public header.
     32 Status SparsifyGather(const GraphDef& input_graph_def,
     33                       const TransformFuncContext& context,
     34                       GraphDef* output_graph_def);
     35 Status ReadTensorFromCheckpoint(
     36     const string& tensor_name, const std::unique_ptr<BundleReader>& ckpt_reader,
     37     const string& shape_and_slice, Tensor* tensor);
     38 
     39 class SparsifyGatherTest : public ::testing::Test {
     40  protected:
     41   NodeDef* CreateNode(const StringPiece name, const StringPiece op,
     42                       const std::vector<NodeDef*>& inputs, GraphDef* graph_def,
     43                       bool control_dep = false) {
     44     NodeDef* node_def = graph_def->add_node();
     45     node_def->set_name(name.ToString());
     46     node_def->set_op(op.ToString());
     47     if (!control_dep) {
     48       std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) {
     49         node_def->add_input(input->name());
     50       });
     51     } else {
     52       std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) {
     53         node_def->add_input(strings::StrCat("^", input->name()));
     54       });
     55     }
     56     return node_def;
     57   }
     58 
     59   void MakeGather(StringPiece name, bool gather_v2, NodeDef* params,
     60                   NodeDef* indices, GraphDef* graph_def) {
     61     if (gather_v2) {
     62       NodeDef* axis_node =
     63           CreateNode(strings::StrCat(name, "_axis"), "Const", {}, graph_def);
     64       Tensor axis_t(DT_INT32, TensorShape({}));
     65       axis_t.scalar<int32>()() = 0;
     66       SetNodeTensorAttr<int32>("value", axis_t, axis_node);
     67       CreateNode(name, "GatherV2", {params, indices, axis_node}, graph_def);
     68     } else {
     69       CreateNode(name, "Gather", {params, indices}, graph_def);
     70     }
     71   }
     72 
     73   void TestSinglePartition(bool gather_v2, bool include_shared_init,
     74                            bool test_variable, bool test_kept_concat,
     75                            const string& shared_init_name = "group_deps") {
     76     GraphDef graph_def;
     77 
     78     const auto checkpoint_path =
     79         io::JoinPath(testing::TmpDir(), "checkpoint_single");
     80     // Build the graph.
     81     NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
     82     NodeDef* w_node;
     83     NodeDef* zeros_const;
     84     NodeDef* zeros_shape;
     85     NodeDef* zeros_node;
     86     NodeDef* assign_node;
     87 
     88     Tensor weights(DT_FLOAT, TensorShape({4, 1}));
     89     test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001});
     90 
     91     if (!test_variable) {
     92       w_node = CreateNode("w/part_1", "Const", {}, &graph_def);
     93       SetNodeTensorAttr<float>("value", weights, w_node);
     94     } else {
     95       w_node = CreateNode("w/part_1", "VariableV2", {}, &graph_def);
     96 
     97       zeros_shape = CreateNode("w/part_1/Initializer/zeros/shape_as_tensor",
     98                                "Const", {}, &graph_def);
     99       zeros_const = CreateNode("w/part_1/Initializer/zeros/Const", "Const", {},
    100                                &graph_def);
    101       zeros_node = CreateNode("w/part_1/Initializer/zeros", "Fill",
    102                               {zeros_shape, zeros_const}, &graph_def);
    103       assign_node = CreateNode("w/part_1/Assign", "Assign",
    104                                {w_node, zeros_node}, &graph_def);
    105 
    106       NodeDef* save_const_node =
    107           CreateNode("save/Const", "Const", {}, &graph_def);
    108 
    109       Tensor tensor_names_values(DT_STRING, TensorShape({1}));
    110       test::FillValues<string>(&tensor_names_values, {"w"});
    111       NodeDef* tensor_names_node =
    112           CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
    113       SetNodeTensorAttr<string>("value", tensor_names_values,
    114                                 tensor_names_node);
    115 
    116       NodeDef* tensor_shapes_slices_node = CreateNode(
    117           "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
    118       Tensor shapes_slices_val(DT_STRING, TensorShape({1}));
    119       shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
    120       SetNodeTensorAttr<string>("value", shapes_slices_val,
    121                                 tensor_shapes_slices_node);
    122 
    123       NodeDef* restore_node = CreateNode(
    124           "save/RestoreV2", "RestoreV2",
    125           {save_const_node, tensor_names_node, tensor_shapes_slices_node},
    126           &graph_def);
    127       CreateNode("save/Assign", "Assign", {w_node, restore_node}, &graph_def);
    128 
    129       BundleWriter writer(Env::Default(), checkpoint_path);
    130       TF_ASSERT_OK(writer.Add("w", weights));
    131       TF_ASSERT_OK(writer.Finish());
    132     }
    133     SetNodeAttr("dtype", DT_FLOAT, w_node);
    134 
    135     NodeDef* identity_node =
    136         CreateNode("w/read", "Identity", {w_node}, &graph_def);
    137     MakeGather("gather", gather_v2, identity_node, input_node, &graph_def);
    138     if (include_shared_init) {
    139       if (!test_variable) {
    140         CreateNode(shared_init_name, "NoOp", {}, &graph_def);
    141       } else {
    142         CreateNode(shared_init_name, "NoOp", {assign_node}, &graph_def, true);
    143       }
    144     }
    145 
    146     NodeDef* concat_axis_node =
    147         CreateNode("linear/concat/axis", "Const", {}, &graph_def);
    148     NodeDef* concat_input_node =
    149         CreateNode("concat/input/node", "Const", {}, &graph_def);
    150     NodeDef* concat_node = nullptr;
    151     if (!test_kept_concat) {
    152       concat_node = CreateNode(
    153           "concat/node", "ConcatV2",
    154           {identity_node, concat_input_node, concat_axis_node}, &graph_def);
    155       SetNodeAttr("N", 2, concat_node);
    156     } else {
    157       NodeDef* concat_input_node_2 =
    158           CreateNode("concat/input/node_2", "Const", {}, &graph_def);
    159       concat_node = CreateNode("concat/node", "ConcatV2",
    160                                {identity_node, concat_input_node,
    161                                 concat_input_node_2, concat_axis_node},
    162                                &graph_def);
    163       SetNodeAttr("N", 3, concat_node);
    164     }
    165 
    166     // Run the op.
    167     GraphDef result;
    168     TransformFuncContext context;
    169     context.input_names = {"ids"};
    170     context.output_names = {"gather"};
    171     if (test_variable) {
    172       context.params["input_checkpoint"] = {checkpoint_path};
    173     }
    174     if (shared_init_name != "group_deps") {
    175       context.params["group_init_node"] = {shared_init_name};
    176     }
    177     TF_ASSERT_OK(SparsifyGather(graph_def, context, &result));
    178 
    179     // Validation begins.
    180     std::map<string, const NodeDef*> node_lookup;
    181     MapNamesToNodes(result, &node_lookup);
    182 
    183     // Check nodes.
    184     EXPECT_EQ(0,
    185               node_lookup.count("w/part_1/Initializer/zeros/shape_as_tensor"));
    186     EXPECT_EQ(0, node_lookup.count("w/part_1/Initializer/zeros/Const"));
    187     EXPECT_EQ(0, node_lookup.count("w/part_1/Initializer/zeros"));
    188     EXPECT_EQ(0, node_lookup.count("w/part_1/Assign"));
    189 
    190     EXPECT_EQ(1, node_lookup.count("ids"));
    191     EXPECT_EQ("Const", node_lookup.at("ids")->op());
    192 
    193     EXPECT_EQ(1, node_lookup.count("concat/node"));
    194 
    195     if (!test_kept_concat) {
    196       EXPECT_EQ(0, node_lookup.count("linear/concat/axis"));
    197       EXPECT_EQ("Identity", node_lookup.at("concat/node")->op());
    198       EXPECT_EQ(1, node_lookup.at("concat/node")->input_size());
    199       EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0));
    200     } else {
    201       EXPECT_EQ(1, node_lookup.count("linear/concat/axis"));
    202       EXPECT_EQ("ConcatV2", node_lookup.at("concat/node")->op());
    203       EXPECT_EQ(3, node_lookup.at("concat/node")->input_size());
    204       EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0));
    205       EXPECT_EQ("concat/input/node_2", node_lookup.at("concat/node")->input(1));
    206       EXPECT_EQ("linear/concat/axis", node_lookup.at("concat/node")->input(2));
    207       EXPECT_EQ(2, node_lookup.at("concat/node")->attr().at("N").i());
    208     }
    209 
    210     EXPECT_EQ(1, node_lookup.count("w/part_1/indices"));
    211     EXPECT_EQ("Const", node_lookup.at("w/part_1/indices")->op());
    212     Tensor expected_indices_tensor(DT_INT64, TensorShape({3}));
    213     test::FillValues<int64>(&expected_indices_tensor, {0, 2, 3});
    214     test::ExpectTensorEqual<int64>(
    215         expected_indices_tensor,
    216         GetNodeTensorAttr(*(node_lookup.at("w/part_1/indices")), "value"));
    217 
    218     EXPECT_EQ(1, node_lookup.count("w/part_1/values"));
    219     EXPECT_EQ("Const", node_lookup.at("w/part_1/values")->op());
    220     Tensor expected_values_tensor(DT_FLOAT, TensorShape({3}));
    221     test::FillValues<float>(&expected_values_tensor, {0.2, 1.2, 0.001});
    222     test::ExpectTensorNear<float>(
    223         expected_values_tensor,
    224         GetNodeTensorAttr(*(node_lookup.at("w/part_1/values")), "value"), 1e-5);
    225 
    226     EXPECT_EQ(1, node_lookup.count("w/part_1/HashTable"));
    227     EXPECT_EQ("HashTable", node_lookup.at("w/part_1/HashTable")->op());
    228 
    229     EXPECT_EQ(1, node_lookup.count("w/part_1/InitializeTable"));
    230     EXPECT_EQ("InitializeTable",
    231               node_lookup.at("w/part_1/InitializeTable")->op());
    232 
    233     // Nodes in "gather" scope.
    234     EXPECT_EQ(1, node_lookup.count("gather/LookupTableFind"));
    235     EXPECT_EQ("LookupTableFind",
    236               node_lookup.at("gather/LookupTableFind")->op());
    237 
    238     EXPECT_EQ(1, node_lookup.count("gather/Const"));
    239     EXPECT_EQ("Const", node_lookup.at("gather/Const")->op());
    240     Tensor expected_gather_default_tensor(DT_FLOAT, TensorShape({}));
    241     test::FillValues<float>(&expected_gather_default_tensor, {0.0});
    242     test::ExpectTensorNear<float>(
    243         expected_gather_default_tensor,
    244         GetNodeTensorAttr(*(node_lookup.at("gather/Const")), "value"), 1e-5);
    245 
    246     EXPECT_EQ(1, node_lookup.count("gather/ExpandDims/Const"));
    247     EXPECT_EQ("Const", node_lookup.at("gather/ExpandDims/Const")->op());
    248     Tensor expected_expand_dims_tensor(DT_INT32, TensorShape({}));
    249     test::FillValues<int32>(&expected_expand_dims_tensor, {-1});
    250     test::ExpectTensorEqual<int32>(
    251         expected_expand_dims_tensor,
    252         GetNodeTensorAttr(*(node_lookup.at("gather/ExpandDims/Const")),
    253                           "value"));
    254 
    255     EXPECT_EQ(1, node_lookup.count("gather"));
    256     EXPECT_EQ("ExpandDims", node_lookup.at("gather")->op());
    257 
    258     EXPECT_EQ(1, node_lookup.count(shared_init_name));
    259     EXPECT_EQ("NoOp", node_lookup.at(shared_init_name)->op());
    260 
    261     // Check connections
    262     EXPECT_EQ("w/part_1/HashTable",
    263               node_lookup.at("w/part_1/InitializeTable")->input(0));
    264     EXPECT_EQ("w/part_1/indices",
    265               node_lookup.at("w/part_1/InitializeTable")->input(1));
    266     EXPECT_EQ("w/part_1/values",
    267               node_lookup.at("w/part_1/InitializeTable")->input(2));
    268 
    269     EXPECT_EQ("w/part_1/HashTable",
    270               node_lookup.at("gather/LookupTableFind")->input(0));
    271     EXPECT_EQ("ids", node_lookup.at("gather/LookupTableFind")->input(1));
    272     EXPECT_EQ("gather/Const",
    273               node_lookup.at("gather/LookupTableFind")->input(2));
    274 
    275     EXPECT_EQ("gather/LookupTableFind", node_lookup.at("gather")->input(0));
    276 
    277     // Check control dependency.
    278     EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(),
    279                         node_lookup.at(shared_init_name)->input().end(),
    280                         "^w/part_1/InitializeTable"),
    281               node_lookup.at(shared_init_name)->input().end());
    282     EXPECT_EQ(1, node_lookup.at(shared_init_name)->input().size());
    283   }
    284 
    285   void TestMultiPartition(bool gather_v2, bool include_shared_init,
    286                           bool test_variable,
    287                           const string& shared_init_name = "group_deps") {
    288     // The 'ids' node is served input for two 'Gather's.
    289     GraphDef graph_def;
    290 
    291     const auto checkpoint_path =
    292         io::JoinPath(testing::TmpDir(), "checkpoint_multiple");
    293     // Build Graph:
    294     // Shared input node
    295     NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
    296 
    297     // Two partitions
    298     NodeDef* w_node1;
    299     NodeDef* w_node2;
    300     NodeDef* zeros_const1;
    301     NodeDef* zeros_shape1;
    302     NodeDef* zeros_node1;
    303     NodeDef* zeros_const2;
    304     NodeDef* zeros_shape2;
    305     NodeDef* zeros_node2;
    306     NodeDef* assign_node1;
    307     NodeDef* assign_node2;
    308 
    309     Tensor weights(DT_FLOAT, TensorShape({4, 1}));
    310     test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001});
    311     if (!test_variable) {
    312       w_node1 = CreateNode("w1/part_1", "Const", {}, &graph_def);
    313       w_node2 = CreateNode("w2/part_1", "Const", {}, &graph_def);
    314       SetNodeTensorAttr<float>("value", weights, w_node1);
    315       SetNodeTensorAttr<float>("value", weights, w_node2);
    316     } else {
    317       NodeDef* save_const_node =
    318           CreateNode("save/Const", "Const", {}, &graph_def);
    319 
    320       NodeDef* tensor_names_node =
    321           CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
    322       Tensor tensor_names_values(DT_STRING, TensorShape({2}));
    323       test::FillValues<string>(&tensor_names_values, {"w1", "w2"});
    324       SetNodeTensorAttr<string>("value", tensor_names_values,
    325                                 tensor_names_node);
    326 
    327       NodeDef* tensor_shapes_slices_node = CreateNode(
    328           "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
    329       Tensor shapes_slices_val(DT_STRING, TensorShape({2}));
    330       shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
    331       shapes_slices_val.flat<string>()(1) = "4 1 0,4:0,1";
    332       SetNodeTensorAttr<string>("value", shapes_slices_val,
    333                                 tensor_shapes_slices_node);
    334 
    335       NodeDef* restore_node = CreateNode(
    336           "save/RestoreV2", "RestoreV2",
    337           {save_const_node, tensor_names_node, tensor_shapes_slices_node},
    338           &graph_def);
    339 
    340       w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def);
    341 
    342       zeros_shape1 = CreateNode("w1/part_1/Initializer/zeros/shape_as_tensor",
    343                                 "Const", {}, &graph_def);
    344       zeros_const1 = CreateNode("w1/part_1/Initializer/zeros/Const", "Const",
    345                                 {}, &graph_def);
    346       zeros_node1 = CreateNode("w1/part_1/Initializer/zeros", "Fill",
    347                                {zeros_shape1, zeros_const1}, &graph_def);
    348       assign_node1 = CreateNode("w1/part_1/Assign", "Assign",
    349                                 {w_node1, zeros_node1}, &graph_def);
    350 
    351       CreateNode("save/Assign", "Assign", {w_node1, restore_node}, &graph_def);
    352 
    353       w_node2 = CreateNode("w2/part_1", "VariableV2", {}, &graph_def);
    354       zeros_shape2 = CreateNode("w2/part_1/Initializer/zeros/shape_as_tensor",
    355                                 "Const", {}, &graph_def);
    356       zeros_const2 = CreateNode("w2/part_1/Initializer/zeros/Const", "Const",
    357                                 {}, &graph_def);
    358       zeros_node2 = CreateNode("w2/part_1/Initializer/zeros", "Fill",
    359                                {zeros_shape2, zeros_const2}, &graph_def);
    360       assign_node2 = CreateNode("w2/part_1/Assign", "Assign",
    361                                 {w_node2, zeros_node2}, &graph_def);
    362 
    363       CreateNode("save/Assign_1", "Assign", {w_node2, restore_node},
    364                  &graph_def);
    365 
    366       BundleWriter writer(Env::Default(), checkpoint_path);
    367       TF_ASSERT_OK(writer.Add("w1", weights));
    368       TF_ASSERT_OK(writer.Add("w2", weights));
    369       TF_ASSERT_OK(writer.Finish());
    370     }
    371     SetNodeAttr("dtype", DT_FLOAT, w_node1);
    372     SetNodeAttr("dtype", DT_FLOAT, w_node2);
    373 
    374     NodeDef* identity_node1 =
    375         CreateNode("w1/part_1/read", "Identity", {w_node1}, &graph_def);
    376     NodeDef* identity_node2 =
    377         CreateNode("w2/part_1/read", "Identity", {w_node2}, &graph_def);
    378     MakeGather("gather1", gather_v2, identity_node1, input_node, &graph_def);
    379     MakeGather("gather2", gather_v2, identity_node2, input_node, &graph_def);
    380 
    381     NodeDef* concat_axis_node =
    382         CreateNode("linear/concat/axis", "Const", {}, &graph_def);
    383     NodeDef* concat_node = CreateNode(
    384         "concat/node", "ConcatV2",
    385         {identity_node1, identity_node2, concat_axis_node}, &graph_def);
    386     SetNodeAttr("N", 2, concat_node);
    387 
    388     // Shared init node
    389     if (include_shared_init) {
    390       if (!test_variable) {
    391         CreateNode(shared_init_name, "NoOp", {}, &graph_def);
    392       } else {
    393         CreateNode(shared_init_name, "NoOp", {assign_node1, assign_node2},
    394                    &graph_def, true);
    395       }
    396     }
    397 
    398     // Run the op.
    399     GraphDef result;
    400     TransformFuncContext context;
    401     context.input_names = {"ids"};
    402     context.output_names = {"gather1", "gather2"};
    403     if (test_variable) {
    404       context.params["input_checkpoint"] = {checkpoint_path};
    405     }
    406     if (shared_init_name != "group_deps") {
    407       context.params["group_init_node"] = {shared_init_name};
    408     }
    409     TF_ASSERT_OK(SparsifyGather(graph_def, context, &result));
    410 
    411     // Validation begins.
    412     std::map<string, const NodeDef*> node_lookup;
    413     MapNamesToNodes(result, &node_lookup);
    414 
    415     // Check nodes.
    416     EXPECT_EQ(0,
    417               node_lookup.count("w1/part_1/Initializer/zeros/shape_as_tensor"));
    418     EXPECT_EQ(0, node_lookup.count("w1/part_1/Initializer/zeros/Const"));
    419     EXPECT_EQ(0, node_lookup.count("w1/part_1/Initializer/zeros"));
    420     EXPECT_EQ(0, node_lookup.count("w1/part_1/Assign"));
    421     EXPECT_EQ(0,
    422               node_lookup.count("w2/part_1/Initializer/zeros/shape_as_tensor"));
    423     EXPECT_EQ(0, node_lookup.count("w2/part_1/Initializer/zeros/Const"));
    424     EXPECT_EQ(0, node_lookup.count("w2/part_1/Initializer/zeros"));
    425     EXPECT_EQ(0, node_lookup.count("w2/part_1/Assign"));
    426     EXPECT_EQ(1, node_lookup.count("ids"));
    427     EXPECT_EQ("Const", node_lookup.at("ids")->op());
    428 
    429     EXPECT_EQ(1, node_lookup.count(shared_init_name));
    430     EXPECT_EQ("NoOp", node_lookup.at(shared_init_name)->op());
    431 
    432     EXPECT_EQ(1, node_lookup.count("w1/part_1/indices"));
    433     EXPECT_EQ("Const", node_lookup.at("w1/part_1/indices")->op());
    434     Tensor expected_indices_tensor1(DT_INT64, TensorShape({3}));
    435     test::FillValues<int64>(&expected_indices_tensor1, {0, 2, 3});
    436     test::ExpectTensorEqual<int64>(
    437         expected_indices_tensor1,
    438         GetNodeTensorAttr(*(node_lookup.at("w1/part_1/indices")), "value"));
    439 
    440     EXPECT_EQ(1, node_lookup.count("w1/part_1/values"));
    441     EXPECT_EQ("Const", node_lookup.at("w1/part_1/values")->op());
    442     Tensor expected_values_tensor1(DT_FLOAT, TensorShape({3}));
    443     test::FillValues<float>(&expected_values_tensor1, {0.2, 1.2, 0.001});
    444     test::ExpectTensorNear<float>(
    445         expected_values_tensor1,
    446         GetNodeTensorAttr(*(node_lookup.at("w1/part_1/values")), "value"),
    447         1e-5);
    448 
    449     EXPECT_EQ(1, node_lookup.count("w1/part_1/HashTable"));
    450     EXPECT_EQ("HashTable", node_lookup.at("w1/part_1/HashTable")->op());
    451 
    452     EXPECT_EQ(1, node_lookup.count("w1/part_1/InitializeTable"));
    453     EXPECT_EQ("InitializeTable",
    454               node_lookup.at("w1/part_1/InitializeTable")->op());
    455 
    456     // Nodes in "gather1" scope.
    457     EXPECT_EQ(1, node_lookup.count("gather1/LookupTableFind"));
    458     EXPECT_EQ("LookupTableFind",
    459               node_lookup.at("gather1/LookupTableFind")->op());
    460 
    461     EXPECT_EQ(1, node_lookup.count("gather1/Const"));
    462     EXPECT_EQ("Const", node_lookup.at("gather1/Const")->op());
    463     Tensor expected_gather_default_tensor1(DT_FLOAT, TensorShape({}));
    464     test::FillValues<float>(&expected_gather_default_tensor1, {0.0});
    465     test::ExpectTensorNear<float>(
    466         expected_gather_default_tensor1,
    467         GetNodeTensorAttr(*(node_lookup.at("gather1/Const")), "value"), 1e-5);
    468 
    469     EXPECT_EQ(1, node_lookup.count("gather1/ExpandDims/Const"));
    470     EXPECT_EQ("Const", node_lookup.at("gather1/ExpandDims/Const")->op());
    471     Tensor expected_expand_dims_tensor1(DT_INT32, TensorShape({}));
    472     test::FillValues<int32>(&expected_expand_dims_tensor1, {-1});
    473     test::ExpectTensorEqual<int32>(
    474         expected_expand_dims_tensor1,
    475         GetNodeTensorAttr(*(node_lookup.at("gather1/ExpandDims/Const")),
    476                           "value"));
    477 
    478     EXPECT_EQ(1, node_lookup.count("gather1"));
    479     EXPECT_EQ("ExpandDims", node_lookup.at("gather1")->op());
    480 
    481     EXPECT_EQ(1, node_lookup.count("w2/part_1/indices"));
    482     EXPECT_EQ("Const", node_lookup.at("w2/part_1/indices")->op());
    483     Tensor expected_indices_tensor2(DT_INT64, TensorShape({3}));
    484     test::FillValues<int64>(&expected_indices_tensor2, {0, 2, 3});
    485     test::ExpectTensorEqual<int64>(
    486         expected_indices_tensor2,
    487         GetNodeTensorAttr(*(node_lookup.at("w2/part_1/indices")), "value"));
    488 
    489     EXPECT_EQ(1, node_lookup.count("w2/part_1/values"));
    490     EXPECT_EQ("Const", node_lookup.at("w2/part_1/values")->op());
    491     Tensor expected_values_tensor2(DT_FLOAT, TensorShape({3}));
    492     test::FillValues<float>(&expected_values_tensor2, {0.2, 1.2, 0.001});
    493     test::ExpectTensorNear<float>(
    494         expected_values_tensor2,
    495         GetNodeTensorAttr(*(node_lookup.at("w2/part_1/values")), "value"),
    496         1e-5);
    497 
    498     EXPECT_EQ(1, node_lookup.count("w2/part_1/HashTable"));
    499     EXPECT_EQ("HashTable", node_lookup.at("w2/part_1/HashTable")->op());
    500 
    501     EXPECT_EQ(1, node_lookup.count("w2/part_1/InitializeTable"));
    502     EXPECT_EQ("InitializeTable",
    503               node_lookup.at("w2/part_1/InitializeTable")->op());
    504 
    505     // Nodes in "gather2" scope.
    506     EXPECT_EQ(1, node_lookup.count("gather2/LookupTableFind"));
    507     EXPECT_EQ("LookupTableFind",
    508               node_lookup.at("gather2/LookupTableFind")->op());
    509 
    510     EXPECT_EQ(1, node_lookup.count("gather2/Const"));
    511     EXPECT_EQ("Const", node_lookup.at("gather2/Const")->op());
    512     Tensor expected_gather_default_tensor2(DT_FLOAT, TensorShape({}));
    513     test::FillValues<float>(&expected_gather_default_tensor2, {0.0});
    514     test::ExpectTensorNear<float>(
    515         expected_gather_default_tensor2,
    516         GetNodeTensorAttr(*(node_lookup.at("gather2/Const")), "value"), 1e-5);
    517 
    518     EXPECT_EQ(1, node_lookup.count("gather2/ExpandDims/Const"));
    519     EXPECT_EQ("Const", node_lookup.at("gather2/ExpandDims/Const")->op());
    520     Tensor expected_expand_dims_tensor2(DT_INT32, TensorShape({}));
    521     test::FillValues<int32>(&expected_expand_dims_tensor2, {-1});
    522     test::ExpectTensorEqual<int32>(
    523         expected_expand_dims_tensor2,
    524         GetNodeTensorAttr(*(node_lookup.at("gather2/ExpandDims/Const")),
    525                           "value"));
    526 
    527     EXPECT_EQ(1, node_lookup.count("gather2"));
    528     EXPECT_EQ("ExpandDims", node_lookup.at("gather2")->op());
    529 
    530     // Check connections
    531     EXPECT_EQ("w1/part_1/HashTable",
    532               node_lookup.at("w1/part_1/InitializeTable")->input(0));
    533     EXPECT_EQ("w1/part_1/indices",
    534               node_lookup.at("w1/part_1/InitializeTable")->input(1));
    535     EXPECT_EQ("w1/part_1/values",
    536               node_lookup.at("w1/part_1/InitializeTable")->input(2));
    537 
    538     EXPECT_EQ("w2/part_1/HashTable",
    539               node_lookup.at("w2/part_1/InitializeTable")->input(0));
    540     EXPECT_EQ("w2/part_1/indices",
    541               node_lookup.at("w2/part_1/InitializeTable")->input(1));
    542     EXPECT_EQ("w2/part_1/values",
    543               node_lookup.at("w2/part_1/InitializeTable")->input(2));
    544 
    545     EXPECT_EQ("w1/part_1/HashTable",
    546               node_lookup.at("gather1/LookupTableFind")->input(0));
    547     EXPECT_EQ("ids", node_lookup.at("gather1/LookupTableFind")->input(1));
    548     EXPECT_EQ("gather1/Const",
    549               node_lookup.at("gather1/LookupTableFind")->input(2));
    550     EXPECT_EQ("gather1/LookupTableFind", node_lookup.at("gather1")->input(0));
    551 
    552     EXPECT_EQ("w2/part_1/HashTable",
    553               node_lookup.at("gather2/LookupTableFind")->input(0));
    554     EXPECT_EQ("ids", node_lookup.at("gather2/LookupTableFind")->input(1));
    555     EXPECT_EQ("gather2/Const",
    556               node_lookup.at("gather2/LookupTableFind")->input(2));
    557     EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0));
    558 
    559     EXPECT_EQ(0, node_lookup.count("linear/concat/axis"));
    560     EXPECT_EQ(0, node_lookup.count("concat/node"));
    561 
    562     // Check control deps.
    563     EXPECT_EQ(2, node_lookup.at(shared_init_name)->input_size());
    564     EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(),
    565                         node_lookup.at(shared_init_name)->input().end(),
    566                         "^w1/part_1/InitializeTable"),
    567               node_lookup.at(shared_init_name)->input().end());
    568 
    569     EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(),
    570                         node_lookup.at(shared_init_name)->input().end(),
    571                         "^w2/part_1/InitializeTable"),
    572               node_lookup.at(shared_init_name)->input().end());
    573   }
    574   void TestReadTensorSlice() {
    575     const auto checkpoint_path =
    576         io::JoinPath(testing::TmpDir(), "checkpoint_slice");
    577 
    578     Tensor weights(DT_FLOAT, TensorShape({2, 1}));
    579     test::FillValues<float>(&weights, {0.2, 0.000001});
    580     BundleWriter writer(Env::Default(), checkpoint_path);
    581     TF_ASSERT_OK(writer.AddSlice("w", TensorShape({4, 1}),
    582                                  TensorSlice::ParseOrDie("0,2:0,1"), weights));
    583     TF_ASSERT_OK(writer.Finish());
    584 
    585     std::unique_ptr<BundleReader> reader(
    586         new BundleReader(Env::Default(), checkpoint_path));
    587 
    588     Tensor results;
    589     TF_ASSERT_OK(
    590         ReadTensorFromCheckpoint("w/part_0", reader, "4 1 0,2:0,1", &results));
    591 
    592     test::ExpectTensorEqual<float>(weights, results);
    593   }
    594 };
    595 
    596 TEST_F(SparsifyGatherTest, TestSinglePartition) {
    597   TestSinglePartition(false, false, false, false);
    598   TestSinglePartition(false, true, false, false);
    599   TestSinglePartition(true, false, false, false);
    600   TestSinglePartition(true, true, false, false);
    601   TestSinglePartition(false, false, true, false);
    602   TestSinglePartition(false, true, true, false);
    603   TestSinglePartition(true, false, true, false);
    604   TestSinglePartition(true, true, true, false);
    605   TestSinglePartition(false, true, false, false, "shared_inits");
    606   TestSinglePartition(true, true, false, false, "shared_inits");
    607   TestSinglePartition(false, true, true, false, "shared_inits");
    608   TestSinglePartition(true, true, true, false, "shared_inits");
    609 
    610   TestSinglePartition(false, false, false, true);
    611   TestSinglePartition(false, true, false, true);
    612   TestSinglePartition(true, false, false, true);
    613   TestSinglePartition(true, true, false, true);
    614   TestSinglePartition(false, false, true, true);
    615   TestSinglePartition(false, true, true, true);
    616   TestSinglePartition(true, false, true, true);
    617   TestSinglePartition(true, true, true, true);
    618   TestSinglePartition(false, true, false, true, "shared_inits");
    619   TestSinglePartition(true, true, false, true, "shared_inits");
    620   TestSinglePartition(false, true, true, true, "shared_inits");
    621   TestSinglePartition(true, true, true, true, "shared_inits");
    622 }
    623 
    624 TEST_F(SparsifyGatherTest, TestMultiPartition) {
    625   TestMultiPartition(false, false, false);
    626   TestMultiPartition(false, true, false);
    627   TestMultiPartition(true, false, false);
    628   TestMultiPartition(true, true, false);
    629   TestMultiPartition(false, false, true);
    630   TestMultiPartition(false, true, true);
    631   TestMultiPartition(true, false, true);
    632   TestMultiPartition(true, true, true);
    633   TestMultiPartition(false, true, false, "shared_inits");
    634   TestMultiPartition(true, true, false, "shared_inits");
    635   TestMultiPartition(false, true, true, "shared_inits");
    636   TestMultiPartition(true, true, true, "shared_inits");
    637 }
    638 
    639 TEST_F(SparsifyGatherTest, TestTensorSlice) { TestReadTensorSlice(); }
    640 
    641 }  // namespace graph_transforms
    642 }  // namespace tensorflow
    643