/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

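// Base class for queue op kernels. Resolves the queue referenced by
// input 0 (either a DT_RESOURCE handle or a Ref(string) handle), then
// dispatches to the subclass's ComputeAsync overload, unref'ing the
// queue once the subclass invokes the done callback.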
class QueueOpKernel : public AsyncOpKernel {
 public:
  explicit QueueOpKernel(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
    QueueInterface* queue;
    if (ctx->input_dtype(0) == DT_RESOURCE) {
      OP_REQUIRES_OK_ASYNC(
          ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
    } else {
      OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
                           callback);
    }
    ComputeAsync(ctx, queue, [callback, queue]() {
      queue->Unref();
      callback();
    });
  }

 protected:
  virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                            DoneCallback callback) = 0;
};

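// Base class for kernels that enqueue to or dequeue from a queue. Adds
// the "timeout_ms" attribute shared by those ops; only the default value
// of -1 (no timeout) is currently accepted.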
class QueueAccessOpKernel : public QueueOpKernel {
 public:
  explicit QueueAccessOpKernel(OpKernelConstruction* context)
      : QueueOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
    // TODO(keveman): Enable timeout.
    OP_REQUIRES(context, timeout_ == -1,
                errors::InvalidArgument("Timeout not supported yet."));
  }

 protected:
  int64 timeout_;
};

// Defines an EnqueueOp, the execution of which enqueues a tuple of
// tensors in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input k: (k-1)th element of the tuple.
class EnqueueOp : public QueueAccessOpKernel {
 public:
  explicit EnqueueOp(OpKernelConstruction* context)
      : QueueAccessOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    DataTypeVector expected_inputs;
    if (ctx->input_dtype(0) == DT_RESOURCE) {
      expected_inputs.push_back(DT_RESOURCE);
    } else {
      expected_inputs.push_back(DT_STRING_REF);
    }
    for (DataType dt : queue->component_dtypes()) {
      expected_inputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
                         callback);

    QueueInterface::Tuple tuple;
    OpInputList components;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
                         callback);
    for (const Tensor& Tcomponent : components) {
      tuple.push_back(Tcomponent);
    }

    OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
    queue->TryEnqueue(tuple, ctx, callback);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp);

// Defines an EnqueueManyOp, the execution of which slices each
// component of a tuple of tensors along the 0th dimension, and
// enqueues tuples of slices in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input k: (k-1)th element of the tuple.
//
// N.B. All tuple components must have the same size in the 0th
// dimension.
class EnqueueManyOp : public QueueAccessOpKernel {
 public:
  explicit EnqueueManyOp(OpKernelConstruction* context)
      : QueueAccessOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    DataTypeVector expected_inputs;
    if (ctx->input_dtype(0) == DT_RESOURCE) {
      expected_inputs.push_back(DT_RESOURCE);
    } else {
      expected_inputs.push_back(DT_STRING_REF);
    }
    for (DataType dt : queue->component_dtypes()) {
      expected_inputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
                         callback);

    QueueInterface::Tuple tuple;
    OpInputList components;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
                         callback);
    for (const Tensor& Tcomponent : components) {
      tuple.push_back(Tcomponent);
    }

    OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
    queue->TryEnqueueMany(tuple, ctx, callback);
  }

  ~EnqueueManyOp() override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU),
                        EnqueueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU),
                        EnqueueManyOp);

// Defines a DequeueOp, the execution of which dequeues a tuple of
// tensors from the given Queue.
//
// The op has one input, which is the handle of the appropriate
// Queue. The op has k outputs, where k is the number of components in
// the tuples stored in the given Queue, and output i is the ith
// component of the dequeued tuple.
class DequeueOp : public QueueAccessOpKernel {
 public:
  explicit DequeueOp(OpKernelConstruction* context)
      : QueueAccessOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    if (ctx->input_dtype(0) == DT_RESOURCE) {
      OP_REQUIRES_OK_ASYNC(
          ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
          callback);
    } else {
      OP_REQUIRES_OK_ASYNC(
          ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
          callback);
    }

    queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
      if (!ctx->status().ok()) {
        callback();
        return;
      }
      OpOutputList output_components;
      OP_REQUIRES_OK_ASYNC(
          ctx, ctx->output_list("components", &output_components), callback);
      for (int i = 0; i < ctx->num_outputs(); ++i) {
        output_components.set(i, tuple[i]);
      }
      callback();
    });
  }

  ~DequeueOp() override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp);

// Defines a DequeueManyOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
class DequeueManyOp : public QueueAccessOpKernel {
 public:
  explicit DequeueManyOp(OpKernelConstruction* context)
      : QueueAccessOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    const Tensor& Tnum_elements = ctx->input(1);
    int32 num_elements = Tnum_elements.flat<int32>()(0);

    OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
                      errors::InvalidArgument("DequeueManyOp requested ",
                                              num_elements, " < 0 elements"),
                      callback);

    if (ctx->input_dtype(0) == DT_RESOURCE) {
      OP_REQUIRES_OK_ASYNC(ctx,
                           ctx->MatchSignature({DT_RESOURCE, DT_INT32},
                                               queue->component_dtypes()),
                           callback);
    } else {
      OP_REQUIRES_OK_ASYNC(ctx,
                           ctx->MatchSignature({DT_STRING_REF, DT_INT32},
                                               queue->component_dtypes()),
                           callback);
    }

    queue->TryDequeueMany(
        num_elements, ctx, false /* allow_small_batch */,
        [ctx, callback](const QueueInterface::Tuple& tuple) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          OpOutputList output_components;
          OP_REQUIRES_OK_ASYNC(
              ctx, ctx->output_list("components", &output_components),
              callback);
          for (int i = 0; i < ctx->num_outputs(); ++i) {
            output_components.set(i, tuple[i]);
          }
          callback();
        });
  }

  ~DequeueManyOp() override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
                        DequeueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU),
                        DequeueManyOp);

// Defines a DequeueUpToOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The difference between this op and DequeueMany is the handling when
// the Queue is closed.  While the DequeueMany op returns an error when
// there are fewer than num_elements elements left in the closed queue,
// this op returns between 1 and
// min(num_elements, elements_remaining_in_queue) elements, and does not
// block.  For example, if a closed queue holds 3 remaining elements and
// num_elements is 5, DequeueMany fails but DequeueUpTo returns those 3
// elements.  If there are no elements left, the standard DequeueMany
// error is returned.
//
// This op only works if the underlying Queue implementation accepts
// the allow_small_batch = true parameter to TryDequeueMany.
// If it does not, an errors::Unimplemented error is returned.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
//
// This kernel always calls TryDequeueMany with allow_small_batch = true.
// If the Queue supports it, this causes the queue to return smaller
// (possibly zero length) batches when it is closed, up to however
// many elements are available when the op executes.  In this case,
// the Queue does not block when closed.
class DequeueUpToOp : public QueueAccessOpKernel {
 public:
  explicit DequeueUpToOp(OpKernelConstruction* context)
      : QueueAccessOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    const Tensor& Tnum_elements = ctx->input(1);
    int32 num_elements = Tnum_elements.flat<int32>()(0);

    OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
                      errors::InvalidArgument("DequeueUpToOp requested ",
                                              num_elements, " < 0 elements"),
                      callback);

    if (ctx->input_dtype(0) == DT_RESOURCE) {
      OP_REQUIRES_OK_ASYNC(ctx,
                           ctx->MatchSignature({DT_RESOURCE, DT_INT32},
                                               queue->component_dtypes()),
                           callback);
    } else {
      OP_REQUIRES_OK_ASYNC(ctx,
                           ctx->MatchSignature({DT_STRING_REF, DT_INT32},
                                               queue->component_dtypes()),
                           callback);
    }

    queue->TryDequeueMany(
        num_elements, ctx, true /* allow_small_batch */,
        [ctx, callback](const QueueInterface::Tuple& tuple) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          OpOutputList output_components;
          OP_REQUIRES_OK_ASYNC(
              ctx, ctx->output_list("components", &output_components),
              callback);
          for (int i = 0; i < ctx->num_outputs(); ++i) {
            output_components.set(i, tuple[i]);
          }
          callback();
        });
  }

  ~DequeueUpToOp() override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
                        DequeueUpToOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU),
                        DequeueUpToOp);

// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
//
// The op has one input, which is the handle of the appropriate Queue.
class QueueCloseOp : public QueueOpKernel {
 public:
  explicit QueueCloseOp(OpKernelConstruction* context)
      : QueueOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
                                             &cancel_pending_enqueues_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    queue->Close(ctx, cancel_pending_enqueues_, callback);
  }

 private:
  bool cancel_pending_enqueues_;
  TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp);
REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp);

// Defines a QueueSizeOp, which computes the number of elements in the
// given Queue, and emits it as an output tensor.
//
// The op has one input, which is the handle of the appropriate Queue;
// and one output, which is a single-element tensor containing the current
// size of that Queue.
class QueueSizeOp : public QueueOpKernel {
 public:
  explicit QueueSizeOp(OpKernelConstruction* context)
      : QueueOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    Tensor* Tqueue_size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
    Tqueue_size->flat<int32>().setConstant(queue->size());
    callback();
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp);
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp);

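// Defines a QueueIsClosedOp, which reports whether the given Queue has
// been closed.
//
// The op has one input, which is the handle of the appropriate Queue;
// and one output, which is a scalar boolean tensor that is true iff the
// Queue is closed.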
class QueueIsClosedOp : public QueueOpKernel {
 public:
  explicit QueueIsClosedOp(OpKernelConstruction* context)
      : QueueOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    Tensor* Tqueue_is_closed = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
    Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
    callback();
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU),
                        QueueIsClosedOp);
REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU),
                        QueueIsClosedOp);

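// Defines a FakeQueueOp, which takes a resource handle (input 0) and
// emits a Ref(string) handle tensor of shape {2} holding the resource's
// container and name, mimicking the reference-typed handle produced by
// the non-V2 queue ops.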
class FakeQueueOp : public OpKernel {
 public:
  explicit FakeQueueOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->allocate_persistent(DT_STRING, TensorShape({2}),
                                                &handle_, nullptr));
  }

  void Compute(OpKernelContext* context) override {
    ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0);
    handle_.AccessTensor(context)->flat<string>()(0) = ref.container();
    handle_.AccessTensor(context)->flat<string>()(1) = ref.name();
    context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
  }

 private:
  mutex mu_;
  PersistentTensor handle_;
};

REGISTER_KERNEL_BUILDER(Name("FakeQueue").Device(DEVICE_CPU), FakeQueueOp);

}  // namespace tensorflow