Home | History | Annotate | Download | only in data
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
     17 #define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
     18 
     19 #include "tensorflow/core/common_runtime/function.h"
     20 #include "tensorflow/core/framework/dataset.h"
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/framework/tensor_shape.h"
     23 #include "tensorflow/core/framework/types.h"
     24 #include "tensorflow/core/kernels/ops_util.h"
     25 
     26 namespace tensorflow {
     27 namespace data {
     28 
     29 class IteratorResource;
     30 
     31 class IteratorHandleOp : public OpKernel {
     32  public:
     33   explicit IteratorHandleOp(OpKernelConstruction* ctx);
     34 
     35   // The resource is deleted from the resource manager only when it is private
     36   // to kernel. Ideally the resource should be deleted when it is no longer held
     37   // by anyone, but it would break backward compatibility.
     38   ~IteratorHandleOp() override;
     39 
     40   void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_);
     41 
     42  private:
     43   // During the first Compute(), resource is either created or looked up using
     44   // shared_name. In the latter case, the resource found should be verified if
     45   // it is compatible with this op's configuration. The verification may fail in
     46   // cases such as two graphs asking queues of the same shared name to have
     47   // inconsistent capacities.
     48   Status VerifyResource(IteratorResource* resource);
     49 
     50   template <typename To, typename From>  // use like this: down_cast<T*>(foo);
     51   static inline To down_cast(From* f) {  // so we only accept pointers
     52     static_assert(
     53         (std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
     54         "target type not derived from source type");
     55 
     56     // We skip the assert and hence the dynamic_cast if RTTI is disabled.
     57 #if !defined(__GNUC__) || defined(__GXX_RTTI)
     58     // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
     59     assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
     60 #endif  // !defined(__GNUC__) || defined(__GXX_RTTI)
     61     return static_cast<To>(f);
     62   }
     63 
     64   FunctionLibraryRuntime* CreatePrivateFLR(
     65       OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
     66       std::unique_ptr<FunctionLibraryDefinition>* flib_def,
     67       std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr);
     68 
     69   mutex mu_;
     70   ContainerInfo cinfo_;  // Written once under mu_ then constant afterwards.
     71   IteratorResource* resource_ GUARDED_BY(mu_) = nullptr;
     72   DataTypeVector output_dtypes_;
     73   std::vector<PartialTensorShape> output_shapes_;
     74   const int graph_def_version_;
     75   string name_;
     76 };
     77 
     78 // Like IteratorHandleOp, but creates handles which are never shared, and does
     79 // not hold a reference to these handles. The latter is important for eager
     80 // execution, since OpKernel instances generally live as long as the program
     81 // running them.
     82 class AnonymousIteratorHandleOp : public OpKernel {
     83  public:
     84   explicit AnonymousIteratorHandleOp(OpKernelConstruction* context);
     85 
     86   void Compute(OpKernelContext* context) override;
     87 
     88  private:
     89   // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp
     90   // instances.
     91   static mutex static_resource_lookup_mutex_;
     92   // current_id_ is just a hint for creating unique names. If it turns out
     93   // there's a collision (e.g. because another AnonymousIteratorHandleOp
     94   // instance is generating handles) we'll just skip that id.
     95   static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_);
     96   DataTypeVector output_dtypes_;
     97   std::vector<PartialTensorShape> output_shapes_;
     98   const int graph_def_version_;
     99 };
    100 
    101 class MakeIteratorOp : public OpKernel {
    102  public:
    103   explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    104 
    105   void Compute(OpKernelContext* ctx) override;
    106 };
    107 
    108 class IteratorGetNextOp : public AsyncOpKernel {
    109  public:
    110   explicit IteratorGetNextOp(OpKernelConstruction* ctx)
    111       : AsyncOpKernel(ctx),
    112         background_worker_(ctx->env(), "tf_data_iterator_get_next") {}
    113 
    114   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
    115 
    116  private:
    117   BackgroundWorker background_worker_;
    118 };
    119 
    120 class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
    121  public:
    122   explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
    123       : AsyncOpKernel(ctx),
    124         background_worker_(ctx->env(),
    125                            "tf_data_iterator_get_next_as_optional") {
    126     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    127     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    128   }
    129 
    130   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
    131 
    132  private:
    133   BackgroundWorker background_worker_;
    134   DataTypeVector output_types_;
    135   std::vector<PartialTensorShape> output_shapes_;
    136 };
    137 
    138 class IteratorGetNextSyncOp : public OpKernel {
    139  public:
    140   explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    141 
    142   void Compute(OpKernelContext* ctx) override;
    143 };
    144 
    145 class IteratorToStringHandleOp : public OpKernel {
    146  public:
    147   explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
    148       : OpKernel(ctx) {}
    149 
    150   void Compute(OpKernelContext* ctx) override;
    151 };
    152 
    153 class IteratorFromStringHandleOp : public OpKernel {
    154  public:
    155   explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx);
    156 
    157   void Compute(OpKernelContext* ctx) override;
    158 
    159  private:
    160   DataTypeVector output_dtypes_;
    161   std::vector<PartialTensorShape> output_shapes_;
    162 };
    163 
    164 }  // namespace data
    165 }  // namespace tensorflow
    166 
    167 #endif  // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
    168