Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <map>
     17 #include <memory>
     18 #include <set>
     19 
     20 #include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h"
     21 #include "tensorflow/contrib/cloud/kernels/bigquery_table_partition.pb.h"
     22 #include "tensorflow/core/framework/reader_base.h"
     23 #include "tensorflow/core/framework/reader_op_kernel.h"
     24 #include "tensorflow/core/lib/core/errors.h"
     25 #include "tensorflow/core/lib/math/math_util.h"
     26 #include "tensorflow/core/lib/strings/numbers.h"
     27 
     28 namespace tensorflow {
     29 namespace {
     30 
// Number of rows each BigQueryTableAccessor buffers per fetch.
constexpr int64 kDefaultRowBufferSize = 1000;  // Number of rows to buffer.

// Reads the BigQuery table attributes ("project_id", "dataset_id",
// "table_id", "timestamp_millis", "columns", "test_end_point") from the
// kernel's construction context into the given output parameters.
// Returns a non-OK status (propagated from GetAttr) if any attribute is
// missing or has the wrong type; outputs may be partially written in that
// case.
Status GetTableAttrs(OpKernelConstruction* context, string* project_id,
                     string* dataset_id, string* table_id,
                     int64* timestamp_millis, std::vector<string>* columns,
                     string* test_end_point) {
  TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id));
  TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id));
  TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id));
  TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis));
  TF_RETURN_IF_ERROR(context->GetAttr("columns", columns));
  TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point));
  return Status::OK();
}
     46 
     47 }  // namespace
     48 
     49 // Note that overridden methods with names ending in "Locked" are called by
     50 // ReaderBase while a mutex is held.
     51 // See comments for ReaderBase.
     52 class BigQueryReader : public ReaderBase {
     53  public:
     54   explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor,
     55                           const string& node_name)
     56       : ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")),
     57         bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {}
     58 
     59   Status OnWorkStartedLocked() override {
     60     BigQueryTablePartition partition;
     61     if (!partition.ParseFromString(current_work())) {
     62       return errors::InvalidArgument(
     63           "Could not parse work as valid partition.");
     64     }
     65     TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition));
     66     return Status::OK();
     67   }
     68 
     69   Status ReadLocked(string* key, string* value, bool* produced,
     70                     bool* at_end) override {
     71     *at_end = false;
     72     *produced = false;
     73     if (bigquery_table_accessor_->Done()) {
     74       *at_end = true;
     75       return Status::OK();
     76     }
     77 
     78     Example example;
     79     int64 row_id;
     80     TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example));
     81 
     82     *key = std::to_string(row_id);
     83     *value = example.SerializeAsString();
     84     *produced = true;
     85     return Status::OK();
     86   }
     87 
     88  private:
     89   // Not owned.
     90   BigQueryTableAccessor* bigquery_table_accessor_;
     91 };
     92 
     93 class BigQueryReaderOp : public ReaderOpKernel {
     94  public:
     95   explicit BigQueryReaderOp(OpKernelConstruction* context)
     96       : ReaderOpKernel(context) {
     97     string table_id;
     98     string project_id;
     99     string dataset_id;
    100     int64 timestamp_millis;
    101     std::vector<string> columns;
    102     string test_end_point;
    103 
    104     OP_REQUIRES_OK(context,
    105                    GetTableAttrs(context, &project_id, &dataset_id, &table_id,
    106                                  &timestamp_millis, &columns, &test_end_point));
    107     OP_REQUIRES_OK(context,
    108                    BigQueryTableAccessor::New(
    109                        project_id, dataset_id, table_id, timestamp_millis,
    110                        kDefaultRowBufferSize, test_end_point, columns,
    111                        BigQueryTablePartition(), &bigquery_table_accessor_));
    112 
    113     SetReaderFactory([this]() {
    114       return new BigQueryReader(bigquery_table_accessor_.get(), name());
    115     });
    116   }
    117 
    118  private:
    119   std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
    120 };
    121 
    122 REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU),
    123                         BigQueryReaderOp);
    124 
    125 class GenerateBigQueryReaderPartitionsOp : public OpKernel {
    126  public:
    127   explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context)
    128       : OpKernel(context) {
    129     string project_id;
    130     string dataset_id;
    131     string table_id;
    132     int64 timestamp_millis;
    133     std::vector<string> columns;
    134     string test_end_point;
    135 
    136     OP_REQUIRES_OK(context,
    137                    GetTableAttrs(context, &project_id, &dataset_id, &table_id,
    138                                  &timestamp_millis, &columns, &test_end_point));
    139     OP_REQUIRES_OK(context,
    140                    BigQueryTableAccessor::New(
    141                        project_id, dataset_id, table_id, timestamp_millis,
    142                        kDefaultRowBufferSize, test_end_point, columns,
    143                        BigQueryTablePartition(), &bigquery_table_accessor_));
    144     OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context));
    145     OP_REQUIRES_OK(context, InitializeTotalNumberOfRows());
    146   }
    147 
    148   void Compute(OpKernelContext* context) override {
    149     const int64 partition_size = tensorflow::MathUtil::CeilOfRatio<int64>(
    150         total_num_rows_, num_partitions_);
    151     Tensor* output_tensor = nullptr;
    152     OP_REQUIRES_OK(context,
    153                    context->allocate_output(0, TensorShape({num_partitions_}),
    154                                             &output_tensor));
    155 
    156     auto output = output_tensor->template flat<string>();
    157     for (int64 i = 0; i < num_partitions_; ++i) {
    158       BigQueryTablePartition partition;
    159       partition.set_start_index(i * partition_size);
    160       partition.set_end_index(
    161           std::min(total_num_rows_, (i + 1) * partition_size) - 1);
    162       output(i) = partition.SerializeAsString();
    163     }
    164   }
    165 
    166  private:
    167   Status InitializeTotalNumberOfRows() {
    168     total_num_rows_ = bigquery_table_accessor_->total_num_rows();
    169     if (total_num_rows_ <= 0) {
    170       return errors::FailedPrecondition("Invalid total number of rows.");
    171     }
    172     return Status::OK();
    173   }
    174 
    175   Status InitializeNumberOfPartitions(OpKernelConstruction* context) {
    176     TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_));
    177     if (num_partitions_ <= 0) {
    178       return errors::FailedPrecondition("Invalid number of partitions.");
    179     }
    180     return Status::OK();
    181   }
    182 
    183   int64 num_partitions_;
    184   int64 total_num_rows_;
    185   std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
    186 };
    187 
    188 REGISTER_KERNEL_BUILDER(
    189     Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU),
    190     GenerateBigQueryReaderPartitionsOp);
    191 
    192 }  // namespace tensorflow
    193