/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph_partition.h"

#include <unordered_map>
#include <utility>

#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/random_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/equal_graph_def.h"

namespace tensorflow {

// Defined in graph_partition.cc; declared here so the test can call it
// directly. Performs a topological sort of `gdef`, using each node's
// "_start_time" attr to prioritize among ready nodes.
extern Status TopologicalSortNodesWithTimePriority(
    const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes,
    std::unordered_map<const NodeDef*, int64>* node_to_start_time_out);

namespace {

using ops::_Recv;
using ops::_Send;
using ops::Const;
using ops::Identity;
using ops::LoopCond;
using ops::NextIteration;

const char gpu_device[] = "/job:a/replica:0/task:0/device:GPU:0";

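// Used below as PartitionOptions::node_to_loc: a node's partition is simply
// its assigned device.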
string SplitByDevice(const Node* node) { return node->assigned_device_name(); }

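// Naming convention used by these tests: a node whose name starts with 'G'
// is placed on the GPU; otherwise the first letter of its name selects a CPU
// ('A' -> cpu:0, 'B' -> cpu:1, and so on).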
string DeviceName(const Node* node) {
  char first = node->name()[0];
  if (first == 'G') {
    return gpu_device;
  } else {
    const string cpu_prefix = "/job:a/replica:0/task:0/cpu:";
    int index = first - 'A';
    return strings::StrCat(cpu_prefix, index);
  }
}

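// Converts graph_def to a Graph, assigns a device to every node (the
// requested device if present, otherwise one derived from the node name),
// and partitions the graph by device. Also verifies that every partition
// inherits the version information of the original graph.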
void Partition(const GraphDef& graph_def,
               std::unordered_map<string, GraphDef>* partitions) {
  Graph g(OpRegistry::Global());
  GraphConstructorOptions opts;
  TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g));

  // Assigns devices to each node. Uses the first letter of the node name as
  // the device index if no device is specified.
  for (Node* node : g.nodes()) {
    string device_name = !node->requested_device().empty()
                             ? node->requested_device()
                             : DeviceName(node);
    node->set_assigned_device_name(device_name);
  }

  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
  popts.get_incarnation = [](const string& name) {
    return (name[0] - 'A') + 100;
  };
  Status s = Partition(popts, &g, partitions);
  CHECK(s.ok()) << s;

  // Check versions.
  EXPECT_EQ(graph_def.versions().producer(), TF_GRAPH_DEF_VERSION);
  // Partitions must inherit the versions of the original graph.
  for (auto& it : *partitions) {
    EXPECT_EQ(graph_def.versions().producer(), it.second.versions().producer());
    EXPECT_EQ(graph_def.versions().min_consumer(),
              it.second.versions().min_consumer());
  }
}

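// Partitions graph_def and verifies, for every partition, that each _Recv
// node has a control input and that a synthesized control loop is present:
// Enter, Merge, Switch, and NextIteration nodes named with the "_cloop"
// prefix.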
void CheckLoopConstruction(const GraphDef& graph_def) {
  std::unordered_map<string, GraphDef> partitions;
  Partition(graph_def, &partitions);
  for (const auto& kv : partitions) {
    const GraphDef& gdef = kv.second;
    bool has_control_enter = false;
    bool has_control_merge = false;
    bool has_control_switch = false;
    bool has_control_next = false;
    for (const NodeDef& ndef : gdef.node()) {
      // Every _Recv node must have a control input.
      if (ndef.op() == "_Recv") {
        bool has_control = false;
        for (const string& input_name : ndef.input()) {
          if (StringPiece(input_name).starts_with("^")) {
            has_control = true;
            break;
          }
        }
        EXPECT_TRUE(has_control);
      }
      // The partition must contain a control loop; its nodes are named with
      // the "_cloop" prefix.
      if (StringPiece(ndef.name()).starts_with("_cloop")) {
        if (ndef.op() == "Enter") {
          has_control_enter = true;
        }
        if (ndef.op() == "Merge") {
          has_control_merge = true;
        }
        if (ndef.op() == "Switch") {
          has_control_switch = true;
        }
        if (ndef.op() == "NextIteration") {
          has_control_next = true;
        }
      }
    }
    EXPECT_TRUE(has_control_enter);
    EXPECT_TRUE(has_control_merge);
    EXPECT_TRUE(has_control_switch);
    EXPECT_TRUE(has_control_next);
  }
}

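// Stub ops used to build the test graphs; they only need shape functions,
// since the tests never execute them.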
REGISTER_OP("FloatInput")
    .Output("o: float")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BoolInput")
    .Output("o: bool")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("Combine")
    .Input("a: float")
    .Input("b: float")
    .Output("o: float")
    .SetShapeFn(shape_inference::UnknownShape);

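// Builds a node of the given op_type directly with NodeBuilder, since the
// test ops registered above have no generated C++ op wrappers.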
Output ConstructOp(const Scope& scope, const string& op_type,
                   const gtl::ArraySlice<Input>& inputs) {
  if (!scope.ok()) return Output();
  const string unique_name = scope.GetUniqueNameForOp(op_type);
  auto builder =
      NodeBuilder(unique_name, op_type, scope.graph()->op_registry());
  for (auto const& input : inputs) {
    builder.Input(ops::NodeOut(input.node(), input.index()));
  }
  scope.UpdateBuilder(&builder);
  Node* ret;
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  if (!scope.ok()) return Output();
  scope.UpdateStatus(scope.DoShapeInference(ret));
  if (!scope.ok()) return Output();
  return Output(ret);
}

Output FloatInput(const Scope& scope) {
  return ConstructOp(scope, "FloatInput", {});
}

Output BoolInput(const Scope& scope) {
  return ConstructOp(scope, "BoolInput", {});
}

Output Combine(const Scope& scope, Input a, Input b) {
  return ConstructOp(scope, "Combine", {std::move(a), std::move(b)});
}

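// Test fixture: in_ holds the graph to be partitioned, while scope_a_ and
// scope_b_ are used to build the graphs expected for the cpu:0 and cpu:1
// partitions, respectively.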
class GraphPartitionTest : public ::testing::Test {
 protected:
  GraphPartitionTest()
      : in_(Scope::NewRootScope().ExitOnError()),
        scope_a_(Scope::NewRootScope().ExitOnError().WithDevice(
            "/job:a/replica:0/task:0/cpu:0")),
        scope_b_(Scope::NewRootScope().ExitOnError().WithDevice(
            "/job:a/replica:0/task:0/cpu:1")) {}

  const GraphDef& ToGraphDef() {
    TF_EXPECT_OK(in_.ToGraphDef(&in_graph_def_));
    return in_graph_def_;
  }

  void ExpectMatchA() {
    GraphDef graph_def;
    TF_EXPECT_OK(scope_a_.ToGraphDef(&graph_def));
    string a = "/job:a/replica:0/task:0/cpu:0";
    TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]);
  }

  void ExpectMatchB() {
    GraphDef graph_def;
    TF_EXPECT_OK(scope_b_.ToGraphDef(&graph_def));
    string b = "/job:a/replica:0/task:0/cpu:1";
    TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]);
  }

  void ExpectFunctions(const FunctionDefLibrary& library,
                       const std::set<string>& expected_names) {
    std::set<string> actual_names;
    for (const FunctionDef& fdef : library.function()) {
      actual_names.insert(fdef.signature().name());
    }
    EXPECT_EQ(actual_names, expected_names);
  }

  Scope in_;
  GraphDef in_graph_def_;
  Scope scope_a_;
  Scope scope_b_;
  std::unordered_map<string, GraphDef> partitions_;
};

TEST_F(GraphPartitionTest, SingleDevice) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  Combine(in_.WithOpName("A2"), a1, a1);

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(1, partitions_.size());

  a1 = FloatInput(scope_a_.WithOpName("A1"));
  Combine(scope_a_.WithOpName("A2"), a1, a1);
  ExpectMatchA();
}

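// B2 on cpu:1 consumes A1 on cpu:0, so the partitioner must insert a
// _Send/_Recv pair across the data edge.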
TEST_F(GraphPartitionTest, CrossDeviceData) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  Combine(in_.WithOpName("B2"), a1, b1);

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  a1 = FloatInput(scope_a_.WithOpName("A1"));
  _Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b);
  ExpectMatchA();

  b1 = FloatInput(scope_b_.WithOpName("B1"));
  auto recv =
      _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
  Combine(scope_b_.WithOpName("B2"), recv, b1);
  ExpectMatchB();
}

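// A cross-device control edge becomes a dummy Const carrying the control
// dependency, fed into a _Send on the source device; on the destination
// device a _Recv feeds an Identity whose output supplies the control input.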
TEST_F(GraphPartitionTest, CrossDeviceControl) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1);

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  a1 = FloatInput(scope_a_.WithOpName("A1"));
  auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
  _Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b);
  ExpectMatchA();

  auto recv =
      _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
  auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
  b1 = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
  ExpectMatchB();
}

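// Multiple data uses of A1 on cpu:1 should share a single _Send/_Recv pair.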
TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  Combine(in_.WithOpName("B2"), a1, b1);
  Combine(in_.WithOpName("B3"), a1, a1);

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  a1 = FloatInput(scope_a_.WithOpName("A1"));
  _Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b);
  ExpectMatchA();

  auto recv =
      _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
  b1 = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2"), recv, b1);
  Combine(scope_b_.WithOpName("B3"), recv, recv);
  ExpectMatchB();
}

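// Multiple control uses of A1 on cpu:1 should share one
// Const/_Send/_Recv/Identity chain.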
TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1);
  FloatInput(in_.WithOpName("B3").WithControlDependencies(a1));

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  a1 = FloatInput(scope_a_.WithOpName("A1"));
  auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
  _Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b);
  ExpectMatchA();

  auto recv =
      _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b);
  auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
  b1 = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
  FloatInput(scope_b_.WithOpName("B3").WithControlDependencies(id));
  ExpectMatchB();
}

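// A1 feeds cpu:1 through both a data edge and a control edge; the
// partitioner currently transfers the two separately.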
TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  Combine(in_.WithOpName("B2"), a1, b1);
  FloatInput(in_.WithOpName("B3").WithControlDependencies(a1));

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  a1 = FloatInput(scope_a_.WithOpName("A1"));
  auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
  // NOTE: The Send A1/_1 -> A1/_2 is not strictly needed; as a minor
  // optimization, the edge A1/_0 -> A1/_4 could serve as the control instead.
  _Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b);
  _Send(scope_a_.WithOpName("A1/_4"), a1, "edge_2_A1", a, 82, b);
  ExpectMatchA();

  auto recv1 =
      _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b);
  auto id1 = Identity(scope_b_.WithOpName("A1/_3"), recv1);
  auto recv2 =
      _Recv(scope_b_.WithOpName("A1/_5"), DT_FLOAT, "edge_2_A1", a, 82, b);
  b1 = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2"), recv2, b1);
  FloatInput(scope_b_.WithOpName("B3").WithControlDependencies(id1));
  ExpectMatchB();
}

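// A hand-built loop whose body node (B1) is on a different device from the
// rest of the loop. Partitioning must synthesize a control loop in each
// partition.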
TEST_F(GraphPartitionTest, CrossDeviceLoopSimple) {
  auto a1 = BoolInput(in_.WithOpName("A1"));
  auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("A2"), a1, "foo");
  auto a3 = ::tensorflow::ops::Merge(in_.WithOpName("A3"),
                                     {a2, Input("A5", 0, DT_BOOL)})
                .output;
  LoopCond(in_.WithOpName("A4"), a3);
  auto b1 = Identity(in_.WithOpName("B1"), a3);
  NextIteration(in_.WithOpName("A5"), b1);

  CheckLoopConstruction(ToGraphDef());
}

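// A variant that names the Enter (B2) and NextIteration (B5) onto a
// different device from the Merge (A3). The partitioner places these loop
// nodes with their Merge, so A3 keeps direct inputs rather than _Recv nodes.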
TEST_F(GraphPartitionTest, CrossDeviceLoopSimple1) {
  auto a1 = BoolInput(in_.WithOpName("A1"));
  auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("B2"), a1, "foo");
  auto a3 = ::tensorflow::ops::Merge(in_.WithOpName("A3"),
                                     {a2, Input("B5", 0, DT_BOOL)})
                .output;
  LoopCond(in_.WithOpName("A4"), a3);
  auto b1 = Identity(in_.WithOpName("B1"), a3);
  NextIteration(in_.WithOpName("B5"), b1);

  std::unordered_map<string, GraphDef> partitions;
  Partition(ToGraphDef(), &partitions);
  for (const auto& kv : partitions) {
    const GraphDef& gdef = kv.second;
    for (const NodeDef& ndef : gdef.node()) {
      if (ndef.name() == "A3") {
        // A3, B2, and B5 are on the same device.
        EXPECT_EQ(ndef.input(0), "B2");
        EXPECT_EQ(ndef.input(1), "B5");
      }
    }
  }
}

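// Builds a real while loop (while i1 < 10: i1 += i2) whose body AddN runs on
// a different device from the rest of the loop.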
TEST_F(GraphPartitionTest, CrossDeviceLoopFull) {
  Scope cpu0 = in_.WithDevice("/job:a/replica:0/task:0/cpu:0");
  auto p1 = ops::Placeholder(cpu0, DT_INT32);
  auto p2 = ops::Placeholder(cpu0, DT_INT32);
  OutputList outputs;
  // while i1 < 10: i1 += i2
  TF_ASSERT_OK(ops::BuildWhileLoop(
      cpu0, {p1, p2},
      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
        *output = ops::Less(s, inputs[0], 10);
        return s.status();
      },
      [](const Scope& s, const std::vector<Output>& inputs,
         std::vector<Output>* outputs) {
        Scope cpu1 = s.WithDevice("/job:a/replica:0/task:0/cpu:1");
        outputs->push_back(ops::AddN(cpu1, {inputs[0], inputs[1]}));
        outputs->push_back(inputs[1]);
        return s.status();
      },
      "test_loop", &outputs));
  CheckLoopConstruction(ToGraphDef());
}

TEST_F(GraphPartitionTest, PartitionIncompleteGraph) {
  NodeDef ndef;
  Graph g(OpRegistry::Global());
  // Invalid graph since the Combine node requires an input.
  bool parsed = protobuf::TextFormat::ParseFromString(
      R"EOF(
      name: "N"
      op: "Combine"
      )EOF",
      &ndef);
  ASSERT_TRUE(parsed);
  Status status;
  g.AddNode(ndef, &status);
  TF_ASSERT_OK(status);

  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
  popts.get_incarnation = [](const string&) { return 1; };

  std::unordered_map<string, GraphDef> partitions;
  status = Partition(popts, &g, &partitions);
  // Partitioning should fail, but not crash like it did before the changes
  // that accompanied the addition of this test.
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code()) << status;
}

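// Each device invokes only one of the two functions, yet both partitions
// must receive the entire function library.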
TEST_F(GraphPartitionTest, Functions) {
  FunctionDefLibrary fdef_lib;
  *fdef_lib.add_function() = test::function::XTimesTwo();
  *fdef_lib.add_function() = test::function::XTimesFour();
  TF_ASSERT_OK(in_.graph()->AddFunctionLibrary(fdef_lib));

  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  ConstructOp(in_.WithOpName("A2"), "XTimesTwo", {a1});
  ConstructOp(in_.WithOpName("B2"), "XTimesFour", {b1});

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  // Test that the partition graphs inherit the function library from the
  // original graph.
  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  ExpectFunctions(partitions_[a].library(), {"XTimesTwo", "XTimesFour"});
  ExpectFunctions(partitions_[b].library(), {"XTimesTwo", "XTimesFour"});
}

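// _Send/_Recv nodes that already exist in the input graph should have their
// send_device_incarnation attr rewritten via popts.get_incarnation, which
// maps send device "A" to 100 here.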
TEST_F(GraphPartitionTest, SetIncarnation) {
  GraphDef gdef;
  const char* const kSendRecvAttrs = R"proto(
  attr { key: 'T' value { type: DT_FLOAT  }  }
  attr { key: 'client_terminated' value {  b: false } }
  attr { key: 'recv_device' value { s: 'B' } }
  attr { key: 'send_device' value { s: 'A' } }
  attr { key: 'send_device_incarnation' value { i: 0 }  }
  attr { key: 'tensor_name' value { s: 'test' } }
)proto";
  CHECK(protobuf::TextFormat::ParseFromString(
      strings::StrCat(
          "node { name: 'A/Pi' op: 'Const' ",
          "  attr { key: 'dtype' value { type: DT_FLOAT } } ",
          "  attr { key: 'value' value { tensor { ",
          "    dtype: DT_FLOAT tensor_shape {} float_val: 3.14 } } } }",
          "node { name: 'A' op: '_Send' input: 'A/Pi' ", kSendRecvAttrs, "}",
          "node { name: 'B' op: '_Recv' ", kSendRecvAttrs,
          "  attr { key: 'tensor_type' value { type:DT_FLOAT}}}"),
      &gdef));
  gdef.mutable_versions()->set_producer(TF_GRAPH_DEF_VERSION);
  Partition(gdef, &partitions_);
  EXPECT_EQ(2, partitions_.size());

  for (const auto& kv : partitions_) {
    const GraphDef& gdef = kv.second;
    for (const NodeDef& ndef : gdef.node()) {
      if (ndef.name() == "A" || ndef.name() == "B") {
        int64 val;
        TF_CHECK_OK(GetNodeAttr(ndef, "send_device_incarnation", &val));
        EXPECT_EQ(val, 100);  // The send device is "A": ('A' - 'A') + 100.
      }
    }
  }
}

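// With no dependencies between nodes, the sort should order nodes purely by
// their "_start_time" attr.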
TEST(TopologicalSortNodesWithTimePriority, NoDependencies) {
  // Create placeholders, shuffled so that their order in the graph is not
  // strictly increasing.
  Scope root = Scope::NewRootScope().ExitOnError();
  std::vector<int> indexes;
  for (int i = 0; i < 20; ++i) {
    indexes.push_back((i + 2001) % 20);
  }
  std::vector<ops::Placeholder> placeholders;
  for (int i : indexes) {
    placeholders.emplace_back(root.WithOpName(strings::StrCat("p", i)),
                              DT_FLOAT);
    placeholders.back().node()->AddAttr("_start_time", i + 1);
  }

  GraphDef gdef;
  TF_EXPECT_OK(root.ToGraphDef(&gdef));

  std::vector<std::pair<const NodeDef*, int64>> nodes;
  std::unordered_map<const NodeDef*, int64> node_to_start_time;
  TF_CHECK_OK(
      TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time));
  ASSERT_EQ(nodes.size(), 20);
  for (int i = 0; i < nodes.size(); ++i) {
    EXPECT_EQ(strings::StrCat("p", i), nodes[i].first->name());
    EXPECT_EQ(i + 1, nodes[i].second);
  }
}

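// With dependencies, start-time priority applies only among nodes whose
// inputs have already been emitted, and the returned start times are
// adjusted upward to respect dependencies (addn ends up with time 50 despite
// its attr value of 1).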
TEST(TopologicalSortNodesWithTimePriority, Dependencies) {
  // Create placeholders, shuffled so that their order in the graph is not
  // strictly increasing.
  Scope root = Scope::NewRootScope().ExitOnError();
  std::vector<int> indexes;
  std::vector<ops::Placeholder> placeholders_in_order;
  const int num_leaves = 20;
  for (int i = 0; i < num_leaves; ++i) {
    indexes.push_back((i + 2001) % num_leaves);
    placeholders_in_order.emplace_back(root.WithOpName(strings::StrCat("p", i)),
                                       DT_FLOAT);
    placeholders_in_order.back().node()->AddAttr("_start_time", i + 1);
  }
  std::vector<ops::Placeholder> placeholders;
  for (int i : indexes) {
    placeholders.push_back(placeholders_in_order[i]);
  }

  // Create ops that depend on the placeholders. Their start times are in
  // descending order (e.g., the op that depends on the first placeholder
  // runs last).
  std::vector<ops::Square> squares;
  for (int i : indexes) {
    squares.emplace_back(root.WithOpName(strings::StrCat("s", i)),
                         placeholders[i]);
    squares.back().node()->AddAttr("_start_time", 50 - (i + 1));
  }

  // Create an AddN op that sums all the squares.
  std::vector<Input> inputs;
  for (const auto& s : squares) inputs.push_back(s);
  ops::AddN addn = ops::AddN(root.WithOpName("addn"),
                             tensorflow::gtl::ArraySlice<Input>(inputs));
  // The AddN's start time is earlier than those of the nodes it depends on,
  // but because of dependency ordering it must come last in the sorted list.
  addn.node()->AddAttr("_start_time", 1);

  GraphDef gdef;
  TF_EXPECT_OK(root.ToGraphDef(&gdef));

  std::vector<std::pair<const NodeDef*, int64>> nodes;
  std::unordered_map<const NodeDef*, int64> node_to_start_time;
  TF_CHECK_OK(
      TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time));
  ASSERT_EQ(1 + squares.size() + placeholders.size(), nodes.size());
  for (int i = 0; i < placeholders.size(); ++i) {
    const NodeDef* node = nodes[i].first;
    EXPECT_EQ(strings::StrCat("p", i), node->name());
    EXPECT_EQ(i + 1, nodes[i].second);
    EXPECT_EQ(i + 1, node_to_start_time[node]);
  }
  for (int i = 0; i < squares.size(); ++i) {
    int node_index = placeholders.size() + i;
    int square_index = num_leaves - 1 - i;
    const NodeDef* node = nodes[node_index].first;
    EXPECT_EQ(strings::StrCat("s", square_index), node->name());
    EXPECT_EQ(50 - (square_index + 1), nodes[node_index].second);
    EXPECT_EQ(50 - (square_index + 1), node_to_start_time[node]);
  }
  EXPECT_EQ("addn", nodes.back().first->name());
  EXPECT_EQ(50, nodes.back().second);
  EXPECT_EQ(50, node_to_start_time[nodes.back().first]);
}

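// The sort should keep each while loop together: every loop's eight nodes
// appear as one contiguous block beginning with its Enter node.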
TEST(TopologicalSortNodesWithTimePriority, WhileLoop) {
  using namespace ::tensorflow::ops;            // NOLINT(build/namespaces)
  using namespace ::tensorflow::ops::internal;  // NOLINT(build/namespaces)

  // Create placeholders.
  Scope root = Scope::NewRootScope().ExitOnError();
  std::vector<int> indexes;
  std::vector<Placeholder> placeholders_in_order;
  const int num_leaves = 20;
  for (int i = 0; i < num_leaves; ++i) {
    indexes.push_back((i + 2001) % num_leaves);
    placeholders_in_order.emplace_back(root.WithOpName(strings::StrCat("p", i)),
                                       DT_FLOAT);
    placeholders_in_order.back().node()->AddAttr("_start_time", i + 1);
  }
  std::vector<Placeholder> placeholders;
  placeholders.reserve(indexes.size());
  for (int i : indexes) {
    placeholders.push_back(placeholders_in_order[i]);
  }

  // Add a while loop above each placeholder.
  std::vector<Exit> while_exits;
  const int nodes_per_loop = 8;
  for (int i : indexes) {
    Scope scope = root.NewSubScope(strings::StrCat("while", i));
    auto dummy = Placeholder(scope, DT_FLOAT);

    Enter enter(scope, placeholders[i], strings::StrCat("frame", i));
    Merge merge(scope, std::initializer_list<Input>{enter, dummy});
    auto cv = Const(scope.WithControlDependencies({merge.output}), false);
    LoopCond loop_cond(scope, cv);
    Switch switch_node(scope, merge.output, loop_cond);
    Identity identity(scope, switch_node.output_true);
    NextIteration next_iteration(scope, identity);
    while_exits.emplace_back(scope.WithOpName("exit"),
                             switch_node.output_false);

    // Complete the loop by removing the dummy node and attaching the
    // NextIteration node to that input of the Merge node.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);

    int base_start_time = i * 10 + 100;
    for (const auto& op : std::initializer_list<Output>{
             enter, merge.output, cv, loop_cond, switch_node.output_false,
             identity, next_iteration, while_exits.back()}) {
      op.node()->AddAttr("_start_time", base_start_time++);
    }
  }

  // Create ops that depend on the loop exits.
  std::vector<Square> squares;
  squares.reserve(indexes.size());
  for (int i : indexes) {
    squares.emplace_back(root.WithOpName(strings::StrCat("s", i)),
                         while_exits[i]);
    squares.back().node()->AddAttr("_start_time", 500 - (i + 1));
  }

  GraphDef gdef;
  TF_EXPECT_OK(root.ToGraphDef(&gdef));

  // Run the sort. Each while loop's nodes appear as a contiguous block in the
  // output <nodes>, so the output contains more nodes than just the
  // placeholders, exits, and squares.
  std::vector<std::pair<const NodeDef*, int64>> nodes;
  std::unordered_map<const NodeDef*, int64> node_to_start_time;
  TF_CHECK_OK(
      TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time));
  ASSERT_LT(while_exits.size() + squares.size() + placeholders.size(),
            nodes.size());
  int node_index = 0;
  for (int i = 0; i < placeholders.size(); ++i, ++node_index) {
    const NodeDef* node = nodes[i].first;
    EXPECT_EQ(strings::StrCat("p", i), node->name());
    EXPECT_EQ(i + 1, nodes[i].second);
    EXPECT_EQ(i + 1, node_to_start_time[node]);
  }
  for (int i = 0; i < while_exits.size(); ++i, node_index += nodes_per_loop) {
    const NodeDef* node = nodes[node_index].first;
    EXPECT_EQ(strings::StrCat("while", i, "/Enter"), node->name());
    EXPECT_EQ(100 + i * 10, nodes[node_index].second);
    EXPECT_EQ(100 + i * 10, node_to_start_time[node]);
  }
  for (int i = 0; i < squares.size(); ++i, ++node_index) {
    int square_index = num_leaves - 1 - i;
    const NodeDef* node = nodes[node_index].first;
    EXPECT_EQ(strings::StrCat("s", square_index), node->name());
    EXPECT_EQ(500 - (square_index + 1), nodes[node_index].second);
    EXPECT_EQ(500 - (square_index + 1), node_to_start_time[node]);
  }
}


}  // namespace
}  // namespace tensorflow