Home | History | Annotate | Download | only in experimental
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <utility>
     16 
     17 #include "tensorflow/core/framework/dataset.h"
     18 #include "tensorflow/core/framework/partial_tensor_shape.h"
     19 #include "tensorflow/core/framework/tensor.h"
     20 #include "tensorflow/core/kernels/data/experimental/sql/driver_manager.h"
     21 #include "tensorflow/core/kernels/data/experimental/sql/query_connection.h"
     22 #include "tensorflow/core/lib/io/inputbuffer.h"
     23 #include "tensorflow/core/lib/io/record_reader.h"
     24 #include "tensorflow/core/lib/strings/stringprintf.h"
     25 
     26 namespace tensorflow {
     27 namespace data {
     28 namespace {
     29 
     30 // See documentation in ../../ops/dataset_ops.cc for a high-level
     31 // description of the following ops.
     32 
     33 class SqlDatasetOp : public DatasetOpKernel {
     34  public:
     35   explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
     36     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     37     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
     38     for (const DataType& dt : output_types_) {
     39       OP_REQUIRES(ctx,
     40                   dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 ||
     41                       dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 ||
     42                       dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE,
     43                   errors::InvalidArgument(
     44                       "Each element of `output_types_` must be one of: "
     45                       "DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, "
     46                       "DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE "));
     47     }
     48     for (const PartialTensorShape& pts : output_shapes_) {
     49       OP_REQUIRES(ctx, pts.dims() == 0,
     50                   errors::InvalidArgument(
     51                       "Each element of `output_shapes_` must be a scalar."));
     52     }
     53   }
     54   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
     55     string driver_name;
     56     OP_REQUIRES_OK(
     57         ctx, ParseScalarArgument<string>(ctx, "driver_name", &driver_name));
     58 
     59     string data_source_name;
     60     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "data_source_name",
     61                                                     &data_source_name));
     62 
     63     string query;
     64     OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "query", &query));
     65 
     66     // TODO(b/64276826) Change this check when we add support for other
     67     // databases.
     68     OP_REQUIRES(ctx, driver_name == "sqlite",
     69                 errors::InvalidArgument(tensorflow::strings::Printf(
     70                     "The database type, %s, is not supported by SqlDataset. "
     71                     "The set of supported databases is: {'sqlite'}.",
     72                     driver_name.c_str())));
     73 
     74     *output = new Dataset(ctx, driver_name, data_source_name, query,
     75                           output_types_, output_shapes_);
     76   }
     77 
     78  private:
     79   class Dataset : public DatasetBase {
     80    public:
     81     Dataset(OpKernelContext* ctx, const string& driver_name,
     82             const string& data_source_name, const string& query,
     83             const DataTypeVector& output_types,
     84             const std::vector<PartialTensorShape>& output_shapes)
     85         : DatasetBase(DatasetContext(ctx)),
     86           driver_name_(driver_name),
     87           data_source_name_(data_source_name),
     88           query_(query),
     89           output_types_(output_types),
     90           output_shapes_(output_shapes) {}
     91 
     92     std::unique_ptr<IteratorBase> MakeIteratorInternal(
     93         const string& prefix) const override {
     94       return absl::make_unique<Iterator>(
     95           Iterator::Params{this, strings::StrCat(prefix, "::Sql")});
     96     }
     97 
     98     const DataTypeVector& output_dtypes() const override {
     99       return output_types_;
    100     }
    101 
    102     const std::vector<PartialTensorShape>& output_shapes() const override {
    103       return output_shapes_;
    104     }
    105 
    106     string DebugString() const override { return "SqlDatasetOp::Dataset"; }
    107 
    108    protected:
    109     Status AsGraphDefInternal(SerializationContext* ctx,
    110                               DatasetGraphDefBuilder* b,
    111                               Node** output) const override {
    112       Node* driver_name_node;
    113       TF_RETURN_IF_ERROR(b->AddScalar(driver_name_, &driver_name_node));
    114       Node* data_source_name_node;
    115       TF_RETURN_IF_ERROR(
    116           b->AddScalar(data_source_name_, &data_source_name_node));
    117       Node* query_node;
    118       TF_RETURN_IF_ERROR(b->AddScalar(query_, &query_node));
    119       TF_RETURN_IF_ERROR(b->AddDataset(
    120           this, {driver_name_node, data_source_name_node, query_node}, output));
    121       return Status::OK();
    122     }
    123 
    124    private:
    125     class Iterator : public DatasetIterator<Dataset> {
    126      public:
    127       explicit Iterator(const Params& params)
    128           : DatasetIterator<Dataset>(params) {}
    129       ~Iterator() override {
    130         if (query_connection_initialized_) {
    131           Status s = query_connection_->Close();
    132           if (!s.ok()) {
    133             LOG(WARNING) << "Failed to close query connection: " << s;
    134           }
    135         }
    136       }
    137 
    138       Status GetNextInternal(IteratorContext* ctx,
    139                              std::vector<Tensor>* out_tensors,
    140                              bool* end_of_sequence) override {
    141         mutex_lock l(mu_);
    142         if (!query_connection_initialized_) {
    143           TF_RETURN_IF_ERROR(InitializeQueryConnection());
    144         }
    145         next_calls_++;
    146         return query_connection_->GetNext(ctx, out_tensors, end_of_sequence);
    147       }
    148 
    149      protected:
    150       std::shared_ptr<model::Node> CreateNode(
    151           IteratorContext* ctx, model::Node::Args args) const override {
    152         return model::MakeSourceNode(std::move(args));
    153       }
    154 
    155       Status SaveInternal(IteratorStateWriter* writer) override {
    156         mutex_lock l(mu_);
    157         if (query_connection_initialized_) {
    158           TF_RETURN_IF_ERROR(
    159               writer->WriteScalar(full_name("next_calls"), next_calls_));
    160         }
    161         return Status::OK();
    162       }
    163 
    164       Status RestoreInternal(IteratorContext* ctx,
    165                              IteratorStateReader* reader) override {
    166         mutex_lock l(mu_);
    167         if (reader->Contains(full_name("next_calls"))) {
    168           TF_RETURN_IF_ERROR(InitializeQueryConnection());
    169           TF_RETURN_IF_ERROR(
    170               reader->ReadScalar(full_name("next_calls"), &next_calls_));
    171           int64 rem_next_calls = next_calls_;
    172           std::vector<Tensor> out_tensors;
    173           bool end_of_sequence = false;
    174           while (rem_next_calls--) {
    175             TF_RETURN_IF_ERROR(query_connection_->GetNext(ctx, &out_tensors,
    176                                                           &end_of_sequence));
    177             out_tensors.clear();
    178           }
    179         } else {
    180           query_connection_initialized_ = false;
    181         }
    182         return Status::OK();
    183       }
    184 
    185      private:
    186       Status InitializeQueryConnection() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    187         query_connection_initialized_ = true;
    188         query_connection_ =
    189             sql::DriverManager::CreateQueryConnection(dataset()->driver_name_);
    190         Status s = query_connection_->Open(dataset()->data_source_name_,
    191                                            dataset()->query_,
    192                                            dataset()->output_types_);
    193         next_calls_ = 0;
    194         if (!s.ok()) {
    195           LOG(WARNING) << "Failed to connect to database: " << s;
    196           return s;
    197         }
    198         return Status::OK();
    199       }
    200 
    201       mutex mu_;
    202       // TODO(shivaniagrawal): explore ways to seek into a SQLite databases.
    203       int64 next_calls_ GUARDED_BY(mu_) = 0;
    204       std::unique_ptr<sql::QueryConnection> query_connection_ GUARDED_BY(mu_);
    205       bool query_connection_initialized_ GUARDED_BY(mu_) = false;
    206     };
    207     const string driver_name_;
    208     const string data_source_name_;
    209     const string query_;
    210     const DataTypeVector output_types_;
    211     const std::vector<PartialTensorShape> output_shapes_;
    212   };
    213   DataTypeVector output_types_;
    214   std::vector<PartialTensorShape> output_shapes_;
    215 };
    216 
// Registers SqlDatasetOp as the CPU kernel for the "ExperimentalSqlDataset"
// op.
REGISTER_KERNEL_BUILDER(Name("ExperimentalSqlDataset").Device(DEVICE_CPU),
                        SqlDatasetOp);
    219 
    220 }  // namespace
    221 }  // namespace data
    222 }  // namespace tensorflow
    223