1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/master.h" 17 18 #include <map> 19 #include <memory> 20 21 #include "grpc++/grpc++.h" 22 23 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 24 #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" 25 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 26 #include "tensorflow/core/framework/allocator.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/framework/tensor_testutil.h" 29 #include "tensorflow/core/graph/testlib.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/core/notification.h" 32 #include "tensorflow/core/lib/core/status_test_util.h" 33 #include "tensorflow/core/lib/core/threadpool.h" 34 #include "tensorflow/core/lib/gtl/map_util.h" 35 #include "tensorflow/core/platform/logging.h" 36 #include "tensorflow/core/platform/mutex.h" 37 #include "tensorflow/core/platform/test.h" 38 #include "tensorflow/core/platform/types.h" 39 #include "tensorflow/core/protobuf/master.pb.h" 40 #include "tensorflow/core/protobuf/master_service.grpc.pb.h" 41 42 namespace tensorflow { 43 44 class MasterTest : public ::testing::Test { 45 protected: 46 MasterTest() { 47 std::vector<string> targets; 48 SessionOptions options; 49 (*options.config.mutable_device_count())["CPU"] = 1; 50 (*options.config.mutable_device_count())["GPU"] = 0; 51 TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 2, &cluster_)); 52 SharedGrpcChannelPtr channel_ptr; 53 TF_CHECK_OK(NewHostPortGrpcChannel(cluster_->targets()[0], &channel_ptr)); 54 master_ = grpc::MasterService::NewStub(channel_ptr); 55 } 56 57 std::unique_ptr<test::TestCluster> cluster_; 58 std::unique_ptr<grpc::MasterService::Stub> master_; 59 60 // Helpers for MasterService.{CreateSession,RunStep,CloseSession} 61 // rpc calls. 62 63 Status CreateSession(const GraphDef& def, string* handle, 64 int64* initial_version) { 65 ::grpc::ClientContext ctx; 66 CreateSessionRequest req; 67 *(req.mutable_graph_def()) = def; 68 // Invokes placement frequently. 69 req.mutable_config()->set_placement_period(1); 70 CreateSessionResponse resp; 71 const Status s = FromGrpcStatus(master_->CreateSession(&ctx, req, &resp)); 72 if (s.ok()) { 73 *handle = resp.session_handle(); 74 *initial_version = resp.graph_version(); 75 } 76 return s; 77 } 78 79 Status ExtendSession(const string& handle, const GraphDef& def, 80 int64 current_version, int64* new_version) { 81 ::grpc::ClientContext ctx; 82 ExtendSessionRequest req; 83 req.set_session_handle(handle); 84 *(req.mutable_graph_def()) = def; 85 req.set_current_graph_version(current_version); 86 ExtendSessionResponse resp; 87 const Status s = FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp)); 88 if (s.ok()) { 89 *new_version = resp.new_graph_version(); 90 } 91 return s; 92 } 93 94 Status RunStep(const string& handle, 95 const std::vector<std::pair<string, const Tensor*> >& feed, 96 const std::map<string, Tensor*>& fetch) { 97 ::grpc::ClientContext ctx; 98 RunStepRequest req; 99 req.set_session_handle(handle); 100 for (const auto& p : feed) { 101 const string& feed_name = p.first; 102 const Tensor* feed_tensor = p.second; 103 auto f = req.add_feed(); 104 f->set_name(feed_name); 105 feed_tensor->AsProtoTensorContent(f->mutable_tensor()); 106 } 107 for (const auto& p : fetch) { 108 const string& fetch_name = p.first; 109 req.add_fetch(fetch_name); 110 } 111 RunStepResponse resp; 112 const Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp)); 113 if (s.ok()) { 114 for (const auto& fetch_resp : resp.tensor()) { 115 auto it = fetch.find(fetch_resp.name()); 116 CHECK(it != fetch.end()); 117 CHECK(it->second->FromProto(fetch_resp.tensor())); 118 } 119 } 120 return s; 121 } 122 123 Status CloseSession(const string& handle) { 124 ::grpc::ClientContext ctx; 125 CloseSessionRequest req; 126 req.set_session_handle(handle); 127 CloseSessionResponse resp; 128 return FromGrpcStatus(master_->CloseSession(&ctx, req, &resp)); 129 } 130 131 Status Reset() { 132 ::grpc::ClientContext ctx; 133 ResetRequest req; 134 ResetResponse resp; 135 return FromGrpcStatus(master_->Reset(&ctx, req, &resp)); 136 } 137 }; 138 139 TEST_F(MasterTest, CreateClose) { 140 GraphDef def; // Empty. 141 string handle; 142 int64 initial_version; 143 TF_ASSERT_OK(CreateSession(def, &handle, &initial_version)); 144 EXPECT_TRUE(errors::IsAborted(CloseSession("randombits"))); 145 EXPECT_TRUE(CloseSession(handle).ok()); 146 } 147 148 TEST_F(MasterTest, ListDevices) { 149 ::grpc::ClientContext ctx; 150 ListDevicesRequest req; 151 ListDevicesResponse resp; 152 const Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp)); 153 TF_EXPECT_OK(s); 154 EXPECT_EQ(1, resp.local_device_size()); 155 EXPECT_EQ("CPU", resp.local_device(0).device_type()); 156 } 157 158 TEST_F(MasterTest, Reset) { 159 GraphDef def; // Empty. 160 string s1, s2; 161 int64 initial_version1, initial_version2; 162 TF_ASSERT_OK(CreateSession(def, &s1, &initial_version1)); 163 TF_ASSERT_OK(CreateSession(def, &s2, &initial_version2)); 164 EXPECT_TRUE(Reset().ok()); 165 EXPECT_TRUE(errors::IsAborted(CloseSession(s1))); 166 EXPECT_TRUE(errors::IsAborted(CloseSession(s2))); 167 } 168 169 TEST_F(MasterTest, Extend) { 170 GraphDef def_0; // Empty. 171 string handle; 172 int64 initial_version; 173 TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version)); 174 175 Tensor A_expected(DT_FLOAT, TensorShape({2, 2})); 176 test::FillValues<float>(&A_expected, {3.0, 2.0, -1.0, 0.0}); 177 178 Tensor x_expected(DT_FLOAT, TensorShape({2, 1})); 179 test::FillValues<float>(&x_expected, {2.0, 2.0}); 180 181 Graph graph_1(OpRegistry::Global()); 182 test::graph::Constant(&graph_1, A_expected, "A"); 183 GraphDef def_1; 184 test::graph::ToGraphDef(&graph_1, &def_1); 185 int64 version_1; 186 TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1)); 187 EXPECT_GT(version_1, initial_version); 188 Tensor A(DT_FLOAT, TensorShape({2, 2})); 189 TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}})); 190 test::ExpectTensorEqual<float>(A, A_expected); 191 192 Graph graph_2(OpRegistry::Global()); 193 test::graph::Constant(&graph_2, x_expected, "x"); 194 GraphDef def_2; 195 test::graph::ToGraphDef(&graph_2, &def_2); 196 int64 version_2; 197 EXPECT_TRUE(errors::IsAborted( 198 ExtendSession("randombits", def_2, version_1, &version_2))); 199 TF_ASSERT_OK(ExtendSession(handle, def_2, version_1, &version_2)); 200 EXPECT_GT(version_2, version_1); 201 202 Tensor x(DT_FLOAT, TensorShape({2, 1})); 203 TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"x:0", &x}})); 204 test::ExpectTensorEqual<float>(A, A_expected); 205 test::ExpectTensorEqual<float>(x, x_expected); 206 207 TF_ASSERT_OK(CloseSession(handle)); 208 } 209 210 TEST_F(MasterTest, ExtendUpdateStatefulFails) { 211 GraphDef def_0; // Empty. 212 string handle; 213 int64 initial_version; 214 TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version)); 215 216 Graph graph_1(OpRegistry::Global()); 217 test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512})); 218 GraphDef def_1; 219 test::graph::ToGraphDef(&graph_1, &def_1); 220 221 int64 version_1, version_2; 222 TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1)); 223 EXPECT_GT(version_1, initial_version); 224 EXPECT_TRUE(errors::IsInvalidArgument( 225 ExtendSession(handle, def_1, version_1, &version_2))); 226 TF_ASSERT_OK(CloseSession(handle)); 227 } 228 229 TEST_F(MasterTest, ExtendTwiceFails) { 230 GraphDef def_0; // Empty. 231 string handle; 232 int64 initial_version; 233 TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version)); 234 235 Graph graph_1(OpRegistry::Global()); 236 test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512})); 237 GraphDef def_1; 238 test::graph::ToGraphDef(&graph_1, &def_1); 239 240 int64 version_1; 241 TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1)); 242 EXPECT_GT(version_1, initial_version); 243 EXPECT_TRUE(errors::IsAborted( 244 ExtendSession(handle, def_1, initial_version, &version_1))); 245 TF_ASSERT_OK(CloseSession(handle)); 246 } 247 248 TEST_F(MasterTest, ConcurrentExtendOnlyOneSucceeds) { 249 GraphDef def_0; // Empty. 250 string handle; 251 int64 initial_version; 252 TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version)); 253 254 Graph graph_1(OpRegistry::Global()); 255 test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512})); 256 GraphDef def_1; 257 test::graph::ToGraphDef(&graph_1, &def_1); 258 259 Notification n; 260 mutex mu; 261 int succeeded = 0; 262 int failed = 0; 263 auto extend_fn = [this, handle, def_1, initial_version, &n, &mu, &succeeded, 264 &failed]() { 265 n.WaitForNotification(); 266 int64 new_version; 267 Status s = ExtendSession(handle, def_1, initial_version, &new_version); 268 EXPECT_TRUE(s.ok() || errors::IsAborted(s)); 269 { 270 mutex_lock l(mu); 271 if (s.ok()) { 272 ++succeeded; 273 } else { 274 ++failed; 275 } 276 } 277 }; 278 279 // Run 100 concurrent Extend calls and expect only one to succeed. 280 { 281 thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 100); 282 for (int i = 0; i < 100; ++i) { 283 thread_pool.Schedule(extend_fn); 284 } 285 n.Notify(); 286 } 287 288 EXPECT_EQ(failed, 99); 289 EXPECT_EQ(succeeded, 1); 290 TF_ASSERT_OK(CloseSession(handle)); 291 } 292 293 TEST_F(MasterTest, ConcurrentExtendAndRun) { 294 Graph graph_0(OpRegistry::Global()); 295 Tensor a_tensor(DT_FLOAT, TensorShape({2, 2})); 296 test::FillValues<float>(&a_tensor, {3, 2, -1, 0}); 297 test::graph::Constant(&graph_0, a_tensor, "A"); 298 GraphDef def_0; 299 test::graph::ToGraphDef(&graph_0, &def_0); 300 301 string handle; 302 int64 initial_version; 303 TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version)); 304 305 Graph graph_1(OpRegistry::Global()); 306 Tensor b_tensor(DT_FLOAT, TensorShape({2, 2})); 307 test::FillValues<float>(&b_tensor, {1, 0, 0, 1}); 308 test::graph::Constant(&graph_1, b_tensor, "B"); 309 GraphDef def_1; 310 test::graph::ToGraphDef(&graph_1, &def_1); 311 312 Notification extend_done; 313 Notification extend_can_start; 314 315 auto get_a_fn = [this, handle, &extend_done]() { 316 Tensor A(DT_FLOAT, TensorShape({2, 2})); 317 while (!extend_done.HasBeenNotified()) { 318 TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}})); 319 } 320 // Run at least once after the Extend has completed. 321 TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}})); 322 }; 323 324 auto get_a_and_b_fn = [this, handle, &extend_done, &extend_can_start]() { 325 Tensor A(DT_FLOAT, TensorShape({2, 2})); 326 Tensor B(DT_FLOAT, TensorShape({2, 2})); 327 328 // Run at least once before the Extend has completed. 329 EXPECT_TRUE( 330 errors::IsNotFound(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}))); 331 extend_can_start.Notify(); 332 333 // Concurrent with the Extend, we will either fail (as above), or 334 // succeed (as below). 335 while (!extend_done.HasBeenNotified()) { 336 Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}); 337 EXPECT_TRUE(errors::IsNotFound(s) || s.ok()); 338 } 339 340 // Run at least once after the Extend has completed. 341 TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}})); 342 }; 343 344 auto extend_fn = [this, handle, def_1, initial_version, &extend_done, 345 &extend_can_start]() { 346 extend_can_start.WaitForNotification(); 347 int64 version_1; 348 TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1)); 349 extend_done.Notify(); 350 }; 351 352 { 353 thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 3); 354 thread_pool.Schedule(get_a_fn); 355 thread_pool.Schedule(get_a_and_b_fn); 356 thread_pool.Schedule(extend_fn); 357 } 358 359 TF_ASSERT_OK(CloseSession(handle)); 360 } 361 362 TEST_F(MasterTest, EigenProblem) { 363 // A = [3 2; -1 0]; x = rand(2, 1); 364 // for i=1:100; x = A * x; end 365 // We'll try to compute the largest eigenvalue for A. 366 Graph graph(OpRegistry::Global()); 367 Tensor a_tensor(DT_FLOAT, TensorShape({2, 2})); 368 // Store rows [3, 2] and [-1, 0] in row major format. 369 test::FillValues<float>(&a_tensor, {3, 2, -1, 0}); 370 Node* a_node = test::graph::Constant(&graph, a_tensor); 371 372 // x is from the feed. 373 Tensor x_tensor(DT_FLOAT, TensorShape({2, 1})); 374 test::FillValues<float>(&x_tensor, {0, 0}); 375 Node* x_node = test::graph::Constant(&graph, x_tensor); 376 377 // y = A * x 378 Node* y_node = test::graph::Matmul(&graph, a_node, x_node, false, false); 379 380 GraphDef def; 381 test::graph::ToGraphDef(&graph, &def); 382 383 string handle; 384 int64 initial_version; 385 TF_CHECK_OK(CreateSession(def, &handle, &initial_version)); 386 387 // Temps supporting the computation of the convergence condition. 388 const Eigen::array<Eigen::DenseIndex, 1> sum_along_dim(0); 389 const Eigen::array<Eigen::DenseIndex, 2> matrix_transpose({1, 0}); 390 Tensor x(DT_FLOAT, TensorShape({2, 1})); 391 Tensor y(DT_FLOAT, TensorShape({2, 1})); 392 Eigen::Tensor<float, 1, Eigen::RowMajor> y_square_sum; 393 Eigen::Tensor<float, 2, Eigen::RowMajor> y_normalized(2, 1); 394 y_normalized.setRandom(); 395 Eigen::Tensor<float, 1, Eigen::RowMajor> error_square_sum; 396 float lambda; 397 398 // The computation loop. 399 bool converged = false; 400 while (!converged) { 401 // Run one step of the graph. 402 auto x_matrix = x.matrix<float>(); 403 x_matrix = y_normalized; 404 TF_EXPECT_OK( 405 RunStep(handle, {{x_node->name(), &x}}, {{y_node->name() + ":0", &y}})); 406 auto y_matrix = y.matrix<float>(); 407 408 // Client code computes the convergence condition. 409 { 410 lambda = y_matrix(0, 0) / x_matrix(0, 0); 411 y_square_sum = y.matrix<float>().square().sum(sum_along_dim); 412 const float norm = static_cast<float>(sqrt(y_square_sum(0))); 413 y_normalized = y_matrix * (1 / norm); 414 error_square_sum = (x_matrix - y_normalized).square().sum(sum_along_dim); 415 VLOG(1) << "x = [" << x_matrix.shuffle(matrix_transpose) << "] y = [" 416 << y_matrix.shuffle(matrix_transpose) << "] lambda = " << lambda; 417 converged = sqrt(error_square_sum(0)) < 1e-10; 418 } 419 } 420 EXPECT_NEAR(lambda, 2.0, 0.01); 421 TF_EXPECT_OK(CloseSession(handle)); 422 } 423 424 } // namespace tensorflow 425