1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h" 16 17 #include "tensorflow/core/example/feature.pb.h" 18 #include "tensorflow/core/lib/strings/numbers.h" 19 20 namespace tensorflow { 21 namespace { 22 23 constexpr size_t kBufferSize = 1024 * 1024; // In bytes. 24 const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2"; 25 26 bool IsPartitionEmpty(const BigQueryTablePartition& partition) { 27 if (partition.end_index() != -1 && 28 partition.end_index() < partition.start_index()) { 29 return true; 30 } 31 return false; 32 } 33 34 Status ParseJson(StringPiece json, Json::Value* result) { 35 Json::Reader reader; 36 if (!reader.parse(json.ToString(), *result)) { 37 return errors::Internal("Couldn't parse JSON response from BigQuery."); 38 } 39 return Status::OK(); 40 } 41 42 Status ParseColumnType(const string& type, 43 BigQueryTableAccessor::ColumnType* enum_type) { 44 if (type == "RECORD") { 45 *enum_type = BigQueryTableAccessor::ColumnType::kRecord; 46 } else if (type == "STRING") { 47 *enum_type = BigQueryTableAccessor::ColumnType::kString; 48 } else if (type == "BYTES") { 49 *enum_type = BigQueryTableAccessor::ColumnType::kBytes; 50 } else if (type == "INTEGER") { 51 *enum_type = BigQueryTableAccessor::ColumnType::kInteger; 52 } else if (type == "FLOAT") { 53 *enum_type = BigQueryTableAccessor::ColumnType::kFloat; 54 } else if (type == "BOOLEAN") { 55 *enum_type = BigQueryTableAccessor::ColumnType::kBoolean; 56 } else if (type == "TIMESTAMP") { 57 *enum_type = BigQueryTableAccessor::ColumnType::kTimestamp; 58 } else if (type == "DATE") { 59 *enum_type = BigQueryTableAccessor::ColumnType::kDate; 60 } else if (type == "TIME") { 61 *enum_type = BigQueryTableAccessor::ColumnType::kTime; 62 } else if (type == "DATETIME") { 63 *enum_type = BigQueryTableAccessor::ColumnType::kDatetime; 64 } else { 65 return errors::Internal( 66 strings::StrCat("Could not parse column type ", type)); 67 } 68 return Status::OK(); 69 } 70 71 } // namespace 72 73 Status BigQueryTableAccessor::New( 74 const string& project_id, const string& dataset_id, const string& table_id, 75 int64 timestamp_millis, int64 row_buffer_size, const string& end_point, 76 const std::vector<string>& columns, const BigQueryTablePartition& partition, 77 std::unique_ptr<BigQueryTableAccessor>* accessor) { 78 return New(project_id, dataset_id, table_id, timestamp_millis, 79 row_buffer_size, end_point, columns, partition, nullptr, nullptr, 80 accessor); 81 } 82 83 Status BigQueryTableAccessor::New( 84 const string& project_id, const string& dataset_id, const string& table_id, 85 int64 timestamp_millis, int64 row_buffer_size, const string& end_point, 86 const std::vector<string>& columns, const BigQueryTablePartition& partition, 87 std::unique_ptr<AuthProvider> auth_provider, 88 std::unique_ptr<HttpRequest::Factory> http_request_factory, 89 std::unique_ptr<BigQueryTableAccessor>* accessor) { 90 if (timestamp_millis <= 0) { 91 return errors::InvalidArgument( 92 "Cannot use zero or negative timestamp to query a table."); 93 } 94 const string& big_query_end_point = 95 end_point.empty() ? kBigQueryEndPoint : end_point; 96 if (auth_provider == nullptr && http_request_factory == nullptr) { 97 accessor->reset(new BigQueryTableAccessor( 98 project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, 99 big_query_end_point, columns, partition)); 100 } else { 101 accessor->reset(new BigQueryTableAccessor( 102 project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, 103 big_query_end_point, columns, partition, std::move(auth_provider), 104 std::move(http_request_factory))); 105 } 106 return (*accessor)->ReadSchema(); 107 } 108 109 BigQueryTableAccessor::BigQueryTableAccessor( 110 const string& project_id, const string& dataset_id, const string& table_id, 111 int64 timestamp_millis, int64 row_buffer_size, const string& end_point, 112 const std::vector<string>& columns, const BigQueryTablePartition& partition) 113 : BigQueryTableAccessor( 114 project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, 115 end_point, columns, partition, 116 std::unique_ptr<AuthProvider>(new GoogleAuthProvider()), 117 std::unique_ptr<HttpRequest::Factory>( 118 new CurlHttpRequest::Factory())) { 119 row_buffer_.resize(row_buffer_size); 120 } 121 122 BigQueryTableAccessor::BigQueryTableAccessor( 123 const string& project_id, const string& dataset_id, const string& table_id, 124 int64 timestamp_millis, int64 row_buffer_size, const string& end_point, 125 const std::vector<string>& columns, const BigQueryTablePartition& partition, 126 std::unique_ptr<AuthProvider> auth_provider, 127 std::unique_ptr<HttpRequest::Factory> http_request_factory) 128 : project_id_(project_id), 129 dataset_id_(dataset_id), 130 table_id_(table_id), 131 timestamp_millis_(timestamp_millis), 132 columns_(columns.begin(), columns.end()), 133 bigquery_end_point_(end_point), 134 partition_(partition), 135 auth_provider_(std::move(auth_provider)), 136 http_request_factory_(std::move(http_request_factory)) { 137 row_buffer_.resize(row_buffer_size); 138 Reset(); 139 } 140 141 Status BigQueryTableAccessor::SetPartition( 142 const BigQueryTablePartition& partition) { 143 if (partition.start_index() < 0) { 144 return errors::InvalidArgument("Start index cannot be negative."); 145 } 146 partition_ = partition; 147 Reset(); 148 return Status::OK(); 149 } 150 151 void BigQueryTableAccessor::Reset() { 152 first_buffered_row_index_ = partition_.start_index(); 153 next_row_in_buffer_ = -1; 154 next_page_token_ = ""; 155 } 156 157 Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { 158 if (Done()) { 159 return errors::OutOfRange("Reached end of table ", FullTableName()); 160 } 161 162 // If the next row is already fetched and cached, return the row from the 163 // buffer. Otherwise, fill up the row buffer from BigQuery and return a row. 164 if (next_row_in_buffer_ != -1 && 165 next_row_in_buffer_ < ComputeMaxResultsArg()) { 166 *row_id = first_buffered_row_index_ + next_row_in_buffer_; 167 *example = row_buffer_[next_row_in_buffer_]; 168 next_row_in_buffer_++; 169 } else { 170 string auth_token; 171 TF_RETURN_IF_ERROR( 172 AuthProvider::GetToken(auth_provider_.get(), &auth_token)); 173 174 std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); 175 std::vector<char> output_buffer; 176 output_buffer.reserve(kBufferSize); 177 178 // The first time that we access BigQuery there is no page token. After that 179 // we use the page token (which returns rows faster). 180 if (!next_page_token_.empty()) { 181 request->SetUri(strings::StrCat( 182 BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), 183 "&pageToken=", request->EscapeString(next_page_token_))); 184 first_buffered_row_index_ += row_buffer_.size(); 185 } else { 186 request->SetUri(strings::StrCat( 187 BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), 188 "&startIndex=", first_buffered_row_index_)); 189 } 190 request->AddAuthBearerHeader(auth_token); 191 request->SetResultBuffer(&output_buffer); 192 TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ", 193 FullTableName()); 194 195 // Parse the returned row. 196 StringPiece response_piece = 197 StringPiece(&output_buffer[0], output_buffer.size()); 198 Json::Value root; 199 TF_RETURN_IF_ERROR(ParseJson(response_piece, &root)); 200 for (unsigned int i = 0; i < root["rows"].size(); ++i) { 201 row_buffer_[i].Clear(); 202 TF_RETURN_IF_ERROR( 203 ParseColumnValues(root["rows"][i], schema_root_, &row_buffer_[i])); 204 } 205 206 next_page_token_ = root["pageToken"].asString(); 207 *row_id = first_buffered_row_index_; 208 *example = row_buffer_[0]; 209 next_row_in_buffer_ = 1; 210 } 211 return Status::OK(); 212 } 213 214 int64 BigQueryTableAccessor::ComputeMaxResultsArg() { 215 if (partition_.end_index() == -1) { 216 return row_buffer_.size(); 217 } 218 if (IsPartitionEmpty(partition_)) { 219 return 0; 220 } 221 return std::min(static_cast<int64>(row_buffer_.size()), 222 static_cast<int64>(partition_.end_index() - 223 partition_.start_index() + 1)); 224 } 225 226 Status BigQueryTableAccessor::ParseColumnValues( 227 const Json::Value& value, const SchemaNode& root_schema_node, 228 Example* example) { 229 if (value.empty()) { 230 return Status::OK(); 231 } 232 if (value["f"].isNull()) { 233 return Status::OK(); 234 } 235 int value_index = 0; 236 for (const auto& schema_node : root_schema_node.schema_nodes) { 237 if (value["f"][value_index].isNull()) { 238 value_index++; 239 continue; 240 } 241 242 if (schema_node.type == ColumnType::kRecord) { 243 TF_RETURN_IF_ERROR(ParseColumnValues(value["f"][value_index]["v"], 244 schema_node, example)); 245 } else { 246 // Append the column value only if user has requested the column. 247 if (columns_.empty() || 248 columns_.find(schema_node.name) != columns_.end()) { 249 TF_RETURN_IF_ERROR(AppendValueToExample(schema_node.name, 250 value["f"][value_index]["v"], 251 schema_node.type, example)); 252 } 253 } 254 value_index++; 255 } 256 return Status::OK(); 257 } 258 259 Status BigQueryTableAccessor::ReadSchema() { 260 string auth_token; 261 TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token)); 262 263 // Send a request to read the schema. 264 std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); 265 std::vector<char> output_buffer; 266 output_buffer.reserve(kBufferSize); 267 request->SetUri(BigQueryUriPrefix()); 268 request->AddAuthBearerHeader(auth_token); 269 request->SetResultBuffer(&output_buffer); 270 TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ", 271 FullTableName()); 272 273 // Parse the schema. 274 StringPiece response_piece = 275 StringPiece(&output_buffer[0], output_buffer.size()); 276 277 Json::Value root; 278 TF_RETURN_IF_ERROR(ParseJson(response_piece, &root)); 279 const auto& columns = root["schema"]["fields"]; 280 string column_name_prefix = ""; 281 schema_root_ = {"", ColumnType::kNone}; 282 TF_RETURN_IF_ERROR( 283 ExtractColumnType(columns, column_name_prefix, &schema_root_)); 284 if (root["numRows"].isNull()) { 285 return errors::Internal("Number of rows cannot be extracted for table ", 286 FullTableName()); 287 } 288 strings::safe_strto64(root["numRows"].asString().c_str(), &total_num_rows_); 289 return Status::OK(); 290 } 291 292 Status BigQueryTableAccessor::ExtractColumnType( 293 const Json::Value& columns, const string& column_name_prefix, 294 SchemaNode* root) { 295 for (auto columns_it = columns.begin(); columns_it != columns.end(); 296 ++columns_it) { 297 if ((*columns_it)["mode"].asString() == "REPEATED") { 298 return errors::Unimplemented(strings::StrCat( 299 "Tables with repeated columns are not supported: ", FullTableName())); 300 } 301 ColumnType type; 302 const string current_column_name = strings::StrCat( 303 column_name_prefix, (*columns_it)["name"].asString().c_str()); 304 TF_RETURN_IF_ERROR( 305 ParseColumnType((*columns_it)["type"].asString().c_str(), &type)); 306 root->schema_nodes.emplace_back(current_column_name, type); 307 if (type == ColumnType::kRecord) { 308 const auto new_prefix = strings::StrCat(current_column_name, "."); 309 TF_RETURN_IF_ERROR(ExtractColumnType((*columns_it)["fields"], new_prefix, 310 &root->schema_nodes.back())); 311 } 312 } 313 return Status::OK(); 314 } 315 316 Status BigQueryTableAccessor::AppendValueToExample( 317 const string& column_name, const Json::Value& column_value, 318 const BigQueryTableAccessor::ColumnType type, Example* example) { 319 if (column_value.isNull()) { 320 return Status::OK(); 321 } 322 auto& feature = 323 (*example->mutable_features()->mutable_feature())[column_name]; 324 325 switch (type) { 326 case BigQueryTableAccessor::ColumnType::kNone: 327 case BigQueryTableAccessor::ColumnType::kRecord: 328 return errors::Unimplemented("Cannot append type to an example."); 329 case BigQueryTableAccessor::ColumnType::kTimestamp: 330 case BigQueryTableAccessor::ColumnType::kDate: 331 case BigQueryTableAccessor::ColumnType::kTime: 332 case BigQueryTableAccessor::ColumnType::kDatetime: 333 case BigQueryTableAccessor::ColumnType::kString: 334 case BigQueryTableAccessor::ColumnType::kBytes: 335 feature.mutable_bytes_list()->add_value(column_value.asString()); 336 break; 337 case BigQueryTableAccessor::ColumnType::kBoolean: 338 feature.mutable_int64_list()->add_value( 339 column_value.asString() == "false" ? 0 : 1); 340 break; 341 case BigQueryTableAccessor::ColumnType::kInteger: 342 int64 column_value_int64; 343 if (!strings::safe_strto64(column_value.asString().c_str(), 344 &column_value_int64)) { 345 return errors::Internal("Cannot convert value to integer ", 346 column_value.asString().c_str()); 347 } 348 feature.mutable_int64_list()->add_value(column_value_int64); 349 break; 350 case BigQueryTableAccessor::ColumnType::kFloat: 351 // BigQuery float is actually a double. 352 double column_value_double; 353 if (!strings::safe_strtod(column_value.asString().c_str(), 354 &column_value_double)) { 355 return errors::Internal("Cannot convert value to double: ", 356 column_value.asString().c_str()); 357 } 358 feature.mutable_float_list()->add_value( 359 static_cast<float>(column_value_double)); 360 break; 361 } 362 return Status::OK(); 363 } 364 365 string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() { 366 CurlHttpRequest request; 367 return strings::StrCat(bigquery_end_point_, "/projects/", 368 request.EscapeString(project_id_), "/datasets/", 369 request.EscapeString(dataset_id_), "/tables/", 370 request.EscapeString(table_id_), "/"); 371 } 372 373 bool BigQueryTableAccessor::Done() { 374 return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) || 375 IsPartitionEmpty(partition_) || 376 (partition_.end_index() != -1 && 377 partition_.end_index() < 378 first_buffered_row_index_ + next_row_in_buffer_); 379 } 380 381 } // namespace tensorflow 382