Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h"
     16 
     17 #include "tensorflow/core/example/feature.pb.h"
     18 #include "tensorflow/core/lib/strings/numbers.h"
     19 
     20 namespace tensorflow {
     21 namespace {
     22 
     23 constexpr size_t kBufferSize = 1024 * 1024;  // In bytes.
     24 const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2";
     25 
     26 bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
     27   if (partition.end_index() != -1 &&
     28       partition.end_index() < partition.start_index()) {
     29     return true;
     30   }
     31   return false;
     32 }
     33 
     34 Status ParseJson(StringPiece json, Json::Value* result) {
     35   Json::Reader reader;
     36   if (!reader.parse(string(json), *result)) {
     37     return errors::Internal("Couldn't parse JSON response from BigQuery.");
     38   }
     39   return Status::OK();
     40 }
     41 
     42 Status ParseColumnType(const string& type,
     43                        BigQueryTableAccessor::ColumnType* enum_type) {
     44   if (type == "RECORD") {
     45     *enum_type = BigQueryTableAccessor::ColumnType::kRecord;
     46   } else if (type == "STRING") {
     47     *enum_type = BigQueryTableAccessor::ColumnType::kString;
     48   } else if (type == "BYTES") {
     49     *enum_type = BigQueryTableAccessor::ColumnType::kBytes;
     50   } else if (type == "INTEGER") {
     51     *enum_type = BigQueryTableAccessor::ColumnType::kInteger;
     52   } else if (type == "FLOAT") {
     53     *enum_type = BigQueryTableAccessor::ColumnType::kFloat;
     54   } else if (type == "BOOLEAN") {
     55     *enum_type = BigQueryTableAccessor::ColumnType::kBoolean;
     56   } else if (type == "TIMESTAMP") {
     57     *enum_type = BigQueryTableAccessor::ColumnType::kTimestamp;
     58   } else if (type == "DATE") {
     59     *enum_type = BigQueryTableAccessor::ColumnType::kDate;
     60   } else if (type == "TIME") {
     61     *enum_type = BigQueryTableAccessor::ColumnType::kTime;
     62   } else if (type == "DATETIME") {
     63     *enum_type = BigQueryTableAccessor::ColumnType::kDatetime;
     64   } else {
     65     return errors::Internal(
     66         strings::StrCat("Could not parse column type ", type));
     67   }
     68   return Status::OK();
     69 }
     70 
     71 }  // namespace
     72 
     73 Status BigQueryTableAccessor::New(
     74     const string& project_id, const string& dataset_id, const string& table_id,
     75     int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
     76     const std::vector<string>& columns, const BigQueryTablePartition& partition,
     77     std::unique_ptr<BigQueryTableAccessor>* accessor) {
     78   return New(project_id, dataset_id, table_id, timestamp_millis,
     79              row_buffer_size, end_point, columns, partition, nullptr, nullptr,
     80              accessor);
     81 }
     82 
     83 Status BigQueryTableAccessor::New(
     84     const string& project_id, const string& dataset_id, const string& table_id,
     85     int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
     86     const std::vector<string>& columns, const BigQueryTablePartition& partition,
     87     std::unique_ptr<AuthProvider> auth_provider,
     88     std::shared_ptr<HttpRequest::Factory> http_request_factory,
     89     std::unique_ptr<BigQueryTableAccessor>* accessor) {
     90   if (timestamp_millis <= 0) {
     91     return errors::InvalidArgument(
     92         "Cannot use zero or negative timestamp to query a table.");
     93   }
     94   const string& big_query_end_point =
     95       end_point.empty() ? kBigQueryEndPoint : end_point;
     96   if (auth_provider == nullptr && http_request_factory == nullptr) {
     97     http_request_factory = std::make_shared<CurlHttpRequest::Factory>();
     98     auto compute_engine_metadata_client =
     99         std::make_shared<ComputeEngineMetadataClient>(http_request_factory);
    100     auth_provider = std::unique_ptr<AuthProvider>(
    101         new GoogleAuthProvider(compute_engine_metadata_client));
    102   }
    103 
    104   accessor->reset(new BigQueryTableAccessor(
    105       project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
    106       big_query_end_point, columns, partition, std::move(auth_provider),
    107       std::move(http_request_factory)));
    108 
    109   return (*accessor)->ReadSchema();
    110 }
    111 
// Stores the table coordinates, read options and dependencies, then positions
// the accessor at the start of `partition`. It performs no I/O and no argument
// validation; BigQueryTableAccessor::New() constructs accessors through this
// constructor and then calls ReadSchema().
//
// `columns` is copied into columns_; an empty set means every column is
// appended to the produced examples (see ParseColumnValues()).
// `end_point` is stored as bigquery_end_point_, the prefix of every request
// URI built by BigQueryUriPrefix(). `timestamp_millis` is stored in
// timestamp_millis_ but is not read anywhere else in this file.
BigQueryTableAccessor::BigQueryTableAccessor(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<AuthProvider> auth_provider,
    std::shared_ptr<HttpRequest::Factory> http_request_factory)
    : project_id_(project_id),
      dataset_id_(dataset_id),
      table_id_(table_id),
      timestamp_millis_(timestamp_millis),
      columns_(columns.begin(), columns.end()),
      bigquery_end_point_(end_point),
      partition_(partition),
      auth_provider_(std::move(auth_provider)),
      http_request_factory_(std::move(http_request_factory)) {
  // One Example slot per buffered row. row_buffer_.size() doubles as the page
  // size requested from BigQuery (see ComputeMaxResultsArg()).
  // NOTE(review): a negative row_buffer_size would convert to a huge size_t
  // here; callers are expected to validate it beforehand.
  row_buffer_.resize(row_buffer_size);
  // Start reading at the first row of `partition`.
  Reset();
}
    130 
    131 Status BigQueryTableAccessor::SetPartition(
    132     const BigQueryTablePartition& partition) {
    133   if (partition.start_index() < 0) {
    134     return errors::InvalidArgument("Start index cannot be negative.");
    135   }
    136   partition_ = partition;
    137   Reset();
    138   return Status::OK();
    139 }
    140 
    141 void BigQueryTableAccessor::Reset() {
    142   first_buffered_row_index_ = partition_.start_index();
    143   next_row_in_buffer_ = -1;
    144   next_page_token_ = "";
    145 }
    146 
    147 Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
    148   if (Done()) {
    149     return errors::OutOfRange("Reached end of table ", FullTableName());
    150   }
    151 
    152   // If the next row is already fetched and cached, return the row from the
    153   // buffer. Otherwise, fill up the row buffer from BigQuery and return a row.
    154   if (next_row_in_buffer_ != -1 &&
    155       next_row_in_buffer_ < ComputeMaxResultsArg()) {
    156     *row_id = first_buffered_row_index_ + next_row_in_buffer_;
    157     *example = row_buffer_[next_row_in_buffer_];
    158     next_row_in_buffer_++;
    159   } else {
    160     string auth_token;
    161     TF_RETURN_IF_ERROR(
    162         AuthProvider::GetToken(auth_provider_.get(), &auth_token));
    163 
    164     std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
    165     std::vector<char> output_buffer;
    166     output_buffer.reserve(kBufferSize);
    167 
    168     // The first time that we access BigQuery there is no page token. After that
    169     // we use the page token (which returns rows faster).
    170     if (!next_page_token_.empty()) {
    171       request->SetUri(strings::StrCat(
    172           BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
    173           "&pageToken=", request->EscapeString(next_page_token_)));
    174       first_buffered_row_index_ += row_buffer_.size();
    175     } else {
    176       request->SetUri(strings::StrCat(
    177           BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
    178           "&startIndex=", first_buffered_row_index_));
    179     }
    180     request->AddAuthBearerHeader(auth_token);
    181     request->SetResultBuffer(&output_buffer);
    182     TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ",
    183                                     FullTableName());
    184 
    185     // Parse the returned row.
    186     StringPiece response_piece =
    187         StringPiece(&output_buffer[0], output_buffer.size());
    188     Json::Value root;
    189     TF_RETURN_IF_ERROR(ParseJson(response_piece, &root));
    190     for (unsigned int i = 0; i < root["rows"].size(); ++i) {
    191       row_buffer_[i].Clear();
    192       TF_RETURN_IF_ERROR(
    193           ParseColumnValues(root["rows"][i], schema_root_, &row_buffer_[i]));
    194     }
    195 
    196     next_page_token_ = root["pageToken"].asString();
    197     *row_id = first_buffered_row_index_;
    198     *example = row_buffer_[0];
    199     next_row_in_buffer_ = 1;
    200   }
    201   return Status::OK();
    202 }
    203 
    204 int64 BigQueryTableAccessor::ComputeMaxResultsArg() {
    205   if (partition_.end_index() == -1) {
    206     return row_buffer_.size();
    207   }
    208   if (IsPartitionEmpty(partition_)) {
    209     return 0;
    210   }
    211   return std::min(static_cast<int64>(row_buffer_.size()),
    212                   static_cast<int64>(partition_.end_index() -
    213                                      partition_.start_index() + 1));
    214 }
    215 
    216 Status BigQueryTableAccessor::ParseColumnValues(
    217     const Json::Value& value, const SchemaNode& root_schema_node,
    218     Example* example) {
    219   if (value.empty()) {
    220     return Status::OK();
    221   }
    222   if (value["f"].isNull()) {
    223     return Status::OK();
    224   }
    225   int value_index = 0;
    226   for (const auto& schema_node : root_schema_node.schema_nodes) {
    227     if (value["f"][value_index].isNull()) {
    228       value_index++;
    229       continue;
    230     }
    231 
    232     if (schema_node.type == ColumnType::kRecord) {
    233       TF_RETURN_IF_ERROR(ParseColumnValues(value["f"][value_index]["v"],
    234                                            schema_node, example));
    235     } else {
    236       // Append the column value only if user has requested the column.
    237       if (columns_.empty() ||
    238           columns_.find(schema_node.name) != columns_.end()) {
    239         TF_RETURN_IF_ERROR(AppendValueToExample(schema_node.name,
    240                                                 value["f"][value_index]["v"],
    241                                                 schema_node.type, example));
    242       }
    243     }
    244     value_index++;
    245   }
    246   return Status::OK();
    247 }
    248 
    249 Status BigQueryTableAccessor::ReadSchema() {
    250   string auth_token;
    251   TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token));
    252 
    253   // Send a request to read the schema.
    254   std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
    255   std::vector<char> output_buffer;
    256   output_buffer.reserve(kBufferSize);
    257   request->SetUri(BigQueryUriPrefix());
    258   request->AddAuthBearerHeader(auth_token);
    259   request->SetResultBuffer(&output_buffer);
    260   TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ",
    261                                   FullTableName());
    262 
    263   // Parse the schema.
    264   StringPiece response_piece =
    265       StringPiece(&output_buffer[0], output_buffer.size());
    266 
    267   Json::Value root;
    268   TF_RETURN_IF_ERROR(ParseJson(response_piece, &root));
    269   const auto& columns = root["schema"]["fields"];
    270   string column_name_prefix = "";
    271   schema_root_ = {"", ColumnType::kNone};
    272   TF_RETURN_IF_ERROR(
    273       ExtractColumnType(columns, column_name_prefix, &schema_root_));
    274   if (root["numRows"].isNull()) {
    275     return errors::Internal("Number of rows cannot be extracted for table ",
    276                             FullTableName());
    277   }
    278   strings::safe_strto64(root["numRows"].asString().c_str(), &total_num_rows_);
    279   return Status::OK();
    280 }
    281 
    282 Status BigQueryTableAccessor::ExtractColumnType(
    283     const Json::Value& columns, const string& column_name_prefix,
    284     SchemaNode* root) {
    285   for (auto columns_it = columns.begin(); columns_it != columns.end();
    286        ++columns_it) {
    287     if ((*columns_it)["mode"].asString() == "REPEATED") {
    288       return errors::Unimplemented(strings::StrCat(
    289           "Tables with repeated columns are not supported: ", FullTableName()));
    290     }
    291     ColumnType type;
    292     const string current_column_name = strings::StrCat(
    293         column_name_prefix, (*columns_it)["name"].asString().c_str());
    294     TF_RETURN_IF_ERROR(
    295         ParseColumnType((*columns_it)["type"].asString().c_str(), &type));
    296     root->schema_nodes.emplace_back(current_column_name, type);
    297     if (type == ColumnType::kRecord) {
    298       const auto new_prefix = strings::StrCat(current_column_name, ".");
    299       TF_RETURN_IF_ERROR(ExtractColumnType((*columns_it)["fields"], new_prefix,
    300                                            &root->schema_nodes.back()));
    301     }
    302   }
    303   return Status::OK();
    304 }
    305 
    306 Status BigQueryTableAccessor::AppendValueToExample(
    307     const string& column_name, const Json::Value& column_value,
    308     const BigQueryTableAccessor::ColumnType type, Example* example) {
    309   if (column_value.isNull()) {
    310     return Status::OK();
    311   }
    312   auto& feature =
    313       (*example->mutable_features()->mutable_feature())[column_name];
    314 
    315   switch (type) {
    316     case BigQueryTableAccessor::ColumnType::kNone:
    317     case BigQueryTableAccessor::ColumnType::kRecord:
    318       return errors::Unimplemented("Cannot append type to an example.");
    319     case BigQueryTableAccessor::ColumnType::kTimestamp:
    320     case BigQueryTableAccessor::ColumnType::kDate:
    321     case BigQueryTableAccessor::ColumnType::kTime:
    322     case BigQueryTableAccessor::ColumnType::kDatetime:
    323     case BigQueryTableAccessor::ColumnType::kString:
    324     case BigQueryTableAccessor::ColumnType::kBytes:
    325       feature.mutable_bytes_list()->add_value(column_value.asString());
    326       break;
    327     case BigQueryTableAccessor::ColumnType::kBoolean:
    328       feature.mutable_int64_list()->add_value(
    329           column_value.asString() == "false" ? 0 : 1);
    330       break;
    331     case BigQueryTableAccessor::ColumnType::kInteger:
    332       int64 column_value_int64;
    333       if (!strings::safe_strto64(column_value.asString().c_str(),
    334                                  &column_value_int64)) {
    335         return errors::Internal("Cannot convert value to integer ",
    336                                 column_value.asString().c_str());
    337       }
    338       feature.mutable_int64_list()->add_value(column_value_int64);
    339       break;
    340     case BigQueryTableAccessor::ColumnType::kFloat:
    341       // BigQuery float is actually a double.
    342       double column_value_double;
    343       if (!strings::safe_strtod(column_value.asString().c_str(),
    344                                 &column_value_double)) {
    345         return errors::Internal("Cannot convert value to double: ",
    346                                 column_value.asString().c_str());
    347       }
    348       feature.mutable_float_list()->add_value(
    349           static_cast<float>(column_value_double));
    350       break;
    351   }
    352   return Status::OK();
    353 }
    354 
    355 string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() {
    356   CurlHttpRequest request;
    357   return strings::StrCat(bigquery_end_point_, "/projects/",
    358                          request.EscapeString(project_id_), "/datasets/",
    359                          request.EscapeString(dataset_id_), "/tables/",
    360                          request.EscapeString(table_id_), "/");
    361 }
    362 
    363 bool BigQueryTableAccessor::Done() {
    364   return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) ||
    365          IsPartitionEmpty(partition_) ||
    366          (partition_.end_index() != -1 &&
    367           partition_.end_index() <
    368               first_buffered_row_index_ + next_row_in_buffer_);
    369 }
    370 
    371 }  // namespace tensorflow
    372