1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/training/queue_runner.h" 17 18 #include <string> 19 #include <vector> 20 21 #include "tensorflow/cc/framework/scope.h" 22 #include "tensorflow/cc/ops/standard_ops.h" 23 #include "tensorflow/cc/training/coordinator.h" 24 #include "tensorflow/core/framework/graph.pb.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/framework/types.pb.h" 28 #include "tensorflow/core/lib/core/error_codes.pb.h" 29 #include "tensorflow/core/lib/core/notification.h" 30 #include "tensorflow/core/lib/core/status_test_util.h" 31 #include "tensorflow/core/platform/env.h" 32 #include "tensorflow/core/platform/test.h" 33 #include "tensorflow/core/protobuf/queue_runner.pb.h" 34 #include "tensorflow/core/public/session.h" 35 36 namespace tensorflow { 37 namespace { 38 39 using error::Code; 40 using ops::Assign; 41 using ops::Const; 42 using ops::CountUpTo; 43 using ops::FIFOQueue; 44 using ops::QueueClose; 45 using ops::QueueDequeue; 46 using ops::QueueEnqueue; 47 using ops::RandomNormal; 48 using ops::Square; 49 using ops::Variable; 50 51 constexpr char kAssignOpName[] = "assign"; 52 constexpr char kCancelOp0[] = "cancel0"; 53 constexpr char kCancelOp1[] = "cancel1"; 54 constexpr char kCloseOp0[] = "close0"; 55 constexpr char kCloseOp1[] = "close1"; 56 constexpr char kCountUpToOpName[] = "count"; 57 constexpr char kDequeueOp0[] = "dequeue0"; 58 constexpr char kDequeueOp1[] = "dequeue1"; 59 constexpr char kEnqueueOp0[] = "enqueue0"; 60 constexpr char kEnqueueOp1[] = "enqueue1"; 61 constexpr char kIllegalOpName1[] = "would fail"; 62 constexpr char kIllegalOpName2[] = "fail again"; 63 constexpr char kQueueName[] = "unit_test"; 64 constexpr char kQueueName0[] = "q0"; 65 constexpr char kQueueName1[] = "q1"; 66 constexpr char kSquareOpName[] = "square"; 67 constexpr char kVarOpName[] = "var"; 68 69 GraphDef BuildSimpleGraph() { 70 Scope root = Scope::NewRootScope(); 71 auto init_value = Const(root, 0); 72 auto var = Variable(root.WithOpName(kVarOpName), TensorShape({}), 73 DataType::DT_INT32); 74 auto assign = Assign(root.WithOpName(kAssignOpName), var, init_value); 75 auto count = CountUpTo(root.WithOpName(kCountUpToOpName), var, 10); 76 Square(root.WithOpName(kSquareOpName), var); // NOLINT 77 78 GraphDef graph_def; 79 TF_EXPECT_OK(root.ToGraphDef(&graph_def)); 80 return graph_def; 81 } 82 83 QueueRunnerDef BuildQueueRunnerDef( 84 const std::string& queue_name, const std::vector<std::string>& enqueue_ops, 85 const std::string& close_op, const std::string& cancel_op, 86 const std::vector<Code>& queue_closed_error_codes) { 87 QueueRunnerDef queue_runner_def; 88 *queue_runner_def.mutable_queue_name() = queue_name; 89 for (const std::string& enqueue_op : enqueue_ops) { 90 *queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op; 91 } 92 *queue_runner_def.mutable_close_op_name() = close_op; 93 *queue_runner_def.mutable_cancel_op_name() = cancel_op; 94 for (const auto& error_code : queue_closed_error_codes) { 95 *queue_runner_def.mutable_queue_closed_exception_types()->Add() = 96 error_code; 97 } 98 return queue_runner_def; 99 } 100 101 std::unique_ptr<Session> BuildSessionAndInitVariable( 102 const GraphDef& graph_def) { 103 SessionOptions options; 104 std::unique_ptr<Session> session(NewSession(options)); 105 TF_CHECK_OK(session->Create(graph_def)); 106 107 TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr)); 108 return session; 109 } 110 111 TEST(QueueRunnerTest, BasicTest) { 112 GraphDef graph_def = BuildSimpleGraph(); 113 auto session = BuildSessionAndInitVariable(graph_def); 114 115 QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( 116 kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); 117 118 std::unique_ptr<QueueRunner> qr; 119 TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); 120 TF_CHECK_OK(qr->Start(session.get())); 121 TF_EXPECT_OK(qr->Join()); 122 123 std::vector<Tensor> outputs; 124 TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs)); 125 int square_value = *outputs[0].scalar<int>().data(); 126 EXPECT_EQ(square_value, 100); 127 } 128 129 TEST(QueueRunnerTest, QueueClosedCode) { 130 GraphDef graph_def = BuildSimpleGraph(); 131 auto session = BuildSessionAndInitVariable(graph_def); 132 133 // Start two queues so that multiple threads are in Run. 134 QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( 135 kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", 136 {Code::OUT_OF_RANGE, Code::CANCELLED}); 137 138 std::unique_ptr<QueueRunner> qr; 139 TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); 140 TF_EXPECT_OK(qr->Start(session.get())); 141 TF_EXPECT_OK(qr->Join()); 142 143 std::vector<Tensor> outputs; 144 TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs)); 145 int square_value = *outputs[0].scalar<int>().data(); 146 EXPECT_EQ(square_value, 100); 147 } 148 149 TEST(QueueRunnerTest, QueueCloseFails) { 150 GraphDef graph_def = BuildSimpleGraph(); 151 auto session = BuildSessionAndInitVariable(graph_def); 152 153 QueueRunnerDef queue_runner_def = 154 BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kIllegalOpName1, "", 155 {Code::OUT_OF_RANGE}); 156 157 std::unique_ptr<QueueRunner> qr; 158 TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); 159 TF_EXPECT_OK(qr->Start(session.get())); 160 auto status = qr->Join(); 161 EXPECT_EQ(status.code(), Code::NOT_FOUND) << status; 162 } 163 164 TEST(QueueRunnerTest, CatchErrorInJoin) { 165 GraphDef graph_def = BuildSimpleGraph(); 166 auto session = BuildSessionAndInitVariable(graph_def); 167 168 QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( 169 kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {}); 170 171 std::unique_ptr<QueueRunner> qr; 172 TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); 173 TF_EXPECT_OK(qr->Start(session.get())); 174 EXPECT_EQ(qr->Join().code(), Code::NOT_FOUND); 175 } 176 177 GraphDef BuildDoubleQueueGraph() { 178 Scope root = Scope::NewRootScope(); 179 auto q0 = FIFOQueue(root.WithOpName(kQueueName0), {DataType::DT_INT32}); 180 auto ten = Const(root, 10); 181 auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {ten}); 182 auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0); 183 auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0, 184 QueueClose::CancelPendingEnqueues(true)); 185 auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32}, 186 FIFOQueue::Capacity(3)); 187 auto dequeue0 = 188 QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32}); 189 auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]}); 190 auto dequeue1 = 191 QueueDequeue(root.WithOpName(kDequeueOp1), q1, {DataType::DT_INT32}); 192 auto close1 = QueueClose(root.WithOpName(kCloseOp1), q1); 193 auto cancel1 = QueueClose(root.WithOpName(kCancelOp1), q1, 194 QueueClose::CancelPendingEnqueues(true)); 195 196 GraphDef graph_def; 197 TF_EXPECT_OK(root.ToGraphDef(&graph_def)); 198 return graph_def; 199 } 200 201 TEST(QueueRunnerTest, RealEnqueueDequeue) { 202 auto graph_def = BuildDoubleQueueGraph(); 203 204 SessionOptions options; 205 std::unique_ptr<Session> session(NewSession(options)); 206 TF_CHECK_OK(session->Create(graph_def)); 207 208 QueueRunnerDef queue_runner_def = 209 BuildQueueRunnerDef(kQueueName, {kEnqueueOp1}, kCloseOp1, "", {}); 210 std::unique_ptr<QueueRunner> qr; 211 TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); 212 TF_CHECK_OK(qr->Start(session.get())); 213 214 TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); 215 TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); 216 // Closing queue 0 would also close the queue runner. 217 TF_EXPECT_OK(session->Run({}, {}, {kCloseOp0}, nullptr)); 218 219 TF_EXPECT_OK(qr->Join()); 220 std::vector<Tensor> dq1; 221 TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1)); 222 EXPECT_EQ(*dq1[0].scalar<int>().data(), 10); 223 std::vector<Tensor> dq2; 224 TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq2)); 225 EXPECT_EQ(*dq2[0].scalar<int>().data(), 10); 226 227 EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(), 228 Code::OUT_OF_RANGE); 229 } 230 231 void JoinThread(QueueRunner* queue_runner, bool* join_succeeded, 232 Notification* join_done) { 233 EXPECT_EQ(queue_runner->Join().code(), Code::CANCELLED); 234 *join_succeeded = true; 235 join_done->Notify(); 236 } 237 238 TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) { 239 auto graph_def = BuildDoubleQueueGraph(); 240 241 SessionOptions options; 242 std::unique_ptr<Session> session(NewSession(options)); 243 TF_CHECK_OK(session->Create(graph_def)); 244 245 QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( 246 kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); 247 std::unique_ptr<QueueRunner> qr; 248 TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); 249 TF_CHECK_OK(qr->Start(session.get())); 250 251 TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); 252 253 std::vector<Tensor> dq1; 254 TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1)); 255 EXPECT_EQ(*dq1[0].scalar<int>().data(), 10); 256 257 // The expected behavior is the QueueRunner::Join() call is blocked until 258 // Session::Close() is called. 259 bool join_succeeded = false; 260 Notification join_done; 261 Env::Default()->SchedClosure( 262 std::bind(&JoinThread, qr.get(), &join_succeeded, &join_done)); 263 264 Env::Default()->SleepForMicroseconds(10000000); 265 EXPECT_EQ(join_succeeded, false); 266 267 // Closing the session is required to cancel pending enqueue nodes. 268 TF_EXPECT_OK(session->Close()); 269 270 join_done.WaitForNotification(); 271 EXPECT_EQ(join_succeeded, true); 272 } 273 274 TEST(QueueRunnerTest, EmptyEnqueueOps) { 275 QueueRunnerDef queue_runner_def = 276 BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {}); 277 278 std::unique_ptr<QueueRunner> qr; 279 EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(), 280 Code::INVALID_ARGUMENT); 281 } 282 283 TEST(QueueRunnerTest, StartTimeout) { 284 GraphDef graph_def = BuildDoubleQueueGraph(); 285 SessionOptions options; 286 std::unique_ptr<Session> session(NewSession(options)); 287 TF_CHECK_OK(session->Create(graph_def)); 288 289 QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( 290 kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); 291 292 std::unique_ptr<QueueRunner> qr; 293 TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); 294 // This will timeout since queue0 is not fed and queue1 is fetching data from 295 // queue0. 296 EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED); 297 TF_EXPECT_OK(session->Close()); 298 } 299 300 TEST(QueueRunnerTest, TestCoordinatorStop) { 301 auto graph_def = BuildDoubleQueueGraph(); 302 SessionOptions options; 303 std::unique_ptr<Session> session(NewSession(options)); 304 TF_CHECK_OK(session->Create(graph_def)); 305 306 QueueRunnerDef queue_runner0 = 307 BuildQueueRunnerDef(kQueueName0, {kEnqueueOp0}, kCloseOp0, kCancelOp0, 308 {Code::OUT_OF_RANGE, Code::CANCELLED}); 309 QueueRunnerDef queue_runner1 = 310 BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, 311 {Code::OUT_OF_RANGE, Code::CANCELLED}); 312 313 Coordinator coord; 314 std::unique_ptr<QueueRunner> qr0; 315 TF_EXPECT_OK(QueueRunner::New(queue_runner0, &coord, &qr0)); 316 TF_CHECK_OK(qr0->Start(session.get())); 317 std::unique_ptr<QueueRunner> qr1; 318 TF_EXPECT_OK(QueueRunner::New(queue_runner1, &coord, &qr1)); 319 TF_CHECK_OK(qr1->Start(session.get())); 320 321 TF_EXPECT_OK(coord.RegisterRunner(std::move(qr0))); 322 TF_EXPECT_OK(coord.RegisterRunner(std::move(qr1))); 323 324 std::vector<Tensor> dq; 325 TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq)); 326 EXPECT_EQ(*dq[0].scalar<int>().data(), 10); 327 328 TF_EXPECT_OK(coord.RequestStop()); 329 TF_EXPECT_OK(coord.Join()); 330 } 331 332 TEST(QueueRunnerTest, CallbackCalledOnError) { 333 GraphDef graph_def = BuildSimpleGraph(); 334 auto session = BuildSessionAndInitVariable(graph_def); 335 336 QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( 337 kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {}); 338 339 std::unique_ptr<QueueRunner> qr; 340 TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); 341 bool error_caught = false; 342 qr->AddErrorCallback([&error_caught](const Status&) { error_caught = true; }); 343 TF_EXPECT_OK(qr->Start(session.get())); 344 EXPECT_FALSE(qr->Join().ok()); 345 EXPECT_TRUE(error_caught); 346 } 347 348 TEST(QueueRunnerTest, RunMetaDataTest) { 349 Scope root = Scope::NewRootScope(); 350 auto q0 = FIFOQueue(root.WithOpName(kQueueName), {DataType::DT_FLOAT}); 351 Output rnd = RandomNormal(root.WithOpName("rnd"), {1, 1}, DataType::DT_FLOAT); 352 Output square = Square(root.WithOpName(kSquareOpName), rnd); 353 auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {square}); 354 auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0); 355 auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0, 356 QueueClose::CancelPendingEnqueues(true)); 357 auto dequeue0 = 358 QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_FLOAT}); 359 360 GraphDef graph_def; 361 TF_EXPECT_OK(root.ToGraphDef(&graph_def)); 362 for (auto& node : *graph_def.mutable_node()) { 363 node.set_device("/cpu:0"); 364 } 365 SessionOptions sess_options; 366 sess_options.config.mutable_graph_options()->set_build_cost_model(1); 367 std::unique_ptr<Session> session(NewSession(sess_options)); 368 369 TF_CHECK_OK(session->Create(graph_def)); 370 371 QueueRunnerDef queue_runner_def = 372 BuildQueueRunnerDef(kQueueName, {kEnqueueOp0}, kCloseOp0, kCancelOp0, {}); 373 std::unique_ptr<QueueRunner> qr; 374 TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); 375 RunOptions run_options; 376 TF_CHECK_OK(qr->StartAndCollectCostGraph(session.get(), run_options)); 377 378 // Make sure there was at least one element enqueued in q0: this prevents a 379 // race condition where we close the queue before it was populated. 380 std::vector<Tensor> dq0; 381 TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0)); 382 // Second call to run dequeue op is to make sure the cost graph has been 383 // stored. 384 TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0)); 385 386 CostGraphDef cost_graph; 387 TF_CHECK_OK(qr->ExportCostGraph(&cost_graph)); 388 EXPECT_TRUE(cost_graph.node_size() > 0); 389 390 qr->Stop(session.get()); 391 } 392 393 TEST(QueueRunnerTest, NoRunMetaDataTest) { 394 GraphDef graph_def = BuildSimpleGraph(); 395 auto session = BuildSessionAndInitVariable(graph_def); 396 397 QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( 398 kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); 399 std::unique_ptr<QueueRunner> qr; 400 TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); 401 TF_CHECK_OK(qr->Start(session.get())); 402 403 TF_EXPECT_OK(qr->Join()); 404 CostGraphDef cost_graph; 405 EXPECT_EQ(qr->ExportCostGraph(&cost_graph).code(), 406 error::FAILED_PRECONDITION); 407 } 408 409 } // namespace 410 } // namespace tensorflow 411