Home | History | Annotate | Download | only in data
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/partial_tensor_shape.h"
     17 #include "tensorflow/core/framework/tensor.h"
     18 #include "tensorflow/core/kernels/data/dataset.h"
     19 #include "tensorflow/core/kernels/data/stats_aggregator.h"
     20 #include "tensorflow/core/lib/random/random.h"
     21 
     22 namespace tensorflow {
     23 namespace {
     24 
     25 // This op defines a `Dataset` that passes through its input elements and
     26 // records the latency of producing each element in the context's
     27 // `StatsAggregator`.
     28 //
     29 // TODO(mrry): It is likely that many *StatsDatasetOp kernels will have the
     30 // same or similar structure. We should abstract the common boilerplate into
     31 // a base case and/or investigate how to make general-purpose *StatsDatasetOp
     32 // kernels that use TensorFlow functions to represent their logic. For example,
     33 // if the performance were adequate, we might replace this kernel with an
     34 // implementation that executes functions before and after the `GetNext()` call
     35 // on the input, each executing an op that gets the current time and performing
     36 // the subtraction.
     37 class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
     38  public:
     39   explicit LatencyStatsDatasetOp(OpKernelConstruction* ctx)
     40       : UnaryDatasetOpKernel(ctx) {}
     41 
     42   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
     43                    DatasetBase** output) override {
     44     string tag;
     45     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
     46     *output = new Dataset(ctx, input, std::move(tag));
     47   }
     48 
     49  private:
     50   class Dataset : public GraphDatasetBase {
     51    public:
     52     explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
     53         : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
     54       input_->Ref();
     55     }
     56 
     57     ~Dataset() override { input_->Unref(); }
     58 
     59     std::unique_ptr<IteratorBase> MakeIterator(
     60         const string& prefix) const override {
     61       return std::unique_ptr<IteratorBase>(
     62           new Iterator({this, strings::StrCat(prefix, "::LatencyStats")}));
     63     }
     64 
     65     const DataTypeVector& output_dtypes() const override {
     66       return input_->output_dtypes();
     67     }
     68     const std::vector<PartialTensorShape>& output_shapes() const override {
     69       return input_->output_shapes();
     70     }
     71 
     72     string DebugString() override { return "LatencyStatsDatasetOp::Dataset"; }
     73 
     74    protected:
     75     Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
     76                               Node** output) const override {
     77       Node* input_node;
     78       TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
     79       Node* tag_node;
     80       TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
     81       TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output));
     82       return Status::OK();
     83     }
     84 
     85    private:
     86     class Iterator : public DatasetIterator<Dataset> {
     87      public:
     88       explicit Iterator(const Params& params)
     89           : DatasetIterator<Dataset>(params),
     90             input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
     91 
     92       Status GetNextInternal(IteratorContext* ctx,
     93                              std::vector<Tensor>* out_tensors,
     94                              bool* end_of_sequence) override {
     95         tf_shared_lock l(mu_);
     96         uint64 start = ctx->env()->NowMicros();
     97         Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
     98         uint64 end = ctx->env()->NowMicros();
     99         auto stats_aggregator = ctx->stats_aggregator();
    100         if (stats_aggregator && !*end_of_sequence) {
    101           ctx->stats_aggregator()->AddToHistogram(
    102               dataset()->tag_, {static_cast<double>(end - start)});
    103         }
    104         return s;
    105       }
    106 
    107      protected:
    108       Status SaveInternal(IteratorStateWriter* writer) override {
    109         mutex_lock l(mu_);
    110         TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
    111         return Status::OK();
    112       }
    113 
    114       Status RestoreInternal(IteratorContext* ctx,
    115                              IteratorStateReader* reader) override {
    116         mutex_lock l(mu_);
    117         TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
    118         return Status::OK();
    119       }
    120 
    121      private:
    122       mutex mu_;
    123       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    124     };
    125 
    126     const DatasetBase* const input_;
    127     const string tag_;
    128   };
    129 };
    130 
    131 class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
    132  public:
    133   explicit BytesProducedStatsDatasetOp(OpKernelConstruction* ctx)
    134       : UnaryDatasetOpKernel(ctx) {}
    135 
    136   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
    137                    DatasetBase** output) override {
    138     string tag;
    139     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
    140     *output = new Dataset(ctx, input, std::move(tag));
    141   }
    142 
    143  private:
    144   class Dataset : public GraphDatasetBase {
    145    public:
    146     explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
    147         : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
    148       input_->Ref();
    149     }
    150 
    151     ~Dataset() override { input_->Unref(); }
    152 
    153     std::unique_ptr<IteratorBase> MakeIterator(
    154         const string& prefix) const override {
    155       return std::unique_ptr<IteratorBase>(new Iterator(
    156           {this, strings::StrCat(prefix, "::BytesProducedStats")}));
    157     }
    158 
    159     const DataTypeVector& output_dtypes() const override {
    160       return input_->output_dtypes();
    161     }
    162     const std::vector<PartialTensorShape>& output_shapes() const override {
    163       return input_->output_shapes();
    164     }
    165 
    166     string DebugString() override {
    167       return "BytesProducedStatsDatasetOp::Dataset";
    168     }
    169 
    170    protected:
    171     Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
    172                               Node** output) const override {
    173       Node* input_node;
    174       TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
    175       Node* tag_node;
    176       TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
    177       TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output));
    178       return Status::OK();
    179     }
    180 
    181    private:
    182     class Iterator : public DatasetIterator<Dataset> {
    183      public:
    184       explicit Iterator(const Params& params)
    185           : DatasetIterator<Dataset>(params),
    186             input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
    187 
    188       Status GetNextInternal(IteratorContext* ctx,
    189                              std::vector<Tensor>* out_tensors,
    190                              bool* end_of_sequence) override {
    191         tf_shared_lock l(mu_);
    192         Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
    193         auto stats_aggregator = ctx->stats_aggregator();
    194         if (stats_aggregator && s.ok() && !*end_of_sequence) {
    195           size_t total_bytes = 0;
    196           for (const Tensor& t : *out_tensors) {
    197             total_bytes += t.TotalBytes();
    198           }
    199           ctx->stats_aggregator()->AddToHistogram(
    200               dataset()->tag_, {static_cast<double>(total_bytes)});
    201         }
    202         return s;
    203       }
    204 
    205      protected:
    206       Status SaveInternal(IteratorStateWriter* writer) override {
    207         mutex_lock l(mu_);
    208         TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
    209         return Status::OK();
    210       }
    211 
    212       Status RestoreInternal(IteratorContext* ctx,
    213                              IteratorStateReader* reader) override {
    214         mutex_lock l(mu_);
    215         TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
    216         return Status::OK();
    217       }
    218 
    219      private:
    220       mutex mu_;
    221       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    222     };
    223 
    224     const DatasetBase* const input_;
    225     const string tag_;
    226   };
    227 };
    228 
// Register both stats kernels for the CPU device under the op names
// "LatencyStatsDataset" and "BytesProducedStatsDataset".
REGISTER_KERNEL_BUILDER(Name("LatencyStatsDataset").Device(DEVICE_CPU),
                        LatencyStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU),
                        BytesProducedStatsDatasetOp);
    233 
    234 }  // namespace
    235 }  // namespace tensorflow
    236