1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/partial_tensor_shape.h" 16 #include "tensorflow/core/framework/tensor.h" 17 #include "tensorflow/core/kernels/data/dataset.h" 18 19 namespace tensorflow { 20 21 namespace { 22 23 // See documentation in ../ops/dataset_ops.cc for a high-level 24 // description of the following op. 25 26 class RangeDatasetOp : public DatasetOpKernel { 27 public: 28 explicit RangeDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} 29 30 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 31 int64 start; 32 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "start", &start)); 33 34 int64 stop; 35 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stop", &stop)); 36 37 int64 step; 38 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "step", &step)); 39 OP_REQUIRES(ctx, step != 0, 40 errors::InvalidArgument("step must be a non-zero integer.")); 41 42 *output = new Dataset(ctx, start, stop, step); 43 } 44 45 private: 46 class Dataset : public GraphDatasetBase { 47 public: 48 Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step) 49 : GraphDatasetBase(ctx), start_(start), stop_(stop), step_(step) {} 50 51 std::unique_ptr<IteratorBase> MakeIterator( 52 const string& prefix) const override { 53 return std::unique_ptr<IteratorBase>( 54 new Iterator({this, strings::StrCat(prefix, "::Range")})); 55 } 56 57 const DataTypeVector& output_dtypes() const override { 58 static DataTypeVector* dtypes = new DataTypeVector({DT_INT64}); 59 return *dtypes; 60 } 61 62 const std::vector<PartialTensorShape>& output_shapes() const override { 63 static std::vector<PartialTensorShape>* shapes = 64 new std::vector<PartialTensorShape>({{}}); 65 return *shapes; 66 } 67 68 string DebugString() override { 69 return strings::StrCat("RangeDatasetOp(", start_, ", ", stop_, ", ", 70 step_, ")::Dataset"); 71 } 72 73 protected: 74 Status AsGraphDefInternal(DatasetGraphDefBuilder* b, 75 Node** output) const override { 76 Node* start = nullptr; 77 Node* stop = nullptr; 78 Node* step = nullptr; 79 TF_RETURN_IF_ERROR(b->AddScalar(start_, &start)); 80 TF_RETURN_IF_ERROR(b->AddScalar(stop_, &stop)); 81 TF_RETURN_IF_ERROR(b->AddScalar(step_, &step)); 82 TF_RETURN_IF_ERROR(b->AddDataset(this, {start, stop, step}, output)); 83 return Status::OK(); 84 } 85 86 private: 87 class Iterator : public DatasetIterator<Dataset> { 88 public: 89 explicit Iterator(const Params& params) 90 : DatasetIterator<Dataset>(params) { 91 next_ = params.dataset->start_; 92 } 93 94 Status GetNextInternal(IteratorContext* ctx, 95 std::vector<Tensor>* out_tensors, 96 bool* end_of_sequence) override { 97 mutex_lock l(mu_); 98 if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) || 99 (dataset()->step_ < 0 && next_ <= dataset()->stop_)) { 100 *end_of_sequence = true; 101 return Status::OK(); 102 } 103 Tensor value_tensor(ctx->allocator({}), DT_INT64, {}); 104 value_tensor.scalar<int64>()() = next_; 105 out_tensors->emplace_back(std::move(value_tensor)); 106 *end_of_sequence = false; 107 next_ += dataset()->step_; 108 109 return Status::OK(); 110 } 111 112 protected: 113 Status SaveInternal(IteratorStateWriter* writer) override { 114 mutex_lock l(mu_); 115 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("next"), next_)); 116 return Status::OK(); 117 } 118 119 Status RestoreInternal(IteratorContext* ctx, 120 IteratorStateReader* reader) override { 121 mutex_lock l(mu_); 122 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next"), &next_)); 123 return Status::OK(); 124 } 125 126 private: 127 mutex mu_; 128 int64 next_ GUARDED_BY(mu_); 129 }; 130 131 const int64 start_; 132 const int64 stop_; 133 const int64 step_; 134 }; 135 }; 136 137 REGISTER_KERNEL_BUILDER(Name("RangeDataset").Device(DEVICE_CPU), 138 RangeDatasetOp); 139 140 } // namespace 141 142 } // namespace tensorflow 143