1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/parsing_ops.cc. 17 18 #include <numeric> 19 #include <unordered_set> 20 #include <vector> 21 22 #include "tensorflow/core/example/example.pb.h" 23 #include "tensorflow/core/example/feature.pb_text.h" 24 #include "tensorflow/core/framework/common_shape_fns.h" 25 #include "tensorflow/core/framework/numeric_op.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/lib/gtl/array_slice.h" 28 #include "tensorflow/core/platform/logging.h" 29 #include "tensorflow/core/platform/protobuf.h" 30 #include "tensorflow/core/util/example_proto_fast_parsing.h" 31 #include "tensorflow/core/util/example_proto_helper.h" 32 #include "tensorflow/core/util/sparse/sparse_tensor.h" 33 #include "tensorflow/core/util/work_sharder.h" 34 35 namespace tensorflow { 36 37 class ParseExampleOp : public OpKernel { 38 public: 39 explicit ParseExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 40 OP_REQUIRES_OK(ctx, attrs_.Init(ctx)); 41 } 42 43 void Compute(OpKernelContext* ctx) override { 44 const Tensor* names; 45 const Tensor* serialized; 46 OpInputList dense_keys; 47 OpInputList sparse_keys; 48 OpInputList dense_defaults; 49 50 // Grab the input list arguments. 51 OP_REQUIRES_OK(ctx, ctx->input("names", &names)); 52 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized)); 53 OP_REQUIRES_OK(ctx, ctx->input_list("dense_keys", &dense_keys)); 54 OP_REQUIRES_OK(ctx, ctx->input_list("sparse_keys", &sparse_keys)); 55 OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults)); 56 57 std::vector<string> dense_keys_t(attrs_.num_dense); 58 std::vector<string> sparse_keys_t(attrs_.num_sparse); 59 60 // Check that the input list sizes match the attribute declared sizes. 61 CHECK_EQ(dense_keys.size(), attrs_.num_dense); 62 CHECK_EQ(sparse_keys.size(), attrs_.num_sparse); 63 64 // Copy from OpInputList to std::vector<string>. 65 for (int di = 0; di < attrs_.num_dense; ++di) { 66 dense_keys_t[di] = dense_keys[di].scalar<string>()(); 67 } 68 for (int di = 0; di < attrs_.num_sparse; ++di) { 69 sparse_keys_t[di] = sparse_keys[di].scalar<string>()(); 70 } 71 72 if (names->NumElements() > 0) { 73 OP_REQUIRES( 74 ctx, TensorShapeUtils::IsVector(names->shape()), 75 errors::InvalidArgument("Expected names to be a vector, got shape: ", 76 names->shape().DebugString())); 77 OP_REQUIRES( 78 ctx, names->NumElements() == serialized->NumElements(), 79 errors::InvalidArgument( 80 "Expected len(names) == len(serialized), but got: ", 81 names->NumElements(), " vs. ", serialized->NumElements())); 82 } 83 84 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()), 85 errors::InvalidArgument( 86 "Expected serialized to be a vector, got shape: ", 87 serialized->shape().DebugString())); 88 OP_REQUIRES(ctx, dense_defaults.size() == attrs_.num_dense, 89 errors::InvalidArgument( 90 "Expected len(dense_defaults) == len(dense_keys) but got: ", 91 dense_defaults.size(), " vs. ", attrs_.num_dense)); 92 93 for (int d = 0; d < static_cast<int>(attrs_.num_dense); ++d) { 94 const Tensor& def_value = dense_defaults[d]; 95 if (attrs_.variable_length[d]) { 96 OP_REQUIRES(ctx, def_value.NumElements() == 1, 97 errors::InvalidArgument( 98 "dense_shape[", d, "] is a variable length shape: ", 99 attrs_.dense_shapes[d].DebugString(), 100 ", therefore " 101 "def_value[", 102 d, 103 "] must contain a single element (" 104 "the padding element). But its shape is: ", 105 def_value.shape().DebugString())); 106 } else if (def_value.NumElements() > 0) { 107 OP_REQUIRES(ctx, 108 attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()), 109 errors::InvalidArgument( 110 "def_value[", d, 111 "].shape() == ", def_value.shape().DebugString(), 112 " is not compatible with dense_shapes_[", d, 113 "] == ", attrs_.dense_shapes[d].DebugString())); 114 } 115 OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d], 116 errors::InvalidArgument( 117 "dense_defaults[", d, "].dtype() == ", 118 DataTypeString(def_value.dtype()), " != dense_types_[", d, 119 "] == ", DataTypeString(attrs_.dense_types[d]))); 120 } 121 122 example::Result result; 123 124 example::FastParseExampleConfig config; 125 for (int d = 0; d < attrs_.num_dense; ++d) { 126 config.dense.push_back({dense_keys_t[d], attrs_.dense_types[d], 127 attrs_.dense_shapes[d], dense_defaults[d], 128 attrs_.variable_length[d], 129 attrs_.elements_per_stride[d]}); 130 } 131 for (int d = 0; d < attrs_.num_sparse; ++d) { 132 config.sparse.push_back({sparse_keys_t[d], attrs_.sparse_types[d]}); 133 } 134 135 auto serialized_t = serialized->flat<string>(); 136 auto names_t = names->flat<string>(); 137 gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size()); 138 gtl::ArraySlice<string> names_slice(names_t.data(), names_t.size()); 139 140 OP_REQUIRES_OK( 141 ctx, 142 FastParseExample( 143 config, slice, names_slice, 144 ctx->device()->tensorflow_cpu_worker_threads()->workers, &result)); 145 146 OpOutputList dense_values; 147 OpOutputList sparse_indices; 148 OpOutputList sparse_values; 149 OpOutputList sparse_shapes; 150 OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values)); 151 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices)); 152 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values)); 153 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes)); 154 for (int d = 0; d < attrs_.num_dense; ++d) { 155 dense_values.set(d, result.dense_values[d]); 156 } 157 for (int d = 0; d < attrs_.num_sparse; ++d) { 158 sparse_indices.set(d, result.sparse_indices[d]); 159 sparse_values.set(d, result.sparse_values[d]); 160 sparse_shapes.set(d, result.sparse_shapes[d]); 161 } 162 } 163 164 protected: 165 ParseExampleAttrs attrs_; 166 }; 167 168 REGISTER_KERNEL_BUILDER(Name("ParseExample").Device(DEVICE_CPU), 169 ParseExampleOp); 170 171 class ParseSingleExampleOp : public OpKernel { 172 public: 173 explicit ParseSingleExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 174 OP_REQUIRES_OK(ctx, attrs_.Init(ctx)); 175 } 176 177 void Compute(OpKernelContext* ctx) override { 178 const Tensor* serialized; 179 OpInputList dense_defaults; 180 181 // Grab the input list arguments. 182 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized)); 183 OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults)); 184 185 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()), 186 errors::InvalidArgument( 187 "Expected serialized to be a scalar, got shape: ", 188 serialized->shape().DebugString())); 189 OP_REQUIRES(ctx, dense_defaults.size() == attrs_.dense_keys.size(), 190 errors::InvalidArgument( 191 "Expected len(dense_defaults) == len(dense_keys) but got: ", 192 dense_defaults.size(), " vs. ", attrs_.dense_keys.size())); 193 194 for (size_t d = 0; d < attrs_.dense_keys.size(); ++d) { 195 const Tensor& def_value = dense_defaults[d]; 196 if (attrs_.variable_length[d]) { 197 OP_REQUIRES(ctx, def_value.NumElements() == 1, 198 errors::InvalidArgument( 199 "dense_shape[", d, "] is a variable length shape: ", 200 attrs_.dense_shapes[d].DebugString(), 201 ", therefore " 202 "def_value[", 203 d, 204 "] must contain a single element (" 205 "the padding element). But its shape is: ", 206 def_value.shape().DebugString())); 207 } else if (def_value.NumElements() > 0) { 208 OP_REQUIRES(ctx, 209 attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()), 210 errors::InvalidArgument( 211 "def_value[", d, 212 "].shape() == ", def_value.shape().DebugString(), 213 " is not compatible with dense_shapes_[", d, 214 "] == ", attrs_.dense_shapes[d].DebugString())); 215 } 216 OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d], 217 errors::InvalidArgument( 218 "dense_defaults[", d, "].dtype() == ", 219 DataTypeString(def_value.dtype()), " != dense_types_[", d, 220 "] == ", DataTypeString(attrs_.dense_types[d]))); 221 } 222 223 example::Result result; 224 225 // TODO(mrry): Build the configuration once and cache it. 226 example::FastParseExampleConfig config; 227 for (int d = 0; d < attrs_.dense_keys.size(); ++d) { 228 config.dense.push_back({attrs_.dense_keys[d], attrs_.dense_types[d], 229 attrs_.dense_shapes[d], dense_defaults[d], 230 attrs_.variable_length[d], 231 attrs_.elements_per_stride[d]}); 232 } 233 for (int d = 0; d < attrs_.sparse_keys.size(); ++d) { 234 config.sparse.push_back({attrs_.sparse_keys[d], attrs_.sparse_types[d]}); 235 } 236 237 const string& serialized_proto = serialized->scalar<string>()(); 238 239 OP_REQUIRES_OK(ctx, 240 FastParseSingleExample(config, serialized_proto, &result)); 241 242 OpOutputList dense_values; 243 OpOutputList sparse_indices; 244 OpOutputList sparse_values; 245 OpOutputList sparse_shapes; 246 OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values)); 247 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices)); 248 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values)); 249 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes)); 250 for (int d = 0; d < attrs_.dense_keys.size(); ++d) { 251 dense_values.set(d, result.dense_values[d]); 252 } 253 for (int d = 0; d < attrs_.sparse_keys.size(); ++d) { 254 sparse_indices.set(d, result.sparse_indices[d]); 255 sparse_values.set(d, result.sparse_values[d]); 256 sparse_shapes.set(d, result.sparse_shapes[d]); 257 } 258 } 259 260 protected: 261 ParseSingleExampleAttrs attrs_; 262 }; 263 264 REGISTER_KERNEL_BUILDER(Name("ParseSingleExample").Device(DEVICE_CPU), 265 ParseSingleExampleOp); 266 267 class SingleSequenceExampleParserOp : public OpKernel { 268 public: 269 explicit SingleSequenceExampleParserOp(OpKernelConstruction* ctx) 270 : OpKernel(ctx) { 271 OP_REQUIRES_OK(ctx, attrs_.Init(ctx)); 272 } 273 274 void Compute(OpKernelContext* ctx) override { 275 const Tensor* debug_name; 276 const Tensor* serialized; 277 OpInputList context_dense_keys; 278 OpInputList context_sparse_keys; 279 OpInputList context_dense_defaults; 280 OpInputList feature_list_dense_keys; 281 OpInputList feature_list_sparse_keys; 282 const Tensor* feature_list_dense_missing_assumed_empty; 283 284 OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name)); 285 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized)); 286 OP_REQUIRES_OK(ctx, ctx->input("feature_list_dense_missing_assumed_empty", 287 &feature_list_dense_missing_assumed_empty)); 288 OP_REQUIRES_OK(ctx, 289 ctx->input_list("context_dense_keys", &context_dense_keys)); 290 OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_dense_keys", 291 &feature_list_dense_keys)); 292 OP_REQUIRES_OK( 293 ctx, ctx->input_list("context_sparse_keys", &context_sparse_keys)); 294 OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_sparse_keys", 295 &feature_list_sparse_keys)); 296 OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults", 297 &context_dense_defaults)); 298 299 std::vector<string> context_dense_keys_t(attrs_.num_context_dense); 300 std::vector<string> context_sparse_keys_t(attrs_.num_context_sparse); 301 std::vector<string> feature_list_dense_keys_t( 302 attrs_.num_feature_list_dense); 303 std::vector<string> feature_list_sparse_keys_t( 304 attrs_.num_feature_list_sparse); 305 std::unordered_set<string> feature_list_dense_missing_assumed_empty_set; 306 CHECK_EQ(context_dense_keys.size(), attrs_.num_context_dense); 307 CHECK_EQ(context_sparse_keys.size(), attrs_.num_context_sparse); 308 CHECK_EQ(feature_list_dense_keys.size(), attrs_.num_feature_list_dense); 309 CHECK_EQ(feature_list_sparse_keys.size(), attrs_.num_feature_list_sparse); 310 for (int di = 0; di < attrs_.num_context_dense; ++di) { 311 OP_REQUIRES(ctx, 312 TensorShapeUtils::IsScalar(context_dense_keys[di].shape()), 313 errors::InvalidArgument( 314 "Expected context_dense_keys[", di, 315 "] to be a scalar, got shape: ", 316 context_dense_keys[di].shape().DebugString())); 317 context_dense_keys_t[di] = context_dense_keys[di].scalar<string>()(); 318 } 319 for (int di = 0; di < attrs_.num_context_sparse; ++di) { 320 OP_REQUIRES(ctx, 321 TensorShapeUtils::IsScalar(context_sparse_keys[di].shape()), 322 errors::InvalidArgument( 323 "Expected context_sparse_keys[", di, 324 "] to be a scalar, got shape: ", 325 context_sparse_keys[di].shape().DebugString())); 326 context_sparse_keys_t[di] = context_sparse_keys[di].scalar<string>()(); 327 } 328 for (int di = 0; di < attrs_.num_feature_list_dense; ++di) { 329 OP_REQUIRES( 330 ctx, TensorShapeUtils::IsScalar(feature_list_dense_keys[di].shape()), 331 errors::InvalidArgument( 332 "Expected feature_list_dense_keys[", di, 333 "] to be a scalar, got shape: ", 334 feature_list_dense_keys[di].shape().DebugString())); 335 feature_list_dense_keys_t[di] = 336 feature_list_dense_keys[di].scalar<string>()(); 337 } 338 for (int di = 0; di < attrs_.num_feature_list_sparse; ++di) { 339 OP_REQUIRES( 340 ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()), 341 errors::InvalidArgument( 342 "Expected feature_list_sparse_keys[", di, 343 "] to be a scalar, got shape: ", 344 feature_list_sparse_keys[di].shape().DebugString())); 345 feature_list_sparse_keys_t[di] = 346 feature_list_sparse_keys[di].scalar<string>()(); 347 } 348 OP_REQUIRES( 349 ctx, 350 TensorShapeUtils::IsVector( 351 feature_list_dense_missing_assumed_empty->shape()), 352 errors::InvalidArgument( 353 "Expected feature_list_dense_missing_assumed_empty ", 354 "to be a vector, got shape: ", 355 feature_list_dense_missing_assumed_empty->shape().DebugString())); 356 auto feature_list_dense_missing_assumped_empty_t = 357 feature_list_dense_missing_assumed_empty->vec<string>(); 358 for (int de = 0; 359 de < feature_list_dense_missing_assumed_empty->NumElements(); ++de) { 360 feature_list_dense_missing_assumed_empty_set.insert( 361 feature_list_dense_missing_assumped_empty_t(de)); 362 } 363 364 bool has_debug_name = (debug_name->NumElements() > 0); 365 if (has_debug_name) { 366 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(debug_name->shape()), 367 errors::InvalidArgument( 368 "Expected debug_name to be a scalar, got shape: ", 369 debug_name->shape().DebugString())); 370 } 371 auto debug_name_t = debug_name->scalar<string>(); 372 373 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()), 374 errors::InvalidArgument( 375 "Expected serialized to be a scalar, got shape: ", 376 serialized->shape().DebugString())); 377 378 OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense, 379 errors::InvalidArgument("Expected len(context_dense_defaults) " 380 "== len(context_dense_keys) but got: ", 381 context_dense_defaults.size(), " vs. ", 382 attrs_.num_context_dense)); 383 384 std::vector<bool> required(attrs_.num_context_dense); 385 for (int d = 0; d < attrs_.num_context_dense; ++d) { 386 const Tensor& def_value = context_dense_defaults[d]; 387 required[d] = (def_value.NumElements() == 0); // No default provided. 388 389 if (def_value.NumElements() > 0) { 390 OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d], 391 errors::InvalidArgument( 392 "def_value[", d, 393 "].shape() == ", def_value.shape().DebugString(), 394 " != context_dense_shapes_[", d, 395 "] == ", attrs_.context_dense_shapes[d].DebugString())); 396 OP_REQUIRES( 397 ctx, def_value.dtype() == attrs_.context_dense_types[d], 398 errors::InvalidArgument( 399 "context_dense_defaults[", d, "].dtype() == ", 400 DataTypeString(def_value.dtype()), " != context_dense_types_[", 401 d, "] == ", DataTypeString(attrs_.context_dense_types[d]))); 402 } 403 } 404 405 auto serialized_t = serialized->scalar<string>(); 406 407 OpOutputList context_sparse_indices; 408 OpOutputList context_sparse_values; 409 OpOutputList context_sparse_shapes; 410 OpOutputList context_dense_values; 411 OpOutputList feature_list_sparse_indices; 412 OpOutputList feature_list_sparse_values; 413 OpOutputList feature_list_sparse_shapes; 414 OpOutputList feature_list_dense_values; 415 416 OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices", 417 &context_sparse_indices)); 418 OP_REQUIRES_OK( 419 ctx, ctx->output_list("context_sparse_values", &context_sparse_values)); 420 OP_REQUIRES_OK( 421 ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes)); 422 OP_REQUIRES_OK( 423 ctx, ctx->output_list("context_dense_values", &context_dense_values)); 424 OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices", 425 &context_sparse_indices)); 426 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices", 427 &feature_list_sparse_indices)); 428 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values", 429 &feature_list_sparse_values)); 430 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes", 431 &feature_list_sparse_shapes)); 432 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values", 433 &feature_list_dense_values)); 434 435 SequenceExample ex; 436 OP_REQUIRES( 437 ctx, ParseProtoUnlimited(&ex, serialized_t()), 438 errors::InvalidArgument("Could not parse example input, value: '", 439 serialized_t(), "'")); 440 441 const string& name = (has_debug_name) ? debug_name_t() : "<unknown>"; 442 const Features& context = ex.context(); 443 const auto& context_dict = context.feature(); 444 445 // Context Dense ----------------------------------------------------------- 446 447 // Preallocate context_dense_values, since we know their sizes 448 for (int d = 0; d < attrs_.num_context_dense; ++d) { 449 TensorShape out_shape; 450 for (const int dim : attrs_.context_dense_shapes[d].dim_sizes()) 451 out_shape.AddDim(dim); 452 Tensor* out = nullptr; 453 OP_REQUIRES_OK(ctx, context_dense_values.allocate(d, out_shape, &out)); 454 } 455 456 for (int d = 0; d < attrs_.num_context_dense; ++d) { 457 const string& key = context_dense_keys_t[d]; 458 const DataType& dtype = attrs_.context_dense_types[d]; 459 const TensorShape& shape = attrs_.context_dense_shapes[d]; 460 461 const auto& feature_found = context_dict.find(key); 462 OP_REQUIRES( 463 ctx, (feature_found != context_dict.end()) || !required[d], 464 errors::InvalidArgument("Name: ", name, ", Context feature '", key, 465 "' is required but could not be found.")); 466 if (feature_found != context_dict.end()) { 467 const Feature& f = feature_found->second; 468 bool types_match; 469 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); 470 OP_REQUIRES( 471 ctx, types_match, 472 errors::InvalidArgument("Name: ", name, ", Context feature: ", key, 473 ". Data types don't match. ", 474 "Expected type: ", DataTypeString(dtype), 475 " Feature is: ", ProtoDebugString(f))); 476 477 OP_REQUIRES_OK(ctx, FeatureDenseCopy(0, name, key, dtype, shape, f, 478 context_dense_values[d])); 479 } else { 480 RowDenseCopy(0, dtype, context_dense_defaults[d], 481 context_dense_values[d]); 482 } 483 } 484 485 // Context Sparse ---------------------------------------------------------- 486 for (int d = 0; d < attrs_.num_context_sparse; ++d) { 487 const string& key = context_sparse_keys_t[d]; 488 const DataType& dtype = attrs_.context_sparse_types[d]; 489 490 const auto& feature_found = context_dict.find(key); 491 bool feature_has_data = // Found key & data type is set 492 (feature_found != context_dict.end() && 493 (feature_found->second.kind_case() != Feature::KIND_NOT_SET)); 494 495 if (feature_has_data) { 496 const Feature& f = feature_found->second; 497 bool types_match; 498 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); 499 OP_REQUIRES( 500 ctx, types_match, 501 errors::InvalidArgument("Name: ", name, ", Context feature: ", key, 502 ". Data types don't match. ", 503 "Expected type: ", DataTypeString(dtype), 504 " Feature is: ", ProtoDebugString(f))); 505 506 Tensor feature_values = FeatureSparseCopy(0, key, dtype, f); 507 const int64 num_elements = feature_values.NumElements(); 508 TensorShape indices_shape({num_elements, 1}); 509 Tensor* sp_indices_d = nullptr; 510 Tensor* sp_shape_d = nullptr; 511 OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape, 512 &sp_indices_d)); 513 context_sparse_values.set(d, feature_values); 514 OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}), 515 &sp_shape_d)); 516 auto shape_t = sp_shape_d->vec<int64>(); 517 shape_t(0) = num_elements; 518 auto indices_t = sp_indices_d->matrix<int64>(); 519 std::iota(indices_t.data(), indices_t.data() + num_elements, 0); 520 } else { 521 TensorShape indices_shape({0, 1}); 522 TensorShape values_shape({0}); 523 Tensor* sp_indices_d = nullptr; 524 Tensor* sp_values_d = nullptr; 525 Tensor* sp_shape_d = nullptr; 526 OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape, 527 &sp_indices_d)); 528 OP_REQUIRES_OK( 529 ctx, context_sparse_values.allocate(d, values_shape, &sp_values_d)); 530 OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}), 531 &sp_shape_d)); 532 auto shape_t = sp_shape_d->vec<int64>(); 533 shape_t(0) = 0; 534 } 535 } 536 537 // Feature List Dense ------------------------------------------------------ 538 539 // Preallocate context_dense_values, since we can infer their 540 // sizes 541 const FeatureLists& feature_lists = ex.feature_lists(); 542 const auto& feature_list_dict = feature_lists.feature_list(); 543 FeatureList empty_feature_list; // Placeholder for missing FLs 544 545 for (int d = 0; d < attrs_.num_feature_list_dense; ++d) { 546 const string& key = feature_list_dense_keys_t[d]; 547 const DataType& dtype = attrs_.feature_list_dense_types[d]; 548 const TensorShape& shape = attrs_.feature_list_dense_shapes[d]; 549 550 const auto& feature_list_found = feature_list_dict.find(key); 551 bool feature_list_missing = 552 (feature_list_found == feature_list_dict.end()); 553 bool feature_list_allowed_missing = 554 (feature_list_dense_missing_assumed_empty_set.count(key) > 0); 555 556 OP_REQUIRES( 557 ctx, !feature_list_missing || feature_list_allowed_missing, 558 errors::InvalidArgument("Name: ", name, ", Feature list '", key, 559 "' is required but could not be found. " 560 "Did you mean to include it in " 561 "feature_list_dense_missing_assumed_empty or " 562 "feature_list_dense_defaults?")); 563 564 TensorShape out_shape; 565 const FeatureList& fl = (feature_list_missing) 566 ? empty_feature_list 567 : feature_list_found->second; 568 out_shape.AddDim(fl.feature_size()); 569 for (const int dim : attrs_.feature_list_dense_shapes[d].dim_sizes()) { 570 out_shape.AddDim(dim); 571 } 572 Tensor* out = nullptr; 573 OP_REQUIRES_OK(ctx, 574 feature_list_dense_values.allocate(d, out_shape, &out)); 575 576 for (int64 t = 0; t < fl.feature_size(); ++t) { 577 const Feature& f = fl.feature(t); 578 bool types_match; 579 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); 580 OP_REQUIRES(ctx, types_match, 581 errors::InvalidArgument( 582 "Name: ", name, ", Feature list: ", key, ", Index: ", t, 583 ". Data types don't match. ", 584 "Expected type: ", DataTypeString(dtype), 585 " Feature is: ", ProtoDebugString(f))); 586 OP_REQUIRES_OK(ctx, FeatureDenseCopy(t, name, key, dtype, shape, f, 587 feature_list_dense_values[d])); 588 } 589 } 590 591 // Feature List Sparse ----------------------------------------------------- 592 for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) { 593 const string& key = feature_list_sparse_keys_t[d]; 594 const DataType& dtype = attrs_.feature_list_sparse_types[d]; 595 596 const auto& feature_list_found = feature_list_dict.find(key); 597 bool feature_list_has_data = // Found key 598 (feature_list_found != feature_list_dict.end()); 599 600 std::vector<Tensor> sparse_values_tmp; 601 int64 feature_list_size = 0; 602 if (feature_list_has_data) { 603 const FeatureList& fl = feature_list_found->second; 604 feature_list_size = fl.feature_size(); 605 for (int64 t = 0; t < feature_list_size; ++t) { 606 const Feature& f = fl.feature(t); 607 bool types_match; 608 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); 609 OP_REQUIRES( 610 ctx, f.kind_case() == Feature::KIND_NOT_SET || types_match, 611 errors::InvalidArgument("Name: ", name, ", Feature List: ", key, 612 ", Index: ", t, 613 ". Data types don't match. ", 614 "Expected type: ", DataTypeString(dtype), 615 " Feature is: ", ProtoDebugString(f))); 616 sparse_values_tmp.push_back(FeatureSparseCopy(t, key, dtype, f)); 617 } 618 } else { 619 sparse_values_tmp.push_back(Tensor(dtype, TensorShape({0}))); 620 } 621 622 int64 total_num_features = 0; 623 int64 max_num_features = 0; 624 for (int t = 0; t < feature_list_size; ++t) { 625 const Tensor& v = sparse_values_tmp[t]; 626 const int64 num_elements = v.shape().num_elements(); 627 total_num_features += num_elements; 628 max_num_features = std::max(max_num_features, num_elements); 629 } 630 631 TensorShape indices_shape({total_num_features, 2}); 632 TensorShape values_shape({total_num_features}); 633 Tensor* sp_indices_d = nullptr; 634 Tensor* sp_values_d = nullptr; 635 Tensor* sp_shape_d = nullptr; 636 OP_REQUIRES_OK(ctx, feature_list_sparse_indices.allocate(d, indices_shape, 637 &sp_indices_d)); 638 OP_REQUIRES_OK(ctx, feature_list_sparse_values.allocate(d, values_shape, 639 &sp_values_d)); 640 OP_REQUIRES_OK(ctx, feature_list_sparse_shapes.allocate( 641 d, TensorShape({2}), &sp_shape_d)); 642 auto shape_t = sp_shape_d->vec<int64>(); 643 shape_t(0) = feature_list_size; 644 shape_t(1) = max_num_features; 645 646 int64 offset = 0; 647 648 for (int t = 0; t < feature_list_size; ++t) { 649 const int64 num_elements = CopyIntoSparseTensor( 650 sparse_values_tmp[t], t, offset, sp_indices_d, sp_values_d); 651 offset += num_elements; 652 } 653 } 654 } 655 656 protected: 657 ParseSingleSequenceExampleAttrs attrs_; 658 }; 659 660 REGISTER_KERNEL_BUILDER(Name("ParseSingleSequenceExample").Device(DEVICE_CPU), 661 SingleSequenceExampleParserOp); 662 663 #ifndef IS_MOBILE_PLATFORM 664 // when using lite protos on mobile, decoding JSON is not available. 665 666 class DecodeJSONExampleOp : public OpKernel { 667 public: 668 explicit DecodeJSONExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 669 resolver_.reset(protobuf::util::NewTypeResolverForDescriptorPool( 670 "type.googleapis.com", protobuf::DescriptorPool::generated_pool())); 671 } 672 673 void Compute(OpKernelContext* ctx) { 674 const Tensor* json_examples; 675 OP_REQUIRES_OK(ctx, ctx->input("json_examples", &json_examples)); 676 Tensor* binary_examples; 677 OP_REQUIRES_OK( 678 ctx, ctx->allocate_output("binary_examples", json_examples->shape(), 679 &binary_examples)); 680 681 for (int i = 0; i < json_examples->NumElements(); ++i) { 682 const string& json_example = json_examples->flat<string>()(i); 683 auto status = protobuf::util::JsonToBinaryString( 684 resolver_.get(), "type.googleapis.com/tensorflow.Example", 685 json_example, &binary_examples->flat<string>()(i)); 686 OP_REQUIRES(ctx, status.ok(), 687 errors::InvalidArgument("Error while parsing JSON: ", 688 string(status.error_message()))); 689 } 690 } 691 692 private: 693 std::unique_ptr<protobuf::util::TypeResolver> resolver_; 694 }; 695 696 REGISTER_KERNEL_BUILDER(Name("DecodeJSONExample").Device(DEVICE_CPU), 697 DecodeJSONExampleOp); 698 #endif 699 700 } // namespace tensorflow 701