Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
     17 #include "tensorflow/core/framework/op_kernel.h"
     18 
     19 namespace tensorflow {
     20 namespace data {
     21 namespace {
     22 
     23 class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
     24  public:
     25   using UnaryDatasetOpKernel::UnaryDatasetOpKernel;
     26 
     27   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
     28                    DatasetBase** output) override {
     29     BigtableTableResource* table;
     30     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
     31     core::ScopedUnref scoped_unref(table);
     32 
     33     std::vector<string> column_families;
     34     std::vector<string> columns;
     35     OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families",
     36                                                     &column_families));
     37     OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns));
     38     OP_REQUIRES(
     39         ctx, column_families.size() == columns.size(),
     40         errors::InvalidArgument("len(columns) != len(column_families)"));
     41 
     42     const uint64 num_outputs = columns.size() + 1;
     43     std::vector<PartialTensorShape> output_shapes;
     44     output_shapes.reserve(num_outputs);
     45     DataTypeVector output_types;
     46     output_types.reserve(num_outputs);
     47     for (uint64 i = 0; i < num_outputs; ++i) {
     48       output_shapes.push_back({});
     49       output_types.push_back(DT_STRING);
     50     }
     51 
     52     *output =
     53         new Dataset(ctx, input, table, std::move(column_families),
     54                     std::move(columns), output_types, std::move(output_shapes));
     55   }
     56 
     57  private:
     58   class Dataset : public DatasetBase {
     59    public:
     60     explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
     61                      BigtableTableResource* table,
     62                      std::vector<string> column_families,
     63                      std::vector<string> columns,
     64                      const DataTypeVector& output_types,
     65                      std::vector<PartialTensorShape> output_shapes)
     66         : DatasetBase(DatasetContext(ctx)),
     67           input_(input),
     68           table_(table),
     69           column_families_(std::move(column_families)),
     70           columns_(std::move(columns)),
     71           output_types_(output_types),
     72           output_shapes_(std::move(output_shapes)),
     73           filter_(MakeFilter(column_families_, columns_)) {
     74       table_->Ref();
     75       input_->Ref();
     76     }
     77 
     78     ~Dataset() override {
     79       table_->Unref();
     80       input_->Unref();
     81     }
     82 
     83     std::unique_ptr<IteratorBase> MakeIteratorInternal(
     84         const string& prefix) const override {
     85       return std::unique_ptr<IteratorBase>(
     86           new Iterator({this, strings::StrCat(prefix, "::BigtableLookup")}));
     87     }
     88 
     89     const DataTypeVector& output_dtypes() const override {
     90       return output_types_;
     91     }
     92 
     93     const std::vector<PartialTensorShape>& output_shapes() const override {
     94       return output_shapes_;
     95     }
     96 
     97     string DebugString() const override {
     98       return "BigtableLookupDatasetOp::Dataset";
     99     }
    100 
    101    protected:
    102     Status AsGraphDefInternal(SerializationContext* ctx,
    103                               DatasetGraphDefBuilder* b,
    104                               Node** output) const override {
    105       return errors::Unimplemented("%s does not support serialization",
    106                                    DebugString());
    107     }
    108 
    109    private:
    110     static ::google::cloud::bigtable::Filter MakeFilter(
    111         const std::vector<string>& column_families,
    112         const std::vector<string>& columns) {
    113       string column_family_regex = RegexFromStringSet(column_families);
    114       string column_regex = RegexFromStringSet(columns);
    115 
    116       return ::google::cloud::bigtable::Filter::Chain(
    117           ::google::cloud::bigtable::Filter::Latest(1),
    118           ::google::cloud::bigtable::Filter::FamilyRegex(column_family_regex),
    119           ::google::cloud::bigtable::Filter::ColumnRegex(column_regex));
    120     }
    121 
    122     class Iterator : public DatasetIterator<Dataset> {
    123      public:
    124       explicit Iterator(const Params& params)
    125           : DatasetIterator<Dataset>(params) {}
    126 
    127       Status Initialize(IteratorContext* ctx) override {
    128         return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
    129       }
    130 
    131       Status GetNextInternal(IteratorContext* ctx,
    132                              std::vector<Tensor>* out_tensors,
    133                              bool* end_of_sequence) override {
    134         mutex_lock l(mu_);  // Sequence requests.
    135         std::vector<Tensor> input_tensors;
    136         TF_RETURN_IF_ERROR(
    137             input_impl_->GetNext(ctx, &input_tensors, end_of_sequence));
    138         if (*end_of_sequence) {
    139           return Status::OK();
    140         }
    141         if (input_tensors.size() != 1) {
    142           return errors::InvalidArgument(
    143               "Upstream iterator (", dataset()->input_->DebugString(),
    144               ") did not produce a single `tf.string` `tf.Tensor`. It "
    145               "produced ",
    146               input_tensors.size(), " tensors.");
    147         }
    148         if (input_tensors[0].NumElements() == 0) {
    149           return errors::InvalidArgument("Upstream iterator (",
    150                                          dataset()->input_->DebugString(),
    151                                          ") return an empty set of keys.");
    152         }
    153         if (input_tensors[0].NumElements() == 1) {
    154           // Single key lookup.
    155           ::google::cloud::Status status;
    156           auto pair = dataset()->table_->table().ReadRow(
    157               input_tensors[0].scalar<string>()(), dataset()->filter_, status);
    158           if (!status.ok()) {
    159             return GcpStatusToTfStatus(status);
    160           }
    161           if (!pair.first) {
    162             return errors::DataLoss("Row key '",
    163                                     input_tensors[0].scalar<string>()(),
    164                                     "' not found.");
    165           }
    166           TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors));
    167         } else {
    168           // Batched get.
    169           return errors::Unimplemented(
    170               "BigtableLookupDataset doesn't yet support batched retrieval.");
    171         }
    172         return Status::OK();
    173       }
    174 
    175      private:
    176       Status ParseRow(IteratorContext* ctx,
    177                       const ::google::cloud::bigtable::Row& row,
    178                       std::vector<Tensor>* out_tensors) {
    179         out_tensors->reserve(dataset()->columns_.size() + 1);
    180         Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
    181         row_key_tensor.scalar<string>()() = string(row.row_key());
    182         out_tensors->emplace_back(std::move(row_key_tensor));
    183 
    184         if (row.cells().size() > 2 * dataset()->columns_.size()) {
    185           LOG(WARNING) << "An excessive number of columns ("
    186                        << row.cells().size()
    187                        << ") were retrieved when reading row: "
    188                        << row.row_key();
    189         }
    190 
    191         for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
    192           Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
    193           bool found_column = false;
    194           for (auto cell_itr = row.cells().begin();
    195                !found_column && cell_itr != row.cells().end(); ++cell_itr) {
    196             if (cell_itr->family_name() == dataset()->column_families_[i] &&
    197                 string(cell_itr->column_qualifier()) ==
    198                     dataset()->columns_[i]) {
    199               col_tensor.scalar<string>()() = string(cell_itr->value());
    200               found_column = true;
    201             }
    202           }
    203           if (!found_column) {
    204             return errors::DataLoss("Column ", dataset()->column_families_[i],
    205                                     ":", dataset()->columns_[i],
    206                                     " not found in row: ", row.row_key());
    207           }
    208           out_tensors->emplace_back(std::move(col_tensor));
    209         }
    210         return Status::OK();
    211       }
    212 
    213       mutex mu_;
    214       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    215     };
    216 
    217     const DatasetBase* const input_;
    218     BigtableTableResource* table_;
    219     const std::vector<string> column_families_;
    220     const std::vector<string> columns_;
    221     const DataTypeVector output_types_;
    222     const std::vector<PartialTensorShape> output_shapes_;
    223     const ::google::cloud::bigtable::Filter filter_;
    224   };
    225 };
    226 
// Register the op for CPU only (DEVICE_CPU); lookups are host-side RPCs.
REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU),
                        BigtableLookupDatasetOp);
    229 
    230 }  // namespace
    231 }  // namespace data
    232 }  // namespace tensorflow
    233