/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/port.h"

namespace tensorflow {

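// Returns SessionOptions that request `num_cpus` CPU devices and `num_gpus`
// GPU devices on each worker.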
static SessionOptions Devices(int num_cpus, int num_gpus) {
  SessionOptions result;
  (*result.config.mutable_device_count())["CPU"] = num_cpus;
  (*result.config.mutable_device_count())["GPU"] = num_gpus;
  return result;
}

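// Builds a graph computing c = a * b for the constants a = [1 2] and
// b = [2 1]^T, and stores the names of nodes a, b, and c in `node_names`.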
void CreateGraphDef(GraphDef* graph_def, string node_names[3]) {
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({1, 2}));
  test::FillValues<float>(&a_tensor, {1, 2});
  Node* a = test::graph::Constant(&graph, a_tensor);
  node_names[0] = a->name();

  Tensor b_tensor(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&b_tensor, {2, 1});
  Node* b = test::graph::Constant(&graph, b_tensor);
  node_names[1] = b->name();

  Node* c = test::graph::Matmul(&graph, a, b, false, false);
  node_names[2] = c->name();

  test::graph::ToGraphDef(&graph, graph_def);
}

// Asserts that "val" is a float tensor holding a single element equal to
// "expected_val".
static void IsSingleFloatValue(const Tensor& val, float expected_val) {
  ASSERT_EQ(val.dtype(), DT_FLOAT);
  ASSERT_EQ(val.NumElements(), 1);
  ASSERT_EQ(val.flat<float>()(0), expected_val);
}

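// Returns SessionOptions that point at the gRPC server `target`, with the
// given placement period and graph optimizations turned off (L0).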
static SessionOptions Options(const string& target, int placement_period) {
  SessionOptions options;
  // NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target
  // string.
  options.target = strings::StrCat("grpc://", target);
  options.config.set_placement_period(placement_period);
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  return options;
}

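// Creates a session for `options` and CHECK-fails if NewSession() returns
// null.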
static Session* NewRemote(const SessionOptions& options) {
  return CHECK_NOTNULL(NewSession(options));
}

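// Repeatedly creates, runs, and closes a session against a remote cluster,
// running to target nodes and fetching concrete tensor values.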
TEST(GrpcSessionTest, BasicNonProtoAPI) {
  GraphDef graph;
  string node_names[3];
  // c = a * b
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  ASSERT_TRUE(session != nullptr);

  for (int iters = 0; iters < 25; ++iters) {
    TF_CHECK_OK(session->Create(graph));
    {
      // Just run to the target node.
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<string> targets = {node_names[2]};
      TF_CHECK_OK(session->Run(inputs, {}, targets, nullptr));
    }
    {
      // Run to a target node and also fetch a real tensor.
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<string> names = {node_names[2] + ":0"};
      std::vector<string> targets = {node_names[1]};
      std::vector<Tensor> outputs;
      TF_CHECK_OK(session->Run(inputs, names, targets, &outputs));
      ASSERT_TRUE(outputs[0].IsInitialized());
      ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
    }

    TF_CHECK_OK(session->Close());
  }
}

TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
  GraphDef graph;
  string node_names[3];
  // c = a * b
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  ASSERT_TRUE(session != nullptr);
  ASSERT_TRUE(session->Create(graph).ok());

  // Test that the order of the output names matches the order of the
  // returned Tensors.
  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<string> names = {node_names[2] + ":0", node_names[0] + ":0",
                               node_names[1] + ":0"};

  std::vector<string> target_ops = {node_names[1]};
  std::vector<Tensor> outputs;
  ASSERT_TRUE(session->Run(inputs, names, target_ops, &outputs).ok());
  ASSERT_TRUE(outputs[0].IsInitialized());
  ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
  ASSERT_TRUE(outputs[1].IsInitialized());
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));
  ASSERT_TRUE(outputs[2].IsInitialized());
  ASSERT_EQ(2.0, outputs[2].flat<float>()(0));
  ASSERT_TRUE(session->Close().ok());
}

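// With a device filter that only allows the first device, creating a graph
// pinned to the second device must fail with INVALID_ARGUMENT.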
TEST(GrpcSessionTest, NonLocalWithFilters) {
  GraphDef graph;
  string node_names[3];
  // c = a * b
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  SessionOptions options;
  options.target = strings::StrCat("grpc://", cluster->targets()[0]);
  options.config.add_device_filters(cluster->devices()[0].name());

  std::unique_ptr<Session> session(NewRemote(options));
  ASSERT_TRUE(session != nullptr);

  {
    GraphDef graph_copy(graph);
    graph::SetDefaultDevice(cluster->devices()[0].name(), &graph_copy);
    TF_CHECK_OK(session->Create(graph_copy));
    TF_CHECK_OK(session->Run({}, {}, {node_names[2]}, nullptr));
    TF_CHECK_OK(session->Close());
  }
  {
    GraphDef graph_copy(graph);
    graph::SetDefaultDevice(cluster->devices()[1].name(), &graph_copy);
    auto status = session->Create(graph_copy);
    EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code());
  }
}

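// Fetching the same tensor twice in a single Run() call should return the
// value once per requested fetch.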
TEST(GrpcSessionTest, FetchMultipleTimes) {
  GraphDef graph;
  string node_names[3];
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(graph));
  const std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;

  const string node = node_names[2] + ":0";
  TF_CHECK_OK(session->Run(inputs, {node, node}, {}, &outputs));
  EXPECT_EQ(2, outputs.size());
  for (int i = 0; i < outputs.size(); ++i) {
    const Tensor& t = outputs[i];
    ASSERT_TRUE(t.IsInitialized()) << i;
    ASSERT_EQ(4.0, t.flat<float>()(0)) << i;
  }
  TF_CHECK_OK(session->Close());
}

// A = [3 2; -1 0]; x = rand(2, 1). We want to compute the largest
// eigenvalue of A, which is 2.0. Iteratively, we do
//   repeat x = y / y.norm(); y = A * x; end
// At the end, we expect "lambda" to converge to 2.0.
void FindMaxEigen(const string& target) {
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  // Store rows [3, 2] and [-1, 0] in row-major format.
  test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
  Node* a = test::graph::Constant(&graph, a_tensor);

  // x is from the feed.
  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&x_tensor, {0, 0});
  Node* x = test::graph::Constant(&graph, x_tensor);

  // y = A * x
  Node* y = test::graph::Matmul(&graph, a, x, false, false);

  // y2 = y.^2
  Node* y2 = test::graph::Unary(&graph, "Square", y);

  // Constant tensor holding the reduction dimension.
  Tensor rdim_tensor(DT_INT32, TensorShape({}));
  rdim_tensor.scalar<int32>()() = 0;
  Node* rdim = test::graph::Constant(&graph, rdim_tensor);

  // y2_sum = sum(y2)
  Node* y2_sum = test::graph::Reduce(&graph, "Sum", y2, rdim);

  // y_norm = sqrt(y2_sum)
  Node* y_norm = test::graph::Unary(&graph, "Sqrt", y2_sum);

  // y_normalized = y ./ y_norm
  Node* y_normalized = test::graph::Binary(&graph, "Div", y, y_norm);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  std::unique_ptr<Session> session(NewRemote(Options(target, 1)));
  ASSERT_TRUE(session != nullptr);
  TF_CHECK_OK(session->Create(def));

  // Set up feeds and fetches.
  float lambda;
  Tensor feed_value(DT_FLOAT, TensorShape({2, 1}));
  feed_value.matrix<float>()(0, 0) = -3.1415;
  feed_value.matrix<float>()(1, 0) = +2.7183;

  for (int i = 0; i < 25; ++i) {
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session->Run({{x->name(), feed_value}},
                             {y->name(), y_normalized->name()}, {}, &outputs));
    const Tensor& y = outputs[0];
    const Tensor& y_normalized = outputs[1];
    // Print out lambda, x, and y.
    CHECK_EQ(2, feed_value.NumElements());
    CHECK_EQ(2, y.NumElements());
    lambda = y.flat<float>()(0) / feed_value.flat<float>()(0);
    printf("%06d lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]\n", i,
           lambda, feed_value.flat<float>()(0), feed_value.flat<float>()(1),
           y.flat<float>()(0), y.flat<float>()(1));
    // Use the normalized y as the next feed value for x.
    feed_value = y_normalized;
  }
  EXPECT_NEAR(2.0, lambda, 1e-6);
}

TEST(FindMaxEigenTest, RemoteDevice) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
  FindMaxEigen(cluster->targets()[0]);
}

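// Assigns device `dev` to the node named `name` in `graph`, and fails
// fatally if no such node exists.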
void SetDevice(GraphDef* graph, const string& name, const string& dev) {
  for (int i = 0; i < graph->node_size(); ++i) {
    if (graph->node(i).name() == name) {
      graph->mutable_node(i)->set_device(dev);
      return;
    }
  }
  LOG(FATAL) << "Name '" << name << "' not found.";
}

// TODO(b/32636929): This test fails 1/1000 times. Disable it while we
// figure out why.
TEST(GrpcSessionTest, DISABLED_MultiDevices) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  Graph graph(OpRegistry::Global());
  const int kSize = 1048576;

  // c = a * b = 2 * 3 * kSize
  Tensor a_tensor(DT_FLOAT, TensorShape({1, kSize}));
  Tensor b_tensor(DT_FLOAT, TensorShape({kSize, 1}));
  for (int i = 0; i < kSize; ++i) {
    a_tensor.flat<float>()(i) = 2;
    b_tensor.flat<float>()(i) = 3;
  }
  Node* a = test::graph::Constant(&graph, a_tensor);
  Node* b = test::graph::Constant(&graph, b_tensor);
  Node* c = test::graph::Matmul(&graph, a, b, false, false);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  // In this test, we force each node (a, b, c) on every possible device.
  // We test all possible cases.
  for (const auto& a_dev : cluster->devices()) {
    for (const auto& b_dev : cluster->devices()) {
      for (const auto& c_dev : cluster->devices()) {
        LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name()
                  << " c: " << c_dev.name();

        SetDevice(&def, a->name(), a_dev.name());
        SetDevice(&def, b->name(), b_dev.name());
        SetDevice(&def, c->name(), c_dev.name());

        std::unique_ptr<Session> session(
            NewRemote(Options(cluster->targets()[0], 1000)));
        ASSERT_TRUE(session != nullptr);
        TF_CHECK_OK(session->Create(def));
        {
          std::vector<Tensor> outputs;
          RunOptions options;
          options.set_trace_level(RunOptions::FULL_TRACE);
          RunMetadata metadata;
          TF_CHECK_OK(
              session->Run(options, {}, {c->name()}, {}, &outputs, &metadata));
          ASSERT_EQ(1, outputs.size());
          IsSingleFloatValue(outputs[0], 6.0 * kSize);

          const StepStats& ss = metadata.step_stats();
          // NOTE(mrry): We only assert that `c` is placed correctly,
          // because the current placement algorithm will move its
          // inputs to be colocated with it, when it is the sole
          // consumer.
          bool c_placed_correctly = false;
          for (const auto& dev : ss.dev_stats()) {
            for (const auto& node : dev.node_stats()) {
              if (node.node_name() == c->name() &&
                  dev.device() == c_dev.name()) {
                c_placed_correctly = true;
              }
            }
          }
          ASSERT_TRUE(c_placed_correctly);
        }
        TF_CHECK_OK(session->Close());
      }
    }
  }
}

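// Produces a large (~1 GB) fill result on one device and reduces it on
// another, exercising the transfer of a large tensor between workers.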
TEST(GrpcSessionTest, LargeTensorSend) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  Graph graph(OpRegistry::Global());

  // Define a 1 GB fill result (1 x 256 x 1024 x 1024 floats).
  Tensor fill_shape_tensor(DT_INT32, TensorShape({4}));
  fill_shape_tensor.vec<int32>()(0) = 1;
  fill_shape_tensor.vec<int32>()(1) = 256;
  fill_shape_tensor.vec<int32>()(2) = 1024;
  fill_shape_tensor.vec<int32>()(3) = 1024;
  Node* fill_shape_node = test::graph::Constant(&graph, fill_shape_tensor);

  Tensor fill_val_tensor(DT_FLOAT, TensorShape({}));
  fill_val_tensor.flat<float>()(0) = 1.0;
  Node* fill_val_node = test::graph::Constant(&graph, fill_val_tensor);

  Node* fill_node =
      test::graph::Binary(&graph, "Fill", fill_shape_node, fill_val_node);

  Tensor max_axes_tensor(DT_INT32, TensorShape({4}));
  max_axes_tensor.vec<int32>()(0) = 0;
  max_axes_tensor.vec<int32>()(1) = 1;
  max_axes_tensor.vec<int32>()(2) = 2;
  max_axes_tensor.vec<int32>()(3) = 3;
  Node* max_axes_node = test::graph::Constant(&graph, max_axes_tensor);
  Node* max_node = test::graph::Reduce(&graph, "Max", fill_node, max_axes_node);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  // Place the fill on one device and the reduction on the other, so that the
  // large tensor must be sent between workers.
  SetDevice(&def, fill_node->name(), cluster->devices()[0].name());
  SetDevice(&def, max_node->name(), cluster->devices()[1].name());

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1000)));
  ASSERT_TRUE(session != nullptr);
  TF_CHECK_OK(session->Create(def));
  {
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session->Run({}, {max_node->name()}, {}, &outputs));
    ASSERT_EQ(1, outputs.size());
    IsSingleFloatValue(outputs[0], 1.0);
  }
  TF_CHECK_OK(session->Close());
}

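// Runs an Identity op over a string tensor with every combination of device
// placements. Placements that put the string tensor on a GPU are expected to
// fail at graph-creation time.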
TEST(GrpcSessionTest, MultiDevices_String) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster));
  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1000)));
  ASSERT_TRUE(session != nullptr);

  // b = a
  Graph graph(OpRegistry::Global());
  Tensor a_tensor(DT_STRING, TensorShape({2, 2}));
  for (int i = 0; i < 4; ++i) {
    a_tensor.flat<string>()(i) = "hello, world";
  }
  Node* a = test::graph::Constant(&graph, a_tensor);
  Node* b = test::graph::Identity(&graph, a);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  // In this test, we force each node (a, b) on every possible device.
  // We test all possible cases.
  for (const auto& a_dev : cluster->devices()) {
    for (const auto& b_dev : cluster->devices()) {
      LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name();
      SetDevice(&def, a->name(), a_dev.name());
      SetDevice(&def, b->name(), b_dev.name());

      Status s = session->Create(def);
      if (s.ok()) {
        std::vector<Tensor> outputs;
        TF_CHECK_OK(session->Run({}, {b->name()}, {}, &outputs));
        ASSERT_EQ(1, outputs.size());
        ASSERT_EQ(outputs[0].dtype(), DT_STRING);
        ASSERT_EQ(outputs[0].NumElements(), 4);
        for (int i = 0; i < outputs[0].NumElements(); ++i) {
          EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world");
        }
        TF_CHECK_OK(session->Close());
      } else {
        LOG(ERROR) << "Error: " << s;
        ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) ||
                    (b_dev.device_type() == DEVICE_GPU));
        ASSERT_FALSE(s.ok());
      }
    }
  }
}

TEST(GrpcSessionTest, SendRecv_Node_Naming) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 3, &cluster));
  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  ASSERT_TRUE(session != nullptr);

  // This test case needs at least 3 devices.
  CHECK_GE(cluster->devices().size(), 3);
  const DeviceAttributes& src = cluster->devices()[0];
  const DeviceAttributes& dst0 = cluster->devices()[1];
  const DeviceAttributes& dst1 = cluster->devices()[2];
  LOG(INFO) << "src = " << src.name() << " dst0 = " << dst0.name()
            << " dst1 = " << dst1.name();

  // Within the same session, we compute two subgraphs:
  //   1) a on 'src' sends to b on 'dst0';
  //   2) a on 'src' sends to c on 'dst1'.
  Graph graph(OpRegistry::Global());
  Tensor a_tensor(DT_FLOAT, TensorShape({1, 1}));
  a_tensor.flat<float>()(0) = 100;
  Node* a = test::graph::Constant(&graph, a_tensor);
  Node* b = test::graph::Identity(&graph, a);
  Node* c = test::graph::Identity(&graph, a);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  // The base graph has a, b, and c assigned to devices explicitly.
  SetDevice(&def, a->name(), src.name());
  SetDevice(&def, b->name(), dst0.name());
  SetDevice(&def, c->name(), dst1.name());
  TF_CHECK_OK(session->Create(def));

  // Run subgraph a -> b, and fetch b.
  {
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session->Run({}, {b->name()}, {}, &outputs));
    ASSERT_EQ(1, outputs.size());
    IsSingleFloatValue(outputs[0], 100);
  }

  // Run subgraph a -> c, and fetch c.
  {
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs));
    ASSERT_EQ(1, outputs.size());
    IsSingleFloatValue(outputs[0], 100);
  }

  TF_CHECK_OK(session->Close());
}

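// Verifies that an error raised by one subgraph aborts the step and is
// surfaced through the status returned by Session::Run().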
TEST(GrpcSessionTest, Error) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
  const string& master = cluster->targets()[0];
  const string& dev_a = cluster->devices()[0].name();
  const string& dev_b = cluster->devices()[1].name();
  LOG(INFO) << "master " << master << " dev_a " << dev_a << " dev_b " << dev_b;
  GraphDef gdef;
  std::vector<string> fetches;
  {
    Graph g(OpRegistry::Global());

    // a2 = a + error(a)
    //
    // The subgraph for "a" fails. The master will cancel the subgraph for
    // "b" and then return from Session::Run.
    auto a = test::graph::Constant(&g, Tensor());
    a->set_assigned_device_name(dev_a);
    auto a_err = test::graph::Error(&g, a, "fantasia!");
    a_err->set_assigned_device_name(dev_a);
    auto a2 = test::graph::Add(&g, a, a_err);
    a2->set_assigned_device_name(dev_a);
    fetches.push_back(a2->name());

    // b2 = b + delay(b)
    //
    // The subgraph for "b" sleeps at the node "b_delay". When the sleep
    // finishes, the subgraph "b" will continue execution until it
    // notices that it has been canceled. Meanwhile, the subgraph's executor
    // and its related state (registered ops) should still be alive.
    auto b = test::graph::Constant(&g, Tensor());
    b->set_assigned_device_name(dev_b);
    auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000));
    b_delay->set_assigned_device_name(dev_b);
    auto b2 = test::graph::Add(&g, b, b_delay);
    b2->set_assigned_device_name(dev_b);
    fetches.push_back(b2->name());
    test::graph::ToGraphDef(&g, &gdef);
  }
  std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(gdef));
  {
    Status status = session->Run({}, fetches, {}, nullptr);
    EXPECT_FALSE(status.ok());
    EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
  }
  // session->Close() should clean up all state related to the session,
  // e.g., deregister subgraphs with the workers, etc.
  TF_CHECK_OK(session->Close());

  // Sleep a bit so that most of the asynchronous work finishes before
  // the test process exits.
  Env::Default()->SleepForMicroseconds(2000000);
}

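// Same as GrpcSessionTest.Error, but with an error message of roughly 1 MB,
// to check that a large error status survives the trip back over gRPC.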
TEST(GrpcSessionTest, LongErrorMessage) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
  const string& master = cluster->targets()[0];
  const string& dev_a = cluster->devices()[0].name();
  const string& dev_b = cluster->devices()[1].name();
  LOG(INFO) << "master " << master << " dev_a " << dev_a << " dev_b " << dev_b;
  GraphDef gdef;
  std::vector<string> fetches;
  {
    Graph g(OpRegistry::Global());

    // a2 = a + error(a)
    //
    // The subgraph for "a" fails. The master will cancel the subgraph for
    // "b" and then return from Session::Run.
    auto a = test::graph::Constant(&g, Tensor());
    a->set_assigned_device_name(dev_a);
    std::vector<char> long_string_buffer(1024 * 1024, 'x');
    StringPiece long_string(long_string_buffer.data(), 1024 * 1024);
    string name = strings::StrCat(long_string, "fantasia!");
    auto a_err = test::graph::Error(&g, a, name);
    a_err->set_assigned_device_name(dev_a);
    auto a2 = test::graph::Add(&g, a, a_err);
    a2->set_assigned_device_name(dev_a);
    fetches.push_back(a2->name());

    // b2 = b + delay(b)
    //
    // The subgraph for "b" sleeps at the node "b_delay". When the sleep
    // finishes, the subgraph "b" will continue execution until it
    // notices that it has been canceled. Meanwhile, the subgraph's executor
    // and its related state (registered ops) should still be alive.
    auto b = test::graph::Constant(&g, Tensor());
    b->set_assigned_device_name(dev_b);
    auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000));
    b_delay->set_assigned_device_name(dev_b);
    auto b2 = test::graph::Add(&g, b, b_delay);
    b2->set_assigned_device_name(dev_b);
    fetches.push_back(b2->name());
    test::graph::ToGraphDef(&g, &gdef);
  }
  std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(gdef));
  {
    Status status = session->Run({}, fetches, {}, nullptr);
    EXPECT_FALSE(status.ok());
    EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
  }
  // session->Close() should clean up all state related to the session,
  // e.g., deregister subgraphs with the workers, etc.
  TF_CHECK_OK(session->Close());

  // Sleep a bit so that most of the asynchronous work finishes before
  // the test process exits.
  Env::Default()->SleepForMicroseconds(2000000);
}

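// Verifies that a variable initialized through one session is visible to
// subsequent sessions created against the same cluster, i.e. that worker
// state outlives an individual session.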
TEST(SessionTest, SharedVar) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
  const string master = cluster->targets()[0];
  CHECK_EQ(cluster->devices().size(), 1);

  GraphDef gdef;
  string init_name;
  string inc_name;
  string get_name;
  {
    Graph g(OpRegistry::Global());
    Tensor one(DT_FLOAT, TensorShape({}));
    one.scalar<float>()() = 1.0;
    Node* var = test::graph::Var(&g, DT_FLOAT, one.shape());
    Node* init = test::graph::Assign(&g, var, test::graph::Constant(&g, one));
    init_name = init->name();
    Node* update = test::graph::Assign(
        &g, var, test::graph::Add(&g, var, test::graph::Constant(&g, one)));
    inc_name = update->name();
    get_name = var->name();
    test::graph::ToGraphDef(&g, &gdef);
  }

  // Initialize the variable to 1.0.
  {
    Session* sess = NewRemote(Options(master, 1));
    TF_CHECK_OK(sess->Create(gdef));
    std::vector<std::pair<string, Tensor>> inp;
    TF_CHECK_OK(sess->Run(inp, {}, {init_name}, nullptr));
    TF_CHECK_OK(sess->Close());
    delete sess;
  }

  for (int rep = 1; rep < 10; ++rep) {
    // Increment the variable by 1.0 in a fresh session.
    {
      Session* sess = NewRemote(Options(master, 1));
      TF_CHECK_OK(sess->Create(gdef));
      std::vector<std::pair<string, Tensor>> inp;
      TF_CHECK_OK(sess->Run(inp, {}, {inc_name}, nullptr));
      TF_CHECK_OK(sess->Close());
      delete sess;
    }

    // Get the variable's value from yet another session.
    {
      Session* sess = NewRemote(Options(master, 1));
      TF_CHECK_OK(sess->Create(gdef));
      std::vector<std::pair<string, Tensor>> inp;
      std::vector<Tensor> ret;
      TF_CHECK_OK(sess->Run(inp, {get_name}, {}, &ret));
      ASSERT_EQ(ret.size(), 1);
      EXPECT_EQ(ret[0].scalar<float>()(), 1.0 * (1 + rep));
      TF_CHECK_OK(sess->Close());
      delete sess;
    }
  }
}

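// Parses `graph_def_ascii` as a GraphDef, attempts to create a session with
// it, and expects creation to fail with an error containing
// `error_substring`.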
void CreateInvalidGraph(const string& graph_def_ascii,
                        const string& error_substring) {
  GraphDef graph;
  CHECK(protobuf::TextFormat::ParseFromString(graph_def_ascii, &graph));

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  Status s = session->Create(graph);

  ASSERT_FALSE(s.ok());
  EXPECT_NE(s.error_message().find(error_substring), string::npos);
}

TEST(SessionTest, InvalidOpName) {
  CreateInvalidGraph(R"(
    node {
      name: 'a:b' op: 'Const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
  )",
                     "Illegal op name");

  CreateInvalidGraph(R"(
    node {
      name: 'a:0' op: 'Const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
  )",
                     "Illegal op name");

  CreateInvalidGraph(R"(
    node {
      name: '_a' op: 'Const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
  )",
                     "Illegal op name");
}

TEST(SessionTest, InvalidOpInputName) {
  CreateInvalidGraph(R"(
    node {
      name: 'a' op: 'const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
    node {
      name:'b' op:'MatMul' input:'a:first' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                     "Illegal op input name");

  CreateInvalidGraph(R"(
    node {
      name: 'a' op: 'const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
    node {
      name:'b' op:'MatMul' input:'_a' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                     "Illegal op input name");

  CreateInvalidGraph(R"(
    node {
      name: 'a' op: 'const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
    node {
      name:'b' op:'MatMul' input:'_a:0' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                     "Illegal op input name");

  CreateInvalidGraph(R"(
    node {
      name: 'a' op: 'const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
    node {
      name:'b' op:'MatMul' input:'a:01' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                     "Illegal op input name");
}

TEST(SessionTest, ExtendValidation) {
  GraphDef graph;
  bool success = protobuf::TextFormat::ParseFromString(R"(
    node {
      name: 'a' op: 'Const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
  )",
                                                       &graph);
  // NOTE(mrry): CHECK not done inline to avoid a compilation error in
  // open-source (due to a multi-line string in a macro argument).
  ASSERT_TRUE(success);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  TF_CHECK_OK(session->Create(graph));

  // 1. Fail with an unknown input name.
  GraphDef extension;
  success = protobuf::TextFormat::ParseFromString(R"(
    node {
      name:'b' op:'MatMul' input:'a:first' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                                                  &extension);
  ASSERT_TRUE(success);

  Status s = session->Extend(extension);
  ASSERT_FALSE(s.ok());
  EXPECT_NE(s.error_message().find("Illegal op input name"), string::npos);

  // 2. Succeed with a valid node.
  success = protobuf::TextFormat::ParseFromString(R"(
    node {
      name:'b' op:'MatMul' input:'a' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                                                  &extension);
  ASSERT_TRUE(success);
  TF_CHECK_OK(session->Extend(extension));

  // 3. Fail with a duplicate node.
  success = protobuf::TextFormat::ParseFromString(R"(
    node {
      name:'b' op:'MatMul' input:'a' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
      attr { key: '_kernel' value { s: 'eigen' } }
    }
  )",
                                                  &extension);
  ASSERT_TRUE(success);
  s = session->Extend(extension);
  ASSERT_FALSE(s.ok());
  EXPECT_NE(s.error_message().find("'b', which was created by a previous call"),
            string::npos);
}

    900 // Tests that Create() with "operation_timeout_in_ms" set times out.
    901 TEST(SessionTest, CreateTimeoutWithSessionOptions) {
    902   // Creates a RemoteSession with "operation_timeout_in_ms" set to 100.
    903   SessionOptions options = Options("example.org:2222", 1);
    904   options.config.set_operation_timeout_in_ms(100);
    905   std::unique_ptr<Session> session(NewRemote(options));
    906 
    907   // Creates a long running op.
    908   Graph graph(OpRegistry::Global());
    909   Node* b = test::graph::Constant(&graph, Tensor());
    910   test::graph::Delay(&graph, b, Microseconds(1000000));
    911   GraphDef gdef;
    912   test::graph::ToGraphDef(&graph, &gdef);
    913   Status status = session->Create(gdef);
    914   // Either error is possible, depending on the environment.
    915   EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() ||
    916               error::UNAVAILABLE == status.code());
    917 }
    918 
    919 // Tests that Create() with "timeout_in_ms" in RunOptions set times out.
    920 TEST(SessionTest, CreateTimeoutWithRunOptions) {
    921   SessionOptions options = Options("example.org:2222", 1);
    922   std::unique_ptr<Session> session(NewRemote(options));
    923 
    924   // Creates a long running op.
    925   Graph graph(OpRegistry::Global());
    926   Node* b = test::graph::Constant(&graph, Tensor());
    927   test::graph::Delay(&graph, b, Microseconds(1000000));
    928   GraphDef gdef;
    929   test::graph::ToGraphDef(&graph, &gdef);
    930   RunOptions run_options;
    931   // Sets RunOption timeout_in_ms to 20.
    932   run_options.set_timeout_in_ms(20);
    933   Status status = session->Create(run_options, gdef);
    934   // Either error is possible, depending on the environment.
    935   EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() ||
    936               error::UNAVAILABLE == status.code());
    937 }
    938 
    939 // Tests that Run() with "operation_timeout_in_ms" set times out.
    940 TEST(SessionTest, RunTimeoutWithSessionOptions) {
    941   // Creates a RemoteSession with "operation_timeout_in_ms" set to 100.
    942   std::unique_ptr<test::TestCluster> cluster;
    943   TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
    944   SessionOptions options = Options(cluster->targets()[0], 100);
    945   options.config.set_operation_timeout_in_ms(1);
    946   std::unique_ptr<Session> session(NewRemote(options));
    947 
    948   // Creates a long running op.
    949   Graph graph(OpRegistry::Global());
    950   Node* b = test::graph::Constant(&graph, Tensor());
    951   Node* b_delay = test::graph::Delay(&graph, b, Microseconds(2000000));
    952   GraphDef gdef;
    953   test::graph::ToGraphDef(&graph, &gdef);
    954   RunOptions run_options;
    955   TF_CHECK_OK(session->Create(run_options, gdef));
    956 
    957   // Verifies that Run() times out, and the error code is DEADLINE_EXCEEDED.
    958   std::vector<std::pair<string, Tensor>> inputs;
    959   Status status = session->Run(inputs, {}, {b_delay->name()}, nullptr);
    960   // TODO(sherrym): Due to potentially a GRPC bug, we sometimes get
    961   // GRPC_CHTTP2_INTERNAL_ERROR which is mapped to error::INTERNAL.
    962   EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() ||
    963               error::INTERNAL == status.code());
    964 }
    965 
    966 // Tests that Run() with "timeout_in_ms" set times out.
    967 TEST(SessionTest, RunTimeoutWithRunOptions) {
    968   std::unique_ptr<test::TestCluster> cluster;
    969   TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
    970   SessionOptions options = Options(cluster->targets()[0], 1);
    971   std::unique_ptr<Session> session(NewRemote(options));
    972 
    973   // Creates a long running op.
    974   Graph graph(OpRegistry::Global());
    975   Node* b = test::graph::Constant(&graph, Tensor());
    976   Node* b_delay = test::graph::Delay(&graph, b, Microseconds(1000000));
    977   GraphDef gdef;
    978   test::graph::ToGraphDef(&graph, &gdef);
    979   TF_CHECK_OK(session->Create(gdef));
    980 
    981   // Verifies that Run() times out, and the error code is DEADLINE_EXCEEDED.
    982   std::vector<std::pair<string, Tensor>> inputs;
    983   RunOptions run_options;
    984   run_options.set_timeout_in_ms(100);
    985   Status status = session->Run(run_options, inputs, {}, {b_delay->name()},
    986                                nullptr, nullptr);
    987   // TODO(sherrym): Due to potentially a GRPC bug, we sometimes get
    988   // GRPC_CHTTP2_INTERNAL_ERROR which is mapped to error::INTERNAL.
    989   EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() ||
    990               error::INTERNAL == status.code());
    991 }
    992 
    993 }  // namespace tensorflow
    994