Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/parsing_ops.cc.
     17 
     18 #include <numeric>
     19 #include <unordered_set>
     20 #include <vector>
     21 
     22 #include "tensorflow/core/example/example.pb.h"
     23 #include "tensorflow/core/example/feature.pb_text.h"
     24 #include "tensorflow/core/framework/common_shape_fns.h"
     25 #include "tensorflow/core/framework/numeric_op.h"
     26 #include "tensorflow/core/framework/register_types.h"
     27 #include "tensorflow/core/lib/gtl/array_slice.h"
     28 #include "tensorflow/core/platform/logging.h"
     29 #include "tensorflow/core/platform/protobuf.h"
     30 #include "tensorflow/core/util/example_proto_fast_parsing.h"
     31 #include "tensorflow/core/util/example_proto_helper.h"
     32 #include "tensorflow/core/util/sparse/sparse_tensor.h"
     33 #include "tensorflow/core/util/work_sharder.h"
     34 
     35 namespace tensorflow {
     36 
     37 class ParseExampleOp : public OpKernel {
     38  public:
     39   explicit ParseExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     40     OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
     41   }
     42 
     43   void Compute(OpKernelContext* ctx) override {
     44     const Tensor* names;
     45     const Tensor* serialized;
     46     OpInputList dense_keys;
     47     OpInputList sparse_keys;
     48     OpInputList dense_defaults;
     49 
     50     // Grab the input list arguments.
     51     OP_REQUIRES_OK(ctx, ctx->input("names", &names));
     52     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
     53     OP_REQUIRES_OK(ctx, ctx->input_list("dense_keys", &dense_keys));
     54     OP_REQUIRES_OK(ctx, ctx->input_list("sparse_keys", &sparse_keys));
     55     OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));
     56 
     57     std::vector<string> dense_keys_t(attrs_.num_dense);
     58     std::vector<string> sparse_keys_t(attrs_.num_sparse);
     59 
     60     // Check that the input list sizes match the attribute declared sizes.
     61     CHECK_EQ(dense_keys.size(), attrs_.num_dense);
     62     CHECK_EQ(sparse_keys.size(), attrs_.num_sparse);
     63 
     64     // Copy from OpInputList to std::vector<string>.
     65     for (int di = 0; di < attrs_.num_dense; ++di) {
     66       dense_keys_t[di] = dense_keys[di].scalar<string>()();
     67     }
     68     for (int di = 0; di < attrs_.num_sparse; ++di) {
     69       sparse_keys_t[di] = sparse_keys[di].scalar<string>()();
     70     }
     71 
     72     if (names->NumElements() > 0) {
     73       OP_REQUIRES(
     74           ctx, TensorShapeUtils::IsVector(names->shape()),
     75           errors::InvalidArgument("Expected names to be a vector, got shape: ",
     76                                   names->shape().DebugString()));
     77       OP_REQUIRES(
     78           ctx, names->NumElements() == serialized->NumElements(),
     79           errors::InvalidArgument(
     80               "Expected len(names) == len(serialized), but got: ",
     81               names->NumElements(), " vs. ", serialized->NumElements()));
     82     }
     83 
     84     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()),
     85                 errors::InvalidArgument(
     86                     "Expected serialized to be a vector, got shape: ",
     87                     serialized->shape().DebugString()));
     88     OP_REQUIRES(ctx, dense_defaults.size() == attrs_.num_dense,
     89                 errors::InvalidArgument(
     90                     "Expected len(dense_defaults) == len(dense_keys) but got: ",
     91                     dense_defaults.size(), " vs. ", attrs_.num_dense));
     92 
     93     for (int d = 0; d < static_cast<int>(attrs_.num_dense); ++d) {
     94       const Tensor& def_value = dense_defaults[d];
     95       if (attrs_.variable_length[d]) {
     96         OP_REQUIRES(ctx, def_value.NumElements() == 1,
     97                     errors::InvalidArgument(
     98                         "dense_shape[", d, "] is a variable length shape: ",
     99                         attrs_.dense_shapes[d].DebugString(),
    100                         ", therefore "
    101                         "def_value[",
    102                         d,
    103                         "] must contain a single element ("
    104                         "the padding element).  But its shape is: ",
    105                         def_value.shape().DebugString()));
    106       } else if (def_value.NumElements() > 0) {
    107         OP_REQUIRES(ctx,
    108                     attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()),
    109                     errors::InvalidArgument(
    110                         "def_value[", d,
    111                         "].shape() == ", def_value.shape().DebugString(),
    112                         " is not compatible with dense_shapes_[", d,
    113                         "] == ", attrs_.dense_shapes[d].DebugString()));
    114       }
    115       OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d],
    116                   errors::InvalidArgument(
    117                       "dense_defaults[", d, "].dtype() == ",
    118                       DataTypeString(def_value.dtype()), " != dense_types_[", d,
    119                       "] == ", DataTypeString(attrs_.dense_types[d])));
    120     }
    121 
    122     example::Result result;
    123 
    124     example::FastParseExampleConfig config;
    125     for (int d = 0; d < attrs_.num_dense; ++d) {
    126       config.dense.push_back({dense_keys_t[d], attrs_.dense_types[d],
    127                               attrs_.dense_shapes[d], dense_defaults[d],
    128                               attrs_.variable_length[d],
    129                               attrs_.elements_per_stride[d]});
    130     }
    131     for (int d = 0; d < attrs_.num_sparse; ++d) {
    132       config.sparse.push_back({sparse_keys_t[d], attrs_.sparse_types[d]});
    133     }
    134 
    135     auto serialized_t = serialized->flat<string>();
    136     auto names_t = names->flat<string>();
    137     gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
    138     gtl::ArraySlice<string> names_slice(names_t.data(), names_t.size());
    139 
    140     OP_REQUIRES_OK(
    141         ctx,
    142         FastParseExample(
    143             config, slice, names_slice,
    144             ctx->device()->tensorflow_cpu_worker_threads()->workers, &result));
    145 
    146     OpOutputList dense_values;
    147     OpOutputList sparse_indices;
    148     OpOutputList sparse_values;
    149     OpOutputList sparse_shapes;
    150     OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values));
    151     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices));
    152     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values));
    153     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes));
    154     for (int d = 0; d < attrs_.num_dense; ++d) {
    155       dense_values.set(d, result.dense_values[d]);
    156     }
    157     for (int d = 0; d < attrs_.num_sparse; ++d) {
    158       sparse_indices.set(d, result.sparse_indices[d]);
    159       sparse_values.set(d, result.sparse_values[d]);
    160       sparse_shapes.set(d, result.sparse_shapes[d]);
    161     }
    162   }
    163 
    164  protected:
    165   ParseExampleAttrs attrs_;
    166 };
    167 
    168 REGISTER_KERNEL_BUILDER(Name("ParseExample").Device(DEVICE_CPU),
    169                         ParseExampleOp);
    170 
    171 class ParseSingleExampleOp : public OpKernel {
    172  public:
    173   explicit ParseSingleExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    174     OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
    175   }
    176 
    177   void Compute(OpKernelContext* ctx) override {
    178     const Tensor* serialized;
    179     OpInputList dense_defaults;
    180 
    181     // Grab the input list arguments.
    182     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
    183     OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));
    184 
    185     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()),
    186                 errors::InvalidArgument(
    187                     "Expected serialized to be a scalar, got shape: ",
    188                     serialized->shape().DebugString()));
    189     OP_REQUIRES(ctx, dense_defaults.size() == attrs_.dense_keys.size(),
    190                 errors::InvalidArgument(
    191                     "Expected len(dense_defaults) == len(dense_keys) but got: ",
    192                     dense_defaults.size(), " vs. ", attrs_.dense_keys.size()));
    193 
    194     for (size_t d = 0; d < attrs_.dense_keys.size(); ++d) {
    195       const Tensor& def_value = dense_defaults[d];
    196       if (attrs_.variable_length[d]) {
    197         OP_REQUIRES(ctx, def_value.NumElements() == 1,
    198                     errors::InvalidArgument(
    199                         "dense_shape[", d, "] is a variable length shape: ",
    200                         attrs_.dense_shapes[d].DebugString(),
    201                         ", therefore "
    202                         "def_value[",
    203                         d,
    204                         "] must contain a single element ("
    205                         "the padding element).  But its shape is: ",
    206                         def_value.shape().DebugString()));
    207       } else if (def_value.NumElements() > 0) {
    208         OP_REQUIRES(ctx,
    209                     attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()),
    210                     errors::InvalidArgument(
    211                         "def_value[", d,
    212                         "].shape() == ", def_value.shape().DebugString(),
    213                         " is not compatible with dense_shapes_[", d,
    214                         "] == ", attrs_.dense_shapes[d].DebugString()));
    215       }
    216       OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d],
    217                   errors::InvalidArgument(
    218                       "dense_defaults[", d, "].dtype() == ",
    219                       DataTypeString(def_value.dtype()), " != dense_types_[", d,
    220                       "] == ", DataTypeString(attrs_.dense_types[d])));
    221     }
    222 
    223     example::Result result;
    224 
    225     // TODO(mrry): Build the configuration once and cache it.
    226     example::FastParseExampleConfig config;
    227     for (int d = 0; d < attrs_.dense_keys.size(); ++d) {
    228       config.dense.push_back({attrs_.dense_keys[d], attrs_.dense_types[d],
    229                               attrs_.dense_shapes[d], dense_defaults[d],
    230                               attrs_.variable_length[d],
    231                               attrs_.elements_per_stride[d]});
    232     }
    233     for (int d = 0; d < attrs_.sparse_keys.size(); ++d) {
    234       config.sparse.push_back({attrs_.sparse_keys[d], attrs_.sparse_types[d]});
    235     }
    236 
    237     const string& serialized_proto = serialized->scalar<string>()();
    238 
    239     OP_REQUIRES_OK(ctx,
    240                    FastParseSingleExample(config, serialized_proto, &result));
    241 
    242     OpOutputList dense_values;
    243     OpOutputList sparse_indices;
    244     OpOutputList sparse_values;
    245     OpOutputList sparse_shapes;
    246     OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values));
    247     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices));
    248     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values));
    249     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes));
    250     for (int d = 0; d < attrs_.dense_keys.size(); ++d) {
    251       dense_values.set(d, result.dense_values[d]);
    252     }
    253     for (int d = 0; d < attrs_.sparse_keys.size(); ++d) {
    254       sparse_indices.set(d, result.sparse_indices[d]);
    255       sparse_values.set(d, result.sparse_values[d]);
    256       sparse_shapes.set(d, result.sparse_shapes[d]);
    257     }
    258   }
    259 
    260  protected:
    261   ParseSingleExampleAttrs attrs_;
    262 };
    263 
    264 REGISTER_KERNEL_BUILDER(Name("ParseSingleExample").Device(DEVICE_CPU),
    265                         ParseSingleExampleOp);
    266 
    267 class SingleSequenceExampleParserOp : public OpKernel {
    268  public:
    269   explicit SingleSequenceExampleParserOp(OpKernelConstruction* ctx)
    270       : OpKernel(ctx) {
    271     OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
    272   }
    273 
    274   void Compute(OpKernelContext* ctx) override {
    275     const Tensor* debug_name;
    276     const Tensor* serialized;
    277     OpInputList context_dense_keys;
    278     OpInputList context_sparse_keys;
    279     OpInputList context_dense_defaults;
    280     OpInputList feature_list_dense_keys;
    281     OpInputList feature_list_sparse_keys;
    282     const Tensor* feature_list_dense_missing_assumed_empty;
    283 
    284     OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
    285     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
    286     OP_REQUIRES_OK(ctx, ctx->input("feature_list_dense_missing_assumed_empty",
    287                                    &feature_list_dense_missing_assumed_empty));
    288     OP_REQUIRES_OK(ctx,
    289                    ctx->input_list("context_dense_keys", &context_dense_keys));
    290     OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_dense_keys",
    291                                         &feature_list_dense_keys));
    292     OP_REQUIRES_OK(
    293         ctx, ctx->input_list("context_sparse_keys", &context_sparse_keys));
    294     OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_sparse_keys",
    295                                         &feature_list_sparse_keys));
    296     OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
    297                                         &context_dense_defaults));
    298 
    299     std::vector<string> context_dense_keys_t(attrs_.num_context_dense);
    300     std::vector<string> context_sparse_keys_t(attrs_.num_context_sparse);
    301     std::vector<string> feature_list_dense_keys_t(
    302         attrs_.num_feature_list_dense);
    303     std::vector<string> feature_list_sparse_keys_t(
    304         attrs_.num_feature_list_sparse);
    305     std::unordered_set<string> feature_list_dense_missing_assumed_empty_set;
    306     CHECK_EQ(context_dense_keys.size(), attrs_.num_context_dense);
    307     CHECK_EQ(context_sparse_keys.size(), attrs_.num_context_sparse);
    308     CHECK_EQ(feature_list_dense_keys.size(), attrs_.num_feature_list_dense);
    309     CHECK_EQ(feature_list_sparse_keys.size(), attrs_.num_feature_list_sparse);
    310     for (int di = 0; di < attrs_.num_context_dense; ++di) {
    311       OP_REQUIRES(ctx,
    312                   TensorShapeUtils::IsScalar(context_dense_keys[di].shape()),
    313                   errors::InvalidArgument(
    314                       "Expected context_dense_keys[", di,
    315                       "] to be a scalar, got shape: ",
    316                       context_dense_keys[di].shape().DebugString()));
    317       context_dense_keys_t[di] = context_dense_keys[di].scalar<string>()();
    318     }
    319     for (int di = 0; di < attrs_.num_context_sparse; ++di) {
    320       OP_REQUIRES(ctx,
    321                   TensorShapeUtils::IsScalar(context_sparse_keys[di].shape()),
    322                   errors::InvalidArgument(
    323                       "Expected context_sparse_keys[", di,
    324                       "] to be a scalar, got shape: ",
    325                       context_sparse_keys[di].shape().DebugString()));
    326       context_sparse_keys_t[di] = context_sparse_keys[di].scalar<string>()();
    327     }
    328     for (int di = 0; di < attrs_.num_feature_list_dense; ++di) {
    329       OP_REQUIRES(
    330           ctx, TensorShapeUtils::IsScalar(feature_list_dense_keys[di].shape()),
    331           errors::InvalidArgument(
    332               "Expected feature_list_dense_keys[", di,
    333               "] to be a scalar, got shape: ",
    334               feature_list_dense_keys[di].shape().DebugString()));
    335       feature_list_dense_keys_t[di] =
    336           feature_list_dense_keys[di].scalar<string>()();
    337     }
    338     for (int di = 0; di < attrs_.num_feature_list_sparse; ++di) {
    339       OP_REQUIRES(
    340           ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()),
    341           errors::InvalidArgument(
    342               "Expected feature_list_sparse_keys[", di,
    343               "] to be a scalar, got shape: ",
    344               feature_list_sparse_keys[di].shape().DebugString()));
    345       feature_list_sparse_keys_t[di] =
    346           feature_list_sparse_keys[di].scalar<string>()();
    347     }
    348     OP_REQUIRES(
    349         ctx,
    350         TensorShapeUtils::IsVector(
    351             feature_list_dense_missing_assumed_empty->shape()),
    352         errors::InvalidArgument(
    353             "Expected feature_list_dense_missing_assumed_empty ",
    354             "to be a vector, got shape: ",
    355             feature_list_dense_missing_assumed_empty->shape().DebugString()));
    356     auto feature_list_dense_missing_assumped_empty_t =
    357         feature_list_dense_missing_assumed_empty->vec<string>();
    358     for (int de = 0;
    359          de < feature_list_dense_missing_assumed_empty->NumElements(); ++de) {
    360       feature_list_dense_missing_assumed_empty_set.insert(
    361           feature_list_dense_missing_assumped_empty_t(de));
    362     }
    363 
    364     bool has_debug_name = (debug_name->NumElements() > 0);
    365     if (has_debug_name) {
    366       OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(debug_name->shape()),
    367                   errors::InvalidArgument(
    368                       "Expected debug_name to be a scalar, got shape: ",
    369                       debug_name->shape().DebugString()));
    370     }
    371     auto debug_name_t = debug_name->scalar<string>();
    372 
    373     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()),
    374                 errors::InvalidArgument(
    375                     "Expected serialized to be a scalar, got shape: ",
    376                     serialized->shape().DebugString()));
    377 
    378     OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense,
    379                 errors::InvalidArgument("Expected len(context_dense_defaults) "
    380                                         "== len(context_dense_keys) but got: ",
    381                                         context_dense_defaults.size(), " vs. ",
    382                                         attrs_.num_context_dense));
    383 
    384     std::vector<bool> required(attrs_.num_context_dense);
    385     for (int d = 0; d < attrs_.num_context_dense; ++d) {
    386       const Tensor& def_value = context_dense_defaults[d];
    387       required[d] = (def_value.NumElements() == 0);  // No default provided.
    388 
    389       if (def_value.NumElements() > 0) {
    390         OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d],
    391                     errors::InvalidArgument(
    392                         "def_value[", d,
    393                         "].shape() == ", def_value.shape().DebugString(),
    394                         " != context_dense_shapes_[", d,
    395                         "] == ", attrs_.context_dense_shapes[d].DebugString()));
    396         OP_REQUIRES(
    397             ctx, def_value.dtype() == attrs_.context_dense_types[d],
    398             errors::InvalidArgument(
    399                 "context_dense_defaults[", d, "].dtype() == ",
    400                 DataTypeString(def_value.dtype()), " != context_dense_types_[",
    401                 d, "] == ", DataTypeString(attrs_.context_dense_types[d])));
    402       }
    403     }
    404 
    405     auto serialized_t = serialized->scalar<string>();
    406 
    407     OpOutputList context_sparse_indices;
    408     OpOutputList context_sparse_values;
    409     OpOutputList context_sparse_shapes;
    410     OpOutputList context_dense_values;
    411     OpOutputList feature_list_sparse_indices;
    412     OpOutputList feature_list_sparse_values;
    413     OpOutputList feature_list_sparse_shapes;
    414     OpOutputList feature_list_dense_values;
    415 
    416     OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
    417                                          &context_sparse_indices));
    418     OP_REQUIRES_OK(
    419         ctx, ctx->output_list("context_sparse_values", &context_sparse_values));
    420     OP_REQUIRES_OK(
    421         ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
    422     OP_REQUIRES_OK(
    423         ctx, ctx->output_list("context_dense_values", &context_dense_values));
    424     OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
    425                                          &context_sparse_indices));
    426     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices",
    427                                          &feature_list_sparse_indices));
    428     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values",
    429                                          &feature_list_sparse_values));
    430     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes",
    431                                          &feature_list_sparse_shapes));
    432     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values",
    433                                          &feature_list_dense_values));
    434 
    435     SequenceExample ex;
    436     OP_REQUIRES(
    437         ctx, ParseProtoUnlimited(&ex, serialized_t()),
    438         errors::InvalidArgument("Could not parse example input, value: '",
    439                                 serialized_t(), "'"));
    440 
    441     const string& name = (has_debug_name) ? debug_name_t() : "<unknown>";
    442     const Features& context = ex.context();
    443     const auto& context_dict = context.feature();
    444 
    445     // Context Dense -----------------------------------------------------------
    446 
    447     // Preallocate context_dense_values, since we know their sizes
    448     for (int d = 0; d < attrs_.num_context_dense; ++d) {
    449       TensorShape out_shape;
    450       for (const int dim : attrs_.context_dense_shapes[d].dim_sizes())
    451         out_shape.AddDim(dim);
    452       Tensor* out = nullptr;
    453       OP_REQUIRES_OK(ctx, context_dense_values.allocate(d, out_shape, &out));
    454     }
    455 
    456     for (int d = 0; d < attrs_.num_context_dense; ++d) {
    457       const string& key = context_dense_keys_t[d];
    458       const DataType& dtype = attrs_.context_dense_types[d];
    459       const TensorShape& shape = attrs_.context_dense_shapes[d];
    460 
    461       const auto& feature_found = context_dict.find(key);
    462       OP_REQUIRES(
    463           ctx, (feature_found != context_dict.end()) || !required[d],
    464           errors::InvalidArgument("Name: ", name, ", Context feature '", key,
    465                                   "' is required but could not be found."));
    466       if (feature_found != context_dict.end()) {
    467         const Feature& f = feature_found->second;
    468         bool types_match;
    469         OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
    470         OP_REQUIRES(
    471             ctx, types_match,
    472             errors::InvalidArgument("Name: ", name, ", Context feature: ", key,
    473                                     ".  Data types don't match. ",
    474                                     "Expected type: ", DataTypeString(dtype),
    475                                     "  Feature is: ", ProtoDebugString(f)));
    476 
    477         OP_REQUIRES_OK(ctx, FeatureDenseCopy(0, name, key, dtype, shape, f,
    478                                              context_dense_values[d]));
    479       } else {
    480         RowDenseCopy(0, dtype, context_dense_defaults[d],
    481                      context_dense_values[d]);
    482       }
    483     }
    484 
    485     // Context Sparse ----------------------------------------------------------
    486     for (int d = 0; d < attrs_.num_context_sparse; ++d) {
    487       const string& key = context_sparse_keys_t[d];
    488       const DataType& dtype = attrs_.context_sparse_types[d];
    489 
    490       const auto& feature_found = context_dict.find(key);
    491       bool feature_has_data =  // Found key & data type is set
    492           (feature_found != context_dict.end() &&
    493            (feature_found->second.kind_case() != Feature::KIND_NOT_SET));
    494 
    495       if (feature_has_data) {
    496         const Feature& f = feature_found->second;
    497         bool types_match;
    498         OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
    499         OP_REQUIRES(
    500             ctx, types_match,
    501             errors::InvalidArgument("Name: ", name, ", Context feature: ", key,
    502                                     ".  Data types don't match. ",
    503                                     "Expected type: ", DataTypeString(dtype),
    504                                     "  Feature is: ", ProtoDebugString(f)));
    505 
    506         Tensor feature_values = FeatureSparseCopy(0, key, dtype, f);
    507         const int64 num_elements = feature_values.NumElements();
    508         TensorShape indices_shape({num_elements, 1});
    509         Tensor* sp_indices_d = nullptr;
    510         Tensor* sp_shape_d = nullptr;
    511         OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape,
    512                                                             &sp_indices_d));
    513         context_sparse_values.set(d, feature_values);
    514         OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}),
    515                                                            &sp_shape_d));
    516         auto shape_t = sp_shape_d->vec<int64>();
    517         shape_t(0) = num_elements;
    518         auto indices_t = sp_indices_d->matrix<int64>();
    519         std::iota(indices_t.data(), indices_t.data() + num_elements, 0);
    520       } else {
    521         TensorShape indices_shape({0, 1});
    522         TensorShape values_shape({0});
    523         Tensor* sp_indices_d = nullptr;
    524         Tensor* sp_values_d = nullptr;
    525         Tensor* sp_shape_d = nullptr;
    526         OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape,
    527                                                             &sp_indices_d));
    528         OP_REQUIRES_OK(
    529             ctx, context_sparse_values.allocate(d, values_shape, &sp_values_d));
    530         OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}),
    531                                                            &sp_shape_d));
    532         auto shape_t = sp_shape_d->vec<int64>();
    533         shape_t(0) = 0;
    534       }
    535     }
    536 
    537     // Feature List Dense ------------------------------------------------------
    538 
    539     // Preallocate context_dense_values, since we can infer their
    540     // sizes
    541     const FeatureLists& feature_lists = ex.feature_lists();
    542     const auto& feature_list_dict = feature_lists.feature_list();
    543     FeatureList empty_feature_list;  // Placeholder for missing FLs
    544 
    545     for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
    546       const string& key = feature_list_dense_keys_t[d];
    547       const DataType& dtype = attrs_.feature_list_dense_types[d];
    548       const TensorShape& shape = attrs_.feature_list_dense_shapes[d];
    549 
    550       const auto& feature_list_found = feature_list_dict.find(key);
    551       bool feature_list_missing =
    552           (feature_list_found == feature_list_dict.end());
    553       bool feature_list_allowed_missing =
    554           (feature_list_dense_missing_assumed_empty_set.count(key) > 0);
    555 
    556       OP_REQUIRES(
    557           ctx, !feature_list_missing || feature_list_allowed_missing,
    558           errors::InvalidArgument("Name: ", name, ", Feature list '", key,
    559                                   "' is required but could not be found.  "
    560                                   "Did you mean to include it in "
    561                                   "feature_list_dense_missing_assumed_empty or "
    562                                   "feature_list_dense_defaults?"));
    563 
    564       TensorShape out_shape;
    565       const FeatureList& fl = (feature_list_missing)
    566                                   ? empty_feature_list
    567                                   : feature_list_found->second;
    568       out_shape.AddDim(fl.feature_size());
    569       for (const int dim : attrs_.feature_list_dense_shapes[d].dim_sizes()) {
    570         out_shape.AddDim(dim);
    571       }
    572       Tensor* out = nullptr;
    573       OP_REQUIRES_OK(ctx,
    574                      feature_list_dense_values.allocate(d, out_shape, &out));
    575 
    576       for (int64 t = 0; t < fl.feature_size(); ++t) {
    577         const Feature& f = fl.feature(t);
    578         bool types_match;
    579         OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
    580         OP_REQUIRES(ctx, types_match,
    581                     errors::InvalidArgument(
    582                         "Name: ", name, ", Feature list: ", key, ", Index: ", t,
    583                         ".  Data types don't match. ",
    584                         "Expected type: ", DataTypeString(dtype),
    585                         "  Feature is: ", ProtoDebugString(f)));
    586         OP_REQUIRES_OK(ctx, FeatureDenseCopy(t, name, key, dtype, shape, f,
    587                                              feature_list_dense_values[d]));
    588       }
    589     }
    590 
    591     // Feature List Sparse -----------------------------------------------------
    592     for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
    593       const string& key = feature_list_sparse_keys_t[d];
    594       const DataType& dtype = attrs_.feature_list_sparse_types[d];
    595 
    596       const auto& feature_list_found = feature_list_dict.find(key);
    597       bool feature_list_has_data =  // Found key
    598           (feature_list_found != feature_list_dict.end());
    599 
    600       std::vector<Tensor> sparse_values_tmp;
    601       int64 feature_list_size = 0;
    602       if (feature_list_has_data) {
    603         const FeatureList& fl = feature_list_found->second;
    604         feature_list_size = fl.feature_size();
    605         for (int64 t = 0; t < feature_list_size; ++t) {
    606           const Feature& f = fl.feature(t);
    607           bool types_match;
    608           OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
    609           OP_REQUIRES(
    610               ctx, f.kind_case() == Feature::KIND_NOT_SET || types_match,
    611               errors::InvalidArgument("Name: ", name, ", Feature List: ", key,
    612                                       ", Index: ", t,
    613                                       ".  Data types don't match. ",
    614                                       "Expected type: ", DataTypeString(dtype),
    615                                       "  Feature is: ", ProtoDebugString(f)));
    616           sparse_values_tmp.push_back(FeatureSparseCopy(t, key, dtype, f));
    617         }
    618       } else {
    619         sparse_values_tmp.push_back(Tensor(dtype, TensorShape({0})));
    620       }
    621 
    622       int64 total_num_features = 0;
    623       int64 max_num_features = 0;
    624       for (int t = 0; t < feature_list_size; ++t) {
    625         const Tensor& v = sparse_values_tmp[t];
    626         const int64 num_elements = v.shape().num_elements();
    627         total_num_features += num_elements;
    628         max_num_features = std::max(max_num_features, num_elements);
    629       }
    630 
    631       TensorShape indices_shape({total_num_features, 2});
    632       TensorShape values_shape({total_num_features});
    633       Tensor* sp_indices_d = nullptr;
    634       Tensor* sp_values_d = nullptr;
    635       Tensor* sp_shape_d = nullptr;
    636       OP_REQUIRES_OK(ctx, feature_list_sparse_indices.allocate(d, indices_shape,
    637                                                                &sp_indices_d));
    638       OP_REQUIRES_OK(ctx, feature_list_sparse_values.allocate(d, values_shape,
    639                                                               &sp_values_d));
    640       OP_REQUIRES_OK(ctx, feature_list_sparse_shapes.allocate(
    641                               d, TensorShape({2}), &sp_shape_d));
    642       auto shape_t = sp_shape_d->vec<int64>();
    643       shape_t(0) = feature_list_size;
    644       shape_t(1) = max_num_features;
    645 
    646       int64 offset = 0;
    647 
    648       for (int t = 0; t < feature_list_size; ++t) {
    649         const int64 num_elements = CopyIntoSparseTensor(
    650             sparse_values_tmp[t], t, offset, sp_indices_d, sp_values_d);
    651         offset += num_elements;
    652       }
    653     }
    654   }
    655 
    656  protected:
    657   ParseSingleSequenceExampleAttrs attrs_;
    658 };
    659 
    660 REGISTER_KERNEL_BUILDER(Name("ParseSingleSequenceExample").Device(DEVICE_CPU),
    661                         SingleSequenceExampleParserOp);
    662 
    663 #ifndef IS_MOBILE_PLATFORM
    664 // when using lite protos on mobile, decoding JSON is not available.
    665 
    666 class DecodeJSONExampleOp : public OpKernel {
    667  public:
    668   explicit DecodeJSONExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    669     resolver_.reset(protobuf::util::NewTypeResolverForDescriptorPool(
    670         "type.googleapis.com", protobuf::DescriptorPool::generated_pool()));
    671   }
    672 
    673   void Compute(OpKernelContext* ctx) {
    674     const Tensor* json_examples;
    675     OP_REQUIRES_OK(ctx, ctx->input("json_examples", &json_examples));
    676     Tensor* binary_examples;
    677     OP_REQUIRES_OK(
    678         ctx, ctx->allocate_output("binary_examples", json_examples->shape(),
    679                                   &binary_examples));
    680 
    681     for (int i = 0; i < json_examples->NumElements(); ++i) {
    682       const string& json_example = json_examples->flat<string>()(i);
    683       auto status = protobuf::util::JsonToBinaryString(
    684           resolver_.get(), "type.googleapis.com/tensorflow.Example",
    685           json_example, &binary_examples->flat<string>()(i));
    686       OP_REQUIRES(ctx, status.ok(),
    687                   errors::InvalidArgument("Error while parsing JSON: ",
    688                                           string(status.error_message())));
    689     }
    690   }
    691 
    692  private:
    693   std::unique_ptr<protobuf::util::TypeResolver> resolver_;
    694 };
    695 
    696 REGISTER_KERNEL_BUILDER(Name("DecodeJSONExample").Device(DEVICE_CPU),
    697                         DecodeJSONExampleOp);
    698 #endif
    699 
    700 }  // namespace tensorflow
    701