Home | History | Annotate | Download | only in experimental
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <map>
     16 
     17 #include "tensorflow/core/common_runtime/function.h"
     18 #include "tensorflow/core/framework/dataset.h"
     19 #include "tensorflow/core/framework/partial_tensor_shape.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/kernels/data/captured_function.h"
     22 #include "tensorflow/core/kernels/data/window_dataset.h"
     23 #include "tensorflow/core/lib/random/random.h"
     24 
     25 namespace tensorflow {
     26 namespace data {
     27 namespace {
     28 
     29 // See documentation in ../../ops/dataset_ops.cc for a high-level
     30 // description of the following op.
     31 class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
     32  public:
     33   explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx)
     34       : UnaryDatasetOpKernel(ctx) {
     35     OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
     36     OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
     37     OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size_func", &window_size_func_));
     38     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     39     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
     40   }
     41 
     42   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
     43                    DatasetBase** output) override {
     44     std::unique_ptr<CapturedFunction> captured_key_func;
     45     OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
     46                                                  "key_func_other_arguments",
     47                                                  &captured_key_func));
     48     std::unique_ptr<CapturedFunction> captured_reduce_func;
     49     OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
     50                                                  "reduce_func_other_arguments",
     51                                                  &captured_reduce_func));
     52     std::unique_ptr<CapturedFunction> captured_window_size_func;
     53     OP_REQUIRES_OK(ctx,
     54                    CapturedFunction::Create(window_size_func_, ctx,
     55                                             "window_size_func_other_arguments",
     56                                             &captured_window_size_func));
     57 
     58     *output = new Dataset(
     59         ctx, input, key_func_, reduce_func_, window_size_func_,
     60         std::move(captured_key_func), std::move(captured_reduce_func),
     61         std::move(captured_window_size_func), output_types_, output_shapes_);
     62   }
     63 
     64  private:
     65   class Dataset : public DatasetBase {
     66    public:
     67     Dataset(OpKernelContext* ctx, const DatasetBase* input,
     68             const NameAttrList& key_func, const NameAttrList& reduce_func,
     69             const NameAttrList& window_size_func,
     70             std::unique_ptr<CapturedFunction> captured_key_func,
     71             std::unique_ptr<CapturedFunction> captured_reduce_func,
     72             std::unique_ptr<CapturedFunction> captured_window_size_func,
     73             const DataTypeVector& output_types,
     74             const std::vector<PartialTensorShape>& output_shapes)
     75         : DatasetBase(DatasetContext(ctx)),
     76           input_(input),
     77           key_func_(key_func),
     78           reduce_func_(reduce_func),
     79           window_size_func_(window_size_func),
     80           captured_key_func_(std::move(captured_key_func)),
     81           captured_reduce_func_(std::move(captured_reduce_func)),
     82           captured_window_size_func_(std::move(captured_window_size_func)),
     83           output_types_(output_types),
     84           output_shapes_(output_shapes) {
     85       input_->Ref();
     86     }
     87 
     88     ~Dataset() override { input_->Unref(); }
     89 
     90     std::unique_ptr<IteratorBase> MakeIteratorInternal(
     91         const string& prefix) const override {
     92       return absl::make_unique<Iterator>(
     93           Iterator::Params{this, strings::StrCat(prefix, "::GroupByWindow")});
     94     }
     95 
     96     const DataTypeVector& output_dtypes() const override {
     97       return output_types_;
     98     }
     99     const std::vector<PartialTensorShape>& output_shapes() const override {
    100       return output_shapes_;
    101     }
    102 
    103     string DebugString() const override {
    104       return "GroupByWindowDatasetOp::Dataset";
    105     }
    106 
    107    protected:
    108     Status AsGraphDefInternal(SerializationContext* ctx,
    109                               DatasetGraphDefBuilder* b,
    110                               Node** output) const override {
    111       TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func_.name()));
    112       TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func_.name()));
    113       TF_RETURN_IF_ERROR(b->AddFunction(ctx, window_size_func_.name()));
    114       Node* input_graph_node = nullptr;
    115       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    116 
    117       std::vector<Node*> key_func_other_arguments_node;
    118       DataTypeVector key_func_other_arguments_types;
    119       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
    120           ctx, b, captured_key_func_, &key_func_other_arguments_node,
    121           &key_func_other_arguments_types));
    122 
    123       std::vector<Node*> reduce_func_other_arguments_node;
    124       DataTypeVector reduce_func_other_arguments_types;
    125       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
    126           ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node,
    127           &reduce_func_other_arguments_types));
    128 
    129       std::vector<Node*> window_size_func_other_arguments_node;
    130       DataTypeVector window_size_func_other_arguments_types;
    131       TF_RETURN_IF_ERROR(
    132           OtherArgumentsNodeAndType(ctx, b, captured_window_size_func_,
    133                                     &window_size_func_other_arguments_node,
    134                                     &window_size_func_other_arguments_types));
    135 
    136       AttrValue key_func;
    137       b->BuildAttrValue(key_func_, &key_func);
    138       AttrValue reduce_func;
    139       b->BuildAttrValue(reduce_func_, &reduce_func);
    140       AttrValue window_size_func;
    141       b->BuildAttrValue(window_size_func_, &window_size_func);
    142 
    143       AttrValue key_func_other_arguments_types_attr;
    144       b->BuildAttrValue(key_func_other_arguments_types,
    145                         &key_func_other_arguments_types_attr);
    146       AttrValue reduce_func_other_arguments_types_attr;
    147       b->BuildAttrValue(reduce_func_other_arguments_types,
    148                         &reduce_func_other_arguments_types_attr);
    149       AttrValue window_size_func_other_arguments_types_attr;
    150       b->BuildAttrValue(window_size_func_other_arguments_types,
    151                         &window_size_func_other_arguments_types_attr);
    152 
    153       TF_RETURN_IF_ERROR(b->AddDataset(
    154           this, {{0, input_graph_node}},
    155           {{1, key_func_other_arguments_node},
    156            {2, reduce_func_other_arguments_node},
    157            {3, window_size_func_other_arguments_node}},
    158           {{"key_func", key_func},
    159            {"reduce_func", reduce_func},
    160            {"window_size_func", window_size_func},
    161            {"Tkey_func_other_arguments", key_func_other_arguments_types_attr},
    162            {"Treduce_func_other_arguments",
    163             reduce_func_other_arguments_types_attr},
    164            {"Twindow_size_func_other_arguments",
    165             window_size_func_other_arguments_types_attr}},
    166           output));
    167       return Status::OK();
    168     }
    169 
    170    private:
    171     class Iterator : public DatasetIterator<Dataset> {
    172      public:
    173       explicit Iterator(const Params& params)
    174           : DatasetIterator<Dataset>(params) {}
    175 
    176       Status Initialize(IteratorContext* ctx) override {
    177         TF_RETURN_IF_ERROR(
    178             dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
    179         TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(
    180             ctx, &instantiated_key_func_));
    181         TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(
    182             ctx, &instantiated_reduce_func_));
    183         TF_RETURN_IF_ERROR(dataset()->captured_window_size_func_->Instantiate(
    184             ctx, &instantiated_window_size_func_));
    185         return Status::OK();
    186       }
    187 
    188       Status GetNextInternal(IteratorContext* ctx,
    189                              std::vector<Tensor>* out_tensors,
    190                              bool* end_of_sequence) override {
    191         mutex_lock l(mu_);
    192         do {
    193           if (current_group_iterator_) {
    194             // We are currently processing a group, so try to get the
    195             // next element.
    196             bool end_of_group;
    197             TF_RETURN_IF_ERROR(current_group_iterator_->GetNext(
    198                 ctx, out_tensors, &end_of_group));
    199             if (!end_of_group) {
    200               // Produce the subelement as output.
    201               *end_of_sequence = false;
    202               return Status::OK();
    203             }
    204             // We have reached the end of the current group, so maybe move on
    205             // to the next group.
    206             current_group_iterator_.reset();
    207             groups_.erase(current_key_);
    208           }
    209 
    210           // Iterate through the input dataset until we get a full
    211           // group, or reach the end.
    212           while (!end_of_input_) {
    213             std::vector<Tensor> next_input_element;
    214             TF_RETURN_IF_ERROR(
    215                 input_impl_->GetNext(ctx, &next_input_element, &end_of_input_));
    216 
    217             if (!end_of_input_) {
    218               // Run the key function on the input element to identify its
    219               // group.
    220               std::vector<Tensor> key_func_output;
    221               TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
    222                   ctx, next_input_element, &key_func_output));
    223 
    224               if (key_func_output.size() != 1 ||
    225                   key_func_output[0].dtype() != DT_INT64 ||
    226                   key_func_output[0].NumElements() != 1) {
    227                 // TODO(b/78665031): Support non-int64 keys.
    228                 return errors::InvalidArgument(
    229                     "`key_func` must return a scalar int64.");
    230               }
    231               const int64 key = key_func_output[0].scalar<int64>()();
    232 
    233               if (window_sizes_.find(key) == window_sizes_.end()) {
    234                 // Run the window size function on the key to identify its
    235                 // window size.
    236                 std::vector<Tensor> window_size_func_output;
    237                 TF_RETURN_IF_ERROR(instantiated_window_size_func_->Run(
    238                     ctx, std::move(key_func_output), &window_size_func_output));
    239 
    240                 if (window_size_func_output.size() != 1 ||
    241                     window_size_func_output[0].dtype() != DT_INT64 ||
    242                     window_size_func_output[0].NumElements() != 1) {
    243                   // TODO(mrry): Support non-int64 window sizes.
    244                   return errors::InvalidArgument(
    245                       "`window_size_func` must return a scalar int64.");
    246                 }
    247                 const int64 window_size =
    248                     window_size_func_output[0].scalar<int64>()();
    249                 if (window_size <= 0) {
    250                   return errors::InvalidArgument(
    251                       "Window size must be greater than zero, but got ",
    252                       window_size, ".");
    253                 }
    254                 window_sizes_[key] = window_size;
    255               }
    256 
    257               const int64 window_size = window_sizes_[key];
    258 
    259               std::vector<std::vector<Tensor>>& group = groups_[key];
    260               group.push_back(std::move(next_input_element));
    261 
    262               if (group.size() == window_size) {
    263                 current_key_ = key;
    264                 TF_RETURN_IF_ERROR(StartFlushingGroup(ctx, key));
    265                 break;
    266               }
    267             }
    268           }
    269 
    270           if (end_of_input_) {
    271             if (!groups_.empty()) {
    272               // We have consumed all of the input, so flush an
    273               // arbitrarily chosen group.
    274               current_key_ = groups_.begin()->first;
    275               TF_RETURN_IF_ERROR(
    276                   StartFlushingGroup(ctx, groups_.begin()->first));
    277             }
    278           }
    279         } while (current_group_iterator_ || !end_of_input_);
    280 
    281         *end_of_sequence = true;
    282         return Status::OK();
    283       }
    284 
    285      protected:
    286       std::shared_ptr<model::Node> CreateNode(
    287           IteratorContext* ctx, model::Node::Args args) const override {
    288         return model::MakeUnknownRatioNode(std::move(args));
    289       }
    290 
    291       Status SaveInternal(IteratorStateWriter* writer) override {
    292         mutex_lock l(mu_);
    293         TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
    294 
    295         if (end_of_input_) {
    296           TF_RETURN_IF_ERROR(
    297               writer->WriteScalar(full_name("end_of_input"), ""));
    298         }
    299 
    300         // Saving groups_
    301         if (!groups_.empty()) {
    302           TF_RETURN_IF_ERROR(
    303               writer->WriteScalar(full_name("groups_size"), groups_.size()));
    304           int idx = 0;
    305           for (auto it = groups_.begin(); it != groups_.end(); it++) {
    306             int64 key = it->first;
    307             TF_RETURN_IF_ERROR(writer->WriteScalar(
    308                 full_name(strings::StrCat("groups_[", idx, "]->key")), key));
    309             TF_RETURN_IF_ERROR(SaveGroup(
    310                 writer, full_name(strings::StrCat("groups_[", idx, "]")),
    311                 it->second));
    312             idx++;
    313           }
    314         }
    315 
    316         // Saving window_sizes_
    317         if (!window_sizes_.empty()) {
    318           TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("window_sizes_size"),
    319                                                  window_sizes_.size()));
    320           int idx = 0;
    321           for (auto it = window_sizes_.begin(); it != window_sizes_.end();
    322                it++) {
    323             TF_RETURN_IF_ERROR(writer->WriteScalar(
    324                 full_name(strings::StrCat("window_sizes_[", idx, "]->key")),
    325                 it->first));
    326             TF_RETURN_IF_ERROR(writer->WriteScalar(
    327                 full_name(strings::StrCat("window_sizes_[", idx, "]->value")),
    328                 it->second));
    329             idx++;
    330           }
    331         }
    332 
    333         if (current_group_iterator_) {
    334           TF_RETURN_IF_ERROR(SaveInput(writer, current_group_iterator_));
    335 
    336           // Saving current_key_
    337           TF_RETURN_IF_ERROR(
    338               writer->WriteScalar(full_name("current_key"), current_key_));
    339         } else {
    340           TF_RETURN_IF_ERROR(writer->WriteScalar(
    341               full_name("current_iterator_not_initialized"), ""));
    342         }
    343 
    344         return Status::OK();
    345       }
    346 
    347       Status RestoreInternal(IteratorContext* ctx,
    348                              IteratorStateReader* reader) override {
    349         mutex_lock l(mu_);
    350         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
    351 
    352         if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
    353 
    354         // Restoring groups
    355         if (reader->Contains(full_name("groups_size"))) {
    356           int64 size;
    357           TF_RETURN_IF_ERROR(
    358               reader->ReadScalar(full_name("groups_size"), &size));
    359           for (int idx = 0; idx < size; idx++) {
    360             int64 key;
    361             TF_RETURN_IF_ERROR(reader->ReadScalar(
    362                 full_name(strings::StrCat("groups_[", idx, "]->key")), &key));
    363             std::vector<std::vector<Tensor>> group;
    364             TF_RETURN_IF_ERROR(RestoreGroup(
    365                 reader, full_name(strings::StrCat("groups_[", idx, "]")),
    366                 &group));
    367             groups_[key] = group;
    368           }
    369         }
    370 
    371         // Restoring Windows
    372         if (reader->Contains(full_name("window_sizes_size"))) {
    373           int64 size;
    374           TF_RETURN_IF_ERROR(
    375               reader->ReadScalar(full_name("window_sizes_size"), &size));
    376           for (int idx = 0; idx < size; idx++) {
    377             int64 key;
    378             TF_RETURN_IF_ERROR(reader->ReadScalar(
    379                 full_name(strings::StrCat("window_sizes_[", idx, "]->key")),
    380                 &key));
    381             TF_RETURN_IF_ERROR(reader->ReadScalar(
    382                 full_name(strings::StrCat("window_sizes_[", idx, "]->value")),
    383                 &window_sizes_[key]));
    384           }
    385         }
    386 
    387         if (reader->Contains(full_name("current_iterator_not_initialized"))) {
    388           current_group_iterator_.reset();
    389         } else {
    390           // Restore current_key_
    391           TF_RETURN_IF_ERROR(
    392               reader->ReadScalar(full_name("current_key"), &current_key_));
    393 
    394           // Initialize current_group_iterator_
    395           TF_RETURN_IF_ERROR(StartFlushingGroup(ctx, current_key_));
    396           // Restore current_group_iterator_ state
    397           TF_RETURN_IF_ERROR(
    398               RestoreInput(ctx, reader, current_group_iterator_));
    399         }
    400         return Status::OK();
    401       }
    402 
    403      private:
    404       Status SaveGroup(IteratorStateWriter* writer, const string& name,
    405                        const std::vector<std::vector<Tensor>>& group)
    406           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    407         TF_RETURN_IF_ERROR(
    408             writer->WriteScalar(strings::StrCat(name, "_size"), group.size()));
    409         for (int i = 0; i < group.size(); i++) {
    410           TF_RETURN_IF_ERROR(writer->WriteScalar(
    411               strings::StrCat(name, "[", i, "]_size"), group[i].size()));
    412           for (int j = 0; j < group[i].size(); j++) {
    413             TF_RETURN_IF_ERROR(writer->WriteTensor(
    414                 strings::StrCat(name, "[", i, "][", j, "]"), group[i][j]));
    415           }
    416         }
    417         return Status::OK();
    418       }
    419 
    420       Status RestoreGroup(IteratorStateReader* reader, const string& name,
    421                           std::vector<std::vector<Tensor>>* group)
    422           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    423         int64 group_size;
    424         TF_RETURN_IF_ERROR(
    425             reader->ReadScalar(strings::StrCat(name, "_size"), &group_size));
    426         group->resize(group_size);
    427         for (int i = 0; i < group_size; i++) {
    428           int64 vector_size;
    429           TF_RETURN_IF_ERROR(reader->ReadScalar(
    430               strings::StrCat(name, "[", i, "]_size"), &vector_size));
    431           group->at(i).resize(vector_size);
    432           for (int j = 0; j < vector_size; j++) {
    433             TF_RETURN_IF_ERROR(reader->ReadTensor(
    434                 strings::StrCat(name, "[", i, "][", j, "]"), &group->at(i)[j]));
    435           }
    436         }
    437         return Status::OK();
    438       }
    439 
    440       Status StartFlushingGroup(IteratorContext* ctx, int64 key)
    441           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    442         DatasetBase* group_dataset;
    443         TF_RETURN_IF_ERROR(NewWindowDataset(
    444             groups_[key], dataset()->input_->output_dtypes(),
    445             dataset()->input_->output_shapes(), &group_dataset));
    446 
    447         Tensor key_arg(DT_INT64, TensorShape({}));
    448         key_arg.scalar<int64>()() = key;
    449 
    450         Tensor group_dataset_arg(DT_VARIANT, TensorShape({}));
    451         TF_RETURN_IF_ERROR(
    452             StoreDatasetInVariantTensor(group_dataset, &group_dataset_arg));
    453 
    454         std::vector<Tensor> args(
    455             {std::move(key_arg), std::move(group_dataset_arg)});
    456         std::vector<Tensor> return_values;
    457         TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(ctx, std::move(args),
    458                                                           &return_values));
    459 
    460         if (!(return_values.size() == 1 &&
    461               return_values[0].dtype() == DT_VARIANT &&
    462               TensorShapeUtils::IsScalar(return_values[0].shape()))) {
    463           return errors::InvalidArgument(
    464               "`reduce_func` must return a single scalar of dtype "
    465               "DT_VARIANT.");
    466         }
    467 
    468         // Retrieve the dataset that was created in `f`.
    469         // `returned_dataset` is borrowed from the `return_values[0]`.
    470         DatasetBase* returned_dataset;
    471         TF_RETURN_IF_ERROR(
    472             GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
    473 
    474         // Create an iterator for the dataset that was returned by `f`.
    475         return returned_dataset->MakeIterator(ctx, prefix(),
    476                                               &current_group_iterator_);
    477       }
    478 
    479       mutex mu_;
    480       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    481       // TODO(mrry): Optimize for dense key space if appropriate.
    482       bool end_of_input_ GUARDED_BY(mu_) = false;
    483       int64 current_key_ GUARDED_BY(mu_);
    484       std::map<int64, std::vector<std::vector<Tensor>>> groups_ GUARDED_BY(mu_);
    485       std::unique_ptr<IteratorBase> current_group_iterator_ GUARDED_BY(mu_);
    486       std::map<int64, int64> window_sizes_ GUARDED_BY(mu_);
    487       std::unique_ptr<InstantiatedCapturedFunction> instantiated_key_func_;
    488       std::unique_ptr<InstantiatedCapturedFunction> instantiated_reduce_func_;
    489       std::unique_ptr<InstantiatedCapturedFunction>
    490           instantiated_window_size_func_;
    491     };
    492 
    493     Status OtherArgumentsNodeAndType(
    494         SerializationContext* ctx, DatasetGraphDefBuilder* b,
    495         const std::unique_ptr<CapturedFunction>& captured_func,
    496         std::vector<Node*>* other_arguments_node,
    497         DataTypeVector* other_arguments_types) const {
    498       other_arguments_node->reserve(captured_func->captured_inputs().size());
    499       other_arguments_types->reserve(captured_func->captured_inputs().size());
    500       for (const Tensor& t : captured_func->captured_inputs()) {
    501         Node* node;
    502         DatasetBase* input;
    503         Status s = GetDatasetFromVariantTensor(t, &input);
    504         if (s.ok()) {
    505           TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
    506         } else {
    507           TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
    508         }
    509         other_arguments_node->emplace_back(node);
    510         other_arguments_types->emplace_back(t.dtype());
    511       }
    512       return Status::OK();
    513     }
    514 
    515     const DatasetBase* const input_;
    516     const NameAttrList key_func_;
    517     const NameAttrList reduce_func_;
    518     const NameAttrList window_size_func_;
    519     const std::unique_ptr<CapturedFunction> captured_key_func_;
    520     const std::unique_ptr<CapturedFunction> captured_reduce_func_;
    521     const std::unique_ptr<CapturedFunction> captured_window_size_func_;
    522     const DataTypeVector output_types_;
    523     const std::vector<PartialTensorShape> output_shapes_;
    524   };
    525 
    526   DataTypeVector output_types_;
    527   std::vector<PartialTensorShape> output_shapes_;
    528   NameAttrList key_func_;
    529   NameAttrList reduce_func_;
    530   NameAttrList window_size_func_;
    531 };
    532 
    533 REGISTER_KERNEL_BUILDER(
    534     Name("ExperimentalGroupByWindowDataset").Device(DEVICE_CPU),
    535     GroupByWindowDatasetOp);
    536 
    537 }  // namespace
    538 }  // namespace data
    539 }  // namespace tensorflow
    540