1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h" 16 17 #include "tensorflow/core/example/feature.pb.h" 18 #include "tensorflow/core/lib/strings/numbers.h" 19 20 namespace tensorflow { 21 namespace { 22 23 constexpr size_t kBufferSize = 1024 * 1024; // In bytes. 24 const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2"; 25 26 bool IsPartitionEmpty(const BigQueryTablePartition& partition) { 27 if (partition.end_index() != -1 && 28 partition.end_index() < partition.start_index()) { 29 return true; 30 } 31 return false; 32 } 33 34 Status ParseJson(StringPiece json, Json::Value* result) { 35 Json::Reader reader; 36 if (!reader.parse(string(json), *result)) { 37 return errors::Internal("Couldn't parse JSON response from BigQuery."); 38 } 39 return Status::OK(); 40 } 41 42 Status ParseColumnType(const string& type, 43 BigQueryTableAccessor::ColumnType* enum_type) { 44 if (type == "RECORD") { 45 *enum_type = BigQueryTableAccessor::ColumnType::kRecord; 46 } else if (type == "STRING") { 47 *enum_type = BigQueryTableAccessor::ColumnType::kString; 48 } else if (type == "BYTES") { 49 *enum_type = BigQueryTableAccessor::ColumnType::kBytes; 50 } else if (type == "INTEGER") { 51 *enum_type = BigQueryTableAccessor::ColumnType::kInteger; 52 } else if (type == "FLOAT") { 53 *enum_type = BigQueryTableAccessor::ColumnType::kFloat; 54 } else if (type == "BOOLEAN") { 55 *enum_type = BigQueryTableAccessor::ColumnType::kBoolean; 56 } else if (type == "TIMESTAMP") { 57 *enum_type = BigQueryTableAccessor::ColumnType::kTimestamp; 58 } else if (type == "DATE") { 59 *enum_type = BigQueryTableAccessor::ColumnType::kDate; 60 } else if (type == "TIME") { 61 *enum_type = BigQueryTableAccessor::ColumnType::kTime; 62 } else if (type == "DATETIME") { 63 *enum_type = BigQueryTableAccessor::ColumnType::kDatetime; 64 } else { 65 return errors::Internal( 66 strings::StrCat("Could not parse column type ", type)); 67 } 68 return Status::OK(); 69 } 70 71 } // namespace 72 73 Status BigQueryTableAccessor::New( 74 const string& project_id, const string& dataset_id, const string& table_id, 75 int64 timestamp_millis, int64 row_buffer_size, const string& end_point, 76 const std::vector<string>& columns, const BigQueryTablePartition& partition, 77 std::unique_ptr<BigQueryTableAccessor>* accessor) { 78 return New(project_id, dataset_id, table_id, timestamp_millis, 79 row_buffer_size, end_point, columns, partition, nullptr, nullptr, 80 accessor); 81 } 82 83 Status BigQueryTableAccessor::New( 84 const string& project_id, const string& dataset_id, const string& table_id, 85 int64 timestamp_millis, int64 row_buffer_size, const string& end_point, 86 const std::vector<string>& columns, const BigQueryTablePartition& partition, 87 std::unique_ptr<AuthProvider> auth_provider, 88 std::shared_ptr<HttpRequest::Factory> http_request_factory, 89 std::unique_ptr<BigQueryTableAccessor>* accessor) { 90 if (timestamp_millis <= 0) { 91 return errors::InvalidArgument( 92 "Cannot use zero or negative timestamp to query a table."); 93 } 94 const string& big_query_end_point = 95 end_point.empty() ? kBigQueryEndPoint : end_point; 96 if (auth_provider == nullptr && http_request_factory == nullptr) { 97 http_request_factory = std::make_shared<CurlHttpRequest::Factory>(); 98 auto compute_engine_metadata_client = 99 std::make_shared<ComputeEngineMetadataClient>(http_request_factory); 100 auth_provider = std::unique_ptr<AuthProvider>( 101 new GoogleAuthProvider(compute_engine_metadata_client)); 102 } 103 104 accessor->reset(new BigQueryTableAccessor( 105 project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, 106 big_query_end_point, columns, partition, std::move(auth_provider), 107 std::move(http_request_factory))); 108 109 return (*accessor)->ReadSchema(); 110 } 111 112 BigQueryTableAccessor::BigQueryTableAccessor( 113 const string& project_id, const string& dataset_id, const string& table_id, 114 int64 timestamp_millis, int64 row_buffer_size, const string& end_point, 115 const std::vector<string>& columns, const BigQueryTablePartition& partition, 116 std::unique_ptr<AuthProvider> auth_provider, 117 std::shared_ptr<HttpRequest::Factory> http_request_factory) 118 : project_id_(project_id), 119 dataset_id_(dataset_id), 120 table_id_(table_id), 121 timestamp_millis_(timestamp_millis), 122 columns_(columns.begin(), columns.end()), 123 bigquery_end_point_(end_point), 124 partition_(partition), 125 auth_provider_(std::move(auth_provider)), 126 http_request_factory_(std::move(http_request_factory)) { 127 row_buffer_.resize(row_buffer_size); 128 Reset(); 129 } 130 131 Status BigQueryTableAccessor::SetPartition( 132 const BigQueryTablePartition& partition) { 133 if (partition.start_index() < 0) { 134 return errors::InvalidArgument("Start index cannot be negative."); 135 } 136 partition_ = partition; 137 Reset(); 138 return Status::OK(); 139 } 140 141 void BigQueryTableAccessor::Reset() { 142 first_buffered_row_index_ = partition_.start_index(); 143 next_row_in_buffer_ = -1; 144 next_page_token_ = ""; 145 } 146 147 Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { 148 if (Done()) { 149 return errors::OutOfRange("Reached end of table ", FullTableName()); 150 } 151 152 // If the next row is already fetched and cached, return the row from the 153 // buffer. Otherwise, fill up the row buffer from BigQuery and return a row. 154 if (next_row_in_buffer_ != -1 && 155 next_row_in_buffer_ < ComputeMaxResultsArg()) { 156 *row_id = first_buffered_row_index_ + next_row_in_buffer_; 157 *example = row_buffer_[next_row_in_buffer_]; 158 next_row_in_buffer_++; 159 } else { 160 string auth_token; 161 TF_RETURN_IF_ERROR( 162 AuthProvider::GetToken(auth_provider_.get(), &auth_token)); 163 164 std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); 165 std::vector<char> output_buffer; 166 output_buffer.reserve(kBufferSize); 167 168 // The first time that we access BigQuery there is no page token. After that 169 // we use the page token (which returns rows faster). 170 if (!next_page_token_.empty()) { 171 request->SetUri(strings::StrCat( 172 BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), 173 "&pageToken=", request->EscapeString(next_page_token_))); 174 first_buffered_row_index_ += row_buffer_.size(); 175 } else { 176 request->SetUri(strings::StrCat( 177 BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), 178 "&startIndex=", first_buffered_row_index_)); 179 } 180 request->AddAuthBearerHeader(auth_token); 181 request->SetResultBuffer(&output_buffer); 182 TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ", 183 FullTableName()); 184 185 // Parse the returned row. 186 StringPiece response_piece = 187 StringPiece(&output_buffer[0], output_buffer.size()); 188 Json::Value root; 189 TF_RETURN_IF_ERROR(ParseJson(response_piece, &root)); 190 for (unsigned int i = 0; i < root["rows"].size(); ++i) { 191 row_buffer_[i].Clear(); 192 TF_RETURN_IF_ERROR( 193 ParseColumnValues(root["rows"][i], schema_root_, &row_buffer_[i])); 194 } 195 196 next_page_token_ = root["pageToken"].asString(); 197 *row_id = first_buffered_row_index_; 198 *example = row_buffer_[0]; 199 next_row_in_buffer_ = 1; 200 } 201 return Status::OK(); 202 } 203 204 int64 BigQueryTableAccessor::ComputeMaxResultsArg() { 205 if (partition_.end_index() == -1) { 206 return row_buffer_.size(); 207 } 208 if (IsPartitionEmpty(partition_)) { 209 return 0; 210 } 211 return std::min(static_cast<int64>(row_buffer_.size()), 212 static_cast<int64>(partition_.end_index() - 213 partition_.start_index() + 1)); 214 } 215 216 Status BigQueryTableAccessor::ParseColumnValues( 217 const Json::Value& value, const SchemaNode& root_schema_node, 218 Example* example) { 219 if (value.empty()) { 220 return Status::OK(); 221 } 222 if (value["f"].isNull()) { 223 return Status::OK(); 224 } 225 int value_index = 0; 226 for (const auto& schema_node : root_schema_node.schema_nodes) { 227 if (value["f"][value_index].isNull()) { 228 value_index++; 229 continue; 230 } 231 232 if (schema_node.type == ColumnType::kRecord) { 233 TF_RETURN_IF_ERROR(ParseColumnValues(value["f"][value_index]["v"], 234 schema_node, example)); 235 } else { 236 // Append the column value only if user has requested the column. 237 if (columns_.empty() || 238 columns_.find(schema_node.name) != columns_.end()) { 239 TF_RETURN_IF_ERROR(AppendValueToExample(schema_node.name, 240 value["f"][value_index]["v"], 241 schema_node.type, example)); 242 } 243 } 244 value_index++; 245 } 246 return Status::OK(); 247 } 248 249 Status BigQueryTableAccessor::ReadSchema() { 250 string auth_token; 251 TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token)); 252 253 // Send a request to read the schema. 254 std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); 255 std::vector<char> output_buffer; 256 output_buffer.reserve(kBufferSize); 257 request->SetUri(BigQueryUriPrefix()); 258 request->AddAuthBearerHeader(auth_token); 259 request->SetResultBuffer(&output_buffer); 260 TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ", 261 FullTableName()); 262 263 // Parse the schema. 264 StringPiece response_piece = 265 StringPiece(&output_buffer[0], output_buffer.size()); 266 267 Json::Value root; 268 TF_RETURN_IF_ERROR(ParseJson(response_piece, &root)); 269 const auto& columns = root["schema"]["fields"]; 270 string column_name_prefix = ""; 271 schema_root_ = {"", ColumnType::kNone}; 272 TF_RETURN_IF_ERROR( 273 ExtractColumnType(columns, column_name_prefix, &schema_root_)); 274 if (root["numRows"].isNull()) { 275 return errors::Internal("Number of rows cannot be extracted for table ", 276 FullTableName()); 277 } 278 strings::safe_strto64(root["numRows"].asString().c_str(), &total_num_rows_); 279 return Status::OK(); 280 } 281 282 Status BigQueryTableAccessor::ExtractColumnType( 283 const Json::Value& columns, const string& column_name_prefix, 284 SchemaNode* root) { 285 for (auto columns_it = columns.begin(); columns_it != columns.end(); 286 ++columns_it) { 287 if ((*columns_it)["mode"].asString() == "REPEATED") { 288 return errors::Unimplemented(strings::StrCat( 289 "Tables with repeated columns are not supported: ", FullTableName())); 290 } 291 ColumnType type; 292 const string current_column_name = strings::StrCat( 293 column_name_prefix, (*columns_it)["name"].asString().c_str()); 294 TF_RETURN_IF_ERROR( 295 ParseColumnType((*columns_it)["type"].asString().c_str(), &type)); 296 root->schema_nodes.emplace_back(current_column_name, type); 297 if (type == ColumnType::kRecord) { 298 const auto new_prefix = strings::StrCat(current_column_name, "."); 299 TF_RETURN_IF_ERROR(ExtractColumnType((*columns_it)["fields"], new_prefix, 300 &root->schema_nodes.back())); 301 } 302 } 303 return Status::OK(); 304 } 305 306 Status BigQueryTableAccessor::AppendValueToExample( 307 const string& column_name, const Json::Value& column_value, 308 const BigQueryTableAccessor::ColumnType type, Example* example) { 309 if (column_value.isNull()) { 310 return Status::OK(); 311 } 312 auto& feature = 313 (*example->mutable_features()->mutable_feature())[column_name]; 314 315 switch (type) { 316 case BigQueryTableAccessor::ColumnType::kNone: 317 case BigQueryTableAccessor::ColumnType::kRecord: 318 return errors::Unimplemented("Cannot append type to an example."); 319 case BigQueryTableAccessor::ColumnType::kTimestamp: 320 case BigQueryTableAccessor::ColumnType::kDate: 321 case BigQueryTableAccessor::ColumnType::kTime: 322 case BigQueryTableAccessor::ColumnType::kDatetime: 323 case BigQueryTableAccessor::ColumnType::kString: 324 case BigQueryTableAccessor::ColumnType::kBytes: 325 feature.mutable_bytes_list()->add_value(column_value.asString()); 326 break; 327 case BigQueryTableAccessor::ColumnType::kBoolean: 328 feature.mutable_int64_list()->add_value( 329 column_value.asString() == "false" ? 0 : 1); 330 break; 331 case BigQueryTableAccessor::ColumnType::kInteger: 332 int64 column_value_int64; 333 if (!strings::safe_strto64(column_value.asString().c_str(), 334 &column_value_int64)) { 335 return errors::Internal("Cannot convert value to integer ", 336 column_value.asString().c_str()); 337 } 338 feature.mutable_int64_list()->add_value(column_value_int64); 339 break; 340 case BigQueryTableAccessor::ColumnType::kFloat: 341 // BigQuery float is actually a double. 342 double column_value_double; 343 if (!strings::safe_strtod(column_value.asString().c_str(), 344 &column_value_double)) { 345 return errors::Internal("Cannot convert value to double: ", 346 column_value.asString().c_str()); 347 } 348 feature.mutable_float_list()->add_value( 349 static_cast<float>(column_value_double)); 350 break; 351 } 352 return Status::OK(); 353 } 354 355 string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() { 356 CurlHttpRequest request; 357 return strings::StrCat(bigquery_end_point_, "/projects/", 358 request.EscapeString(project_id_), "/datasets/", 359 request.EscapeString(dataset_id_), "/tables/", 360 request.EscapeString(table_id_), "/"); 361 } 362 363 bool BigQueryTableAccessor::Done() { 364 return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) || 365 IsPartitionEmpty(partition_) || 366 (partition_.end_index() != -1 && 367 partition_.end_index() < 368 first_buffered_row_index_ + next_row_in_buffer_); 369 } 370 371 } // namespace tensorflow 372