/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/stats_aggregator.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

namespace {

// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following ops.

const char kIteratorVariantTypeName[] = "tensorflow::Iterator";

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (expected[i] != received[i]) {
      return errors::InvalidArgument("Data type mismatch at component ", i,
                                     ": expected ", DataTypeString(expected[i]),
                                     " but got ", DataTypeString(received[i]),
                                     ".");
    }
  }
  return Status::OK();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (!expected[i].IsCompatibleWith(received[i])) {
      return errors::InvalidArgument("Incompatible shapes at component ", i,
                                     ": expected ", expected[i].DebugString(),
                                     " but got ", received[i].DebugString(),
                                     ".");
    }
  }

  return Status::OK();
}

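// Wraps the state of a single iterator: the underlying IteratorBase, the
// expected output dtypes/shapes, and (optionally) a private function library
// runtime. Handles to this resource are produced by IteratorHandleOp and
// OneShotIteratorOp below, and consumed by the MakeIterator, IteratorGetNext*,
// and (De)SerializeIterator kernels.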
class IteratorResource : public ResourceBase {
 public:
  IteratorResource(const DataTypeVector& output_dtypes,
                   const std::vector<PartialTensorShape>& output_shapes,
                   const int /*unused: graph_def_version*/,
                   std::unique_ptr<DeviceMgr> device_mgr,
                   std::unique_ptr<FunctionLibraryDefinition> flib_def,
                   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                   FunctionLibraryRuntime* lib)
      : device_mgr_(std::move(device_mgr)),
        flib_def_(std::move(flib_def)),
        pflr_(std::move(pflr)),
        lib_(lib),
        iterator_(nullptr),
        output_dtypes_(output_dtypes),
        output_shapes_(output_shapes) {}

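  // Fetches the next element from the wrapped iterator. A local copy of the
  // shared_ptr is taken so that the iterator remains alive for the duration
  // of the call even if set_iterator() replaces it concurrently.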
  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    std::shared_ptr<IteratorBase> captured_iterator(iterator_);
    if (captured_iterator) {
      if (lib_ != nullptr) {
        ctx->set_lib(lib_);
      }
      return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence);
    } else {
      return errors::FailedPrecondition(
          "GetNext() failed because the iterator has not been initialized. "
          "Ensure that you have run the initializer operation for this "
          "iterator before getting the next element.");
    }
  }

  Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
    std::shared_ptr<IteratorBase> captured_iterator(iterator_);
    if (captured_iterator) {
      return captured_iterator->Save(ctx, writer);
    } else {
      return errors::FailedPrecondition(
          "Save() failed because the iterator has not been initialized. "
          "Ensure that you have run the initializer operation for this "
          "iterator before saving it.");
    }
  }

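  // Restores the iterator from checkpointed state: rebuilds the dataset from
  // the GraphDef serialized in the checkpoint, runs it with a GraphRunner to
  // obtain the DatasetBase, creates a fresh iterator over it, and then
  // delegates to that iterator's Restore() to recover its position.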
  Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
    string serialized_graph_def;
    TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey,
                                          &serialized_graph_def));
    GraphDef graph_def;
    if (!graph_def.ParseFromString(serialized_graph_def)) {
      return errors::Internal("Error parsing dataset GraphDef.");
    }
    string output_node;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
    DatasetBase* dataset = nullptr;
    Graph graph(OpRegistry::Global());
    TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
    std::vector<Tensor> outputs;
    GraphRunner graph_runner(ctx->env());

    // Build a new FLR that knows about the functions in the graph.
    std::shared_ptr<FunctionLibraryDefinition> flib_def(
        new FunctionLibraryDefinition(
            *ctx->function_library()->GetFunctionLibraryDefinition()));
    TF_RETURN_IF_ERROR(flib_def->AddLibrary(graph_def.library()));

    TF_RETURN_IF_ERROR(
        graph_runner.Run(&graph, lib_, {}, {output_node}, &outputs));
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));

    TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
    std::shared_ptr<IteratorBase> captured_iterator(iterator_);

    if (captured_iterator) {
      IteratorContext::Params params;
      params.env = ctx->env();
      params.runner = *(ctx->runner());
      params.function_library = flib_def;
      params.lib = lib_;
      DeviceBase* device = lib_->device();
      params.allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };
      IteratorContext iter_ctx(std::move(params));

      TF_RETURN_IF_ERROR(captured_iterator->Restore(&iter_ctx, reader));
      mutex_lock l(mu_);
      lib_def_ = std::move(flib_def);
      return Status::OK();
    } else {
      return errors::FailedPrecondition(
          "Failed to restore iterator. Make sure the checkpoint ",
          "is not corrupt. If the checkpoint does not contain the GraphDef, ",
          "you will need to initialize your iterator before restoring.");
    }
  }

  std::shared_ptr<const FunctionLibraryDefinition> function_library() {
    tf_shared_lock l(mu_);
    return lib_def_;
  }

  // Transfers ownership of iterator to this. This method is thread-safe.
  Status set_iterator(std::unique_ptr<IteratorBase> iterator) {
    if (iterator) {
      TF_RETURN_IF_ERROR(
          VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
      TF_RETURN_IF_ERROR(
          VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
    }
    iterator_.reset(iterator.release());
    return Status::OK();
  }

  void set_stats_aggregator(std::shared_ptr<StatsAggregator> stats_aggregator) {
    mutex_lock l(mu_);
    stats_aggregator_ = std::move(stats_aggregator);
  }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    tf_shared_lock l(mu_);
    return stats_aggregator_;
  }

  string DebugString() override { return "Iterator resource"; }

  const DataTypeVector& output_dtypes() const { return output_dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const {
    return output_shapes_;
  }

 private:
  // The following (device_mgr_, flib_def_, pflr_) are only used when the
  // IteratorResource is shared between sessions and in that case we create
  // a new FLR. Otherwise these are set to null.
  std::unique_ptr<DeviceMgr> device_mgr_;
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  FunctionLibraryRuntime* lib_ = nullptr;  // not owned.
  std::shared_ptr<IteratorBase> iterator_;
  mutex mu_;
  std::shared_ptr<StatsAggregator> stats_aggregator_ GUARDED_BY(mu_);
  std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
  const DataTypeVector output_dtypes_;
  const std::vector<PartialTensorShape> output_shapes_;
};

// Helper class for reading data from a VariantTensorData object.
class VariantTensorDataReader : public IteratorStateReader {
 public:
  explicit VariantTensorDataReader(const VariantTensorData* data)
      : data_(data) {
    PreProcess();
  }

  // Returns OK iff the initialization was successful, i.e.,
  // pre-processing did not have errors.
  Status status() const { return status_; }

  Status ReadScalar(StringPiece key, int64* val) override {
    return ReadScalarInternal(key, val);
  }

  Status ReadScalar(StringPiece key, string* val) override {
    return ReadScalarInternal(key, val);
  }

  Status ReadTensor(StringPiece key, Tensor* val) override {
    return ReadTensorInternal(key, val);
  }

  bool Contains(StringPiece key) override {
    return map_.find(key.ToString()) != map_.end();
  }

 private:
  void PreProcess() {
    string metadata;
    data_->get_metadata(&metadata);
    IteratorStateMetadata proto;
    if (!proto.ParseFromString(metadata)) {
      status_ = errors::Internal("Error parsing IteratorStateMetadata.");
      return;
    }
    size_t num_entries = proto.keys_size();
    CHECK_EQ(num_entries, data_->tensors_size());
    for (size_t i = 0; i < num_entries; i++) {
      map_[proto.keys(i)] = i;
    }
  }

  template <typename T>
  Status ReadScalarInternal(StringPiece key, T* val) {
    if (map_.find(key.ToString()) == map_.end()) {
      return errors::NotFound(key);
    }
    *val = data_->tensors(map_[key.ToString()]).scalar<T>()();
    return Status::OK();
  }

  Status ReadTensorInternal(StringPiece key, Tensor* val) {
    if (map_.find(key.ToString()) == map_.end()) {
      return errors::NotFound(key);
    }
    *val = data_->tensors(map_[key.ToString()]);
    return Status::OK();
  }

  std::map<string, size_t> map_;
  const VariantTensorData* data_;  // Not owned.
  Status status_;
};

// Helper class for writing data to a VariantTensorData object.
class VariantTensorDataWriter : public IteratorStateWriter {
 public:
  // Does not take ownership of data.
  explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {}

  Status WriteScalar(StringPiece key, const int64 val) override {
    return WriteScalarInternal(key, val);
  }

  Status WriteScalar(StringPiece key, const string& val) override {
    return WriteScalarInternal(key, val);
  }

  Status WriteTensor(StringPiece key, const Tensor& val) override {
    return WriteTensorInternal(key, val);
  }

  // Writes the metadata to `data_`.
  Status Flush() {
    string metadata;
    if (!metadata_proto_.SerializeToString(&metadata)) {
      return errors::Internal("Unable to serialize IteratorStateMetadata.");
    }
    data_->set_metadata(metadata);
    return Status::OK();
  }

 private:
  template <typename T>
  Status WriteScalarInternal(StringPiece key, const T& val) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    return WriteTensorInternal(key, val_t);
  }

  Status WriteTensorInternal(StringPiece key, const Tensor& val) {
    // Write key to the metadata proto. This gets written to `data_`
    // when `Flush()` is called. We do this lazily to avoid multiple
    // serialization calls.
    metadata_proto_.add_keys(key.ToString());

    // Update tensors.
    *(data_->add_tensors()) = val;
    return Status::OK();
  }

  VariantTensorData* data_;
  // TODO(srbs): Set the version string.
  IteratorStateMetadata metadata_proto_;
};

// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
// The get() method returns an IteratorStateReader which can be used
// to restore iterator state.
//
// Usage example:
//
// Encoding:
//
//   Tensor t(DT_VARIANT, TensorShape({}));
//   t.scalar<Variant>()() = IteratorStateVariant(iterator_resource);
//
// Encode() sets the type_name of the VariantTensorData object to
// IteratorStateVariant::TypeName().
//
// Decoding:
//
//   Variant v = <VariantTensorDataProto object>;
//   DecodeUnaryVariant(&v);
//   IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
//   iterator_resource->Restore(ctx, wrapper->get())
//
// The type_name of the VariantTensorData object to be decoded must
// match IteratorStateVariant::TypeName().
class IteratorStateVariant {
 public:
  IteratorStateVariant() : data_(nullptr) {}
  IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
    if (other.data_) {
      Decode(*other.data_);
    }
  }
  // Initializes this object with the current state of the iterator so
  // that it can be written on the next call to Encode().
  Status InitializeFromIterator(OpKernelContext* ctx,
                                IteratorResource* iterator_resource) {
    data_.reset(new VariantTensorData());
    data_->set_type_name(TypeName());
    VariantTensorDataWriter writer(data_.get());
    TF_RETURN_IF_ERROR(iterator_resource->Save(ctx, &writer));
    TF_RETURN_IF_ERROR(writer.Flush());
    return Status::OK();
  }
  string TypeName() const { return kIteratorVariantTypeName; }
  void Encode(VariantTensorData* data) const { *data = *data_; }
  bool Decode(const VariantTensorData& data) {
    if (data.type_name() != TypeName()) {
      return false;
    }
    std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
    *tensor_data = data;
    std::unique_ptr<VariantTensorDataReader> reader(
        new VariantTensorDataReader(tensor_data.get()));
    status_ = reader->status();
    if (!status_.ok()) {
      return false;
    }
    data_ = std::move(tensor_data);
    reader_ = std::move(reader);
    return true;
  }
  IteratorStateReader* get() { return reader_.get(); }
  Status status() const { return status_; }
  string DebugString() const {
    if (data_) {
      return strings::StrCat("IteratorStateVariant<",
                             "data: ", data_->DebugString(),
                             " status: ", status_.ToString(), ">");
    } else {
      return strings::StrCat("IteratorStateVariant<empty>");
    }
  }

 private:
  std::unique_ptr<IteratorStateReader> reader_;
  Status status_;
  std::unique_ptr<VariantTensorData> data_;
};

// Register the reader class in the global variant decode_fn registry
// so that a Variant containing a serialized representation of iterator state
// can be decoded using DecodeUnaryVariant. If we don't do this we will need
// to manually decode the returned Variant using MaybeDecodeAndCopy in
// DeserializeIteratorOp which is not recommended.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                       kIteratorVariantTypeName);

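// Produces a handle to an IteratorResource, creating the resource on the
// first call to Compute() (or looking it up by shared_name if it already
// exists) and verifying that its dtypes/shapes match this op's attrs.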
class IteratorHandleOp : public OpKernel {
 public:
  explicit IteratorHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
  }

  // The resource is deleted from the resource manager only when it is private
  // to the kernel. Ideally the resource should be deleted when it is no longer
  // held by anyone, but doing so would break backward compatibility.
  ~IteratorHandleOp() override {
    if (resource_ != nullptr) {
      resource_->Unref();
      if (cinfo_.resource_is_private_to_kernel()) {
        if (!cinfo_.resource_manager()
                 ->template Delete<IteratorResource>(cinfo_.container(),
                                                     cinfo_.name())
                 .ok()) {
          // Do nothing; the resource may have been deleted by session resets.
        }
      }
    }
  }

  void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
    {
      mutex_lock l(mu_);
      if (resource_ == nullptr) {
        FunctionLibraryRuntime* lib;
        std::unique_ptr<DeviceMgr> device_mgr(nullptr);
        std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
        std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
        // If the iterator is shared then we construct a new FLR, and pass that
        // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
        // functions from the iterator. We may add this functionality if there
        // is sufficient demand, but it will require a significant refactoring.
        if (!name_.empty()) {
          lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
        } else {
          OP_REQUIRES_OK(context, context->function_library()->Clone(
                                      &flib_def, &pflr, &lib));
        }

        ResourceMgr* mgr = context->resource_manager();
        OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

        IteratorResource* resource;
        OP_REQUIRES_OK(
            context,
            mgr->LookupOrCreate<IteratorResource>(
                cinfo_.container(), cinfo_.name(), &resource,
                [lib, &device_mgr, &flib_def, &pflr,
                 this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                  *ret = new IteratorResource(
                      output_dtypes_, output_shapes_, graph_def_version_,
                      std::move(device_mgr), std::move(flib_def),
                      std::move(pflr), lib);
                  return Status::OK();
                }));

        Status s = VerifyResource(resource);
        if (TF_PREDICT_FALSE(!s.ok())) {
          resource->Unref();
          context->SetStatus(s);
          return;
        }

        resource_ = resource;
      }
    }
    OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                                context, 0, cinfo_.container(), cinfo_.name(),
                                MakeTypeIndex<IteratorResource>()));
  }

 private:
  // During the first Compute(), the resource is either created or looked up
  // using shared_name. In the latter case, the resource found should be
  // verified to be compatible with this op's configuration. The verification
  // may fail, for example, when two graphs ask queues with the same shared
  // name to have inconsistent capacities.
  Status VerifyResource(IteratorResource* resource) {
    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
    return Status::OK();
  }

  template <typename To, typename From>  // use like this: down_cast<T*>(foo);
  static inline To down_cast(From* f) {  // so we only accept pointers
    static_assert(
        (std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
        "target type not derived from source type");

    // We skip the assert and hence the dynamic_cast if RTTI is disabled.
#if !defined(__GNUC__) || defined(__GXX_RTTI)
    // Uses RTTI in dbg and fastbuild; asserts are disabled in opt builds.
    assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
#endif  // !defined(__GNUC__) || defined(__GXX_RTTI)
    return static_cast<To>(f);
  }

  FunctionLibraryRuntime* CreatePrivateFLR(
      OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
      std::unique_ptr<FunctionLibraryDefinition>* flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
    // Wrap the existing device in order to see any captured resources
    // in its resource manager. The existing device will outlive the
    // IteratorResource, because we are storing the IteratorResource
    // in that device's resource manager.
    Device* wrapped_device = RenamedDevice::NewRenamedDevice(
        ctx->device()->name(), down_cast<Device*>(ctx->device()),
        false /* owns_underlying */, false /* isolate_session_state */);
    device_mgr->reset(new DeviceMgr({wrapped_device}));
    flib_def->reset(new FunctionLibraryDefinition(
        *ctx->function_library()->GetFunctionLibraryDefinition()));
    pflr->reset(new ProcessFunctionLibraryRuntime(
        device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(),
        {} /* TODO(mrry): OptimizerOptions? */,
        nullptr /* TODO(mrry): ClusterFLR */));

    return (*pflr)->GetFLR(ctx->device()->name());
  }

  mutex mu_;
  ContainerInfo cinfo_;  // Written once under mu_ then constant afterwards.
  IteratorResource* resource_ GUARDED_BY(mu_) = nullptr;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
  const int graph_def_version_;
  string name_;
};

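// Binds a dataset (input 0) to an iterator resource (input 1) by creating a
// new iterator over the dataset and handing it to the resource.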
class MakeIteratorOp : public OpKernel {
 public:
  explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    DatasetBase* dataset;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
    OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(
                            dataset->MakeIterator("Iterator")));
    iterator_resource->Unref();
  }
};

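// Evaluates a dataset that is expected to contain exactly one element: emits
// that element as the op's outputs, then calls GetNext() a second time to
// verify that the dataset is exhausted, failing otherwise.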
class ToSingleElementOp : public AsyncOpKernel {
 public:
  explicit ToSingleElementOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        thread_pool_(new thread::ThreadPool(
            ctx->env(), ThreadOptions(),
            strings::StrCat("to_single_element_op_thread_",
                            SanitizeThreadSuffix(name())),
            1 /* num_threads */, false /* low_latency_hint */)) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    // The call to `iterator->GetNext()` may block and depend on an
    // inter-op thread pool thread, so we issue the call from the
    // owned thread pool.
    thread_pool_->Schedule([ctx, done]() {
      DatasetBase* dataset;
      OP_REQUIRES_OK_ASYNC(
          ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
      auto iterator = dataset->MakeIterator("SingleElementIterator");

      IteratorContext::Params params;
      params.env = ctx->env();
      params.runner = *(ctx->runner());
      params.lib = ctx->function_library();
      DeviceBase* device = ctx->function_library()->device();
      params.allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };

      IteratorContext iter_ctx(std::move(params));

      std::vector<Tensor> components;
      components.reserve(dataset->output_dtypes().size());
      bool end_of_sequence;

      OP_REQUIRES_OK_ASYNC(
          ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
          done);
      OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
                        errors::InvalidArgument("Dataset was empty."), done);

      for (int i = 0; i < components.size(); ++i) {
        // TODO(mrry): Check that the shapes match the shape attrs.
        ctx->set_output(i, components[i]);
      }

      components.clear();
      OP_REQUIRES_OK_ASYNC(
          ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
          done);
      OP_REQUIRES_ASYNC(
          ctx, end_of_sequence,
          errors::InvalidArgument("Dataset had more than one element."), done);

      done();
    });
  }

 private:
  std::unique_ptr<thread::ThreadPool> thread_pool_;
};

class OneShotIteratorOp : public AsyncOpKernel {
 public:
  explicit OneShotIteratorOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        thread_pool_(new thread::ThreadPool(
            ctx->env(), ThreadOptions(),
            strings::StrCat("one_shot_iterator_initialization_thread_",
                            SanitizeThreadSuffix(name())),
            1 /* num_threads */, false /* low_latency_hint */)),
        graph_def_version_(ctx->graph_def_version())

  {
    string shared_name;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
    OP_REQUIRES(ctx, shared_name.empty(),
                errors::InvalidArgument("OneShotIteratorOp does not currently "
                                        "support the 'shared_name' attr."));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("dataset_factory", &dataset_factory_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

  ~OneShotIteratorOp() override {
    if (iterator_resource_ != nullptr) {
      iterator_resource_->Unref();
      if (!cinfo_.resource_manager()
               ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
               .ok()) {
        // Do nothing; the resource may have been deleted by session resets.
      }
    }
  }

  // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`, but
  // because `ResourceOpKernel<T>::CreateResource()` does not provide access
  // to the `OpKernelContext*`, which we need to invoke the factory function,
  // it is not possible to implement this kernel by overriding
  // `CreateResource()`. Furthermore, because this kernel might block while
  // running the initialization function, we must implement it as an async
  // kernel.
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    {
      mutex_lock l(mu_);
      if (iterator_resource_ == nullptr && initialization_status_.ok()) {
        // The initialization thread will call `done`.
        if (!initialization_started_) {
          // TODO(mrry): Convert the initialization code to use
          // callbacks instead of wasting a thread.
          thread_pool_->Schedule([this, ctx, done]() { Init(ctx, done); });
          initialization_started_ = true;
        } else {
          done_callbacks_.emplace_back(ctx, std::move(done));
        }
        return;
      }
    }
    ProduceOutput(ctx, std::move(done));
  }

 private:
  void Init(OpKernelContext* ctx, DoneCallback done) {
    IteratorResource* iterator = nullptr;
    ContainerInfo cinfo;
    Status s = TryInit(ctx, &iterator, &cinfo);

    std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
    {
      mutex_lock l(mu_);
      if (s.ok()) {
        iterator_resource_ = iterator;
        cinfo_ = cinfo;
      }
      initialization_status_ = s;
      std::swap(done_callbacks_, callbacks_to_run);
    }

    for (auto&& ctx_done : callbacks_to_run) {
      ProduceOutput(ctx_done.first, std::move(ctx_done.second));
    }
    ProduceOutput(ctx, std::move(done));
  }

  Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
                 ContainerInfo* cinfo) {
    TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));

    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def, &pflr, &lib));

    // Create an IteratorResource that will hold the iterator for this op.
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()->LookupOrCreate<IteratorResource>(
            cinfo->container(), cinfo->name(), iterator,
            [lib, this, &flib_def, &pflr](IteratorResource** ret)
                EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                  *ret = new IteratorResource(
                      output_dtypes_, output_shapes_, graph_def_version_,
                      nullptr, std::move(flib_def), std::move(pflr), lib);
                  return Status::OK();
                }));

    core::ScopedUnref unref_iterator(*iterator);

    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));

    // Call the dataset_factory_func_ to create a new dataset,
    // over which this op will iterate.
    FunctionLibraryRuntime::Handle f_handle;
    TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
        dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
        &f_handle));
    FunctionLibraryRuntime::Options opts;
    opts.cancellation_manager = ctx->cancellation_manager();
    // Choose a step ID that is guaranteed not to clash with any
    // Session-generated step ID. DirectSession only generates
    // non-negative step IDs (contiguous, starting from 0), and
    // MasterSession generates 56-bit random step IDs whose MSB is
    // always 0, so a negative random step ID should suffice.
    opts.step_id = -std::abs(static_cast<int64>(random::New64()));
    ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
      ctx->resource_manager()->Cleanup(name).IgnoreError();
    });
    opts.step_container = &step_container;
    opts.runner = ctx->runner();
    Notification n;
    Status factory_status;
    std::vector<Tensor> return_values;
    ctx->function_library()->Run(opts, f_handle, {}, &return_values,
                                 [&n, &factory_status](Status s) {
                                   factory_status.Update(s);
                                   n.Notify();
                                 });
    n.WaitForNotification();
    TF_RETURN_IF_ERROR(factory_status);
    if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
        !TensorShapeUtils::IsScalar(return_values[0].shape())) {
      return errors::InvalidArgument(
          "The `dataset_factory` function must return "
          "a single scalar of dtype DT_VARIANT.");
    }

    // Create an iterator for the dataset that was created in the
    // factory function.
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
    TF_RETURN_IF_ERROR(
        (*iterator)->set_iterator(dataset->MakeIterator("Iterator")));

    (*iterator)->Ref();
    return Status::OK();
  }

  void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
    Tensor* handle;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
                         done);
    Status s;
    {
      mutex_lock l(mu_);
      s = initialization_status_;
      if (s.ok()) {
        handle->scalar<ResourceHandle>()() =
            MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
                                                 cinfo_.name());
      }
    }
    OP_REQUIRES_OK_ASYNC(ctx, s, done);
    done();
  }

  NameAttrList dataset_factory_func_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;

  std::unique_ptr<thread::ThreadPool> thread_pool_;

  mutex mu_;
  ContainerInfo cinfo_ GUARDED_BY(mu_);
  IteratorResource* iterator_resource_ GUARDED_BY(mu_) = nullptr;

  bool initialization_started_ GUARDED_BY(mu_) = false;
  Status initialization_status_ GUARDED_BY(mu_);
  std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
      GUARDED_BY(mu_);
  const int graph_def_version_;
};

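// Asynchronously fetches the next element from an iterator resource. The
// potentially blocking GetNext() call is issued from a dedicated
// single-threaded pool so that it does not tie up an inter-op thread.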
class IteratorGetNextOp : public AsyncOpKernel {
 public:
  explicit IteratorGetNextOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        thread_pool_(new thread::ThreadPool(
            ctx->env(), ThreadOptions(),
            strings::StrCat("iterator_get_next_thread_",
                            SanitizeThreadSuffix(name())),
            1 /* num_threads */, false /* low_latency_hint */)) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    IteratorResource* iterator;
    OP_REQUIRES_OK_ASYNC(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
    // The call to `iterator->GetNext()` may block and depend on an
    // inter-op thread pool thread, so we issue the call from the
    // owned thread pool.
    thread_pool_->Schedule(std::bind(
        [this, ctx, iterator](DoneCallback done) {
          core::ScopedUnref unref_iterator(iterator);

          std::vector<Tensor> components;
          bool end_of_sequence = false;

          IteratorContext::Params params;
          params.env = ctx->env();
          params.stats_aggregator_getter = [iterator]() {
            return iterator->stats_aggregator();
          };
          params.runner = *(ctx->runner());
          params.function_library = iterator->function_library();
          DeviceBase* device = ctx->function_library()->device();
          params.allocator_getter = [device](AllocatorAttributes attrs) {
            return device->GetAllocator(attrs);
          };
          IteratorContext iter_ctx(std::move(params));

          OP_REQUIRES_OK_ASYNC(
              ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
              done);
          OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
                            errors::OutOfRange("End of sequence"), done);

          for (int i = 0; i < components.size(); ++i) {
            // TODO(mrry): Check that the shapes match the shape attrs.
            ctx->set_output(i, components[i]);
          }

          done();
        },
        std::move(done)));
  }

 private:
  std::unique_ptr<thread::ThreadPool> thread_pool_;
};

class IteratorGetNextSyncOp : public OpKernel {
 public:
  explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    IteratorResource* iterator;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
    core::ScopedUnref unref_iterator(iterator);

    std::vector<Tensor> components;
    bool end_of_sequence = false;

    IteratorContext::Params params;
    params.env = ctx->env();
    params.stats_aggregator_getter = [iterator]() {
      return iterator->stats_aggregator();
    };
    params.runner = *(ctx->runner());
    params.function_library = iterator->function_library();
    DeviceBase* device = ctx->function_library()->device();
    params.allocator_getter = [device](AllocatorAttributes attrs) {
      return device->GetAllocator(attrs);
    };
    IteratorContext iter_ctx(std::move(params));

    OP_REQUIRES_OK(ctx,
                   iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
    OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));

    for (int i = 0; i < components.size(); ++i) {
      // TODO(mrry): Check that the shapes match the shape attrs.
      ctx->set_output(i, components[i]);
    }
  }
};

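// Converts an iterator resource handle into its serialized string form, which
// IteratorFromStringHandleOp below can convert back into a resource handle.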
class IteratorToStringHandleOp : public OpKernel {
 public:
  explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& resource_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                errors::InvalidArgument("resource_handle must be a scalar"));

    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
    iterator_resource->Unref();

    Tensor* string_handle_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &string_handle_t));
    string_handle_t->scalar<string>()() =
        resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
  }
};

class IteratorFromStringHandleOp : public OpKernel {
 public:
  explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES(
        ctx,
        output_dtypes_.empty() || output_shapes_.empty() ||
            output_dtypes_.size() == output_shapes_.size(),
        errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
                                "are set, they must have the same length."));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& string_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
                errors::InvalidArgument("string_handle must be a scalar"));

    ResourceHandle resource_handle;
    OP_REQUIRES(
        ctx,
        resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
        errors::InvalidArgument(
            "Could not parse string_handle as a valid ResourceHandle"));

    OP_REQUIRES(
        ctx, resource_handle.device() == ctx->device()->attributes().name(),
        errors::InvalidArgument("Attempted to create an iterator on device \"",
                                ctx->device()->attributes().name(),
                                "\" from handle defined on device \"",
                                resource_handle.device(), "\""));

    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, resource_handle, &iterator_resource));
    core::ScopedUnref unref_iterator(iterator_resource);
    if (!output_dtypes_.empty()) {
      OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
                                           iterator_resource->output_dtypes()));
    }
    if (!output_shapes_.empty()) {
      OP_REQUIRES_OK(
          ctx, VerifyShapesCompatible(output_shapes_,
                                      iterator_resource->output_shapes()));
    }

    Tensor* resource_handle_t;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
    resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
  }

 private:
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
};

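// Saves the state of an iterator resource into a scalar variant tensor
// (an IteratorStateVariant) that DeserializeIteratorOp can later restore from.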
class SerializeIteratorOp : public OpKernel {
 public:
  explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& resource_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                errors::InvalidArgument("resource_handle must be a scalar"));

    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
    core::ScopedUnref unref_iterator(iterator_resource);
    Tensor* variant_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t));
    IteratorStateVariant v;
    OP_REQUIRES_OK(ctx, v.InitializeFromIterator(ctx, iterator_resource));
    variant_t->scalar<Variant>()() = v;
  }
};

class DeserializeIteratorOp : public OpKernel {
 public:
  explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));

    Variant variant = ctx->input(1).scalar<Variant>()();
    auto* wrapper = variant.get<IteratorStateVariant>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "DeserializeIteratorOp: Unable to parse variant tensor."));
    OP_REQUIRES_OK(ctx, wrapper->status());
    OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get()));
  }
};

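// Associates a StatsAggregator resource with an iterator resource so that
// GetNext calls can expose it to datasets via IteratorContext.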
class IteratorSetStatsAggregatorOp : public OpKernel {
 public:
  explicit IteratorSetStatsAggregatorOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
    core::ScopedUnref unref_iterator(iterator_resource);

    StatsAggregatorResource* stats_aggregator_resource;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
                                       &stats_aggregator_resource));
    core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
    // TODO(mrry): Consider allowing multiple StatsAggregator ops to
    // subscribe to updates, and/or unsubscribing.
    OP_REQUIRES(ctx, !iterator_resource->stats_aggregator(),
                errors::FailedPrecondition(
                    "Iterator already associated with a StatsAggregator"));
    iterator_resource->set_stats_aggregator(
        stats_aggregator_resource->stats_aggregator());
  }
};

REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
                        MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                        ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                        OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU),
                        IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                        SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                        DeserializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorSetStatsAggregator").Device(DEVICE_CPU),
                        IteratorSetStatsAggregatorOp);

}  // namespace

}  // namespace tensorflow