/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class QueueOpKernel : public AsyncOpKernel {
 public:
  explicit QueueOpKernel(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
    QueueInterface* queue;
    if (ctx->input_dtype(0) == DT_RESOURCE) {
      OP_REQUIRES_OK_ASYNC(
          ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
    } else {
      OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
                           callback);
    }
    ComputeAsync(ctx, queue, [callback, queue]() {
      queue->Unref();
      callback();
    });
  }

 protected:
  virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                            DoneCallback callback) = 0;
};

class QueueAccessOpKernel : public QueueOpKernel {
 public:
  explicit QueueAccessOpKernel(OpKernelConstruction* context)
      : QueueOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
    // TODO(keveman): Enable timeout.
    OP_REQUIRES(context, timeout_ == -1,
                errors::InvalidArgument("Timeout not supported yet."));
  }

 protected:
  int64 timeout_;
};
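// A minimal sketch (illustrative only, not registered in this file) of how a
// subclass plugs into QueueOpKernel: the public ComputeAsync above resolves
// the queue handle (resource or ref-string), takes a reference on the
// QueueInterface, and wraps `callback` so that the reference is released
// exactly once after the subclass finishes.  A subclass therefore implements
// only the protected overload, and must invoke `callback` on every code path.
// The class name below is hypothetical.
//
//   class ExampleQueueNoOp : public QueueOpKernel {
//    public:
//     explicit ExampleQueueNoOp(OpKernelConstruction* context)
//         : QueueOpKernel(context) {}
//
//    protected:
//     void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
//                       DoneCallback callback) override {
//       // `queue` stays valid until `callback` runs; the wrapper installed
//       // by QueueOpKernel::ComputeAsync then calls queue->Unref().
//       callback();
//     }
//   };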
// Defines an EnqueueOp, the execution of which enqueues a tuple of
// tensors in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input k: (k-1)th element of the tuple.
class EnqueueOp : public QueueAccessOpKernel {
 public:
  explicit EnqueueOp(OpKernelConstruction* context)
      : QueueAccessOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    DataTypeVector expected_inputs;
    if (ctx->input_dtype(0) == DT_RESOURCE) {
      expected_inputs.push_back(DT_RESOURCE);
    } else {
      expected_inputs.push_back(DT_STRING_REF);
    }
    for (DataType dt : queue->component_dtypes()) {
      expected_inputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
                         callback);

    QueueInterface::Tuple tuple;
    OpInputList components;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
                         callback);
    for (const Tensor& Tcomponent : components) {
      tuple.push_back(Tcomponent);
    }

    OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
    queue->TryEnqueue(tuple, ctx, callback);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp);

// Defines an EnqueueManyOp, the execution of which slices each
// component of a tuple of tensors along the 0th dimension, and
// enqueues tuples of slices in the given Queue.
//
// The op has 1 + k inputs, where k is the number of components in the
// tuples stored in the given Queue:
// - Input 0: queue handle.
// - Input 1: 0th element of the tuple.
// - ...
// - Input k: (k-1)th element of the tuple.
//
// N.B. All tuple components must have the same size in the 0th
// dimension.
class EnqueueManyOp : public QueueAccessOpKernel {
 public:
  explicit EnqueueManyOp(OpKernelConstruction* context)
      : QueueAccessOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    DataTypeVector expected_inputs;
    if (ctx->input_dtype(0) == DT_RESOURCE) {
      expected_inputs.push_back(DT_RESOURCE);
    } else {
      expected_inputs.push_back(DT_STRING_REF);
    }
    for (DataType dt : queue->component_dtypes()) {
      expected_inputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
                         callback);

    QueueInterface::Tuple tuple;
    OpInputList components;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
                         callback);
    for (const Tensor& Tcomponent : components) {
      tuple.push_back(Tcomponent);
    }

    OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
    queue->TryEnqueueMany(tuple, ctx, callback);
  }

  ~EnqueueManyOp() override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU),
                        EnqueueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU),
                        EnqueueManyOp);
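// Illustration of the EnqueueMany slicing contract above (assumed component
// shapes, for exposition only): for a queue whose tuples have two components
// of shapes [2] and [], feeding inputs of shapes [4, 2] and [4] enqueues four
// tuples, the i-th being (component0[i, :], component1[i]).  All components
// must agree on the size of dimension 0; otherwise ValidateManyTuple fails
// before any element is enqueued.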
// Defines a DequeueOp, the execution of which dequeues a tuple of
// tensors from the given Queue.
//
// The op has one input, which is the handle of the appropriate
// Queue. The op has k outputs, where k is the number of components in
// the tuples stored in the given Queue, and output i is the ith
// component of the dequeued tuple.
class DequeueOp : public QueueAccessOpKernel {
 public:
  explicit DequeueOp(OpKernelConstruction* context)
      : QueueAccessOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    if (ctx->input_dtype(0) == DT_RESOURCE) {
      OP_REQUIRES_OK_ASYNC(
          ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
          callback);
    } else {
      OP_REQUIRES_OK_ASYNC(
          ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
          callback);
    }

    queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
      if (!ctx->status().ok()) {
        callback();
        return;
      }
      OpOutputList output_components;
      OP_REQUIRES_OK_ASYNC(
          ctx, ctx->output_list("components", &output_components), callback);
      for (int i = 0; i < ctx->num_outputs(); ++i) {
        output_components.set(i, tuple[i]);
      }
      callback();
    });
  }

  ~DequeueOp() override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp);
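// Note on the completion-callback pattern used by the dequeue kernels above
// and below: TryDequeue (and TryDequeueMany) invokes the supplied lambda once
// the dequeue resolves, whether successfully, by cancellation, or with an
// error recorded on the context.  The lambda must therefore check
// ctx->status() before touching `tuple`, and must invoke the done callback on
// every path so the wrapper installed by QueueOpKernel can Unref the queue
// and signal completion to the executor.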
// Defines a DequeueManyOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
class DequeueManyOp : public QueueAccessOpKernel {
 public:
  explicit DequeueManyOp(OpKernelConstruction* context)
      : QueueAccessOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    const Tensor& Tnum_elements = ctx->input(1);
    int32 num_elements = Tnum_elements.flat<int32>()(0);

    OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
                      errors::InvalidArgument("DequeueManyOp requested ",
                                              num_elements, " < 0 elements"),
                      callback);

    if (ctx->input_dtype(0) == DT_RESOURCE) {
      OP_REQUIRES_OK_ASYNC(ctx,
                           ctx->MatchSignature({DT_RESOURCE, DT_INT32},
                                               queue->component_dtypes()),
                           callback);
    } else {
      OP_REQUIRES_OK_ASYNC(ctx,
                           ctx->MatchSignature({DT_STRING_REF, DT_INT32},
                                               queue->component_dtypes()),
                           callback);
    }

    queue->TryDequeueMany(
        num_elements, ctx, false /* allow_small_batch */,
        [ctx, callback](const QueueInterface::Tuple& tuple) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          OpOutputList output_components;
          OP_REQUIRES_OK_ASYNC(
              ctx, ctx->output_list("components", &output_components),
              callback);
          for (int i = 0; i < ctx->num_outputs(); ++i) {
            output_components.set(i, tuple[i]);
          }
          callback();
        });
  }

  ~DequeueManyOp() override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
                        DequeueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU),
                        DequeueManyOp);
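// Shape illustration for DequeueMany (assumed shapes, for exposition only):
// dequeuing num_elements = n from a queue whose tuples have components of
// shapes [2] and [] produces outputs of shapes [n, 2] and [n], i.e. the n
// dequeued tuples concatenated along a new 0th dimension.  Because
// allow_small_batch is false here, the op waits until n full elements are
// available or the queue is closed.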
// Defines a DequeueUpToOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
// dimension, and emits the result as a single tuple of tensors.
//
// The difference between this op and DequeueMany is the handling when
// the Queue is closed.  While the DequeueMany op will return an error
// if there are fewer than num_elements elements left in the closed
// queue, this op will return between 1 and
// min(num_elements, elements_remaining_in_queue) elements, and will
// not block.  If there are no elements left, then the standard
// DequeueMany error is returned.
//
// This op only works if the underlying Queue implementation accepts
// the allow_small_batch = true parameter to TryDequeueMany.
// If it does not, an errors::Unimplemented error is returned.
//
// The op has two inputs:
// - Input 0: the handle to a queue.
// - Input 1: the number of elements to dequeue.
//
// The op has k outputs, where k is the number of components in the
// tuples stored in the given Queue, and output i is the ith component
// of the dequeued tuple.
//
// Rather than exposing this as an attribute, the kernel hard-codes
// allow_small_batch = true in its call to TryDequeueMany: when the
// queue is closed, the queue returns smaller batches, up to however
// many elements are available when the op executes, instead of
// blocking.
class DequeueUpToOp : public QueueAccessOpKernel {
 public:
  explicit DequeueUpToOp(OpKernelConstruction* context)
      : QueueAccessOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    const Tensor& Tnum_elements = ctx->input(1);
    int32 num_elements = Tnum_elements.flat<int32>()(0);

    OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
                      errors::InvalidArgument("DequeueUpToOp requested ",
                                              num_elements, " < 0 elements"),
                      callback);

    if (ctx->input_dtype(0) == DT_RESOURCE) {
      OP_REQUIRES_OK_ASYNC(ctx,
                           ctx->MatchSignature({DT_RESOURCE, DT_INT32},
                                               queue->component_dtypes()),
                           callback);
    } else {
      OP_REQUIRES_OK_ASYNC(ctx,
                           ctx->MatchSignature({DT_STRING_REF, DT_INT32},
                                               queue->component_dtypes()),
                           callback);
    }

    queue->TryDequeueMany(
        num_elements, ctx, true /* allow_small_batch */,
        [ctx, callback](const QueueInterface::Tuple& tuple) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          OpOutputList output_components;
          OP_REQUIRES_OK_ASYNC(
              ctx, ctx->output_list("components", &output_components),
              callback);
          for (int i = 0; i < ctx->num_outputs(); ++i) {
            output_components.set(i, tuple[i]);
          }
          callback();
        });
  }

  ~DequeueUpToOp() override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
                        DequeueUpToOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU),
                        DequeueUpToOp);

// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
//
// The op has one input, which is the handle of the appropriate Queue.
class QueueCloseOp : public QueueOpKernel {
 public:
  explicit QueueCloseOp(OpKernelConstruction* context)
      : QueueOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
                                             &cancel_pending_enqueues_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    queue->Close(ctx, cancel_pending_enqueues_, callback);
  }

 private:
  bool cancel_pending_enqueues_;
  TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp);
REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp);
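// Semantics of the cancel_pending_enqueues attribute, as implemented by the
// standard queue implementations (summarized here for reference): with
// cancel_pending_enqueues = false, enqueue requests already blocked on the
// queue may still complete, while new enqueues fail; with
// cancel_pending_enqueues = true, blocked enqueue requests are cancelled
// immediately.  Dequeue operations continue to succeed as long as elements
// remain in the closed queue.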
// Defines a QueueSizeOp, which computes the number of elements in the
// given Queue, and emits it as an output tensor.
//
// The op has one input, which is the handle of the appropriate Queue;
// and one output, which is a single-element tensor containing the current
// size of that Queue.
class QueueSizeOp : public QueueOpKernel {
 public:
  explicit QueueSizeOp(OpKernelConstruction* context)
      : QueueOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    Tensor* Tqueue_size = nullptr;
    // Use the _ASYNC variant so that `callback` still runs if allocation
    // fails; otherwise the queue reference would leak and the executor would
    // never be notified of completion.
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size), callback);
    Tqueue_size->flat<int32>().setConstant(queue->size());
    callback();
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp);
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp);

// Defines a QueueIsClosedOp, which emits a scalar boolean tensor that is
// true iff the given Queue has been closed.
class QueueIsClosedOp : public QueueOpKernel {
 public:
  explicit QueueIsClosedOp(OpKernelConstruction* context)
      : QueueOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
                    DoneCallback callback) override {
    Tensor* Tqueue_is_closed = nullptr;
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed),
        callback);
    Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
    callback();
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
};

REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU),
                        QueueIsClosedOp);
REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU),
                        QueueIsClosedOp);

// Defines a FakeQueueOp, which converts a resource handle input into a
// ref-typed string handle of shape [2] (container name, resource name), so
// that resource-based queues can be passed to kernels that expect the older
// ref-string handle form.
class FakeQueueOp : public OpKernel {
 public:
  explicit FakeQueueOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->allocate_persistent(DT_STRING, TensorShape({2}),
                                                &handle_, nullptr));
  }

  void Compute(OpKernelContext* context) override {
    ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0);
    handle_.AccessTensor(context)->flat<string>()(0) = ref.container();
    handle_.AccessTensor(context)->flat<string>()(1) = ref.name();
    context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
  }

 private:
  mutex mu_;
  PersistentTensor handle_;
};

REGISTER_KERNEL_BUILDER(Name("FakeQueue").Device(DEVICE_CPU), FakeQueueOp);

}  // namespace tensorflow