1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/io_ops.cc. 17 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/queue_interface.h" 20 #include "tensorflow/core/framework/reader_interface.h" 21 #include "tensorflow/core/framework/tensor_shape.h" 22 #include "tensorflow/core/kernels/ops_util.h" 23 #include "tensorflow/core/lib/core/threadpool.h" 24 #include "tensorflow/core/lib/strings/strcat.h" 25 26 namespace tensorflow { 27 28 class ReaderVerbSyncOpKernel : public OpKernel { 29 public: 30 using OpKernel::OpKernel; 31 32 void Compute(OpKernelContext* context) override { 33 ReaderInterface* reader; 34 OP_REQUIRES_OK(context, 35 GetResourceFromContext(context, "reader_handle", &reader)); 36 ComputeWithReader(context, reader); 37 reader->Unref(); 38 } 39 40 protected: 41 virtual void ComputeWithReader(OpKernelContext* context, 42 ReaderInterface* reader) = 0; 43 }; 44 45 class ReaderVerbAsyncOpKernel : public AsyncOpKernel { 46 public: 47 using AsyncOpKernel::AsyncOpKernel; 48 49 explicit ReaderVerbAsyncOpKernel(OpKernelConstruction* context) 50 : AsyncOpKernel(context), 51 thread_pool_(new thread::ThreadPool( 52 context->env(), ThreadOptions(), 53 strings::StrCat("reader_thread_", SanitizeThreadSuffix(name())), 54 1 /* num_threads */, false /* low_latency_hint */)) {} 55 56 void ComputeAsync(OpKernelContext* context, DoneCallback done) override { 57 ReaderInterface* reader; 58 OP_REQUIRES_OK_ASYNC( 59 context, GetResourceFromContext(context, "reader_handle", &reader), 60 done); 61 thread_pool_->Schedule([this, context, reader, done]() { 62 ComputeWithReader(context, reader); 63 reader->Unref(); 64 done(); 65 }); 66 } 67 68 protected: 69 virtual void ComputeWithReader(OpKernelContext* context, 70 ReaderInterface* reader) = 0; 71 72 private: 73 std::unique_ptr<thread::ThreadPool> thread_pool_; 74 }; 75 76 class ReaderReadOp : public ReaderVerbAsyncOpKernel { 77 public: 78 using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel; 79 80 void ComputeWithReader(OpKernelContext* context, 81 ReaderInterface* reader) override { 82 QueueInterface* queue; 83 OP_REQUIRES_OK(context, 84 GetResourceFromContext(context, "queue_handle", &queue)); 85 core::ScopedUnref unref_me(queue); 86 Tensor* key = nullptr; 87 OP_REQUIRES_OK(context, 88 context->allocate_output("key", TensorShape({}), &key)); 89 Tensor* value = nullptr; 90 OP_REQUIRES_OK(context, 91 context->allocate_output("value", TensorShape({}), &value)); 92 93 auto key_scalar = key->scalar<string>(); 94 auto value_scalar = value->scalar<string>(); 95 reader->Read(queue, &key_scalar(), &value_scalar(), context); 96 } 97 }; 98 99 REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp); 100 REGISTER_KERNEL_BUILDER(Name("ReaderReadV2").Device(DEVICE_CPU), ReaderReadOp); 101 102 class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel { 103 public: 104 using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel; 105 106 void ComputeWithReader(OpKernelContext* context, 107 ReaderInterface* reader) override { 108 QueueInterface* queue; 109 110 const Tensor* num_records_tensor; 111 OP_REQUIRES_OK(context, context->input("num_records", &num_records_tensor)); 112 int64 num_records = num_records_tensor->scalar<int64>()(); 113 114 OP_REQUIRES_OK(context, 115 GetResourceFromContext(context, "queue_handle", &queue)); 116 core::ScopedUnref unref_me(queue); 117 118 std::vector<string> keys_vec; 119 keys_vec.reserve(num_records); 120 std::vector<string> values_vec; 121 values_vec.reserve(num_records); 122 123 int64 num_actually_read = 124 reader->ReadUpTo(num_records, queue, &keys_vec, &values_vec, context); 125 126 OP_REQUIRES(context, num_actually_read == keys_vec.size(), 127 errors::InvalidArgument("num_actually_read != len(keys_vec")); 128 129 OP_REQUIRES(context, num_actually_read == values_vec.size(), 130 errors::InvalidArgument("num_actually_read != len(values_vec")); 131 132 Tensor* keys = nullptr; 133 OP_REQUIRES_OK(context, 134 context->allocate_output( 135 "keys", TensorShape({num_actually_read}), &keys)); 136 137 Tensor* values = nullptr; 138 OP_REQUIRES_OK(context, 139 context->allocate_output( 140 "values", TensorShape({num_actually_read}), &values)); 141 142 auto keys_t = keys->vec<string>(); 143 auto values_t = values->vec<string>(); 144 for (int i = 0; i < num_actually_read; ++i) { 145 keys_t(i) = std::move(keys_vec[i]); 146 values_t(i) = std::move(values_vec[i]); 147 } 148 } 149 }; 150 151 REGISTER_KERNEL_BUILDER(Name("ReaderReadUpTo").Device(DEVICE_CPU), 152 ReaderReadUpToOp); 153 REGISTER_KERNEL_BUILDER(Name("ReaderReadUpToV2").Device(DEVICE_CPU), 154 ReaderReadUpToOp); 155 156 class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel { 157 public: 158 using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; 159 160 void ComputeWithReader(OpKernelContext* context, 161 ReaderInterface* reader) override { 162 Tensor* output = nullptr; 163 OP_REQUIRES_OK(context, context->allocate_output("records_produced", 164 TensorShape({}), &output)); 165 output->scalar<int64>()() = reader->NumRecordsProduced(); 166 } 167 }; 168 169 REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProduced").Device(DEVICE_CPU), 170 ReaderNumRecordsProducedOp); 171 REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProducedV2").Device(DEVICE_CPU), 172 ReaderNumRecordsProducedOp); 173 174 class ReaderNumWorkUnitsCompletedOp : public ReaderVerbSyncOpKernel { 175 public: 176 using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; 177 178 void ComputeWithReader(OpKernelContext* context, 179 ReaderInterface* reader) override { 180 Tensor* output = nullptr; 181 OP_REQUIRES_OK(context, context->allocate_output("units_completed", 182 TensorShape({}), &output)); 183 output->scalar<int64>()() = reader->NumWorkUnitsCompleted(); 184 } 185 }; 186 187 REGISTER_KERNEL_BUILDER(Name("ReaderNumWorkUnitsCompleted").Device(DEVICE_CPU), 188 ReaderNumWorkUnitsCompletedOp); 189 REGISTER_KERNEL_BUILDER( 190 Name("ReaderNumWorkUnitsCompletedV2").Device(DEVICE_CPU), 191 ReaderNumWorkUnitsCompletedOp); 192 193 class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel { 194 public: 195 using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; 196 197 void ComputeWithReader(OpKernelContext* context, 198 ReaderInterface* reader) override { 199 Tensor* output = nullptr; 200 OP_REQUIRES_OK(context, 201 context->allocate_output("state", TensorShape({}), &output)); 202 OP_REQUIRES_OK(context, 203 reader->SerializeState(&output->scalar<string>()())); 204 } 205 }; 206 207 REGISTER_KERNEL_BUILDER(Name("ReaderSerializeState").Device(DEVICE_CPU), 208 ReaderSerializeStateOp); 209 REGISTER_KERNEL_BUILDER(Name("ReaderSerializeStateV2").Device(DEVICE_CPU), 210 ReaderSerializeStateOp); 211 212 class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel { 213 public: 214 using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; 215 216 void ComputeWithReader(OpKernelContext* context, 217 ReaderInterface* reader) override { 218 const Tensor* tensor; 219 OP_REQUIRES_OK(context, context->input("state", &tensor)); 220 OP_REQUIRES( 221 context, TensorShapeUtils::IsScalar(tensor->shape()), 222 errors::InvalidArgument("Reader state must be scalar, but had shape: ", 223 tensor->shape().DebugString())); 224 OP_REQUIRES_OK(context, reader->RestoreState(tensor->scalar<string>()())); 225 } 226 }; 227 228 REGISTER_KERNEL_BUILDER(Name("ReaderRestoreState").Device(DEVICE_CPU), 229 ReaderRestoreStateOp); 230 REGISTER_KERNEL_BUILDER(Name("ReaderRestoreStateV2").Device(DEVICE_CPU), 231 ReaderRestoreStateOp); 232 233 class ReaderResetOp : public ReaderVerbSyncOpKernel { 234 public: 235 using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; 236 237 void ComputeWithReader(OpKernelContext* context, 238 ReaderInterface* reader) override { 239 OP_REQUIRES_OK(context, reader->Reset()); 240 } 241 }; 242 243 REGISTER_KERNEL_BUILDER(Name("ReaderReset").Device(DEVICE_CPU), ReaderResetOp); 244 REGISTER_KERNEL_BUILDER(Name("ReaderResetV2").Device(DEVICE_CPU), 245 ReaderResetOp); 246 247 } // namespace tensorflow 248