/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h"

#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace tensorflow {
namespace {

constexpr size_t kBufferSize = 1024 * 1024;  // In bytes.
const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2";

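// Returns true when the partition covers no rows: an explicit end_index
// (anything other than the -1 "unbounded" sentinel) that precedes
// start_index. For example, {start_index: 5, end_index: 4} is empty, while
// {start_index: 5, end_index: -1} extends to the end of the table.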
bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
  return partition.end_index() != -1 &&
         partition.end_index() < partition.start_index();
}

Status ParseJson(StringPiece json, Json::Value* result) {
  Json::Reader reader;
  if (!reader.parse(json.ToString(), *result)) {
    return errors::Internal("Couldn't parse JSON response from BigQuery.");
  }
  return Status::OK();
}

Status ParseColumnType(const string& type,
                       BigQueryTableAccessor::ColumnType* enum_type) {
  if (type == "RECORD") {
    *enum_type = BigQueryTableAccessor::ColumnType::kRecord;
  } else if (type == "STRING") {
    *enum_type = BigQueryTableAccessor::ColumnType::kString;
  } else if (type == "BYTES") {
    *enum_type = BigQueryTableAccessor::ColumnType::kBytes;
  } else if (type == "INTEGER") {
    *enum_type = BigQueryTableAccessor::ColumnType::kInteger;
  } else if (type == "FLOAT") {
    *enum_type = BigQueryTableAccessor::ColumnType::kFloat;
  } else if (type == "BOOLEAN") {
    *enum_type = BigQueryTableAccessor::ColumnType::kBoolean;
  } else if (type == "TIMESTAMP") {
    *enum_type = BigQueryTableAccessor::ColumnType::kTimestamp;
  } else if (type == "DATE") {
    *enum_type = BigQueryTableAccessor::ColumnType::kDate;
  } else if (type == "TIME") {
    *enum_type = BigQueryTableAccessor::ColumnType::kTime;
  } else if (type == "DATETIME") {
    *enum_type = BigQueryTableAccessor::ColumnType::kDatetime;
  } else {
    return errors::Internal(
        strings::StrCat("Could not parse column type ", type));
  }
  return Status::OK();
}

}  // namespace

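// Factory that uses the default Google auth provider and curl transport.
// A minimal caller-side sketch (identifiers such as "my-project" are
// hypothetical placeholders):
//
//   std::unique_ptr<BigQueryTableAccessor> accessor;
//   BigQueryTablePartition partition;
//   partition.set_start_index(0);
//   partition.set_end_index(-1);  // Read to the end of the table.
//   TF_RETURN_IF_ERROR(BigQueryTableAccessor::New(
//       "my-project", "my_dataset", "my_table",
//       /*timestamp_millis=*/1, /*row_buffer_size=*/1000, /*end_point=*/"",
//       /*columns=*/{}, partition, &accessor));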
Status BigQueryTableAccessor::New(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<BigQueryTableAccessor>* accessor) {
  return New(project_id, dataset_id, table_id, timestamp_millis,
             row_buffer_size, end_point, columns, partition, nullptr, nullptr,
             accessor);
}

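// Factory overload that accepts explicit auth and HTTP-request dependencies
// (useful for injecting fakes in tests). timestamp_millis must be positive,
// an empty end_point selects the default BigQuery v2 endpoint, and
// row_buffer_size bounds the number of rows cached per fetch. The schema is
// read eagerly, so an unreachable table fails here rather than on first read.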
Status BigQueryTableAccessor::New(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<AuthProvider> auth_provider,
    std::unique_ptr<HttpRequest::Factory> http_request_factory,
    std::unique_ptr<BigQueryTableAccessor>* accessor) {
  if (timestamp_millis <= 0) {
    return errors::InvalidArgument(
        "Cannot use zero or negative timestamp to query a table.");
  }
  const string& big_query_end_point =
      end_point.empty() ? kBigQueryEndPoint : end_point;
  if (auth_provider == nullptr && http_request_factory == nullptr) {
    accessor->reset(new BigQueryTableAccessor(
        project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
        big_query_end_point, columns, partition));
  } else {
    accessor->reset(new BigQueryTableAccessor(
        project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
        big_query_end_point, columns, partition, std::move(auth_provider),
        std::move(http_request_factory)));
  }
  return (*accessor)->ReadSchema();
}

BigQueryTableAccessor::BigQueryTableAccessor(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition)
    : BigQueryTableAccessor(
          project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
          end_point, columns, partition,
          std::unique_ptr<AuthProvider>(new GoogleAuthProvider()),
          std::unique_ptr<HttpRequest::Factory>(
              new CurlHttpRequest::Factory())) {
  row_buffer_.resize(row_buffer_size);
}

BigQueryTableAccessor::BigQueryTableAccessor(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<AuthProvider> auth_provider,
    std::unique_ptr<HttpRequest::Factory> http_request_factory)
    : project_id_(project_id),
      dataset_id_(dataset_id),
      table_id_(table_id),
      timestamp_millis_(timestamp_millis),
      columns_(columns.begin(), columns.end()),
      bigquery_end_point_(end_point),
      partition_(partition),
      auth_provider_(std::move(auth_provider)),
      http_request_factory_(std::move(http_request_factory)) {
  row_buffer_.resize(row_buffer_size);
  Reset();
}

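// Installs a new partition and invalidates any buffered rows so the next
// ReadRow fetches from the partition's start_index. Only the start index is
// validated here; an end_index of -1 means "unbounded".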
Status BigQueryTableAccessor::SetPartition(
    const BigQueryTablePartition& partition) {
  if (partition.start_index() < 0) {
    return errors::InvalidArgument("Start index cannot be negative.");
  }
  partition_ = partition;
  Reset();
  return Status::OK();
}

void BigQueryTableAccessor::Reset() {
  first_buffered_row_index_ = partition_.start_index();
  next_row_in_buffer_ = -1;
  next_page_token_ = "";
}

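// Reads one row, refilling the row buffer from BigQuery when it is
// exhausted, and returns OutOfRange once the table or partition ends. A
// minimal read-loop sketch, assuming Done() is part of the public interface
// as its use below suggests:
//
//   int64 row_id;
//   Example example;
//   while (!accessor->Done()) {
//     TF_RETURN_IF_ERROR(accessor->ReadRow(&row_id, &example));
//     // ... consume `example` ...
//   }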
Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
  if (Done()) {
    return errors::OutOfRange("Reached end of table ", FullTableName());
  }

  // If the next row is already fetched and cached, return the row from the
  // buffer. Otherwise, fill up the row buffer from BigQuery and return a row.
  if (next_row_in_buffer_ != -1 &&
      next_row_in_buffer_ < ComputeMaxResultsArg()) {
    *row_id = first_buffered_row_index_ + next_row_in_buffer_;
    *example = row_buffer_[next_row_in_buffer_];
    next_row_in_buffer_++;
  } else {
    string auth_token;
    TF_RETURN_IF_ERROR(
        AuthProvider::GetToken(auth_provider_.get(), &auth_token));

    std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
    std::vector<char> output_buffer;
    output_buffer.reserve(kBufferSize);

    // The first request to BigQuery has no page token; subsequent requests
    // use the page token, which returns rows faster than a startIndex.
    if (!next_page_token_.empty()) {
      request->SetUri(strings::StrCat(
          BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
          "&pageToken=", request->EscapeString(next_page_token_)));
      first_buffered_row_index_ += row_buffer_.size();
    } else {
      request->SetUri(strings::StrCat(
          BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
          "&startIndex=", first_buffered_row_index_));
    }
    request->AddAuthBearerHeader(auth_token);
    request->SetResultBuffer(&output_buffer);
    TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ",
                                    FullTableName());

    // Parse the returned rows into the row buffer.
    StringPiece response_piece =
        StringPiece(&output_buffer[0], output_buffer.size());
    Json::Value root;
    TF_RETURN_IF_ERROR(ParseJson(response_piece, &root));
    for (unsigned int i = 0; i < root["rows"].size(); ++i) {
      row_buffer_[i].Clear();
      TF_RETURN_IF_ERROR(
          ParseColumnValues(root["rows"][i], schema_root_, &row_buffer_[i]));
    }

    next_page_token_ = root["pageToken"].asString();
    *row_id = first_buffered_row_index_;
    *example = row_buffer_[0];
    next_row_in_buffer_ = 1;
  }
  return Status::OK();
}

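// Caps the maxResults request argument at the buffer size and, for bounded
// partitions, at the number of rows remaining. For example, a partition with
// start_index 10 and end_index 14 and a 1000-row buffer yields
// min(1000, 14 - 10 + 1) = 5.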
int64 BigQueryTableAccessor::ComputeMaxResultsArg() {
  if (partition_.end_index() == -1) {
    return row_buffer_.size();
  }
  if (IsPartitionEmpty(partition_)) {
    return 0;
  }
  return std::min(static_cast<int64>(row_buffer_.size()),
                  static_cast<int64>(partition_.end_index() -
                                     partition_.start_index() + 1));
}

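// Walks one row of the tabledata response, which BigQuery encodes as a list
// of cells under "f", each holding its value under "v" (nested RECORD cells
// repeat the same shape). The recursion mirrors the schema tree built by
// ExtractColumnType, so cell order and schema order must match.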
Status BigQueryTableAccessor::ParseColumnValues(
    const Json::Value& value, const SchemaNode& root_schema_node,
    Example* example) {
  if (value.empty()) {
    return Status::OK();
  }
  if (value["f"].isNull()) {
    return Status::OK();
  }
  int value_index = 0;
  for (const auto& schema_node : root_schema_node.schema_nodes) {
    if (value["f"][value_index].isNull()) {
      value_index++;
      continue;
    }

    if (schema_node.type == ColumnType::kRecord) {
      TF_RETURN_IF_ERROR(ParseColumnValues(value["f"][value_index]["v"],
                                           schema_node, example));
    } else {
      // Append the column value only if the user has requested the column.
      if (columns_.empty() ||
          columns_.find(schema_node.name) != columns_.end()) {
        TF_RETURN_IF_ERROR(AppendValueToExample(schema_node.name,
                                                value["f"][value_index]["v"],
                                                schema_node.type, example));
      }
    }
    value_index++;
  }
  return Status::OK();
}

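// Fetches the table metadata, builds the schema tree used to decode rows,
// and records the table's total row count (numRows), which Done() uses to
// detect the end of the table.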
Status BigQueryTableAccessor::ReadSchema() {
  string auth_token;
  TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token));

  // Send a request to read the schema.
  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
  std::vector<char> output_buffer;
  output_buffer.reserve(kBufferSize);
  request->SetUri(BigQueryUriPrefix());
  request->AddAuthBearerHeader(auth_token);
  request->SetResultBuffer(&output_buffer);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ",
                                  FullTableName());

  // Parse the schema.
  StringPiece response_piece =
      StringPiece(&output_buffer[0], output_buffer.size());

  Json::Value root;
  TF_RETURN_IF_ERROR(ParseJson(response_piece, &root));
  const auto& columns = root["schema"]["fields"];
  string column_name_prefix = "";
  schema_root_ = {"", ColumnType::kNone};
  TF_RETURN_IF_ERROR(
      ExtractColumnType(columns, column_name_prefix, &schema_root_));
  if (root["numRows"].isNull()) {
    return errors::Internal("Number of rows cannot be extracted for table ",
                            FullTableName());
  }
  if (!strings::safe_strto64(root["numRows"].asString().c_str(),
                             &total_num_rows_)) {
    return errors::Internal("Could not parse the number of rows for table ",
                            FullTableName());
  }
  return Status::OK();
}

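// Flattens the (possibly nested) schema into SchemaNodes, joining nested
// field names with dots: a RECORD column "address" containing "city"
// produces the feature name "address.city".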
Status BigQueryTableAccessor::ExtractColumnType(
    const Json::Value& columns, const string& column_name_prefix,
    SchemaNode* root) {
  for (auto columns_it = columns.begin(); columns_it != columns.end();
       ++columns_it) {
    if ((*columns_it)["mode"].asString() == "REPEATED") {
      return errors::Unimplemented(strings::StrCat(
          "Tables with repeated columns are not supported: ", FullTableName()));
    }
    ColumnType type;
    const string current_column_name =
        strings::StrCat(column_name_prefix, (*columns_it)["name"].asString());
    TF_RETURN_IF_ERROR(
        ParseColumnType((*columns_it)["type"].asString(), &type));
    root->schema_nodes.emplace_back(current_column_name, type);
    if (type == ColumnType::kRecord) {
      const auto new_prefix = strings::StrCat(current_column_name, ".");
      TF_RETURN_IF_ERROR(ExtractColumnType((*columns_it)["fields"], new_prefix,
                                           &root->schema_nodes.back()));
    }
  }
  return Status::OK();
}

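// Converts one cell into the appropriate Example feature list: STRING,
// BYTES, and the date/time types become bytes_list entries, BOOLEAN and
// INTEGER become int64_list entries, and FLOAT (a double in BigQuery) is
// narrowed into a float_list entry.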
Status BigQueryTableAccessor::AppendValueToExample(
    const string& column_name, const Json::Value& column_value,
    const BigQueryTableAccessor::ColumnType type, Example* example) {
  if (column_value.isNull()) {
    return Status::OK();
  }
  auto& feature =
      (*example->mutable_features()->mutable_feature())[column_name];

  switch (type) {
    case BigQueryTableAccessor::ColumnType::kNone:
    case BigQueryTableAccessor::ColumnType::kRecord:
      return errors::Unimplemented("Cannot append this type to an example.");
    case BigQueryTableAccessor::ColumnType::kTimestamp:
    case BigQueryTableAccessor::ColumnType::kDate:
    case BigQueryTableAccessor::ColumnType::kTime:
    case BigQueryTableAccessor::ColumnType::kDatetime:
    case BigQueryTableAccessor::ColumnType::kString:
    case BigQueryTableAccessor::ColumnType::kBytes:
      feature.mutable_bytes_list()->add_value(column_value.asString());
      break;
    case BigQueryTableAccessor::ColumnType::kBoolean:
      feature.mutable_int64_list()->add_value(
          column_value.asString() == "false" ? 0 : 1);
      break;
    case BigQueryTableAccessor::ColumnType::kInteger:
      int64 column_value_int64;
      if (!strings::safe_strto64(column_value.asString().c_str(),
                                 &column_value_int64)) {
        return errors::Internal("Cannot convert value to integer: ",
                                column_value.asString().c_str());
      }
      feature.mutable_int64_list()->add_value(column_value_int64);
      break;
    case BigQueryTableAccessor::ColumnType::kFloat:
      // BigQuery float is actually a double.
      double column_value_double;
      if (!strings::safe_strtod(column_value.asString().c_str(),
                                &column_value_double)) {
        return errors::Internal("Cannot convert value to double: ",
                                column_value.asString().c_str());
      }
      feature.mutable_float_list()->add_value(
          static_cast<float>(column_value_double));
      break;
  }
  return Status::OK();
}

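// Builds the tables-resource URI, with each identifier percent-escaped:
//   <endpoint>/projects/<project_id>/datasets/<dataset_id>/tables/<table_id>/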
string BigQueryTableAccessor::BigQueryUriPrefix() {
  CurlHttpRequest request;
  return strings::StrCat(bigquery_end_point_, "/projects/",
                         request.EscapeString(project_id_), "/datasets/",
                         request.EscapeString(dataset_id_), "/tables/",
                         request.EscapeString(table_id_), "/");
}

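// The read is finished when the cursor (first_buffered_row_index_ +
// next_row_in_buffer_) reaches the table's last row, the partition is
// empty, or the cursor passes a bounded partition's end_index.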
bool BigQueryTableAccessor::Done() {
  return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) ||
         IsPartitionEmpty(partition_) ||
         (partition_.end_index() != -1 &&
          partition_.end_index() <
              first_buffered_row_index_ + next_row_in_buffer_);
}

}  // namespace tensorflow