Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/graph/subgraph.h"
     17 
     18 #include <string>
     19 #include <vector>
     20 
     21 #include "tensorflow/core/framework/graph.pb.h"
     22 #include "tensorflow/core/framework/partial_tensor_shape.h"
     23 #include "tensorflow/core/graph/graph.h"
     24 #include "tensorflow/core/graph/graph_constructor.h"
     25 #include "tensorflow/core/graph/graph_def_builder.h"
     26 #include "tensorflow/core/graph/graph_def_builder_util.h"
     27 #include "tensorflow/core/kernels/ops_util.h"
     28 #include "tensorflow/core/lib/core/status.h"
     29 #include "tensorflow/core/lib/core/status_test_util.h"
     30 #include "tensorflow/core/lib/strings/str_util.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/protobuf.h"
     33 #include "tensorflow/core/platform/test.h"
     34 #include "tensorflow/core/platform/test_benchmark.h"
     35 
     36 // TODO(josh11b): Test setting the "device" field of a NodeDef.
     37 // TODO(josh11b): Test that feeding won't prune targets.
     38 
     39 namespace tensorflow {
     40 namespace {
     41 
     42 class SubgraphTest : public ::testing::Test {
     43  protected:
     44   SubgraphTest() : g_(new Graph(OpRegistry::Global())) {
     45     device_info_.set_name("/job:a/replica:0/task:0/cpu:0");
     46     device_info_.set_device_type(DeviceType(DEVICE_CPU).type());
     47     device_info_.set_incarnation(0);
     48   }
     49 
     50   ~SubgraphTest() override {}
     51 
     52   void ExpectOK(const string& gdef_ascii) {
     53     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_));
     54     GraphConstructorOptions opts;
     55     TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get()));
     56   }
     57 
     58   Node* FindNode(const string& name) {
     59     for (Node* n : g_->nodes()) {
     60       if (n->name() == name) return n;
     61     }
     62     return nullptr;
     63   }
     64 
     65   bool HasNode(const string& name) { return FindNode(name) != nullptr; }
     66 
     67   void ExpectNodes(const string& nodes) {
     68     int count = 0;
     69     std::vector<string> actual_nodes;
     70     for (Node* n : g_->nodes()) {
     71       if (n->IsOp()) {
     72         count++;
     73         actual_nodes.push_back(n->name());
     74       }
     75     }
     76     std::sort(actual_nodes.begin(), actual_nodes.end());
     77 
     78     LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " ");
     79 
     80     std::vector<string> expected_nodes = str_util::Split(nodes, ',');
     81     std::sort(expected_nodes.begin(), expected_nodes.end());
     82     for (const string& s : expected_nodes) {
     83       Node* n = FindNode(s);
     84       EXPECT_TRUE(n != nullptr) << s;
     85       if (n->type_string() == "_Send" || n->type_string() == "_Recv") {
     86         EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s;
     87       }
     88     }
     89 
     90     EXPECT_TRUE(actual_nodes.size() == expected_nodes.size())
     91         << "\nActual:   " << str_util::Join(actual_nodes, ",")
     92         << "\nExpected: " << str_util::Join(expected_nodes, ",");
     93   }
     94 
     95   bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) {
     96     for (const Edge* e : g_->edges()) {
     97       if (e->src()->name() == src && e->src_output() == src_out &&
     98           e->dst()->name() == dst && e->dst_input() == dst_in)
     99         return true;
    100     }
    101     return false;
    102   }
    103   bool HasControlEdge(const string& src, const string& dst) {
    104     return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
    105   }
    106 
    107   string Subgraph(const string& fed_str, const string& fetch_str,
    108                   const string& targets_str,
    109                   bool use_function_convention = false) {
    110     Graph* subgraph = new Graph(OpRegistry::Global());
    111     CopyGraph(*g_, subgraph);
    112     std::vector<string> fed =
    113         str_util::Split(fed_str, ',', str_util::SkipEmpty());
    114     std::vector<string> fetch =
    115         str_util::Split(fetch_str, ',', str_util::SkipEmpty());
    116     std::vector<string> targets =
    117         str_util::Split(targets_str, ',', str_util::SkipEmpty());
    118 
    119     subgraph::RewriteGraphMetadata metadata;
    120     Status s = subgraph::RewriteGraphForExecution(
    121         subgraph, fed, fetch, targets, device_info_, use_function_convention,
    122         &metadata);
    123     if (!s.ok()) {
    124       delete subgraph;
    125       return s.ToString();
    126     }
    127 
    128     EXPECT_EQ(fed.size(), metadata.feed_types.size());
    129     EXPECT_EQ(fetch.size(), metadata.fetch_types.size());
    130 
    131     // Replace the graph with the subgraph for the rest of the display program
    132     g_.reset(subgraph);
    133     return "OK";
    134   }
    135 
    136   Graph* graph() { return g_.get(); }
    137 
    138  private:
    139   GraphDef gdef_;
    140   std::unique_ptr<Graph> g_;
    141   DeviceAttributes device_info_;
    142 };
    143 
    144 REGISTER_OP("TestParams").Output("o: float");
    145 REGISTER_OP("TestInput").Output("a: float").Output("b: float");
    146 REGISTER_OP("TestRelu").Input("i: float").Output("o: float");
    147 REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
    148 
    149 TEST_F(SubgraphTest, Targets1) {
    150   ExpectOK(
    151       "node { name: 'W1' op: 'TestParams' }"
    152       "node { name: 'W2' op: 'TestParams' }"
    153       "node { name: 'input' op: 'TestInput' }"
    154       "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
    155       "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
    156       "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
    157       "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
    158   EXPECT_EQ("OK", Subgraph("", "", "t1"));
    159   ExpectNodes("W1,input,t1");
    160 }
    161 
    162 TEST_F(SubgraphTest, Targets2) {
    163   ExpectOK(
    164       "node { name: 'W1' op: 'TestParams' }"
    165       "node { name: 'W2' op: 'TestParams' }"
    166       "node { name: 'input' op: 'TestInput' }"
    167       "node { name: 't1' op: 'TestMul' input: 'W1' input: 'input:1' }"
    168       "node { name: 't2' op: 'TestMul' input: 'W2' input: 't1' }"
    169       "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
    170       "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
    171   EXPECT_EQ("OK", Subgraph("", "", "t2,t3_a"));
    172   ExpectNodes("W1,W2,input,t1,t2,t3_a");
    173 }
    174 
    175 TEST_F(SubgraphTest, FedOutputs1) {
    176   ExpectOK(
    177       "node { name: 'W1' op: 'TestParams' }"
    178       "node { name: 'W2' op: 'TestParams' }"
    179       "node { name: 'input' op: 'TestInput' }"
    180       "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
    181       "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
    182       "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
    183       "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
    184   EXPECT_EQ("OK", Subgraph("input:1", "", "t2"));
    185   ExpectNodes("W1,W2,_recv_input_1,t1,t2");
    186 }
    187 
    188 TEST_F(SubgraphTest, FedOutputs1_FunctionConvention) {
    189   ExpectOK(
    190       "node { name: 'W1' op: 'TestParams' }"
    191       "node { name: 'W2' op: 'TestParams' }"
    192       "node { name: 'input' op: 'TestInput' }"
    193       "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
    194       "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
    195       "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
    196       "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
    197   EXPECT_EQ("OK",
    198             Subgraph("input:1", "", "t2", true /* use_function_convention */));
    199   ExpectNodes("W1,W2,_arg_input_1_0,t1,t2");
    200 }
    201 
    202 TEST_F(SubgraphTest, FedRefNode) {
    203   ExpectOK(
    204       "node { name: 'W1' op: 'TestParams' }"
    205       "node { name: 'W2' op: 'TestParams' }"
    206       "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }");
    207   EXPECT_EQ("OK", Subgraph("W1:0", "", "t1"));
    208   ExpectNodes("_recv_W1_0,W2,t1");
    209   Node* n = FindNode("_recv_W1_0");
    210   EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
    211 }
    212 
    213 TEST_F(SubgraphTest, FedRefNode_FunctionConvention) {
    214   ExpectOK(
    215       "node { name: 'W1' op: 'TestParams' }"
    216       "node { name: 'W2' op: 'TestParams' }"
    217       "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }");
    218   EXPECT_EQ("OK",
    219             Subgraph("W1:0", "", "t1", true /* use_function_convention */));
    220   ExpectNodes("_arg_W1_0_0,W2,t1");
    221   Node* n = FindNode("_arg_W1_0_0");
    222   EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
    223 }
    224 
    225 TEST_F(SubgraphTest, FedOutputs2_FunctionConvention) {
    226   ExpectOK(
    227       "node { name: 'W1' op: 'TestParams' }"
    228       "node { name: 'W2' op: 'TestParams' }"
    229       "node { name: 'input' op: 'TestInput' }"
    230       "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
    231       "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
    232       "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
    233       "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
    234   // We feed input:1, but nothing connects to it, so the _recv(input:1)
    235   // node also disappears.
    236   EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2",
    237                            true /* use_function_convention */));
    238   ExpectNodes("_arg_t1_0_1,_arg_W2_0_2,t2");
    239 }
    240 
    241 TEST_F(SubgraphTest, FetchOutputs1) {
    242   ExpectOK(
    243       "node { name: 'W1' op: 'TestParams' }"
    244       "node { name: 'W2' op: 'TestParams' }"
    245       "node { name: 'input' op: 'TestInput' }"
    246       "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
    247       "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
    248       "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
    249       "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
    250   EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2"));
    251   ExpectNodes(
    252       "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0");
    253 }
    254 
    255 TEST_F(SubgraphTest, FetchOutputs1_FunctionConvention) {
    256   ExpectOK(
    257       "node { name: 'W1' op: 'TestParams' }"
    258       "node { name: 'W2' op: 'TestParams' }"
    259       "node { name: 'input' op: 'TestInput' }"
    260       "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
    261       "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
    262       "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
    263       "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
    264   EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2",
    265                            true /* use_function_convention */));
    266   ExpectNodes(
    267       "W1,W2,input,t1,t2,_retval_W2_0_0,_retval_input_1_1,_retval_t1_0_2,_"
    268       "retval_t2_0_3");
    269 }
    270 
    271 TEST_F(SubgraphTest, FetchOutputs2) {
    272   ExpectOK(
    273       "node { name: 'W1' op: 'TestParams' }"
    274       "node { name: 'W2' op: 'TestParams' }"
    275       "node { name: 'input' op: 'TestInput' }"
    276       "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
    277       "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
    278       "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
    279       "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
    280   EXPECT_EQ("OK", Subgraph("", "t3_a", "t2"));
    281   ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0");
    282 }
    283 
    284 TEST_F(SubgraphTest, FetchOutputs2_FunctionConvention) {
    285   ExpectOK(
    286       "node { name: 'W1' op: 'TestParams' }"
    287       "node { name: 'W2' op: 'TestParams' }"
    288       "node { name: 'input' op: 'TestInput' }"
    289       "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
    290       "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
    291       "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
    292       "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
    293   EXPECT_EQ("OK",
    294             Subgraph("", "t3_a", "t2", true /* use_function_convention */));
    295   ExpectNodes("W1,W2,input,t1,t2,t3_a,_retval_t3_a_0_0");
    296 }
    297 
    298 TEST_F(SubgraphTest, ChainOfFools) {
    299   ExpectOK(
    300       "node { name: 'a' op: 'TestParams' }"
    301       "node { name: 'b' op: 'TestRelu' input: 'a'}"
    302       "node { name: 'c' op: 'TestRelu' input: 'b'}"
    303       "node { name: 'd' op: 'TestRelu' input: 'c'}"
    304       "node { name: 'e' op: 'TestRelu' input: 'd'}"
    305       "node { name: 'f' op: 'TestRelu' input: 'e'}");
    306   EXPECT_EQ("OK", Subgraph("c:0", "b:0,e:0", ""));
    307   ExpectNodes("a,b,_send_b_0,_recv_c_0,d,e,_send_e_0");
    308   EXPECT_TRUE(HasEdge("a", 0, "b", 0));
    309   EXPECT_TRUE(HasEdge("b", 0, "_send_b_0", 0));
    310   EXPECT_TRUE(HasEdge("_recv_c_0", 0, "d", 0));
    311   EXPECT_TRUE(HasEdge("d", 0, "e", 0));
    312   EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0));
    313 }
    314 
    315 static bool HasSubstr(const string& base, const string& substr) {
    316   bool ok = StringPiece(base).contains(substr);
    317   EXPECT_TRUE(ok) << base << ", expected substring " << substr;
    318   return ok;
    319 }
    320 
    321 TEST_F(SubgraphTest, Errors) {
    322   ExpectOK(
    323       "node { name: 'a' op: 'TestParams' }"
    324       "node { name: 'b' op: 'TestRelu' input: 'a'}"
    325       "node { name: 'c' op: 'TestRelu' input: 'b'}"
    326       "node { name: 'd' op: 'TestRelu' input: 'c'}"
    327       "node { name: 'e' op: 'TestRelu' input: 'd'}"
    328       "node { name: 'f' op: 'TestRelu' input: 'e'}");
    329   // Duplicated feed and fetch
    330   EXPECT_TRUE(
    331       HasSubstr(Subgraph("c:0", "b:0,c:0", ""), "both fed and fetched"));
    332   // Feed not found.
    333   EXPECT_TRUE(HasSubstr(Subgraph("foo:0", "c:0", ""), "unable to find"));
    334   // Fetch not found.
    335   EXPECT_TRUE(HasSubstr(Subgraph("", "foo:0", ""), "not found"));
    336   // Target not found.
    337   EXPECT_TRUE(HasSubstr(Subgraph("", "", "foo"), "not found"));
    338   // No targets specified.
    339   EXPECT_TRUE(HasSubstr(Subgraph("", "", ""), "at least one target"));
    340 }
    341 
    342 REGISTER_OP("In").Output("o: float");
    343 REGISTER_OP("Op").Input("i: float").Output("o: float");
    344 
    345 static void BM_SubgraphHelper(int iters, int num_nodes,
    346                               bool use_function_convention) {
    347   DeviceAttributes device_info;
    348   device_info.set_name("/job:a/replica:0/task:0/cpu:0");
    349   device_info.set_device_type(DeviceType(DEVICE_CPU).type());
    350   device_info.set_incarnation(0);
    351 
    352   testing::StopTiming();
    353   Graph g(OpRegistry::Global());
    354   {  // Scope for temporary variables used to construct g.
    355     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    356     Node* last_node = nullptr;
    357     for (int i = 0; i < num_nodes; i++) {
    358       string name = strings::StrCat("N", i);
    359       if (i > 0) {
    360         last_node = ops::UnaryOp("Op", last_node, b.opts().WithName(name));
    361       } else {
    362         last_node = ops::SourceOp("In", b.opts().WithName(name));
    363       }
    364     }
    365     TF_CHECK_OK(GraphDefBuilderToGraph(b, &g));
    366   }
    367 
    368   std::vector<string> fed;
    369   if (num_nodes > 1000) {
    370     fed.push_back(strings::StrCat("N", num_nodes - 1000));
    371   }
    372   std::vector<string> fetch;
    373   std::vector<string> targets = {strings::StrCat("N", num_nodes - 1)};
    374   testing::StartTiming();
    375   while (--iters > 0) {
    376     Graph* subgraph = new Graph(OpRegistry::Global());
    377     CopyGraph(g, subgraph);
    378     subgraph::RewriteGraphMetadata metadata;
    379     TF_CHECK_OK(subgraph::RewriteGraphForExecution(
    380         subgraph, fed, fetch, targets, device_info, use_function_convention,
    381         &metadata));
    382     delete subgraph;
    383   }
    384 }
    385 
    386 static void BM_Subgraph(int iters, int num_nodes) {
    387   BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */);
    388 }
    389 static void BM_SubgraphFunctionConvention(int iters, int num_nodes) {
    390   BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */);
    391 }
    392 BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
    393 BENCHMARK(BM_SubgraphFunctionConvention)
    394     ->Arg(100)
    395     ->Arg(1000)
    396     ->Arg(10000)
    397     ->Arg(100000);
    398 
    399 }  // namespace
    400 }  // namespace tensorflow
    401