Home | History | Annotate | Download | only in data
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/framework/partial_tensor_shape.h"
     16 #include "tensorflow/core/framework/tensor.h"
     17 #include "tensorflow/core/kernels/data/dataset.h"
     18 
     19 namespace tensorflow {
     20 
     21 namespace {
     22 
     23 // See documentation in ../ops/dataset_ops.cc for a high-level
     24 // description of the following op.
     25 
     26 class RangeDatasetOp : public DatasetOpKernel {
     27  public:
     28   explicit RangeDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
     29 
     30   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
     31     int64 start;
     32     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "start", &start));
     33 
     34     int64 stop;
     35     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stop", &stop));
     36 
     37     int64 step;
     38     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "step", &step));
     39     OP_REQUIRES(ctx, step != 0,
     40                 errors::InvalidArgument("step must be a non-zero integer."));
     41 
     42     *output = new Dataset(ctx, start, stop, step);
     43   }
     44 
     45  private:
     46   class Dataset : public GraphDatasetBase {
     47    public:
     48     Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step)
     49         : GraphDatasetBase(ctx), start_(start), stop_(stop), step_(step) {}
     50 
     51     std::unique_ptr<IteratorBase> MakeIterator(
     52         const string& prefix) const override {
     53       return std::unique_ptr<IteratorBase>(
     54           new Iterator({this, strings::StrCat(prefix, "::Range")}));
     55     }
     56 
     57     const DataTypeVector& output_dtypes() const override {
     58       static DataTypeVector* dtypes = new DataTypeVector({DT_INT64});
     59       return *dtypes;
     60     }
     61 
     62     const std::vector<PartialTensorShape>& output_shapes() const override {
     63       static std::vector<PartialTensorShape>* shapes =
     64           new std::vector<PartialTensorShape>({{}});
     65       return *shapes;
     66     }
     67 
     68     string DebugString() override {
     69       return strings::StrCat("RangeDatasetOp(", start_, ", ", stop_, ", ",
     70                              step_, ")::Dataset");
     71     }
     72 
     73    protected:
     74     Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
     75                               Node** output) const override {
     76       Node* start = nullptr;
     77       Node* stop = nullptr;
     78       Node* step = nullptr;
     79       TF_RETURN_IF_ERROR(b->AddScalar(start_, &start));
     80       TF_RETURN_IF_ERROR(b->AddScalar(stop_, &stop));
     81       TF_RETURN_IF_ERROR(b->AddScalar(step_, &step));
     82       TF_RETURN_IF_ERROR(b->AddDataset(this, {start, stop, step}, output));
     83       return Status::OK();
     84     }
     85 
     86    private:
     87     class Iterator : public DatasetIterator<Dataset> {
     88      public:
     89       explicit Iterator(const Params& params)
     90           : DatasetIterator<Dataset>(params) {
     91         next_ = params.dataset->start_;
     92       }
     93 
     94       Status GetNextInternal(IteratorContext* ctx,
     95                              std::vector<Tensor>* out_tensors,
     96                              bool* end_of_sequence) override {
     97         mutex_lock l(mu_);
     98         if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) ||
     99             (dataset()->step_ < 0 && next_ <= dataset()->stop_)) {
    100           *end_of_sequence = true;
    101           return Status::OK();
    102         }
    103         Tensor value_tensor(ctx->allocator({}), DT_INT64, {});
    104         value_tensor.scalar<int64>()() = next_;
    105         out_tensors->emplace_back(std::move(value_tensor));
    106         *end_of_sequence = false;
    107         next_ += dataset()->step_;
    108 
    109         return Status::OK();
    110       }
    111 
    112      protected:
    113       Status SaveInternal(IteratorStateWriter* writer) override {
    114         mutex_lock l(mu_);
    115         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("next"), next_));
    116         return Status::OK();
    117       }
    118 
    119       Status RestoreInternal(IteratorContext* ctx,
    120                              IteratorStateReader* reader) override {
    121         mutex_lock l(mu_);
    122         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next"), &next_));
    123         return Status::OK();
    124       }
    125 
    126      private:
    127       mutex mu_;
    128       int64 next_ GUARDED_BY(mu_);
    129     };
    130 
    131     const int64 start_;
    132     const int64 stop_;
    133     const int64 step_;
    134   };
    135 };
    136 
    137 REGISTER_KERNEL_BUILDER(Name("RangeDataset").Device(DEVICE_CPU),
    138                         RangeDatasetOp);
    139 
    140 }  // namespace
    141 
    142 }  // namespace tensorflow
    143