1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/partial_tensor_shape.h" 17 #include "tensorflow/core/framework/tensor.h" 18 #include "tensorflow/core/kernels/data/dataset.h" 19 #include "tensorflow/core/kernels/data/stats_aggregator.h" 20 #include "tensorflow/core/lib/random/random.h" 21 22 namespace tensorflow { 23 namespace { 24 25 // This op defines a `Dataset` that passes through its input elements and 26 // records the latency of producing each element in the context's 27 // `StatsAggregator`. 28 // 29 // TODO(mrry): It is likely that many *StatsDatasetOp kernels will have the 30 // same or similar structure. We should abstract the common boilerplate into 31 // a base case and/or investigate how to make general-purpose *StatsDatasetOp 32 // kernels that use TensorFlow functions to represent their logic. For example, 33 // if the performance were adequate, we might replace this kernel with an 34 // implementation that executes functions before and after the `GetNext()` call 35 // on the input, each executing an op that gets the current time and performing 36 // the subtraction. 37 class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { 38 public: 39 explicit LatencyStatsDatasetOp(OpKernelConstruction* ctx) 40 : UnaryDatasetOpKernel(ctx) {} 41 42 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 43 DatasetBase** output) override { 44 string tag; 45 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag)); 46 *output = new Dataset(ctx, input, std::move(tag)); 47 } 48 49 private: 50 class Dataset : public GraphDatasetBase { 51 public: 52 explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag) 53 : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) { 54 input_->Ref(); 55 } 56 57 ~Dataset() override { input_->Unref(); } 58 59 std::unique_ptr<IteratorBase> MakeIterator( 60 const string& prefix) const override { 61 return std::unique_ptr<IteratorBase>( 62 new Iterator({this, strings::StrCat(prefix, "::LatencyStats")})); 63 } 64 65 const DataTypeVector& output_dtypes() const override { 66 return input_->output_dtypes(); 67 } 68 const std::vector<PartialTensorShape>& output_shapes() const override { 69 return input_->output_shapes(); 70 } 71 72 string DebugString() override { return "LatencyStatsDatasetOp::Dataset"; } 73 74 protected: 75 Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, 76 Node** output) const override { 77 Node* input_node; 78 TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node)); 79 Node* tag_node; 80 TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node)); 81 TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output)); 82 return Status::OK(); 83 } 84 85 private: 86 class Iterator : public DatasetIterator<Dataset> { 87 public: 88 explicit Iterator(const Params& params) 89 : DatasetIterator<Dataset>(params), 90 input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} 91 92 Status GetNextInternal(IteratorContext* ctx, 93 std::vector<Tensor>* out_tensors, 94 bool* end_of_sequence) override { 95 tf_shared_lock l(mu_); 96 uint64 start = ctx->env()->NowMicros(); 97 Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); 98 uint64 end = ctx->env()->NowMicros(); 99 auto stats_aggregator = ctx->stats_aggregator(); 100 if (stats_aggregator && !*end_of_sequence) { 101 ctx->stats_aggregator()->AddToHistogram( 102 dataset()->tag_, {static_cast<double>(end - start)}); 103 } 104 return s; 105 } 106 107 protected: 108 Status SaveInternal(IteratorStateWriter* writer) override { 109 mutex_lock l(mu_); 110 TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); 111 return Status::OK(); 112 } 113 114 Status RestoreInternal(IteratorContext* ctx, 115 IteratorStateReader* reader) override { 116 mutex_lock l(mu_); 117 TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); 118 return Status::OK(); 119 } 120 121 private: 122 mutex mu_; 123 std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); 124 }; 125 126 const DatasetBase* const input_; 127 const string tag_; 128 }; 129 }; 130 131 class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { 132 public: 133 explicit BytesProducedStatsDatasetOp(OpKernelConstruction* ctx) 134 : UnaryDatasetOpKernel(ctx) {} 135 136 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 137 DatasetBase** output) override { 138 string tag; 139 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag)); 140 *output = new Dataset(ctx, input, std::move(tag)); 141 } 142 143 private: 144 class Dataset : public GraphDatasetBase { 145 public: 146 explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag) 147 : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) { 148 input_->Ref(); 149 } 150 151 ~Dataset() override { input_->Unref(); } 152 153 std::unique_ptr<IteratorBase> MakeIterator( 154 const string& prefix) const override { 155 return std::unique_ptr<IteratorBase>(new Iterator( 156 {this, strings::StrCat(prefix, "::BytesProducedStats")})); 157 } 158 159 const DataTypeVector& output_dtypes() const override { 160 return input_->output_dtypes(); 161 } 162 const std::vector<PartialTensorShape>& output_shapes() const override { 163 return input_->output_shapes(); 164 } 165 166 string DebugString() override { 167 return "BytesProducedStatsDatasetOp::Dataset"; 168 } 169 170 protected: 171 Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, 172 Node** output) const override { 173 Node* input_node; 174 TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node)); 175 Node* tag_node; 176 TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node)); 177 TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output)); 178 return Status::OK(); 179 } 180 181 private: 182 class Iterator : public DatasetIterator<Dataset> { 183 public: 184 explicit Iterator(const Params& params) 185 : DatasetIterator<Dataset>(params), 186 input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} 187 188 Status GetNextInternal(IteratorContext* ctx, 189 std::vector<Tensor>* out_tensors, 190 bool* end_of_sequence) override { 191 tf_shared_lock l(mu_); 192 Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); 193 auto stats_aggregator = ctx->stats_aggregator(); 194 if (stats_aggregator && s.ok() && !*end_of_sequence) { 195 size_t total_bytes = 0; 196 for (const Tensor& t : *out_tensors) { 197 total_bytes += t.TotalBytes(); 198 } 199 ctx->stats_aggregator()->AddToHistogram( 200 dataset()->tag_, {static_cast<double>(total_bytes)}); 201 } 202 return s; 203 } 204 205 protected: 206 Status SaveInternal(IteratorStateWriter* writer) override { 207 mutex_lock l(mu_); 208 TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); 209 return Status::OK(); 210 } 211 212 Status RestoreInternal(IteratorContext* ctx, 213 IteratorStateReader* reader) override { 214 mutex_lock l(mu_); 215 TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); 216 return Status::OK(); 217 } 218 219 private: 220 mutex mu_; 221 std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); 222 }; 223 224 const DatasetBase* const input_; 225 const string tag_; 226 }; 227 }; 228 229 REGISTER_KERNEL_BUILDER(Name("LatencyStatsDataset").Device(DEVICE_CPU), 230 LatencyStatsDatasetOp); 231 REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU), 232 BytesProducedStatsDatasetOp); 233 234 } // namespace 235 } // namespace tensorflow 236