/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/example_proto_fast_parsing.h"

#include <vector>

#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb_text.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/presized_cuckoo_map.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {
namespace example {

namespace {

template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;

// SFINAE-selected overload: enables aliasing on stream types that support it.
// With aliasing on, parsed StringPieces can point directly into the input
// buffer instead of copying.
template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
  a->EnableAliasing(true);
}

// Fallback for stream types without EnableAliasing(): no-op.
template <typename A>
void EnableAliasing(A&& a) {}

// Returns the next wire-format tag byte WITHOUT consuming it, or 0 if the
// underlying buffer cannot be accessed directly.
uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
  DCHECK(stream != nullptr);
  const void* ptr;
  int size;
  if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
  return *static_cast<const uint8*>(ptr);
}

// Proto wire-format tag helpers: (field_number << 3) | wire_type.
// Wire types: 0 = varint, 2 = length-delimited, 5 = fixed32.
constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }

namespace parsed {

// Lazily-parsed view over one serialized Feature proto.
// ParseDataType has to be called first, then appropriate ParseZzzzList.
class Feature {
 public:
  Feature() {}
  explicit Feature(StringPiece serialized) : serialized_(serialized) {}

  // Determines which oneof case (bytes_list / float_list / int64_list) this
  // Feature holds by inspecting — and consuming — the first tag byte.
  // Sets DT_INVALID for an empty feature; errors on any other tag.
  Status ParseDataType(DataType* dtype) {
    DCHECK(dtype != nullptr);
    if (serialized_.empty()) {
      *dtype = DT_INVALID;
      return Status::OK();
    }
    uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
    serialized_.remove_prefix(1);
    switch (oneof_tag) {
      case kDelimitedTag(1):
        *dtype = DT_STRING;
        break;
      case kDelimitedTag(2):
        *dtype = DT_FLOAT;
        break;
      case kDelimitedTag(3):
        *dtype = DT_INT64;
        break;
      default:
        // Initialize variable to avoid compiler warning
        *dtype = DT_INVALID;
        return errors::InvalidArgument("Unsupported datatype.");
    }
    return Status::OK();
  }

  // Counts the strings in the serialized BytesList without copying any of
  // them (skips over each length-delimited element).
  bool GetNumElementsInBytesList(int* num_elements) {
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length = 0;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);
    *num_elements = 0;
    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      uint32 bytes_length = 0;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      if (!stream.Skip(bytes_length)) return false;
      ++*num_elements;
    }
    stream.PopLimit(limit);
    return true;
  }

  // Appends every string of the serialized BytesList to *bytes_list.
  // Result only needs push_back(string&&).
  template <typename Result>
  bool ParseBytesList(Result* bytes_list) {
    DCHECK(bytes_list != nullptr);

    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());

    EnableAliasing(&stream);

    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      // parse string
      uint32 bytes_length;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      string bytes;
      if (!stream.ReadString(&bytes, bytes_length)) return false;
      bytes_list->push_back(std::move(bytes));
    }
    stream.PopLimit(limit);
    return true;
  }

  // Appends the floats of the serialized FloatList to *float_list.
  // Handles both the packed (length-delimited, field tag once) and the
  // non-packed (fixed32 tag per value) wire encodings.
  template <typename Result>
  bool ParseFloatList(Result* float_list) {
    DCHECK(float_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
        return false;
      }

      if (peek_tag == kDelimitedTag(1)) {              // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        while (!stream.ExpectAtEnd()) {
          uint32 buffer32;
          if (!stream.ReadLittleEndian32(&buffer32)) return false;
          float_list->push_back(bit_cast<float>(buffer32));
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kFixed32Tag(1))) return false;
          uint32 buffer32;
          if (!stream.ReadLittleEndian32(&buffer32)) return false;
          float_list->push_back(bit_cast<float>(buffer32));
        }
      }
    }

    stream.PopLimit(limit);
    return true;
  }

  // Appends the int64s of the serialized Int64List to *int64_list.
  // Handles both packed and non-packed (varint tag per value) encodings.
  template <typename Result>
  bool ParseInt64List(Result* int64_list) {
    DCHECK(int64_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
        return false;
      }
      if (peek_tag == kDelimitedTag(1)) {              // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        while (!stream.ExpectAtEnd()) {
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64>(n));
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kVarintTag(1))) return false;
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64>(n));
        }
      }
    }
    stream.PopLimit(limit);
    return true;
  }

  StringPiece GetSerialized() const { return serialized_; }

 private:
  // TODO(lew): Pair of uint8* would be more natural.
  // Unparsed bytes of the Feature proto; aliases the caller's input buffer.
  StringPiece serialized_;
};

using FeatureMapEntry = std::pair<StringPiece, Feature>;
using Example = std::vector<FeatureMapEntry>;

}  // namespace parsed

// Skips over one field whose tag we do not recognize, dispatching on the
// wire type of the tag just read. Returns false for group fields
// (unsupported) or malformed input.
inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
  uint32 data;
  protobuf_uint64 dummy;
  switch (stream->ReadTag() & 0x7) {
    case 0:  // varint
      if (!stream->ReadVarint32(&data)) return false;
      return true;
    case 1:  // fixed64
      if (!stream->ReadLittleEndian64(&dummy)) return false;
      return true;
    case 2:  // length delimited
      if (!stream->ReadVarint32(&data)) return false;
      stream->Skip(data);
      return true;
    case 3:  // group begin
      return false;  // groups not supported.
    case 4:  // group end
      return false;  // groups not supported.
    case 5:  // fixed32
      if (!stream->ReadLittleEndian32(&data)) return false;
      return true;
  }
  return false;  // unrecognized tag type
}

// Reads a length-delimited string as a StringPiece aliasing the stream's
// underlying buffer — no copy. Requires the whole string to be available in
// the direct buffer.
bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
  DCHECK(stream != nullptr);
  DCHECK(result != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  if (length == 0) {
    *result = StringPiece(nullptr, 0);
    return true;
  }
  const void* stream_alias;
  int stream_size;
  if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
    return false;
  }
  if (static_cast<uint32>(stream_size) < length) return false;
  *result = StringPiece(static_cast<const char*>(stream_alias), length);
  stream->Skip(length);
  return true;
}

// Parses one Features.feature map entry: the key string (field 1) followed by
// the serialized Feature (field 2), both captured as aliasing StringPieces.
bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
                          parsed::FeatureMapEntry* feature_map_entry) {
  DCHECK(stream != nullptr);
  DCHECK(feature_map_entry != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  auto limit = stream->PushLimit(length);
  if (!stream->ExpectTag(kDelimitedTag(1))) return false;
  if (!ParseString(stream,
                   &feature_map_entry->first)) return false;
  if (!stream->ExpectTag(kDelimitedTag(2))) return false;
  StringPiece feature_string_piece;
  if (!ParseString(stream, &feature_string_piece)) return false;
  feature_map_entry->second = parsed::Feature(feature_string_piece);
  if (!stream->ExpectAtEnd()) return false;
  stream->PopLimit(limit);
  return true;
}

// Parses the Features message (field 1 of Example): a repeated sequence of
// feature-map entries, appended in encounter order to *example.
bool ParseFeatures(protobuf::io::CodedInputStream* stream,
                   parsed::Example* example) {
  DCHECK(stream != nullptr);
  DCHECK(example != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  auto limit = stream->PushLimit(length);
  while (!stream->ExpectAtEnd()) {
    parsed::FeatureMapEntry feature_map_entry;
    if (!stream->ExpectTag(kDelimitedTag(1))) return false;
    if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
    example->push_back(std::move(feature_map_entry));
  }
  stream->PopLimit(limit);
  return true;
}

// Parses an Example proto from the stream into a flat list of
// (name, Feature) entries; unknown top-level fields are skipped.
bool ParseExample(protobuf::io::CodedInputStream* stream,
                  parsed::Example* example) {
  DCHECK(stream != nullptr);
  DCHECK(example != nullptr);
  // Loop over the input stream which may contain multiple serialized Example
  // protos merged together as strings. This behavior is consistent with Proto's
  // ParseFromString when string representations are concatenated.
  while (!stream->ExpectAtEnd()) {
    if (!stream->ExpectTag(kDelimitedTag(1))) {
      // Not the `features` field — skip whatever field this is.
      if (!SkipExtraneousTag(stream)) return false;
      continue;
    }
    if (!ParseFeatures(stream, example)) return false;
  }
  return true;
}

// Convenience overload: wraps `serialized` in an aliasing CodedInputStream
// so the parsed StringPieces point into `serialized` itself.
bool ParseExample(StringPiece serialized, parsed::Example* example) {
  DCHECK(example != nullptr);
  protobuf::io::CodedInputStream stream(
      reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
  EnableAliasing(&stream);
  return ParseExample(&stream, example);
}

}  // namespace

// Test/reference entry point: fast-parses `serialized` and materializes the
// result back into a regular Example proto. Returns false on parse failure.
bool TestFastParse(const string& serialized, Example* example) {
  DCHECK(example != nullptr);
  parsed::Example parsed_example;
  if (!ParseExample(serialized, &parsed_example)) return false;
  auto& features = *example->mutable_features();
  size_t parsed_example_size = parsed_example.size();
  // Iterate in reverse so that the LAST occurrence of a duplicated feature
  // name wins, matching standard protobuf map-parsing semantics.
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This is a logic that standard protobuf parsing is implementing.
    // I.e. last entry in the map overwrites all the previous ones.
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];
    string name = name_and_feature.first.ToString();
    // Already materialized by a later (reverse-earlier) duplicate — skip.
    if ((*features.mutable_feature()).count(name) > 0) continue;

    auto& value = (*features.mutable_feature())[name];
    DataType dtype;
    if (!name_and_feature.second.ParseDataType(&dtype).ok()) return false;
    switch (dtype) {
      case DT_INVALID:
        // Empty feature: leave the map entry default-initialized.
        break;
      case DT_STRING: {
        SmallVector<string> list;
        if (!name_and_feature.second.ParseBytesList(&list)) return false;
        auto* result_list = value.mutable_bytes_list();
        for (auto& bytes : list) {
          auto* new_value = result_list->add_value();
          new_value->swap(bytes);
        }
        break;
      }
      case DT_FLOAT: {
        SmallVector<float> list;
        if (!name_and_feature.second.ParseFloatList(&list)) return false;
        auto* result_list = value.mutable_float_list();
        for (float f : list) {
          result_list->add_value(f);
        }
        break;
      }
      case DT_INT64: {
        SmallVector<int64> list;
        if (!name_and_feature.second.ParseInt64List(&list)) return false;
        auto* result_list = value.mutable_int64_list();
        for (int64 i : list) {
          result_list->add_value(i);
        }
        break;
      }
      default:
        LOG(FATAL) << "Should not happen.";
    }
  }
  return true;
}

// -----------------------------------------------------------------------------

namespace {

using Config = FastParseExampleConfig;

// Runs f(0) ... f(n - 1). When a thread pool is given, f(1) ... f(n - 1) are
// scheduled on it while f(0) runs on the calling thread; blocks until every
// invocation has finished.
void ParallelFor(const std::function<void(size_t)>& f, size_t n,
                 thread::ThreadPool* thread_pool) {
  if (n == 0) return;
  if (thread_pool == nullptr) {
    for (size_t i = 0; i < n; ++i) {
      f(i);
    }
  } else {
    BlockingCounter counter(n - 1);
    for (size_t i = 1; i < n; ++i) {
      thread_pool->Schedule([i, &f, &counter] {
        f(i);
        counter.DecrementCount();
      });
    }
    f(0);
    counter.Wait();
  }
}

enum class Type { Sparse, Dense };

// Accumulates the parsed values of one feature across the examples of a
// minibatch (used for both sparse and variable-length dense features).
struct SparseBuffer {
  // Features are in one of the 3 vectors below depending on config's dtype.
  // Other 2 vectors remain empty.
  SmallVector<string> bytes_list;
  SmallVector<float> float_list;
  SmallVector<int64> int64_list;

  // Features of example i are elements with indices
  // from example_end_indices[i-1] to example_end_indices[i]-1 on the
  // appropriate xxxxx_list
  std::vector<size_t> example_end_indices;
};

// Hash functor with a reseedable seed; the seed is bumped when the
// PresizedCuckooMap hits a (very unlikely) collision.
struct SeededHasher {
  uint64 operator()(StringPiece s) const {
    return Hash64(s.data(), s.size(), seed);
  }
  uint64 seed{0xDECAFCAFFE};
};

// Fixed-capacity push_back target over preallocated memory (e.g. a Tensor's
// flat buffer). Writes past capacity are dropped but counted, so overflow is
// detectable afterwards via EndDistance() < 0.
template <typename T>
class LimitedArraySlice {
 public:
  LimitedArraySlice(T* begin, size_t num_elements)
      : current_(begin), end_(begin + num_elements) {}

  // May return negative if there were push_back calls after slice was filled.
  int64 EndDistance() const { return end_ - current_; }

  // Attempts to push value to the back of this. If the slice has
  // already been filled, this method has no effect on the underlying data, but
  // it changes the number returned by EndDistance into negative values.
  void push_back(T&& value) {
    if (EndDistance() > 0) *current_ = std::move(value);
    ++current_;
  }

 private:
  T* current_;
  T* end_;
};

// Warns and bumps a monitoring counter when a dense feature appears in more
// than one concatenated tf.Example; only the last occurrence is kept.
void LogDenseFeatureDataLoss(StringPiece feature_name) {
  LOG(WARNING) << "Data loss! Feature '" << feature_name
               << "' is present in multiple concatenated "
                  "tf.Examples. Ignoring all but last one.";
  static auto* duplicated_dense_feature = monitoring::Counter<0>::New(
      "/tensorflow/core/util/example_proto_fast_parsing/"
      "duplicated_dense_feature",
      "Dense feature appears twice in a tf.Example");
  duplicated_dense_feature->GetCell()->IncrementBy(1);
}

// Sparse-feature counterpart of LogDenseFeatureDataLoss.
void LogSparseFeatureDataLoss(StringPiece feature_name) {
  LOG(WARNING) << "Data loss! Feature '" << feature_name
               << "' is present in multiple concatenated "
                  "tf.Examples. 
Ignoring all but last one.";
  static auto* duplicated_sparse_feature = monitoring::Counter<0>::New(
      "/tensorflow/core/util/example_proto_fast_parsing/"
      "duplicated_sparse_feature",
      "Sparse feature appears twice in a tf.Example");
  duplicated_sparse_feature->GetCell()->IncrementBy(1);
}

// Parses one serialized Example and scatters its features:
//  - fixed-shape dense features are written in place into row `example_index`
//    of the preallocated tensors in *output_dense;
//  - variable-length dense features are appended to *output_varlen_dense;
//  - sparse features are appended to *output_sparse.
// `config_index` maps hashed feature names to (feature index, kind);
// collisions are disambiguated by a full string compare against the config.
Status FastParseSerializedExample(
    const string& serialized_example, const string& example_name,
    const size_t example_index, const Config& config,
    const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
    SeededHasher hasher, std::vector<Tensor>* output_dense,
    std::vector<SparseBuffer>* output_varlen_dense,
    std::vector<SparseBuffer>* output_sparse) {
  DCHECK(output_dense != nullptr);
  DCHECK(output_sparse != nullptr);
  parsed::Example parsed_example;
  if (!ParseExample(serialized_example, &parsed_example)) {
    return errors::InvalidArgument("Could not parse example input, value: '",
                                   serialized_example, "'");
  }
  // Tracks, per configured feature, the last example index that provided it;
  // used both for duplicate detection and for default-filling misses below.
  std::vector<int64> sparse_feature_last_example(config.sparse.size(), -1);
  std::vector<int64> dense_feature_last_example(config.dense.size(), -1);

  // Handle features present in the example.
  const size_t parsed_example_size = parsed_example.size();
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This is a logic that standard protobuf parsing is implementing.
    // I.e. last entry in the map overwrites all the previous ones.
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];

    const StringPiece feature_name = name_and_feature.first;
    parsed::Feature& feature = name_and_feature.second;

    std::pair<size_t, Type> d_and_type;
    uint64 h = hasher(feature_name);
    if (!config_index.Find(h, &d_and_type)) continue;

    size_t d = d_and_type.first;
    bool is_dense = d_and_type.second == Type::Dense;

    {
      // Testing for PresizedCuckooMap collision.
      // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
      const string& config_feature_name = is_dense
                                              ? config.dense[d].feature_name
                                              : config.sparse[d].feature_name;
      if (feature_name != config_feature_name) continue;
    }

    auto example_error = [&](StringPiece suffix) {
      return errors::InvalidArgument("Name: ", example_name,
                                     ", Key: ", feature_name,
                                     ", Index: ", example_index, ". ", suffix);
    };

    auto parse_error = [&] {
      return example_error("Can't parse serialized Example.");
    };

    DataType example_dtype;
    TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));

    if (is_dense) {
      if (example_dtype == DT_INVALID) continue;

      // If feature was already visited, skip.
      // Compare comment at the beginning of the loop.
      if (dense_feature_last_example[d] == example_index) {
        LogDenseFeatureDataLoss(feature_name);
        continue;
      }
      dense_feature_last_example[d] = example_index;

      if (example_dtype != config.dense[d].dtype) {
        return example_error(strings::StrCat(
            "Data types don't match. Data type: ",
            DataTypeString(example_dtype),
            " but expected type: ", DataTypeString(config.dense[d].dtype)));
      }
      if (!config.dense[d].variable_length) {
        // Fixed-stride dense: parse directly into the output tensor slot.
        Tensor& out = (*output_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;
        const std::size_t offset = example_index * num_elements;

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(strings::StrCat(
              "Number of ", type_str,
              " values != expected. "
              "Values size: ",
              size,
              " but output shape: ", config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case DT_INT64: {
            auto out_p = out.flat<int64>().data() + offset;
            LimitedArraySlice<int64> slice(out_p, num_elements);
            if (!feature.ParseInt64List(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "int64");
            }
            break;
          }
          case DT_FLOAT: {
            auto out_p = out.flat<float>().data() + offset;
            LimitedArraySlice<float> slice(out_p, num_elements);
            if (!feature.ParseFloatList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "float");
            }
            break;
          }
          case DT_STRING: {
            auto out_p = out.flat<string>().data() + offset;
            LimitedArraySlice<string> slice(out_p, num_elements);
            if (!feature.ParseBytesList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "bytes");
            }
            break;
          }
          default:
            LOG(FATAL) << "Should not happen.";
        }
      } else {  // if variable length
        // Variable-length dense: buffer values; must be a whole number of
        // strides. An empty (DT_INVALID) feature contributes zero values.
        SparseBuffer& out = (*output_varlen_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;

        if (example_dtype != DT_INVALID &&
            example_dtype != config.dense[d].dtype) {
          return example_error(strings::StrCat(
              "Data types don't match. ",
              "Expected type: ", DataTypeString(config.dense[d].dtype)));
        }

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(strings::StrCat(
              "Number of ", type_str,
              " values is not a multiple of stride length. Saw ", size,
              " values but output shape is: ",
              config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case DT_INT64: {
            if (example_dtype != DT_INVALID) {
              if (!feature.ParseInt64List(&out.int64_list)) {
                return parse_error();
              }
              if (out.int64_list.size() % num_elements != 0) {
                return shape_error(out.int64_list.size(), "int64");
              }
            }
            out.example_end_indices.push_back(out.int64_list.size());
            break;
          }
          case DT_FLOAT: {
            if (example_dtype != DT_INVALID) {
              if (!feature.ParseFloatList(&out.float_list)) {
                return parse_error();
              }
              if (out.float_list.size() % num_elements != 0) {
                return shape_error(out.float_list.size(), "float");
              }
            }
            out.example_end_indices.push_back(out.float_list.size());
            break;
          }
          case DT_STRING: {
            if (example_dtype != DT_INVALID) {
              if (!feature.ParseBytesList(&out.bytes_list)) {
                return parse_error();
              }
              if (out.bytes_list.size() % num_elements != 0) {
                return shape_error(out.bytes_list.size(), "bytes");
              }
            }
            out.example_end_indices.push_back(out.bytes_list.size());
            break;
          }
          default:
            LOG(FATAL) << "Should not happen.";
        }
      }
    } else {
      // If feature was already visited, skip.
      // Compare comment at the beginning of the loop.
      if (sparse_feature_last_example[d] == example_index) {
        LogSparseFeatureDataLoss(feature_name);
        continue;
      }
      sparse_feature_last_example[d] = example_index;

      // Handle sparse features.
      SparseBuffer& out = (*output_sparse)[d];
      if (example_dtype != DT_INVALID &&
          example_dtype != config.sparse[d].dtype) {
        return example_error(strings::StrCat(
            "Data types don't match. ",
            "Expected type: ", DataTypeString(config.sparse[d].dtype),
            ", Actual type: ", DataTypeString(example_dtype)));
      }

      switch (config.sparse[d].dtype) {
        case DT_INT64: {
          if (example_dtype != DT_INVALID) {
            if (!feature.ParseInt64List(&out.int64_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.int64_list.size());
          break;
        }
        case DT_FLOAT: {
          if (example_dtype != DT_INVALID) {
            if (!feature.ParseFloatList(&out.float_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.float_list.size());
          break;
        }
        case DT_STRING: {
          if (example_dtype != DT_INVALID) {
            if (!feature.ParseBytesList(&out.bytes_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.bytes_list.size());
          break;
        }
        default:
          LOG(FATAL) << "Should not happen.";
      }
    }
  }

  // Handle missing dense features for fixed strides.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    if (config.dense[d].default_value.NumElements() == 0) {
      return errors::InvalidArgument(
          "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
          " (data type: ", DataTypeString(config.dense[d].dtype), ")",
          " is required but could not be found.");
    }
    const Tensor& in = config.dense[d].default_value;
    Tensor& out = (*output_dense)[d];
    const std::size_t num_elements = in.shape().num_elements();
    const std::size_t offset = example_index * num_elements;

    switch (config.dense[d].dtype) {
      case DT_INT64: {
        std::copy_n(in.flat<int64>().data(), num_elements,
                    out.flat<int64>().data() + offset);
        break;
      }
      case DT_FLOAT: {
        std::copy_n(in.flat<float>().data(), num_elements,
                    out.flat<float>().data() + offset);
        break;
      }
      case DT_STRING: {
        std::copy_n(in.flat<string>().data(), num_elements,
                    out.flat<string>().data() + offset);
        break;
      }
      default:
        LOG(FATAL) << "Should not happen.";
    }
  }

  // Handle missing varlen dense features.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_varlen_dense)[d];
    // Record a zero-length example by repeating the previous end index.
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  // Handle missing sparse features.
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    if (sparse_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_sparse)[d];
    // Record a zero-length example by repeating the previous end index.
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  return Status::OK();
}

// Validates that a config dtype is one of the three supported types.
Status CheckConfigDataType(DataType dtype) {
  switch (dtype) {
    case DT_INT64:
    case DT_FLOAT:
    case DT_STRING:
      return Status::OK();
    default:
      return errors::InvalidArgument("Invalid config dtype: ",
                                     DataTypeString(dtype));
  }
}

// Selects the SparseBuffer member list matching T; specialized below for the
// three supported element types.
template <typename T>
const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);

template <>
const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
  return buffer.int64_list;
}
template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
  return buffer.float_list;
}
template <>
const SmallVector<string>& GetListFromBuffer<string>(
    const SparseBuffer& buffer) {
  return buffer.bytes_list;
}

// Copies [b, e) to t; the string specialization moves instead of copying.
template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
  std::copy(b, e, t);
}
template <>
void
CopyOrMoveBlock(const string* b, const string* e, string* t) {
  std::move(b, e, t);
}

// Fills the output tensor for variable-length dense feature `d`:
// default-fills all `num_elements` slots, then copies (or moves, for strings)
// each example's parsed values from the per-minibatch buffers into that
// example's fixed-size row.
template <typename T>
void FillAndCopyVarLen(
    const int d, const size_t num_elements,
    const size_t num_elements_per_minibatch, const Config& config,
    const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
    Tensor* values) {
  const Tensor& default_value = config.dense[d].default_value;

  // Copy-fill the tensors (creating the zero/fill-padding)
  std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
            default_value.flat<T>()(0));

  // Data is [batch_size, max_num_elements, data_stride_size]
  // and num_elements_per_minibatch = max_num_elements * data_stride_size
  auto data = values->flat<T>().data();

  // Iterate over minibatch elements
  for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
    const SparseBuffer& buffer = varlen_dense_buffers[i][d];
    // Number of examples being stored in this buffer
    const auto& end_indices = buffer.example_end_indices;
    const size_t examples_in_buffer = end_indices.size();
    // const size_t stride_size = config.dense[d].elements_per_stride;

    const auto& list = GetListFromBuffer<T>(buffer);
    auto list_ptr = list.begin();

    size_t elements_tally = 0;
    // Iterate through all the examples stored in this buffer.
    for (size_t j = 0; j < examples_in_buffer; ++j) {
      // Number of elements stored for this example.
      const size_t num_elems = end_indices[j] - elements_tally;
      CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
      // Move forward this many elements in the varlen buffer.
      list_ptr += num_elems;
      // Move forward to the next minibatch entry in the values output.
      data += num_elements_per_minibatch;
      elements_tally = end_indices[j];
    }
    DCHECK(elements_tally == list.size());
  }
}

}  // namespace

// Parses a batch of serialized Example protos into dense tensors and sparse
// buffers per `config`, splitting the batch into size-balanced minibatches
// that are processed in parallel on `thread_pool` (serially when null).
Status FastParseExample(const Config& config,
                        gtl::ArraySlice<string> serialized,
                        gtl::ArraySlice<string> example_names,
                        thread::ThreadPool* thread_pool, Result* result) {
  DCHECK(result != nullptr);
  // Check config so we can safely CHECK(false) in switches on config.*.dtype
  for (auto& c : config.sparse) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
  }
  for (auto& c : config.dense) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
  }

  size_t config_size = config.dense.size() + config.sparse.size();
  SeededHasher hasher;
  // Build config index.
  PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
  bool ok = true;
  // Retry with a new seed on the (astronomically unlikely) hash collision.
  for (size_t i = 0; i < 1000; ++i) {
    for (size_t d = 0; d < config.dense.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
                                      {d, Type::Dense});
    }
    for (size_t d = 0; d < config.sparse.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
                                      {d, Type::Sparse});
    }
    if (ok) break;
    LOG(WARNING) << "Collision found. This should happen only if you have "
                    "around 2^32 entries in your config.";
    hasher.seed++;
    config_index.Clear(config_size);
  }
  if (!ok) {
    return errors::Internal(
        "Could not avoid collision. This should not happen.");
  }

  // Allocate dense output for fixed length dense values
  // (variable-length dense and sparse have to be buffered).
907 std::vector<Tensor> fixed_dense_values(config.dense.size()); 908 for (size_t d = 0; d < config.dense.size(); ++d) { 909 if (config.dense[d].variable_length) continue; 910 TensorShape out_shape; 911 out_shape.AddDim(serialized.size()); 912 for (const int64 dim : config.dense[d].shape.dim_sizes()) { 913 out_shape.AddDim(dim); 914 } 915 fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape); 916 } 917 918 // This parameter affects performance in a big and data-dependent way. 919 const size_t kMiniBatchSizeBytes = 50000; 920 921 // Calculate number of minibatches. 922 // In main regime make each minibatch around kMiniBatchSizeBytes bytes. 923 // Apply 'special logic' below for small and big regimes. 924 const size_t num_minibatches = [&] { 925 size_t result = 0; 926 size_t minibatch_bytes = 0; 927 for (size_t i = 0; i < serialized.size(); i++) { 928 if (minibatch_bytes == 0) { // start minibatch 929 result++; 930 } 931 minibatch_bytes += serialized[i].size() + 1; 932 if (minibatch_bytes > kMiniBatchSizeBytes) { 933 minibatch_bytes = 0; 934 } 935 } 936 // 'special logic' 937 const size_t min_minibatches = std::min<size_t>(8, serialized.size()); 938 const size_t max_minibatches = 64; 939 return std::max<size_t>(min_minibatches, 940 std::min<size_t>(max_minibatches, result)); 941 }(); 942 943 auto first_example_of_minibatch = [&](size_t minibatch) -> size_t { 944 return (serialized.size() * minibatch) / num_minibatches; 945 }; 946 947 // TODO(lew): A big performance low-hanging fruit here is to improve 948 // num_minibatches calculation to take into account actual amount of work 949 // needed, as the size in bytes is not perfect. Linear combination of 950 // size in bytes and average number of features per example is promising. 951 // Even better: measure time instead of estimating, but this is too costly 952 // in small batches. 953 // Maybe accept outside parameter #num_minibatches? 954 955 // Do minibatches in parallel. 
956 std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches); 957 std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches); 958 std::vector<Status> status_of_minibatch(num_minibatches); 959 auto ProcessMiniBatch = [&](size_t minibatch) { 960 sparse_buffers[minibatch].resize(config.sparse.size()); 961 varlen_dense_buffers[minibatch].resize(config.dense.size()); 962 size_t start = first_example_of_minibatch(minibatch); 963 size_t end = first_example_of_minibatch(minibatch + 1); 964 for (size_t e = start; e < end; ++e) { 965 status_of_minibatch[minibatch] = FastParseSerializedExample( 966 serialized[e], 967 (!example_names.empty() ? example_names[e] : "<unknown>"), e, config, 968 config_index, hasher, &fixed_dense_values, 969 &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch]); 970 if (!status_of_minibatch[minibatch].ok()) break; 971 } 972 }; 973 974 ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool); 975 976 for (Status& status : status_of_minibatch) { 977 TF_RETURN_IF_ERROR(status); 978 } 979 980 for (size_t d = 0; d < config.dense.size(); ++d) { 981 result->dense_values.push_back(std::move(fixed_dense_values[d])); 982 } 983 984 // Merge SparseBuffers from all minibatches for every config.sparse. 
985 auto MergeSparseMinibatches = [&](size_t d) { 986 // Loop over minibatches 987 size_t total_num_features = 0; 988 size_t max_num_features = 0; 989 for (auto& sparse_values_tmp : sparse_buffers) { 990 const std::vector<size_t>& end_indices = 991 sparse_values_tmp[d].example_end_indices; 992 total_num_features += end_indices.back(); 993 max_num_features = std::max(max_num_features, end_indices[0]); 994 for (size_t i = 1; i < end_indices.size(); ++i) { 995 size_t example_size = end_indices[i] - end_indices[i - 1]; 996 max_num_features = std::max(max_num_features, example_size); 997 } 998 } 999 1000 TensorShape indices_shape; 1001 indices_shape.AddDim(total_num_features); 1002 indices_shape.AddDim(2); 1003 result->sparse_indices.emplace_back(DT_INT64, indices_shape); 1004 Tensor* indices = &result->sparse_indices.back(); 1005 1006 TensorShape values_shape; 1007 values_shape.AddDim(total_num_features); 1008 result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape); 1009 Tensor* values = &result->sparse_values.back(); 1010 1011 result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2})); 1012 auto shapes_shape_t = result->sparse_shapes.back().vec<int64>(); 1013 shapes_shape_t(0) = serialized.size(); 1014 shapes_shape_t(1) = max_num_features; 1015 1016 size_t offset = 0; 1017 for (size_t i = 0; i < sparse_buffers.size(); ++i) { 1018 const SparseBuffer& buffer = sparse_buffers[i][d]; 1019 1020 // Update indices. 1021 int64* ix_p = &indices->matrix<int64>()(offset, 0); 1022 size_t delta = 0; 1023 size_t example_index = first_example_of_minibatch(i); 1024 for (size_t example_end_index : buffer.example_end_indices) { 1025 size_t feature_index = 0; 1026 for (; delta < example_end_index; ++delta) { 1027 // Column 0: example index 1028 *ix_p = example_index; 1029 // Column 1: the feature index buffer example 1030 *(ix_p + 1) = feature_index; 1031 ix_p += 2; 1032 ++feature_index; 1033 } 1034 ++example_index; 1035 } 1036 1037 // Copy values over. 
1038 switch (config.sparse[d].dtype) { 1039 case DT_INT64: { 1040 std::copy(buffer.int64_list.begin(), buffer.int64_list.end(), 1041 values->flat<int64>().data() + offset); 1042 break; 1043 } 1044 case DT_FLOAT: { 1045 std::copy(buffer.float_list.begin(), buffer.float_list.end(), 1046 values->flat<float>().data() + offset); 1047 break; 1048 } 1049 case DT_STRING: { 1050 std::move(buffer.bytes_list.begin(), buffer.bytes_list.end(), 1051 values->flat<string>().data() + offset); 1052 break; 1053 } 1054 default: 1055 LOG(FATAL) << "Should not happen."; 1056 } 1057 1058 offset += delta; 1059 } 1060 }; 1061 1062 // Merge SparseBuffers from all minibatches for every config.dense having 1063 // variable_length. 1064 auto MergeDenseVarLenMinibatches = [&](size_t d) { 1065 if (!config.dense[d].variable_length) return; 1066 1067 // Loop over minibatches 1068 size_t max_num_features = 0; 1069 for (auto& dense_values_tmp : varlen_dense_buffers) { 1070 std::vector<size_t>& end_indices = 1071 dense_values_tmp[d].example_end_indices; 1072 max_num_features = std::max(max_num_features, end_indices[0]); 1073 for (size_t i = 1; i < end_indices.size(); ++i) { 1074 size_t example_size = end_indices[i] - end_indices[i - 1]; 1075 max_num_features = std::max(max_num_features, example_size); 1076 } 1077 } 1078 1079 const size_t stride_size = config.dense[d].elements_per_stride; 1080 const size_t max_num_elements = max_num_features / stride_size; 1081 TensorShape values_shape; 1082 DCHECK(max_num_features % config.dense[d].elements_per_stride == 0); 1083 const size_t batch_size = serialized.size(); 1084 values_shape.AddDim(batch_size); 1085 values_shape.AddDim(max_num_elements); 1086 for (int i = 1; i < config.dense[d].shape.dims(); ++i) { 1087 values_shape.AddDim(config.dense[d].shape.dim_size(i)); 1088 } 1089 Tensor values(config.dense[d].dtype, values_shape); 1090 result->dense_values[d] = values; 1091 const size_t num_elements = values.NumElements(); 1092 1093 // Nothing to write, exit 
early. 1094 if (num_elements == 0) return; 1095 1096 const size_t num_elements_per_minibatch = num_elements / batch_size; 1097 1098 switch (config.dense[d].dtype) { 1099 case DT_INT64: { 1100 FillAndCopyVarLen<int64>(d, num_elements, num_elements_per_minibatch, 1101 config, varlen_dense_buffers, &values); 1102 break; 1103 } 1104 case DT_FLOAT: { 1105 FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch, 1106 config, varlen_dense_buffers, &values); 1107 break; 1108 } 1109 case DT_STRING: { 1110 FillAndCopyVarLen<string>(d, num_elements, num_elements_per_minibatch, 1111 config, varlen_dense_buffers, &values); 1112 break; 1113 } 1114 default: 1115 LOG(FATAL) << "Should not happen."; 1116 } 1117 }; 1118 1119 for (size_t d = 0; d < config.dense.size(); ++d) { 1120 MergeDenseVarLenMinibatches(d); 1121 } 1122 1123 for (size_t d = 0; d < config.sparse.size(); ++d) { 1124 MergeSparseMinibatches(d); 1125 } 1126 1127 return Status::OK(); 1128 } 1129 1130 Status FastParseSingleExample(const Config& config, const string& serialized, 1131 Result* result) { 1132 DCHECK(result != nullptr); 1133 // Check config so we can safely CHECK(false) in switches on config.*.dtype 1134 for (auto& c : config.sparse) { 1135 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); 1136 } 1137 for (auto& c : config.dense) { 1138 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); 1139 } 1140 1141 // TODO(mrry): Cache the construction of this map at Op construction time. 1142 size_t config_size = config.dense.size() + config.sparse.size(); 1143 SeededHasher hasher; 1144 // Build config index. 
  // Maps hash(feature_name) -> (position in config.dense/sparse,
  // Dense|Sparse); reseed-and-rebuild on collision, up to 1000 attempts.
  // (Same scheme as in FastParseExample.)
  PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
  bool ok = true;
  for (size_t i = 0; i < 1000; ++i) {
    for (size_t d = 0; d < config.dense.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
                                      {d, Type::Dense});
    }
    for (size_t d = 0; d < config.sparse.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
                                      {d, Type::Sparse});
    }
    if (ok) break;
    LOG(WARNING) << "Collision found. This should happen only if you have "
                    "around 2^32 entries in your config.";
    hasher.seed++;
    config_index.Clear(config_size);
  }
  if (!ok) {
    return errors::Internal(
        "Could not avoid collision. This should not happen.");
  }

  // Allocate dense output tensors.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!config.dense[d].variable_length) {
      TensorShape values_shape;
      if (!config.dense[d].shape.AsTensorShape(&values_shape)) {
        return errors::Internal(
            "Fixed-length shape was not a statically defined shape.");
      }
      result->dense_values.emplace_back(config.dense[d].dtype, values_shape);
    } else {
      // Variable-length tensor will be allocated later.
      result->dense_values.emplace_back();
    }
  }

  // Allocate sparse output tensors.
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    // The dense_shape is always a vector of length 1.
    result->sparse_shapes.emplace_back(DT_INT64, TensorShape({1}));
    // Variable-length tensors will be allocated later.
    result->sparse_indices.emplace_back();
    result->sparse_values.emplace_back();
  }

  // Split the serialized proto into (name, raw feature bytes) entries; the
  // feature payloads are parsed lazily below.
  parsed::Example parsed_example;
  if (!ParseExample(serialized, &parsed_example)) {
    return errors::InvalidArgument("Could not parse example input, value: '",
                                   serialized, "'");
  }
  // Track which configured features have been filled so duplicates are
  // skipped and missing ones get defaults at the end.
  std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false);
  std::vector<bool> dense_feature_already_seen(config.dense.size(), false);

  // Handle features present in the example.
  const size_t parsed_example_size = parsed_example.size();
  // Iterate in reverse so that, combined with the "already seen" skips, the
  // LAST occurrence of a duplicated feature name wins.
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This is a logic that standard protobuf parsing is implementing.
    // I.e. last entry in the map overwrites all the previous ones.
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];

    const StringPiece feature_name = name_and_feature.first;
    parsed::Feature& feature = name_and_feature.second;

    // Features not present in the config are silently ignored.
    std::pair<size_t, Type> d_and_type;
    uint64 h = hasher(feature_name);
    if (!config_index.Find(h, &d_and_type)) continue;

    size_t d = d_and_type.first;
    bool is_dense = d_and_type.second == Type::Dense;

    {
      // Testing for PresizedCuckooMap collision.
      // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
      const string& config_feature_name = is_dense
                                              ? config.dense[d].feature_name
                                              : config.sparse[d].feature_name;
      if (feature_name != config_feature_name) continue;
    }

    auto example_error = [feature_name](StringPiece suffix) {
      return errors::InvalidArgument("Key: ", feature_name, ". ", suffix);
    };

    auto parse_error = [feature_name] {
      return errors::InvalidArgument("Key: ", feature_name,
                                     ". Can't parse serialized Example.");
    };

    DataType example_dtype;
    TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
    // An empty feature oneof carries no data; leave the output untouched.
    if (example_dtype == DT_INVALID) continue;

    if (is_dense && !config.dense[d].variable_length) {
      // If feature was already visited, skip.
      // Compare comment at the beginning of the loop.
      if (dense_feature_already_seen[d]) {
        LogDenseFeatureDataLoss(feature_name);
        continue;
      }
      dense_feature_already_seen[d] = true;

      if (example_dtype != config.dense[d].dtype) {
        return example_error(strings::StrCat(
            "Data types don't match. Data type: ",
            DataTypeString(example_dtype),
            " but expected type: ", DataTypeString(config.dense[d].dtype)));
      }

      // Parse straight into the preallocated tensor. The LimitedArraySlice
      // caps writes at num_elements; EndDistance() != 0 afterwards means the
      // feature had too few or too many values for the fixed shape.
      Tensor* out = &result->dense_values[d];
      const std::size_t num_elements = config.dense[d].elements_per_stride;

      switch (example_dtype) {
        case DT_INT64: {
          auto out_p = out->flat<int64>().data();
          LimitedArraySlice<int64> slice(out_p, num_elements);
          if (!feature.ParseInt64List(&slice)) return parse_error();
          if (slice.EndDistance() != 0) {
            return parse_error();
          }
          break;
        }
        case DT_FLOAT: {
          auto out_p = out->flat<float>().data();
          LimitedArraySlice<float> slice(out_p, num_elements);
          if (!feature.ParseFloatList(&slice)) return parse_error();
          if (slice.EndDistance() != 0) {
            return parse_error();
          }
          break;
        }
        case DT_STRING: {
          auto out_p = out->flat<string>().data();
          LimitedArraySlice<string> slice(out_p, num_elements);
          if (!feature.ParseBytesList(&slice)) return parse_error();
          if (slice.EndDistance() != 0) {
            return parse_error();
          }
          break;
        }
        default:
          LOG(FATAL) << "Should not happen.";
      }

    } else {  // if variable length
      // Variable-length dense and sparse features: parse into a temporary
      // buffer first, since the output tensor size is only known afterwards.
      SparseBuffer out_temp;
      const size_t num_elements_divisor =
          is_dense ? config.dense[d].elements_per_stride : 1;
      size_t num_elements;

      if (is_dense) {
        // If feature was already visited, skip.
        // Compare comment at the beginning of the loop.
        if (dense_feature_already_seen[d]) {
          LogDenseFeatureDataLoss(feature_name);
          continue;
        }
        dense_feature_already_seen[d] = true;
        if (example_dtype != config.dense[d].dtype) {
          return example_error(strings::StrCat(
              "Data types don't match. Data type: ",
              DataTypeString(example_dtype),
              " but expected type: ", DataTypeString(config.dense[d].dtype)));
        }
      } else {
        // If feature was already visited, skip.
        // Compare comment at the beginning of the loop.
        if (sparse_feature_already_seen[d]) {
          LogSparseFeatureDataLoss(feature_name);
          continue;
        }
        sparse_feature_already_seen[d] = true;

        // Handle sparse features.
        if (example_dtype != DT_INVALID &&
            example_dtype != config.sparse[d].dtype) {
          return example_error(strings::StrCat(
              "Data types don't match. ",
              "Expected type: ", DataTypeString(config.sparse[d].dtype),
              ", Actual type: ", DataTypeString(example_dtype)));
        }
      }

      switch (example_dtype) {
        case DT_INT64: {
          // TODO(mrry): Use the fact that the `int64_list` is packed to read
          // out the length and pre-allocate the output tensor.
          if (!feature.ParseInt64List(&out_temp.int64_list))
            return parse_error();
          num_elements = out_temp.int64_list.size();
          break;
        }
        case DT_FLOAT: {
          // TODO(mrry): Use the fact that the `float_list` is packed to read
          // out the length and pre-allocate the output tensor.
          if (!feature.ParseFloatList(&out_temp.float_list))
            return parse_error();
          num_elements = out_temp.float_list.size();
          break;
        }
        case DT_STRING: {
          // Pre-count the strings so the vector is reserved once before the
          // real parse.
          int actual_num_elements = 0;
          if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
            return parse_error();
          }
          out_temp.bytes_list.reserve(actual_num_elements);
          if (!feature.ParseBytesList(&out_temp.bytes_list))
            return parse_error();
          num_elements = out_temp.bytes_list.size();
          break;
        }
        default:
          LOG(FATAL) << "Should not happen. " << DataTypeString(example_dtype);
      }

      // Varlen dense values must form whole strides.
      if (num_elements % num_elements_divisor != 0) {
        return parse_error();
      }

      // Allocate the output tensor(s) now that the element count is known.
      Tensor* out;
      if (is_dense) {
        TensorShape values_shape;
        values_shape.AddDim(num_elements / num_elements_divisor);
        for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
          values_shape.AddDim(config.dense[d].shape.dim_size(i));
        }

        out = &result->dense_values[d];
        *out = Tensor(config.dense[d].dtype, values_shape);

      } else {
        Tensor* out_indices = &result->sparse_indices[d];
        Tensor* out_dense_shape = &result->sparse_shapes[d];
        out = &result->sparse_values[d];

        // TODO(mrry): Investigate the possibility of not materializing
        // the indices (and perhaps dense_shape) until they are needed.
        // For a single example the indices are simply 0..num_elements-1 in a
        // [num_elements, 1] tensor.
        *out_indices = Tensor(
            DT_INT64, TensorShape({static_cast<int64>(num_elements), 1}));
        auto indices_flat = out_indices->flat<int64>();
        for (size_t i = 0; i < num_elements; ++i) {
          indices_flat(i) = static_cast<int64>(i);
        }

        *out_dense_shape = Tensor(DT_INT64, TensorShape({1}));
        auto shapes_shape_t = out_dense_shape->vec<int64>();
        shapes_shape_t(0) = num_elements;

        *out = Tensor(config.sparse[d].dtype,
                      TensorShape({static_cast<int64>(num_elements)}));
      }

      // Transfer the buffered values into the output tensor.
      switch (example_dtype) {
        case DT_INT64: {
          CopyOrMoveBlock(out_temp.int64_list.begin(),
                          out_temp.int64_list.end(), out->flat<int64>().data());
          break;
        }
        case DT_FLOAT: {
          CopyOrMoveBlock(out_temp.float_list.begin(),
                          out_temp.float_list.end(), out->flat<float>().data());
          break;
        }
        case DT_STRING: {
          CopyOrMoveBlock(out_temp.bytes_list.begin(),
                          out_temp.bytes_list.end(),
                          out->flat<string>().data());
          break;
        }
        default:
          LOG(FATAL) << "Should not happen.";
      }
    }
  }

  // Handle missing dense features.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!dense_feature_already_seen[d]) {
      if (!config.dense[d].variable_length) {
        // Handle missing fixed-length dense feature: required unless a
        // non-empty default_value was configured.
        if (config.dense[d].default_value.NumElements() == 0) {
          return errors::InvalidArgument(
              "Feature: ", config.dense[d].feature_name,
              " (data type: ", DataTypeString(config.dense[d].dtype), ")",
              " is required but could not be found.");
        }
        result->dense_values[d] = config.dense[d].default_value;
      } else {
        // Handle missing varlen dense feature: emit an empty tensor whose
        // leading (variable) dimension is 0.
        TensorShape empty_shape;
        empty_shape.AddDim(0);
        for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
          empty_shape.AddDim(config.dense[d].shape.dim_size(i));
        }
        result->dense_values[d] = Tensor(config.dense[d].dtype, empty_shape);
      }
    }
  }

  // Handle missing sparse features: empty indices/values and dense_shape [0].
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    if (!sparse_feature_already_seen[d]) {
      result->sparse_indices[d] = Tensor(DT_INT64, TensorShape({0, 1}));
      result->sparse_values[d] =
          Tensor(config.sparse[d].dtype, TensorShape({0}));
      result->sparse_shapes[d].vec<int64>()(0) = 0;
    }
  }

  return Status::OK();
}

}  // namespace example
}  // namespace tensorflow