Home | History | Annotate | Download | only in training
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/cc/training/queue_runner.h"
     17 
     18 #include <string>
     19 #include <vector>
     20 
     21 #include "tensorflow/cc/framework/scope.h"
     22 #include "tensorflow/cc/ops/standard_ops.h"
     23 #include "tensorflow/cc/training/coordinator.h"
     24 #include "tensorflow/core/framework/graph.pb.h"
     25 #include "tensorflow/core/framework/tensor.h"
     26 #include "tensorflow/core/framework/tensor_shape.h"
     27 #include "tensorflow/core/framework/types.pb.h"
     28 #include "tensorflow/core/lib/core/error_codes.pb.h"
     29 #include "tensorflow/core/lib/core/notification.h"
     30 #include "tensorflow/core/lib/core/status_test_util.h"
     31 #include "tensorflow/core/platform/env.h"
     32 #include "tensorflow/core/platform/test.h"
     33 #include "tensorflow/core/protobuf/queue_runner.pb.h"
     34 #include "tensorflow/core/public/session.h"
     35 
     36 namespace tensorflow {
     37 namespace {
     38 
     39 using error::Code;
     40 using ops::Assign;
     41 using ops::Const;
     42 using ops::CountUpTo;
     43 using ops::FIFOQueue;
     44 using ops::QueueClose;
     45 using ops::QueueDequeue;
     46 using ops::QueueEnqueue;
     47 using ops::RandomNormal;
     48 using ops::Square;
     49 using ops::Variable;
     50 
     51 constexpr char kAssignOpName[] = "assign";
     52 constexpr char kCancelOp0[] = "cancel0";
     53 constexpr char kCancelOp1[] = "cancel1";
     54 constexpr char kCloseOp0[] = "close0";
     55 constexpr char kCloseOp1[] = "close1";
     56 constexpr char kCountUpToOpName[] = "count";
     57 constexpr char kDequeueOp0[] = "dequeue0";
     58 constexpr char kDequeueOp1[] = "dequeue1";
     59 constexpr char kEnqueueOp0[] = "enqueue0";
     60 constexpr char kEnqueueOp1[] = "enqueue1";
     61 constexpr char kIllegalOpName1[] = "would fail";
     62 constexpr char kIllegalOpName2[] = "fail again";
     63 constexpr char kQueueName[] = "unit_test";
     64 constexpr char kQueueName0[] = "q0";
     65 constexpr char kQueueName1[] = "q1";
     66 constexpr char kSquareOpName[] = "square";
     67 constexpr char kVarOpName[] = "var";
     68 
     69 GraphDef BuildSimpleGraph() {
     70   Scope root = Scope::NewRootScope();
     71   auto init_value = Const(root, 0);
     72   auto var = Variable(root.WithOpName(kVarOpName), TensorShape({}),
     73                       DataType::DT_INT32);
     74   auto assign = Assign(root.WithOpName(kAssignOpName), var, init_value);
     75   auto count = CountUpTo(root.WithOpName(kCountUpToOpName), var, 10);
     76   Square(root.WithOpName(kSquareOpName), var);  // NOLINT
     77 
     78   GraphDef graph_def;
     79   TF_EXPECT_OK(root.ToGraphDef(&graph_def));
     80   return graph_def;
     81 }
     82 
     83 QueueRunnerDef BuildQueueRunnerDef(
     84     const std::string& queue_name, const std::vector<std::string>& enqueue_ops,
     85     const std::string& close_op, const std::string& cancel_op,
     86     const std::vector<Code>& queue_closed_error_codes) {
     87   QueueRunnerDef queue_runner_def;
     88   *queue_runner_def.mutable_queue_name() = queue_name;
     89   for (const std::string& enqueue_op : enqueue_ops) {
     90     *queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op;
     91   }
     92   *queue_runner_def.mutable_close_op_name() = close_op;
     93   *queue_runner_def.mutable_cancel_op_name() = cancel_op;
     94   for (const auto& error_code : queue_closed_error_codes) {
     95     *queue_runner_def.mutable_queue_closed_exception_types()->Add() =
     96         error_code;
     97   }
     98   return queue_runner_def;
     99 }
    100 
    101 std::unique_ptr<Session> BuildSessionAndInitVariable(
    102     const GraphDef& graph_def) {
    103   SessionOptions options;
    104   std::unique_ptr<Session> session(NewSession(options));
    105   TF_CHECK_OK(session->Create(graph_def));
    106 
    107   TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr));
    108   return session;
    109 }
    110 
    111 TEST(QueueRunnerTest, BasicTest) {
    112   GraphDef graph_def = BuildSimpleGraph();
    113   auto session = BuildSessionAndInitVariable(graph_def);
    114 
    115   QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
    116       kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
    117 
    118   std::unique_ptr<QueueRunner> qr;
    119   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
    120   TF_CHECK_OK(qr->Start(session.get()));
    121   TF_EXPECT_OK(qr->Join());
    122 
    123   std::vector<Tensor> outputs;
    124   TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs));
    125   int square_value = *outputs[0].scalar<int>().data();
    126   EXPECT_EQ(square_value, 100);
    127 }
    128 
    129 TEST(QueueRunnerTest, QueueClosedCode) {
    130   GraphDef graph_def = BuildSimpleGraph();
    131   auto session = BuildSessionAndInitVariable(graph_def);
    132 
    133   // Start two queues so that multiple threads are in Run.
    134   QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
    135       kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "",
    136       {Code::OUT_OF_RANGE, Code::CANCELLED});
    137 
    138   std::unique_ptr<QueueRunner> qr;
    139   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
    140   TF_EXPECT_OK(qr->Start(session.get()));
    141   TF_EXPECT_OK(qr->Join());
    142 
    143   std::vector<Tensor> outputs;
    144   TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs));
    145   int square_value = *outputs[0].scalar<int>().data();
    146   EXPECT_EQ(square_value, 100);
    147 }
    148 
    149 TEST(QueueRunnerTest, QueueCloseFails) {
    150   GraphDef graph_def = BuildSimpleGraph();
    151   auto session = BuildSessionAndInitVariable(graph_def);
    152 
    153   QueueRunnerDef queue_runner_def =
    154       BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kIllegalOpName1, "",
    155                           {Code::OUT_OF_RANGE});
    156 
    157   std::unique_ptr<QueueRunner> qr;
    158   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
    159   TF_EXPECT_OK(qr->Start(session.get()));
    160   auto status = qr->Join();
    161   EXPECT_EQ(status.code(), Code::NOT_FOUND) << status;
    162 }
    163 
    164 TEST(QueueRunnerTest, CatchErrorInJoin) {
    165   GraphDef graph_def = BuildSimpleGraph();
    166   auto session = BuildSessionAndInitVariable(graph_def);
    167 
    168   QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
    169       kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {});
    170 
    171   std::unique_ptr<QueueRunner> qr;
    172   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
    173   TF_EXPECT_OK(qr->Start(session.get()));
    174   EXPECT_EQ(qr->Join().code(), Code::NOT_FOUND);
    175 }
    176 
    177 GraphDef BuildDoubleQueueGraph() {
    178   Scope root = Scope::NewRootScope();
    179   auto q0 = FIFOQueue(root.WithOpName(kQueueName0), {DataType::DT_INT32});
    180   auto ten = Const(root, 10);
    181   auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {ten});
    182   auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
    183   auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
    184                             QueueClose::CancelPendingEnqueues(true));
    185   auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32},
    186                       FIFOQueue::Capacity(3));
    187   auto dequeue0 =
    188       QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32});
    189   auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]});
    190   auto dequeue1 =
    191       QueueDequeue(root.WithOpName(kDequeueOp1), q1, {DataType::DT_INT32});
    192   auto close1 = QueueClose(root.WithOpName(kCloseOp1), q1);
    193   auto cancel1 = QueueClose(root.WithOpName(kCancelOp1), q1,
    194                             QueueClose::CancelPendingEnqueues(true));
    195 
    196   GraphDef graph_def;
    197   TF_EXPECT_OK(root.ToGraphDef(&graph_def));
    198   return graph_def;
    199 }
    200 
    201 TEST(QueueRunnerTest, RealEnqueueDequeue) {
    202   auto graph_def = BuildDoubleQueueGraph();
    203 
    204   SessionOptions options;
    205   std::unique_ptr<Session> session(NewSession(options));
    206   TF_CHECK_OK(session->Create(graph_def));
    207 
    208   QueueRunnerDef queue_runner_def =
    209       BuildQueueRunnerDef(kQueueName, {kEnqueueOp1}, kCloseOp1, "", {});
    210   std::unique_ptr<QueueRunner> qr;
    211   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
    212   TF_CHECK_OK(qr->Start(session.get()));
    213 
    214   TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
    215   TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
    216   // Closing queue 0 would also close the queue runner.
    217   TF_EXPECT_OK(session->Run({}, {}, {kCloseOp0}, nullptr));
    218 
    219   TF_EXPECT_OK(qr->Join());
    220   std::vector<Tensor> dq1;
    221   TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
    222   EXPECT_EQ(*dq1[0].scalar<int>().data(), 10);
    223   std::vector<Tensor> dq2;
    224   TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq2));
    225   EXPECT_EQ(*dq2[0].scalar<int>().data(), 10);
    226 
    227   EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
    228             Code::OUT_OF_RANGE);
    229 }
    230 
    231 void JoinThread(QueueRunner* queue_runner, bool* join_succeeded,
    232                 Notification* join_done) {
    233   EXPECT_EQ(queue_runner->Join().code(), Code::CANCELLED);
    234   *join_succeeded = true;
    235   join_done->Notify();
    236 }
    237 
    238 TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
    239   auto graph_def = BuildDoubleQueueGraph();
    240 
    241   SessionOptions options;
    242   std::unique_ptr<Session> session(NewSession(options));
    243   TF_CHECK_OK(session->Create(graph_def));
    244 
    245   QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
    246       kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
    247   std::unique_ptr<QueueRunner> qr;
    248   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
    249   TF_CHECK_OK(qr->Start(session.get()));
    250 
    251   TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
    252 
    253   std::vector<Tensor> dq1;
    254   TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
    255   EXPECT_EQ(*dq1[0].scalar<int>().data(), 10);
    256 
    257   // The expected behavior is the QueueRunner::Join() call is blocked until
    258   // Session::Close() is called.
    259   bool join_succeeded = false;
    260   Notification join_done;
    261   Env::Default()->SchedClosure(
    262       std::bind(&JoinThread, qr.get(), &join_succeeded, &join_done));
    263 
    264   Env::Default()->SleepForMicroseconds(10000000);
    265   EXPECT_EQ(join_succeeded, false);
    266 
    267   // Closing the session is required to cancel pending enqueue nodes.
    268   TF_EXPECT_OK(session->Close());
    269 
    270   join_done.WaitForNotification();
    271   EXPECT_EQ(join_succeeded, true);
    272 }
    273 
    274 TEST(QueueRunnerTest, EmptyEnqueueOps) {
    275   QueueRunnerDef queue_runner_def =
    276       BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
    277 
    278   std::unique_ptr<QueueRunner> qr;
    279   EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(),
    280             Code::INVALID_ARGUMENT);
    281 }
    282 
    283 TEST(QueueRunnerTest, StartTimeout) {
    284   GraphDef graph_def = BuildDoubleQueueGraph();
    285   SessionOptions options;
    286   std::unique_ptr<Session> session(NewSession(options));
    287   TF_CHECK_OK(session->Create(graph_def));
    288 
    289   QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
    290       kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
    291 
    292   std::unique_ptr<QueueRunner> qr;
    293   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
    294   // This will timeout since queue0 is not fed and queue1 is fetching data from
    295   // queue0.
    296   EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED);
    297   TF_EXPECT_OK(session->Close());
    298 }
    299 
    300 TEST(QueueRunnerTest, TestCoordinatorStop) {
    301   auto graph_def = BuildDoubleQueueGraph();
    302   SessionOptions options;
    303   std::unique_ptr<Session> session(NewSession(options));
    304   TF_CHECK_OK(session->Create(graph_def));
    305 
    306   QueueRunnerDef queue_runner0 =
    307       BuildQueueRunnerDef(kQueueName0, {kEnqueueOp0}, kCloseOp0, kCancelOp0,
    308                           {Code::OUT_OF_RANGE, Code::CANCELLED});
    309   QueueRunnerDef queue_runner1 =
    310       BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
    311                           {Code::OUT_OF_RANGE, Code::CANCELLED});
    312 
    313   Coordinator coord;
    314   std::unique_ptr<QueueRunner> qr0;
    315   TF_EXPECT_OK(QueueRunner::New(queue_runner0, &coord, &qr0));
    316   TF_CHECK_OK(qr0->Start(session.get()));
    317   std::unique_ptr<QueueRunner> qr1;
    318   TF_EXPECT_OK(QueueRunner::New(queue_runner1, &coord, &qr1));
    319   TF_CHECK_OK(qr1->Start(session.get()));
    320 
    321   TF_EXPECT_OK(coord.RegisterRunner(std::move(qr0)));
    322   TF_EXPECT_OK(coord.RegisterRunner(std::move(qr1)));
    323 
    324   std::vector<Tensor> dq;
    325   TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
    326   EXPECT_EQ(*dq[0].scalar<int>().data(), 10);
    327 
    328   TF_EXPECT_OK(coord.RequestStop());
    329   TF_EXPECT_OK(coord.Join());
    330 }
    331 
    332 TEST(QueueRunnerTest, CallbackCalledOnError) {
    333   GraphDef graph_def = BuildSimpleGraph();
    334   auto session = BuildSessionAndInitVariable(graph_def);
    335 
    336   QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
    337       kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {});
    338 
    339   std::unique_ptr<QueueRunner> qr;
    340   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
    341   bool error_caught = false;
    342   qr->AddErrorCallback([&error_caught](const Status&) { error_caught = true; });
    343   TF_EXPECT_OK(qr->Start(session.get()));
    344   EXPECT_FALSE(qr->Join().ok());
    345   EXPECT_TRUE(error_caught);
    346 }
    347 
    348 TEST(QueueRunnerTest, RunMetaDataTest) {
    349   Scope root = Scope::NewRootScope();
    350   auto q0 = FIFOQueue(root.WithOpName(kQueueName), {DataType::DT_FLOAT});
    351   Output rnd = RandomNormal(root.WithOpName("rnd"), {1, 1}, DataType::DT_FLOAT);
    352   Output square = Square(root.WithOpName(kSquareOpName), rnd);
    353   auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {square});
    354   auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
    355   auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
    356                             QueueClose::CancelPendingEnqueues(true));
    357   auto dequeue0 =
    358       QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_FLOAT});
    359 
    360   GraphDef graph_def;
    361   TF_EXPECT_OK(root.ToGraphDef(&graph_def));
    362   for (auto& node : *graph_def.mutable_node()) {
    363     node.set_device("/cpu:0");
    364   }
    365   SessionOptions sess_options;
    366   sess_options.config.mutable_graph_options()->set_build_cost_model(1);
    367   std::unique_ptr<Session> session(NewSession(sess_options));
    368 
    369   TF_CHECK_OK(session->Create(graph_def));
    370 
    371   QueueRunnerDef queue_runner_def =
    372       BuildQueueRunnerDef(kQueueName, {kEnqueueOp0}, kCloseOp0, kCancelOp0, {});
    373   std::unique_ptr<QueueRunner> qr;
    374   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
    375   RunOptions run_options;
    376   TF_CHECK_OK(qr->StartAndCollectCostGraph(session.get(), run_options));
    377 
    378   // Make sure there was at least one element enqueued in q0: this prevents a
    379   // race condition where we close the queue before it was populated.
    380   std::vector<Tensor> dq0;
    381   TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0));
    382   // Second call to run dequeue op is to make sure the cost graph has been
    383   // stored.
    384   TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0));
    385 
    386   CostGraphDef cost_graph;
    387   TF_CHECK_OK(qr->ExportCostGraph(&cost_graph));
    388   EXPECT_TRUE(cost_graph.node_size() > 0);
    389 
    390   qr->Stop(session.get());
    391 }
    392 
    393 TEST(QueueRunnerTest, NoRunMetaDataTest) {
    394   GraphDef graph_def = BuildSimpleGraph();
    395   auto session = BuildSessionAndInitVariable(graph_def);
    396 
    397   QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
    398       kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
    399   std::unique_ptr<QueueRunner> qr;
    400   TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
    401   TF_CHECK_OK(qr->Start(session.get()));
    402 
    403   TF_EXPECT_OK(qr->Join());
    404   CostGraphDef cost_graph;
    405   EXPECT_EQ(qr->ExportCostGraph(&cost_graph).code(),
    406             error::FAILED_PRECONDITION);
    407 }
    408 
    409 }  // namespace
    410 }  // namespace tensorflow
    411