/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <memory>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
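//
// For instance, a hypothetical helper (not part of this header; the name
// `IsDatasetType` is illustrative only) could expand one `case` label per
// supported type:
//
//   #define HANDLE_TYPE(T) case DataTypeToEnum<T>::value:
//   inline bool IsDatasetType(DataType dtype) {
//     switch (dtype) {
//       TF_CALL_DATASET_TYPES(HANDLE_TYPE)
//       return true;
//       default:
//         return false;
//     }
//   }
//   #undef HANDLE_TYPE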
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)

namespace tensorflow {

// Interface for reading values from a key-value store.
// Used for restoring iterator state.
class IteratorStateReader {
 public:
  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
  virtual Status ReadScalar(StringPiece key, string* val) = 0;
  virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
  virtual bool Contains(StringPiece key) = 0;

  virtual ~IteratorStateReader() {}
};

// Interface for writing values to a key-value store.
// Used for saving iterator state.
class IteratorStateWriter {
 public:
  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
  virtual Status WriteScalar(StringPiece key, const string& val) = 0;
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class GraphDatasetBase;
class Node;

// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (int i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with Tensor value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return Status::OK();
  }

  Status AddDataset(const GraphDatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output) {
    return AddDataset(dataset, inputs, {}, output);
  }

  // Adds a node corresponding to the `DatasetType` to the Graph.
  // Return value of `DatasetType::op_name()` is used as the op type for the
  // node.
  // Values for the output_types and output_shapes node attributes are also
  // written if those attributes are defined in the OpDef.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  Status AddDataset(const GraphDatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output) {
    std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
    for (int i = 0; i < inputs.size(); i++) {
      enumerated_inputs[i] = std::make_pair(i, inputs[i]);
    }
    return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
  }

  Status AddDataset(
      const GraphDatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);
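
  // For example (an illustrative sketch; `b` points to this wrapper and the
  // values are arbitrary), scalar and vector arguments become Const nodes
  // whose Node* handles are then passed as inputs to AddDataset:
  //
  //   Node* buffer_size = nullptr;
  //   TF_RETURN_IF_ERROR(b->AddScalar<int64>(1024, &buffer_size));
  //   Node* paddings = nullptr;
  //   TF_RETURN_IF_ERROR(b->AddVector<int64>({2, 2}, &paddings));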

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a
  // matching name has already been added, returns with OK status. If a
  // user-defined function with name `function_name` is not found in the
  // FunctionLibraryDefinition, returns an InvalidArgumentError. If the
  // function with name `function_name` or any of its dependent functions are
  // stateful, returns an InvalidArgument error.
  Status AddFunction(OpKernelContext* ctx, const string& function_name);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }
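
  // A minimal usage sketch for function-valued attrs (assuming a hypothetical
  // captured NameAttrList `func_` and an attr named "f" on the dataset op):
  // the function is added to the graph explicitly, and the attr value built
  // here is later passed to AddDataset.
  //
  //   TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
  //   AttrValue f_attr;
  //   b->BuildAttrValue(func_, &f_attr);
  //   // ... pass {{"f", f_attr}} in the `attrs` argument of AddDataset.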

 private:
  void AddTensorInternal(const Tensor& val, Node** output);

  Status EnsureFunctionIsStateless(OpKernelContext* ctx,
                                   const string& function_name) const {
    const FunctionLibraryDefinition* lib_def =
        ctx->function_library()->GetFunctionLibraryDefinition();
    const FunctionDef* function_def = lib_def->Find(function_name);
    if (!function_def) {
      return errors::InvalidArgument("Unable to find FunctionDef for ",
                                     function_name, " in registry.");
    }
    for (const NodeDef& node_def : function_def->node_def()) {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def));
      // TODO(b/65524810): Hack to allow functions to capture Dataset op
      // nodes needed for FlatMap. Currently, source dataset nodes have been
      // marked stateful to avoid constant folding since we do not have a
      // good way of serializing them.
      if (IsOpWhitelisted(op_def)) {
        continue;
      }
      if (op_def->is_stateful()) {
        return errors::InvalidArgument(
            "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ",
            "in function ", function_name, " is stateful. ",
            "Saving stateful functions is not supported yet.");
      }
    }
    return Status::OK();
  }

  // Returns whether an op has been whitelisted for use inside map_fns.
  // Uses a heuristic to whitelist source dataset ops which have been
  // marked stateful due to b/65524810.
  // Also looks up the `op_def->name` in the global
  // `WhitelistedStatefulOpRegistry`.
  bool IsOpWhitelisted(const OpDef* op_def) const {
    return (StringPiece(op_def->name()).ends_with("Dataset") &&
            op_def->output_arg_size() == 1 &&
            op_def->output_arg(0).type() == DT_VARIANT) ||
           dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
               op_def->name());
  }

  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (auto attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
      }
    }
    return Status::OK();
  }

  GraphDefBuilder* b_;
};

class StatsAggregator;

// A cut-down version of OpKernelContext for running computations in
// iterators. Note that we cannot simply use OpKernelContext here
// because we might run computation in an iterator whose lifetime is
// not nested within the lifetime of a single OpKernelContext
// (e.g. asynchronous prefetching).
//
// TODO(mrry): We will probably need to support more of
// OpKernelContext here. For example, should allocation be handled by
// the IteratorContext?
// TODO(mrry): We're making some daring assumptions about the lifetime
// of the runner passed in here. A runner will be deleted when the original
// step ends, but all existing runners only close over session-lifetime (or
// longer-lived) state, so we can make a copy of the function. There's nothing
// in the definition of the API from which we took the runner to guarantee that
// what we are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    // Interface to operating system functionality.
    Env* env;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // A function that returns the current `StatsAggregator` instance to be
    // used when recording statistics about the iterator.
    //
    // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator`
    // is a property of the `IteratorResource` (which this class does not know
    // about), and (ii) it can change after the `IteratorContext` has been
    // created. Better suggestions are welcome!
    std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter =
        nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* lib = nullptr;
    std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr;

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
  };

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Env* env() const { return params_.env; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    if (params_.stats_aggregator_getter) {
      return params_.stats_aggregator_getter();
    } else {
      return nullptr;
    }
  }

  std::shared_ptr<const FunctionLibraryDefinition> function_library() {
    return params_.function_library;
  }

  FunctionLibraryRuntime* lib() { return params_.lib; }

  void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; }

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

 private:
  Params params_;
};
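
// A minimal construction sketch (assuming the caller has an OpKernelContext*
// `ctx`, as in iterator-related kernels; the local names are illustrative):
//
//   IteratorContext::Params params;
//   params.env = ctx->env();
//   params.runner = *(ctx->runner());
//   params.lib = ctx->function_library();
//   params.allocator_getter = [ctx](AllocatorAttributes attrs) {
//     return ctx->get_allocator(attrs);
//   };
//   IteratorContext iter_ctx(std::move(params));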

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {}

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will
  // be stored in `*end_of_sequence`, and the content of
  // `*out_tensors` will be undefined.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Saves the state of this iterator.
  virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
    return SaveInternal(writer);
  }

  // Restores the state of this iterator.
  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    return RestoreInternal(ctx, reader);
  }

 protected:
  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their parent iterators, e.g., in
  // `RepeatDatasetOp::Dataset`.
  Status SaveParent(IteratorStateWriter* writer,
                    const std::unique_ptr<IteratorBase>& parent) {
    return parent->SaveInternal(writer);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their parent iterators, e.g., in
  // `RepeatDatasetOp::Dataset`.
  Status RestoreParent(IteratorContext* ctx, IteratorStateReader* reader,
                       const std::unique_ptr<IteratorBase>& parent) {
    return parent->RestoreInternal(ctx, reader);
  }

  // Saves the state of this iterator recursively.
  virtual Status SaveInternal(IteratorStateWriter* writer) {
    return errors::Unimplemented("SaveInternal");
  }

  // Restores the state of this iterator recursively.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) {
    return errors::Unimplemented("RestoreInternal");
  }
};
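
// A minimal consumption sketch (names are illustrative; `iterator` is an
// IteratorBase* and `iter_ctx` an IteratorContext*): GetNext is called until
// `*end_of_sequence` becomes true.
//
//   std::vector<Tensor> element;
//   bool end_of_sequence = false;
//   while (!end_of_sequence) {
//     element.clear();
//     TF_RETURN_IF_ERROR(
//         iterator->GetNext(iter_ctx, &element, &end_of_sequence));
//     if (!end_of_sequence) {
//       // ... consume `element` ...
//     }
//   }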

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // Ownership of the created iterator will be transferred to the caller.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  virtual std::unique_ptr<IteratorBase> MakeIterator(
      const string& prefix) const = 0;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // A human-readable debug string for this dataset.
  virtual string DebugString() = 0;

  // Serializes the dataset and writes it to the `writer`.
  virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const {
    return errors::Unimplemented("DatasetBase::Save");
  }

 protected:
  // TODO(srbs): Ideally all graph related logic should reside in
  // GraphDatasetBase. However, that would require Datasets defined in all ops
  // to derive from GraphDatasetBase. Once that is done we can move
  // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase.
  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
    Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset,
                            Node** output) {
      return dataset->AsGraphDefInternal(ctx, this, output);
    }
  };

  virtual Status AsGraphDefInternal(OpKernelContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const {
    return AsGraphDefInternal(b, node);
  }

  virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                                    Node** node) const {
    return errors::Unimplemented("AsGraphDefInternal");
  }
};
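
// A typical AsGraphDefInternal override first serializes its input dataset via
// DatasetGraphDefBuilder::AddParentDataset, then adds Const nodes for its own
// parameters, and finally adds the dataset node itself. Illustrative sketch
// only; the members `input_` and `count_` are hypothetical:
//
//   Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
//                             Node** output) const override {
//     Node* input_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
//     Node* count_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
//     return b->AddDataset(this, {input_node, count_node}, output);
//   }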

// Base-class for datasets that are built by ops.
class GraphDatasetBase : public DatasetBase {
 public:
  GraphDatasetBase(OpKernelContext* ctx)
      : op_name_(ctx->op_kernel().type_string()) {}

  const string op_name() const { return op_name_; }

  Status Save(OpKernelContext* ctx,
              IteratorStateWriter* writer) const override {
    string serialized_graph_def;
    string output_node;
    TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node));
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
    return Status::OK();
  }

  // Key for storing the Dataset graph in the serialized format.
  static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  static const char kDatasetGraphOutputNodeKey[];

 private:
  Status Serialize(OpKernelContext* ctx, string* serialized_graph_def,
                   string* output_node) const;

  const string op_name_;
};

// Represents an iterator that is associated with a particular parent dataset.
template <class DatasetType>
class DatasetIterator : public IteratorBase {
 public:
  struct Params {
    // Owns one reference on the shared dataset resource.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params) : params_(params) {
    params_.dataset->Ref();
  }

  ~DatasetIterator() override { params_.dataset->Unref(); }

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const { return params_.dataset; }

  // The sequence of iterators leading up to this iterator.
  const string prefix() const { return params_.prefix; }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final {
    port::Tracing::TraceMe activity(params_.prefix);
    Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
    if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
      s = errors::Internal(
          "Iterator \"", params_.prefix,
          "\" returned OutOfRange without setting `*end_of_sequence`. This "
          "indicates that an error may have occurred. Original message: ",
          s.error_message());
      LOG(ERROR) << s;
    }
    return s;
  }

  Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final {
    TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer));
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  // Internal implementation of GetNext that is wrapped in tracing logic.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  string full_name(const string& name) const {
    return strings::StrCat(prefix(), ":", name);
  }

 private:
  Params params_;
};
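
// A concrete iterator typically derives from DatasetIterator<Dataset>, keeps
// its position in a member variable, and persists it under keys built with
// full_name(). Partial, illustrative sketch (constructor and GetNextInternal
// omitted; `Dataset` and `i_` are hypothetical):
//
//   class Iterator : public DatasetIterator<Dataset> {
//    protected:
//     Status SaveInternal(IteratorStateWriter* writer) override {
//       return writer->WriteScalar(full_name("i"), i_);
//     }
//     Status RestoreInternal(IteratorContext* ctx,
//                            IteratorStateReader* reader) override {
//       return reader->ReadScalar(full_name("i"), &i_);
//     }
//
//    private:
//     int64 i_ = 0;
//   };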

// Encapsulates the work required to plug a DatasetBase into the core TensorFlow
// graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) final;

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;

  template <typename T>
  Status ParseScalarArgument(OpKernelContext* ctx,
                             const StringPiece& argument_name, T* output) {
    const Tensor* argument_t;
    TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
    if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
      return errors::InvalidArgument(argument_name, " must be a scalar");
    }
    *output = argument_t->scalar<T>()();
    return Status::OK();
  }
};
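
// A typical subclass parses its scalar inputs and constructs a DatasetBase.
// Illustrative sketch only; `MyDatasetOp`, the "count" input, and `Dataset`
// (a GraphDatasetBase subclass) are hypothetical:
//
//   class MyDatasetOp : public DatasetOpKernel {
//    public:
//     explicit MyDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
//
//     void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
//       int64 count;
//       OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
//       *output = new Dataset(ctx, count);
//     }
//   };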

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};

// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to StoreDatasetInVariantTensor().
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor. The consumer must either acquire its own reference to the
// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
// destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
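
// A minimal usage sketch (as in a dataset op kernel's Compute/MakeDataset
// path; `new_dataset` is a hypothetical DatasetBase*): the input dataset is
// borrowed from an input tensor and the result is stored in the output tensor.
//
//   DatasetBase* input;
//   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &input));
//   ...
//   Tensor* output_t;
//   TF_RETURN_IF_ERROR(ctx->allocate_output(0, TensorShape({}), &output_t));
//   TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(new_dataset, output_t));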

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_