Home | History | Annotate | Download | only in experimental
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <map>
     16 
     17 #include "tensorflow/core/common_runtime/function.h"
     18 #include "tensorflow/core/framework/dataset.h"
     19 #include "tensorflow/core/framework/partial_tensor_shape.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/kernels/data/captured_function.h"
     22 #include "tensorflow/core/lib/random/random.h"
     23 
     24 namespace tensorflow {
     25 namespace data {
     26 namespace {
     27 
     28 // See documentation in ../../ops/dataset_ops.cc for a high-level
     29 // description of the following op.
     30 class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
     31  public:
     32   explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx)
     33       : UnaryDatasetOpKernel(ctx) {
     34     OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
     35     OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
     36     OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
     37     OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_));
     38     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     39     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
     40   }
     41 
     42   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
     43                    DatasetBase** output) override {
     44     std::unique_ptr<CapturedFunction> captured_key_func;
     45     OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
     46                                                  "key_func_other_arguments",
     47                                                  &captured_key_func));
     48     std::unique_ptr<CapturedFunction> captured_init_func;
     49     OP_REQUIRES_OK(ctx, CapturedFunction::Create(init_func_, ctx,
     50                                                  "init_func_other_arguments",
     51                                                  &captured_init_func));
     52     std::unique_ptr<CapturedFunction> captured_reduce_func;
     53     OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
     54                                                  "reduce_func_other_arguments",
     55                                                  &captured_reduce_func));
     56     std::unique_ptr<CapturedFunction> captured_finalize_func;
     57     OP_REQUIRES_OK(ctx,
     58                    CapturedFunction::Create(finalize_func_, ctx,
     59                                             "finalize_func_other_arguments",
     60                                             &captured_finalize_func));
     61 
     62     *output = new Dataset(
     63         ctx, input, std::move(captured_key_func), std::move(captured_init_func),
     64         std::move(captured_reduce_func), std::move(captured_finalize_func),
     65         output_types_, output_shapes_);
     66   }
     67 
     68  private:
     69   class Dataset : public DatasetBase {
     70    public:
     71     Dataset(OpKernelContext* ctx, const DatasetBase* input,
     72             std::unique_ptr<CapturedFunction> captured_key_func,
     73             std::unique_ptr<CapturedFunction> captured_init_func,
     74             std::unique_ptr<CapturedFunction> captured_reduce_func,
     75             std::unique_ptr<CapturedFunction> captured_finalize_func,
     76             const DataTypeVector& output_types,
     77             const std::vector<PartialTensorShape>& output_shapes)
     78         : DatasetBase(DatasetContext(ctx)),
     79           input_(input),
     80           captured_key_func_(std::move(captured_key_func)),
     81           captured_init_func_(std::move(captured_init_func)),
     82           captured_reduce_func_(std::move(captured_reduce_func)),
     83           captured_finalize_func_(std::move(captured_finalize_func)),
     84           output_types_(output_types),
     85           output_shapes_(output_shapes) {
     86       input_->Ref();
     87     }
     88 
     89     ~Dataset() override { input_->Unref(); }
     90 
     91     std::unique_ptr<IteratorBase> MakeIteratorInternal(
     92         const string& prefix) const override {
     93       return absl::make_unique<Iterator>(
     94           Iterator::Params{this, strings::StrCat(prefix, "::GroupByReducer")});
     95     }
     96 
     97     const DataTypeVector& output_dtypes() const override {
     98       return output_types_;
     99     }
    100     const std::vector<PartialTensorShape>& output_shapes() const override {
    101       return output_shapes_;
    102     }
    103 
    104     string DebugString() const override {
    105       return "GroupByReducerDatasetOp::Dataset";
    106     }
    107 
    108    protected:
    109     Status AsGraphDefInternal(SerializationContext* ctx,
    110                               DatasetGraphDefBuilder* b,
    111                               Node** output) const override {
    112       TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name()));
    113       TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name()));
    114       TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name()));
    115       TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name()));
    116       Node* input_graph_node = nullptr;
    117       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    118 
    119       std::vector<Node*> key_func_other_arguments_node;
    120       DataTypeVector key_func_other_arguments_types;
    121       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
    122           ctx, b, captured_key_func_, &key_func_other_arguments_node,
    123           &key_func_other_arguments_types));
    124 
    125       std::vector<Node*> init_func_other_arguments_node;
    126       DataTypeVector init_func_other_arguments_types;
    127       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
    128           ctx, b, captured_init_func_, &init_func_other_arguments_node,
    129           &init_func_other_arguments_types));
    130 
    131       std::vector<Node*> reduce_func_other_arguments_node;
    132       DataTypeVector reduce_func_other_arguments_types;
    133       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
    134           ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node,
    135           &reduce_func_other_arguments_types));
    136 
    137       std::vector<Node*> finalize_func_other_arguments_node;
    138       DataTypeVector finalize_func_other_arguments_types;
    139       TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
    140           ctx, b, captured_finalize_func_, &finalize_func_other_arguments_node,
    141           &finalize_func_other_arguments_types));
    142 
    143       AttrValue key_func;
    144       b->BuildAttrValue(this->key_func(), &key_func);
    145       AttrValue init_func;
    146       b->BuildAttrValue(this->init_func(), &init_func);
    147       AttrValue reduce_func;
    148       b->BuildAttrValue(this->reduce_func(), &reduce_func);
    149       AttrValue finalize_func;
    150       b->BuildAttrValue(this->finalize_func(), &finalize_func);
    151 
    152       AttrValue key_func_other_arguments_types_attr;
    153       b->BuildAttrValue(key_func_other_arguments_types,
    154                         &key_func_other_arguments_types_attr);
    155       AttrValue init_func_other_arguments_types_attr;
    156       b->BuildAttrValue(init_func_other_arguments_types,
    157                         &init_func_other_arguments_types_attr);
    158       AttrValue reduce_func_other_arguments_types_attr;
    159       b->BuildAttrValue(reduce_func_other_arguments_types,
    160                         &reduce_func_other_arguments_types_attr);
    161       AttrValue finalize_func_other_arguments_types_attr;
    162       b->BuildAttrValue(finalize_func_other_arguments_types,
    163                         &finalize_func_other_arguments_types_attr);
    164 
    165       TF_RETURN_IF_ERROR(b->AddDataset(
    166           this, {{0, input_graph_node}},
    167           {{1, key_func_other_arguments_node},
    168            {2, init_func_other_arguments_node},
    169            {3, reduce_func_other_arguments_node},
    170            {4, finalize_func_other_arguments_node}},
    171           {{"key_func", key_func},
    172            {"init_func", init_func},
    173            {"reduce_func", reduce_func},
    174            {"finalize_func", finalize_func},
    175            {"Tkey_func_other_arguments", key_func_other_arguments_types_attr},
    176            {"Tinit_func_other_arguments", init_func_other_arguments_types_attr},
    177            {"Treduce_func_other_arguments",
    178             reduce_func_other_arguments_types_attr},
    179            {"Tfinalize_func_other_arguments",
    180             finalize_func_other_arguments_types_attr}},
    181           output));
    182       return Status::OK();
    183     }
    184 
    185    private:
    186     class Iterator : public DatasetIterator<Dataset> {
    187      public:
    188       explicit Iterator(const Params& params)
    189           : DatasetIterator<Dataset>(params) {}
    190 
    191       Status Initialize(IteratorContext* ctx) override {
    192         TF_RETURN_IF_ERROR(
    193             dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
    194         TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(
    195             ctx, &instantiated_key_func_));
    196         TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Instantiate(
    197             ctx, &instantiated_init_func_));
    198         TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(
    199             ctx, &instantiated_reduce_func_));
    200         TF_RETURN_IF_ERROR(dataset()->captured_finalize_func_->Instantiate(
    201             ctx, &instantiated_finalize_func_));
    202         return Status::OK();
    203       }
    204 
    205       Status GetNextInternal(IteratorContext* ctx,
    206                              std::vector<Tensor>* out_tensors,
    207                              bool* end_of_sequence) override {
    208         mutex_lock l(mu_);
    209 
    210         // Iterate through the input dataset, keying input elements to reducers.
    211         while (!end_of_input_) {
    212           std::vector<Tensor> next_input_element;
    213           TF_RETURN_IF_ERROR(
    214               input_impl_->GetNext(ctx, &next_input_element, &end_of_input_));
    215 
    216           if (!end_of_input_) {
    217             // Run the key function on the input element.
    218             std::vector<Tensor> key_func_output;
    219             TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
    220                 ctx, next_input_element, &key_func_output));
    221 
    222             if (key_func_output.size() != 1 ||
    223                 key_func_output[0].dtype() != DT_INT64 ||
    224                 key_func_output[0].NumElements() != 1) {
    225               // TODO(b/78665031): Support non-int64 keys.
    226               return errors::InvalidArgument(
    227                   "`key_func` must return a scalar int64.");
    228             }
    229             const int64 key = key_func_output[0].scalar<int64>()();
    230 
    231             if (states_.find(key) == states_.end()) {
    232               // Run the init function to create the initial state.
    233               std::vector<Tensor> init_func_output;
    234               TF_RETURN_IF_ERROR(instantiated_init_func_->Run(
    235                   ctx, std::move(key_func_output), &init_func_output));
    236               states_[key] = init_func_output;
    237             }
    238 
    239             // Run the reduce function to update the current state.
    240             std::vector<Tensor> args;
    241             args.reserve(states_[key].size() + next_input_element.size());
    242             std::copy(states_[key].begin(), states_[key].end(),
    243                       std::back_inserter(args));
    244             std::copy(next_input_element.begin(), next_input_element.end(),
    245                       std::back_inserter(args));
    246 
    247             std::vector<Tensor> reduce_func_output;
    248             TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(
    249                 ctx, std::move(args), &reduce_func_output));
    250             states_[key] = reduce_func_output;
    251           } else {
    252             keys_.resize(states_.size());
    253             int idx = 0;
    254             for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) {
    255               keys_[idx] = it->first;
    256             }
    257           }
    258         }
    259 
    260         if (keys_index_ == keys_.size()) {
    261           *end_of_sequence = true;
    262           return Status::OK();
    263         }
    264         TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
    265             ctx, states_[keys_[keys_index_++]], out_tensors));
    266         *end_of_sequence = false;
    267         return Status::OK();
    268       }
    269 
    270      protected:
    271       std::shared_ptr<model::Node> CreateNode(
    272           IteratorContext* ctx, model::Node::Args args) const override {
    273         return model::MakeUnknownRatioNode(std::move(args));
    274       }
    275 
    276       Status SaveInternal(IteratorStateWriter* writer) override {
    277         mutex_lock l(mu_);
    278         TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
    279 
    280         if (end_of_input_) {
    281           TF_RETURN_IF_ERROR(
    282               writer->WriteScalar(full_name("end_of_input"), ""));
    283         }
    284 
    285         // Saving states_.
    286         if (!states_.empty()) {
    287           TF_RETURN_IF_ERROR(
    288               writer->WriteScalar(full_name("states_size"), states_.size()));
    289           int idx = 0;
    290           for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) {
    291             int64 key = it->first;
    292             TF_RETURN_IF_ERROR(writer->WriteScalar(
    293                 full_name(strings::StrCat("states[", idx, "]->key")), key));
    294             if (!it->second.empty()) {
    295               TF_RETURN_IF_ERROR(writer->WriteScalar(
    296                   full_name(strings::StrCat("states[", idx, "]->state_size")),
    297                   it->second.size()));
    298               for (int j = 0; j < it->second.size(); ++j) {
    299                 TF_RETURN_IF_ERROR(writer->WriteTensor(
    300                     full_name(
    301                         strings::StrCat("states[", idx, "]->state[", j, "]")),
    302                     it->second[j]));
    303               }
    304             }
    305           }
    306         }
    307 
    308         // Saving keys_index_ and keys_.
    309         if (end_of_input_) {
    310           TF_RETURN_IF_ERROR(
    311               writer->WriteScalar(full_name("keys_index"), keys_index_));
    312           if (!keys_.empty()) {
    313             TF_RETURN_IF_ERROR(
    314                 writer->WriteScalar(full_name("keys_size"), keys_.size()));
    315             for (int idx = 0; idx < keys_.size(); ++idx) {
    316               TF_RETURN_IF_ERROR(writer->WriteScalar(
    317                   full_name(strings::StrCat("keys[", idx, "]")), keys_[idx]));
    318             }
    319           }
    320         }
    321 
    322         return Status::OK();
    323       }
    324 
    325       Status RestoreInternal(IteratorContext* ctx,
    326                              IteratorStateReader* reader) override {
    327         mutex_lock l(mu_);
    328         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
    329 
    330         if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
    331 
    332         // Restoring states_.
    333         if (reader->Contains(full_name("states_size"))) {
    334           int64 size;
    335           TF_RETURN_IF_ERROR(
    336               reader->ReadScalar(full_name("states_size"), &size));
    337           for (int idx = 0; idx < size; ++idx) {
    338             int64 key;
    339             TF_RETURN_IF_ERROR(reader->ReadScalar(
    340                 full_name(strings::StrCat("states[", idx, "]->key")), &key));
    341             std::vector<Tensor> state;
    342             if (reader->Contains(full_name(
    343                     strings::StrCat("states[", idx, "]->state_size")))) {
    344               int64 state_size;
    345               TF_RETURN_IF_ERROR(reader->ReadScalar(
    346                   full_name(strings::StrCat("states[", idx, "]->state_size")),
    347                   &state_size));
    348               state.resize(state_size);
    349               for (int j = 0; j < state_size; ++j) {
    350                 TF_RETURN_IF_ERROR(reader->ReadTensor(
    351                     full_name(
    352                         strings::StrCat("states[", idx, "]->state[", j, "]")),
    353                     &state[j]));
    354               }
    355             }
    356             states_[key] = state;
    357           }
    358         }
    359 
    360         // Restoring keys_index_ and keys_.
    361         if (end_of_input_) {
    362           TF_RETURN_IF_ERROR(
    363               reader->ReadScalar(full_name("keys_index"), &keys_index_));
    364           if (reader->Contains(full_name("keys_size"))) {
    365             int64 size;
    366             TF_RETURN_IF_ERROR(
    367                 reader->ReadScalar(full_name("keys_size"), &size));
    368             keys_.resize(size);
    369             for (int idx = 0; idx < size; ++idx) {
    370               int64 key;
    371               TF_RETURN_IF_ERROR(reader->ReadScalar(
    372                   full_name(strings::StrCat("keys[", idx, "]")), &key));
    373               keys_[idx] = key;
    374             }
    375           }
    376         }
    377 
    378         return Status::OK();
    379       }
    380 
    381      private:
    382       mutex mu_;
    383       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    384       bool end_of_input_ GUARDED_BY(mu_) = false;
    385       std::map<int64, std::vector<Tensor>> states_ GUARDED_BY(mu_);
    386       std::vector<int64> keys_ GUARDED_BY(mu_);
    387       int64 keys_index_ GUARDED_BY(mu_) = 0;
    388       std::unique_ptr<InstantiatedCapturedFunction> instantiated_key_func_;
    389       std::unique_ptr<InstantiatedCapturedFunction> instantiated_init_func_;
    390       std::unique_ptr<InstantiatedCapturedFunction> instantiated_reduce_func_;
    391       std::unique_ptr<InstantiatedCapturedFunction> instantiated_finalize_func_;
    392     };
    393 
    394     const NameAttrList& key_func() const { return captured_key_func_->func(); }
    395 
    396     const NameAttrList& init_func() const {
    397       return captured_init_func_->func();
    398     }
    399 
    400     const NameAttrList& reduce_func() const {
    401       return captured_reduce_func_->func();
    402     }
    403 
    404     const NameAttrList& finalize_func() const {
    405       return captured_finalize_func_->func();
    406     }
    407 
    408     Status OtherArgumentsNodeAndType(
    409         SerializationContext* ctx, DatasetGraphDefBuilder* b,
    410         const std::unique_ptr<CapturedFunction>& captured_func,
    411         std::vector<Node*>* other_arguments_node,
    412         DataTypeVector* other_arguments_types) const {
    413       other_arguments_node->reserve(captured_func->captured_inputs().size());
    414       other_arguments_types->reserve(captured_func->captured_inputs().size());
    415       for (const Tensor& t : captured_func->captured_inputs()) {
    416         Node* node;
    417         DatasetBase* input;
    418         Status s = GetDatasetFromVariantTensor(t, &input);
    419         if (s.ok()) {
    420           TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
    421         } else {
    422           TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
    423         }
    424         other_arguments_node->emplace_back(node);
    425         other_arguments_types->emplace_back(t.dtype());
    426       }
    427       return Status::OK();
    428     }
    429 
    430     const DatasetBase* const input_;
    431     const std::unique_ptr<CapturedFunction> captured_key_func_;
    432     const std::unique_ptr<CapturedFunction> captured_init_func_;
    433     const std::unique_ptr<CapturedFunction> captured_reduce_func_;
    434     const std::unique_ptr<CapturedFunction> captured_finalize_func_;
    435     const DataTypeVector output_types_;
    436     const std::vector<PartialTensorShape> output_shapes_;
    437   };
    438 
    439   DataTypeVector output_types_;
    440   std::vector<PartialTensorShape> output_shapes_;
    441   NameAttrList key_func_;
    442   NameAttrList init_func_;
    443   NameAttrList reduce_func_;
    444   NameAttrList finalize_func_;
    445 };
    446 
    447 REGISTER_KERNEL_BUILDER(
    448     Name("ExperimentalGroupByReducerDataset").Device(DEVICE_CPU),
    449     GroupByReducerDatasetOp);
    450 
    451 }  // namespace
    452 }  // namespace data
    453 }  // namespace tensorflow
    454