Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // DecodeProto is a TensorFlow op which extracts arbitrary fields from protos
     17 // serialized as strings.
     18 //
     19 // See docs in ../ops/decode_proto_op.cc.
     20 //
     21 // This implementation reads the serialized format using a handful of calls from
     22 // the WireFormatLite API used by generated proto code. WireFormatLite is marked
     23 // as an "internal" proto API but is widely used in practice and highly unlikely
     24 // to change. This will be much faster than the previous implementation based on
     25 // constructing a temporary dynamic message in memory and using the proto
     26 // reflection api to read it. It can be used with any proto whose descriptors
     27 // are available at runtime but should be competitive in speed with approaches
     28 // that compile in the proto definitions.
     29 
     30 #include <memory>
     31 #include <string>
     32 #include <vector>
     33 
     34 #include "absl/container/flat_hash_map.h"
     35 #include "third_party/eigen3/Eigen/Core"
     36 #include "tensorflow/core/framework/op_kernel.h"
     37 #include "tensorflow/core/framework/tensor_types.h"
     38 #include "tensorflow/core/framework/types.h"
     39 #include "tensorflow/core/lib/core/errors.h"
     40 #include "tensorflow/core/platform/logging.h"
     41 #include "tensorflow/core/platform/protobuf.h"
     42 #include "tensorflow/core/util/proto/decode.h"
     43 #include "tensorflow/core/util/proto/descriptors.h"
     44 #include "tensorflow/core/util/proto/proto_utils.h"
     45 #include "tensorflow/core/util/ptr_util.h"
     46 
     47 namespace tensorflow {
     48 namespace {
     49 
     50 using ::tensorflow::MakeUnique;
     51 using ::tensorflow::protobuf::Descriptor;
     52 using ::tensorflow::protobuf::DescriptorPool;
     53 using ::tensorflow::protobuf::DynamicMessageFactory;
     54 using ::tensorflow::protobuf::FieldDescriptor;
     55 using ::tensorflow::protobuf::Message;
     56 using ::tensorflow::protobuf::TextFormat;
     57 using ::tensorflow::protobuf::internal::WireFormatLite;
     58 using ::tensorflow::protobuf::io::CodedInputStream;
     59 
     60 const bool kFailOnDecodeError = true;
     61 
     62 // Used to store the default value of a protocol message field, casted to the
     63 // type of the output tensor.
     64 //
     65 // TODO(paskin): Use absl::variant once TensorFlow gets absl dependencies.
     66 struct DefaultValue {
     67   DataType dtype = DataType::DT_INVALID;
     68   union Value {
     69     bool v_bool;           // DT_BOOL
     70     double v_double;       // DT_DOUBLE
     71     float v_float;         // DT_FLOAT
     72     int8 v_int8;           // DT_INT8
     73     int32 v_int32;         // DT_INT32
     74     int64 v_int64;         // DT_INT64
     75     const char* v_string;  // DT_STRING
     76     uint8 v_uint8;         // DT_UINT8
     77     uint8 v_uint32;        // DT_UINT32
     78     uint8 v_uint64;        // DT_UINT64
     79   };
     80   Value value;
     81 };
     82 
     83 // Initializes a DefaultValue object.  This generic template handles numeric
     84 // types and strings are handled by a template specialization below.
     85 //
     86 // Args:
     87 //   dtype: the type of the output tensor
     88 //   value: the default value as obtained from the FieldDescriptor
     89 //   result: the object to initialize
     90 template <typename T>
     91 Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) {
     92   result->dtype = dtype;
     93   switch (dtype) {
     94     case DT_BOOL:
     95       result->value.v_bool = static_cast<bool>(value);
     96       break;
     97     case DT_DOUBLE:
     98       result->value.v_double = static_cast<double>(value);
     99       break;
    100     case DT_FLOAT:
    101       result->value.v_float = static_cast<float>(value);
    102       break;
    103     case DT_INT8:
    104       result->value.v_int8 = static_cast<int8>(value);
    105       break;
    106     case DT_INT32:
    107       result->value.v_int32 = static_cast<int32>(value);
    108       break;
    109     case DT_INT64:
    110       result->value.v_int64 = static_cast<int64>(value);
    111       break;
    112     case DT_UINT8:
    113       result->value.v_uint8 = static_cast<uint8>(value);
    114       break;
    115     case DT_UINT32:
    116       result->value.v_uint32 = static_cast<uint32>(value);
    117       break;
    118     case DT_UINT64:
    119       result->value.v_uint64 = static_cast<uint64>(value);
    120       break;
    121     default:
    122       // We should never get here, given the type checking that occurs earlier.
    123       return errors::Internal(
    124           "Cannot initialize default value for unsupported type: ",
    125           DataTypeString(dtype));
    126   }
    127   return Status::OK();
    128 }
    129 
    130 template <>
    131 Status InitDefaultValue(DataType dtype, const char* value,
    132                         DefaultValue* result) {
    133   // These are sanity checks that should never trigger given the code that
    134   // leads here.
    135   if (TF_PREDICT_FALSE(dtype != DT_STRING)) {
    136     return errors::InvalidArgument(
    137         "Cannot cast field to anything but DT_STRING");
    138   }
    139   if (TF_PREDICT_FALSE(value == nullptr)) {
    140     return errors::InvalidArgument("Null default string value.");
    141   }
    142   result->dtype = DT_STRING;
    143   result->value.v_string = value;
    144   return Status::OK();
    145 }
    146 
    147 // Initializes a default value from the output data type and the field
    148 // descriptor.
    149 Status InitDefaultValueFromFieldDescriptor(DataType dtype,
    150                                            const FieldDescriptor* field_desc,
    151                                            DefaultValue* result) {
    152   switch (field_desc->type()) {
    153     case WireFormatLite::TYPE_DOUBLE:
    154       return InitDefaultValue(dtype, field_desc->default_value_double(),
    155                               result);
    156     case WireFormatLite::TYPE_FLOAT:
    157       return InitDefaultValue(dtype, field_desc->default_value_float(), result);
    158     case WireFormatLite::TYPE_INT64:
    159     case WireFormatLite::TYPE_SINT64:
    160     case WireFormatLite::TYPE_SFIXED64:
    161       return InitDefaultValue(dtype, field_desc->default_value_int64(), result);
    162     case WireFormatLite::TYPE_FIXED64:
    163     case WireFormatLite::TYPE_UINT64:
    164       return InitDefaultValue(dtype, field_desc->default_value_uint64(),
    165                               result);
    166     case WireFormatLite::TYPE_ENUM:
    167     case WireFormatLite::TYPE_INT32:
    168     case WireFormatLite::TYPE_SINT32:
    169     case WireFormatLite::TYPE_SFIXED32:
    170       return InitDefaultValue(dtype, field_desc->default_value_int32(), result);
    171     case WireFormatLite::TYPE_FIXED32:
    172     case WireFormatLite::TYPE_UINT32:
    173       return InitDefaultValue(dtype, field_desc->default_value_uint32(),
    174                               result);
    175     case WireFormatLite::TYPE_BOOL:
    176       return InitDefaultValue(dtype, field_desc->default_value_bool(), result);
    177     case WireFormatLite::TYPE_BYTES:
    178     case WireFormatLite::TYPE_STRING:
    179       // Manipulating default string values as C-style pointers should be OK
    180       // for typical code-generated protocol messages.  It is possible in
    181       // principle to register a message descriptor on the fly, and these
    182       // pointers may not be stable if that descriptor has a weird
    183       // implementation.  (But the return type of default_value_string() is
    184       // const string&, so it'd have to be very weird.)
    185       return InitDefaultValue(dtype, field_desc->default_value_string().c_str(),
    186                               result);
    187     case WireFormatLite::TYPE_GROUP:
    188     case WireFormatLite::TYPE_MESSAGE:
    189       return InitDefaultValue(dtype, "", result);
    190       // default: intentionally omitted in order to enable static checking.
    191   }
    192   return Status::OK();
    193 }
    194 
    195 // A FieldInfo holds a handful of information from the FieldDescriptor
    196 // and user attributes.
    197 struct FieldInfo {
    198   FieldInfo(const FieldDescriptor* field_desc, int user_index,
    199             DefaultValue def_value)
    200       : output_index(user_index), default_value(def_value) {
    201     // Without this intermediate data structure, the profile had hotspots
    202     // calling methods of FieldDescriptor.
    203     number = field_desc->number();
    204 
    205     // The wire format library defines the same constants used in
    206     // descriptor.proto. This static_cast is safe because they are guaranteed to
    207     // stay in sync. We need the field type from the FieldDescriptor here
    208     // because the wire format doesn't tell us anything about what happens
    209     // inside a packed repeated field: there is enough information in the wire
    210     // format to skip the whole field but not enough to know how to parse what's
    211     // inside. For that we go to the schema.
    212     type = static_cast<WireFormatLite::FieldType>(field_desc->type());
    213     is_repeated = field_desc->is_repeated();
    214   }
    215 
    216   // Disable copy and move.
    217   FieldInfo(const FieldInfo&) = delete;
    218   FieldInfo& operator=(const FieldInfo&) = delete;
    219 
    220   // Internally we sort field descriptors by wire number for fast lookup. In
    221   // general this is different from the order given by the user. Output_index
    222   // gives the index into the field_names and output_types attributes and into
    223   // the output tensor list.
    224   int output_index = -1;
    225 
    226   // This is a cache of the relevant fields from `FieldDescriptorProto`. This
    227   // was added after noticing that FieldDescriptor->type() was using 6% of the
    228   // cpu profile.
    229   WireFormatLite::FieldType type;
    230   int number;
    231   bool is_repeated;
    232   DefaultValue default_value;
    233 };
    234 
    235 // A CountCollector counts sizes of repeated and optional fields in a proto.
    236 //
    237 // Each field is tracked by a single CountCollector instance. The instance
    238 // manages a single count, which is stored as a pointer (it is intended to be a
    239 // reference to the `sizes` output which is being filled in). The pointer is
    240 // passed in at initialization.
    241 //
    242 // Counting is done as a separate pass in order to allocate output tensors all
    243 // at once. This allows the TensorFlow runtime to optimize allocation for the
    244 // consumer, while removing the need for copying inside this op. After this
    245 // pass, the DenseCollector class (below) gathers the data: it is more complex
    246 // and provides better motivation for the API here.
    247 class CountCollector {
    248  public:
    249   CountCollector() = delete;
    250 
    251   // The count may be stored inside an Eigen Tensor to eliminate copying.
    252   explicit CountCollector(int32* count) : count_ptr_(count) {}
    253 
    254   // Reads (in this case counts) a single value.
    255   Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
    256     // Only repeated fields can have count > 1.
    257     if (*count_ptr_ == 0 || field.is_repeated) {
    258       (*count_ptr_)++;
    259     }
    260     // We expect a wire type based on the schema field_type, to allow a little
    261     // more checking.
    262     if (!SkipValue(input, field)) {
    263       return errors::DataLoss("ReadValue: Failed skipping field when counting");
    264     }
    265     return Status::OK();
    266   }
    267 
    268   // Reads (in this case counts) a length-delimited list of values.
    269   Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
    270                           size_t buf_size) {
    271     if (buf_size == 0) {
    272       return Status::OK();
    273     }
    274 
    275     const void* tmpbuf;
    276     int unused_max_buf_size;
    277 
    278     input->GetDirectBufferPointerInline(&tmpbuf, &unused_max_buf_size);
    279     // This is safe because the underlying storage for the CodedInputStream is
    280     // owned by the input tensor. If it were a Cord or file-backed stream this
    281     // pointer would go stale after the bytes were skipped.
    282     const uint8* buf = reinterpret_cast<const uint8*>(tmpbuf);
    283 
    284     // Important: we skipped the input->{Push,Pop}Limit() calls for speed,
    285     // so the bounds check on buf_size inside Skip() is critical, and
    286     // must be done before scanning the contents.
    287     if (!input->Skip(buf_size)) {
    288       return errors::DataLoss("ReadPackedValues: Skipping packed field failed");
    289     }
    290 
    291     // Dispatch to the appropriately typed field reader based on the schema
    292     // type.
    293     Status st;
    294     switch (field.type) {
    295       case WireFormatLite::TYPE_DOUBLE:
    296         st = CountPackedFixed<double>(buf, buf_size);
    297         break;
    298       case WireFormatLite::TYPE_FLOAT:
    299         st = CountPackedFixed<float>(buf, buf_size);
    300         break;
    301       case WireFormatLite::TYPE_INT64:
    302         st = CountPackedVarint(buf, buf_size);
    303         break;
    304       case WireFormatLite::TYPE_UINT64:
    305         st = CountPackedVarint(buf, buf_size);
    306         break;
    307       case WireFormatLite::TYPE_INT32:
    308         st = CountPackedVarint(buf, buf_size);
    309         break;
    310       case WireFormatLite::TYPE_FIXED64:
    311         st = CountPackedFixed<uint64>(buf, buf_size);
    312         break;
    313       case WireFormatLite::TYPE_FIXED32:
    314         st = CountPackedFixed<uint32>(buf, buf_size);
    315         break;
    316       case WireFormatLite::TYPE_BOOL:
    317         st = CountPackedVarint(buf, buf_size);
    318         break;
    319       case WireFormatLite::TYPE_STRING:
    320         st = errors::DataLoss("TYPE_STRING encountered as packed");
    321         break;
    322       case WireFormatLite::TYPE_GROUP:
    323         st = errors::DataLoss("TYPE_GROUP encountered as packed");
    324         break;
    325       case WireFormatLite::TYPE_MESSAGE:
    326         st = errors::DataLoss("TYPE_MESSAGE encountered as packed");
    327         break;
    328       case WireFormatLite::TYPE_BYTES:
    329         st = errors::DataLoss("TYPE_BYTES encountered as packed");
    330         break;
    331       case WireFormatLite::TYPE_UINT32:
    332         st = CountPackedVarint(buf, buf_size);
    333         break;
    334       case WireFormatLite::TYPE_ENUM:
    335         st = CountPackedVarint(buf, buf_size);
    336         break;
    337       case WireFormatLite::TYPE_SFIXED32:
    338         st = CountPackedFixed<int32>(buf, buf_size);
    339         break;
    340       case WireFormatLite::TYPE_SFIXED64:
    341         st = CountPackedFixed<int64>(buf, buf_size);
    342         break;
    343       case WireFormatLite::TYPE_SINT32:
    344         st = CountPackedVarint(buf, buf_size);
    345         break;
    346       case WireFormatLite::TYPE_SINT64:
    347         st = CountPackedVarint(buf, buf_size);
    348         break;
    349         // default: intentionally omitted in order to enable static checking.
    350     }
    351     if (!st.ok()) {
    352       return st;
    353     }
    354 
    355     if (!field.is_repeated && *count_ptr_ > 1) {
    356       *count_ptr_ = 1;
    357     }
    358     return Status::OK();
    359   }
    360 
    361  private:
    362   // Skips a length-delimited value.
    363   static bool SkipBytes(CodedInputStream* input) {
    364     uint32 length;
    365     if (!input->ReadVarint32(&length)) {
    366       return false;
    367     }
    368     return input->Skip(length);
    369   }
    370 
    371   // Counts the number of packed varints in an array. The end of a varint is
    372   // signaled by a value < 0x80, so counting them requires parsing the
    373   // bytestream. It is the caller's responsibility to ensure that len > 0.
    374   Status CountPackedVarint(const uint8* buf, size_t len) {
    375     const uint8* bound = buf + len;
    376     int count;
    377 
    378     // The last byte in a valid encoded varint is guaranteed to have the high
    379     // bit unset. We rely on this property to prevent ReadVarint64FromArray from
    380     // going out of bounds, so validate the end of the buf before scanning
    381     // anything.
    382     if (bound[-1] & 0x80) {
    383       return errors::DataLoss("Corrupt packed varint");
    384     }
    385 
    386     // Now we can trust ReadVarint64FromArray to stay in bounds.
    387     for (count = 0; buf < bound; ++count) {
    388       uint64 temp;
    389       bool ok;
    390       buf = internal::ReadVarint64FromArray(buf, &ok, &temp);
    391       if (!ok) {
    392         return errors::DataLoss("Corrupt packed varint");
    393       }
    394     }
    395 
    396     *count_ptr_ += count;
    397     return Status::OK();
    398   }
    399 
    400   // Counts the number of fixed-size values in a packed field. This can be done
    401   // without actually parsing anything.
    402   template <typename T>
    403   Status CountPackedFixed(const uint8* unused_buf, size_t len) {
    404     int count = len / sizeof(T);
    405     if (count * sizeof(T) != len) {
    406       return errors::DataLoss(
    407           "Illegal data length for packed fixed-size type: ", len);
    408     }
    409     *count_ptr_ += len / sizeof(T);
    410     return Status::OK();
    411   }
    412 
    413   // Skips a single value in the input stream. Dispatches to the appropriately
    414   // typed field skipper based on the schema type tag. This is not as permissive
    415   // as just handling the wire type.
    416   static bool SkipValue(CodedInputStream* input, const FieldInfo& field) {
    417     uint32 tmp32;
    418     protobuf_uint64 tmp64;
    419     switch (field.type) {
    420       case WireFormatLite::TYPE_DOUBLE:
    421         return input->ReadLittleEndian64(&tmp64);
    422       case WireFormatLite::TYPE_FLOAT:
    423         return input->ReadLittleEndian32(&tmp32);
    424       case WireFormatLite::TYPE_INT64:
    425         return input->ReadVarint64(&tmp64);
    426       case WireFormatLite::TYPE_UINT64:
    427         return input->ReadVarint64(&tmp64);
    428       case WireFormatLite::TYPE_INT32:
    429         return input->ReadVarint32(&tmp32);
    430       case WireFormatLite::TYPE_FIXED64:
    431         return input->ReadLittleEndian64(&tmp64);
    432       case WireFormatLite::TYPE_FIXED32:
    433         return input->ReadLittleEndian32(&tmp32);
    434       case WireFormatLite::TYPE_BOOL:
    435         return input->ReadVarint32(&tmp32);
    436       case WireFormatLite::TYPE_STRING:
    437         return SkipBytes(input);
    438       case WireFormatLite::TYPE_GROUP:
    439         return WireFormatLite::SkipField(
    440             input, WireFormatLite::MakeTag(
    441                        field.number, WireFormatLite::WIRETYPE_START_GROUP));
    442       case WireFormatLite::TYPE_MESSAGE:
    443         return SkipBytes(input);
    444       case WireFormatLite::TYPE_BYTES:
    445         return SkipBytes(input);
    446       case WireFormatLite::TYPE_UINT32:
    447         return input->ReadVarint32(&tmp32);
    448       case WireFormatLite::TYPE_ENUM:
    449         return input->ReadVarint32(&tmp32);
    450       case WireFormatLite::TYPE_SFIXED32:
    451         return input->ReadLittleEndian32(&tmp32);
    452       case WireFormatLite::TYPE_SFIXED64:
    453         return input->ReadLittleEndian64(&tmp64);
    454       case WireFormatLite::TYPE_SINT32:
    455         return input->ReadVarint32(&tmp32);
    456       case WireFormatLite::TYPE_SINT64:
    457         return input->ReadVarint64(&tmp64);
    458         // default: intentionally omitted in order to enable static checking.
    459     }
    460   }
    461 
    462   int32* count_ptr_ = nullptr;
    463 };
    464 
    465 // A DenseCollector accumulates values from a proto into a tensor.
    466 //
    467 // There is an instance of DenseCollector for each field of each proto. The
    468 // DenseCollector deserializes the value from the wire directly into the
    469 // preallocated output Tensor.
    470 //
    471 // This class is named DenseCollector because in the future there should be a
    472 // SparseCollector that accumulates field data into sparse tensors if the user
    473 // requests it.
    474 class DenseCollector {
    475  public:
    476   DenseCollector() = delete;
    477 
    478   // A DenseCollector applies to one field of a serialized message.
    479   // Note that default_value.dtype is the type of the output tensor.
    480   DenseCollector(uint8* datap, DefaultValue default_value, int max_repeat_count)
    481       : datap_(datap),
    482         default_value_(default_value),
    483         max_repeat_count_(max_repeat_count) {}
    484 
    485   // Reads a value from the input stream and stores it.
    486   //
    487   // Always inlining gave a ~50% speedup on microbenchmarks at one point.
    488   // TODO(nix): try removing it to see if that still holds.
    489   // TODO(jsimsa): ABSL_ATTRIBUTE_ALWAYS_INLINE
    490   Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
    491     // For required and optional fields, we overwrite values[0] with
    492     // the latest one in the wire stream.
    493     // See https://developers.google.com/protocol-buffers/docs/encoding#optional
    494     // Only for repeated fields do we advance the next_repeat_index_ past 1.
    495     // TODO(nix): to handle oneof we must also zero out any previous values
    496     //  seen on the wire.
    497     int32 index = 0;
    498     if (field.is_repeated) {
    499       index = next_repeat_index_;
    500     }
    501     next_repeat_index_ = index + 1;
    502 
    503     return internal::ReadValue(input, field.type, field.number,
    504                                default_value_.dtype, index, datap_);
    505   }
    506 
    507   // Reads and stores a length-delimited list of values.
    508   Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
    509                           const size_t buf_size) {
    510     const void* buf;
    511     int unused_max_buf_size;
    512     input->GetDirectBufferPointerInline(&buf, &unused_max_buf_size);
    513     // This is safe because the underlying storage for the CodedInputStream is
    514     // owned by the input tensor. If it were a Cord or file-backed stream this
    515     // pointer would go stale after the bytes were skipped.
    516     if (!input->Skip(buf_size)) {
    517       return errors::DataLoss(
    518           "ReadPackedValues: Skipping packed field failed.  Field tag: ",
    519           field.number);
    520     }
    521 
    522     // Setting stride=0 causes new values to overwrite old ones for
    523     // non-repeated fields.
    524     const int stride = field.is_repeated ? 1 : 0;
    525 
    526     if (next_repeat_index_ >= max_repeat_count_) {
    527       return errors::DataLoss(
    528           "ReadPackedValues: Tried to write more entries than allowed.  "
    529           "Field tag: ",
    530           field.number, ", Max entries allowed: ", max_repeat_count_);
    531     } else {
    532       return internal::ReadPackedFromArray(buf, buf_size, field.type,
    533                                            field.number, default_value_.dtype,
    534                                            stride, &next_repeat_index_, datap_);
    535     }
    536   }
    537 
    538   // Fills in any missing values in the output array with defaults. Dispatches
    539   // to the appropriately typed field default based on the runtime type tag.
    540   Status FillWithDefaults() {
    541     switch (default_value_.dtype) {
    542       case DataType::DT_BOOL:
    543         return FillDefault<bool>(default_value_.value.v_bool);
    544       case DataType::DT_FLOAT:
    545         return FillDefault<float>(default_value_.value.v_float);
    546       case DataType::DT_DOUBLE:
    547         return FillDefault<double>(default_value_.value.v_double);
    548       case DataType::DT_INT8:
    549         return FillDefault<int8>(default_value_.value.v_int8);
    550       case DataType::DT_INT32:
    551         return FillDefault<int32>(default_value_.value.v_int32);
    552       case DataType::DT_INT64:
    553         return FillDefault<int64>(default_value_.value.v_int64);
    554       case DataType::DT_STRING:
    555         return FillDefault<string>(default_value_.value.v_string);
    556       case DataType::DT_UINT8:
    557         return FillDefault<uint8>(default_value_.value.v_uint8);
    558       case DataType::DT_UINT32:
    559         return FillDefault<uint32>(default_value_.value.v_uint32);
    560       case DataType::DT_UINT64:
    561         return FillDefault<uint64>(default_value_.value.v_uint64);
    562       default:
    563         // There are many tensorflow dtypes not handled here, but they
    564         // should not come up unless type casting is added to the Op.
    565         // Chaining with tf.cast() should do the right thing until then.
    566         return errors::DataLoss("Failed filling defaults for ",
    567                                 DataTypeString(default_value_.dtype));
    568     }
    569   }
    570 
    571  private:
    572   // Fills empty values in the dense representation with a default value. This
    573   // uses next_repeat_index_ which counts the number of parsed values for the
    574   // field.
    575   template <class T>
    576   Status FillDefault(const T& default_value) {
    577     for (int i = next_repeat_index_; i < max_repeat_count_; i++) {
    578       reinterpret_cast<T*>(datap_)[i] = default_value;
    579     }
    580     return Status::OK();
    581   }
    582 
    583   int32 next_repeat_index_ = 0;
    584 
    585   // This is a pointer to data_[message_index_]. There is no bounds checking at
    586   // this level: we computed the max repeat size for each field in
    587   // CountCollector and use the same code to traverse it here, so we are
    588   // guaranteed not to be called for more items than we have allocated space.
    589   void* const datap_ = nullptr;
    590 
    591   const DefaultValue default_value_;
    592   const int max_repeat_count_ = 0;
    593 };
    594 
    595 class DecodeProtoOp : public OpKernel {
    596  public:
    597   explicit DecodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
    598     string descriptor_source;
    599     OP_REQUIRES_OK(context,
    600                    context->GetAttr("descriptor_source", &descriptor_source));
    601 
    602     // We always get back a desc_pool, but we may not own it. If we own it,
    603     // owned_desc_pool_ will be filled in.
    604     DescriptorPool const* desc_pool;
    605     OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
    606                                               &desc_pool, &owned_desc_pool_));
    607 
    608     string message_type;
    609     OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
    610 
    611     const Descriptor* message_desc =
    612         desc_pool->FindMessageTypeByName(message_type);
    613     OP_REQUIRES(context, message_desc != nullptr,
    614                 errors::InvalidArgument("No descriptor found for message type ",
    615                                         message_type));
    616 
    617     std::vector<string> field_names;
    618     OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names));
    619     std::vector<DataType> output_types;
    620     OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_types));
    621     OP_REQUIRES(
    622         context, field_names.size() == output_types.size(),
    623         errors::InvalidArgument("field_names and output_types attributes must "
    624                                 "have the same length"));
    625 
    626     // Gather the field descriptors and check that requested output types match.
    627     int field_index = 0;
    628     std::vector<const FieldDescriptor*> field_descs;
    629     std::vector<const FieldDescriptor*> exts;
    630     absl::flat_hash_map<string, const FieldDescriptor*> ext_name_to_field;
    631     std::vector<const FieldDescriptor*>::iterator ext_it = exts.begin();
    632     for (const string& name : field_names) {
    633       auto fd = message_desc->FindFieldByName(name);
    634       if (fd == nullptr) {
    635         // If field can't be found in original message, try to find a matching
    636         // extension (by its full_name). First check a hashmap for a matching
    637         // extension, and if not found, then iterate through available
    638         // extensions to find a match (updating the hashmap while iterating.)
    639         auto lookup_result = ext_name_to_field.find(name);
    640         if (lookup_result != ext_name_to_field.end()) {
    641           fd = lookup_result->second;
    642         } else {
    643           if (ext_it == exts.begin()) {
    644             desc_pool->FindAllExtensions(message_desc, &exts);
    645             ext_it = exts.begin();
    646           }
    647           while (ext_it != exts.end()) {
    648             auto ext_name = (*ext_it)->full_name();
    649             auto ext_field = *ext_it;
    650             ++ext_it;
    651 
    652             ext_name_to_field.insert({ext_name, ext_field});
    653             if (ext_name == name) {
    654               fd = ext_field;
    655               break;
    656             }
    657           }
    658         }
    659       }
    660       OP_REQUIRES(context, fd != nullptr,
    661                   errors::InvalidArgument("Unknown field: ", name,
    662                                           " in message type ", message_type));
    663       OP_REQUIRES(
    664           context,
    665           proto_utils::IsCompatibleType(fd->type(), output_types[field_index]),
    666           // Many TensorFlow types don't have corresponding proto types and the
    667           // user will get an error if they are requested. It would be nice to
    668           // allow conversions here, but tf.cast already exists so we don't
    669           // duplicate the functionality.
    670           errors::InvalidArgument("Unexpected output type for ",
    671                                   fd->full_name(), ": ", fd->cpp_type(), " to ",
    672                                   output_types[field_index]));
    673 
    674       field_index++;
    675       field_descs.push_back(fd);
    676     }
    677 
    678     // Internally we want the field_descs sorted by their number on the wire.
    679     // But the output tensors are allocated in the order given by the caller.
    680     // Build a mapping i->j, where field_descs[i] corresponds to outputs[j].
    681     std::vector<int> output_indices;
    682     output_indices.reserve(field_names.size());
    683     for (int i = 0; i < field_names.size(); i++) {
    684       output_indices.push_back(i);
    685     }
    686     std::sort(output_indices.begin(), output_indices.end(),
    687               [field_descs](int a, int b) {
    688                 return field_descs[a]->number() < field_descs[b]->number();
    689               });
    690 
    691     // Now store the fields in sorted order.
    692     for (int i = 0; i < field_names.size(); i++) {
    693       const int output_index = output_indices[i];
    694       const DataType dtype = output_types[output_index];
    695       const FieldDescriptor* field_descriptor = field_descs[output_index];
    696       DefaultValue default_value;
    697       OP_REQUIRES_OK(context, InitDefaultValueFromFieldDescriptor(
    698                                   dtype, field_descriptor, &default_value));
    699       fields_.push_back(
    700           MakeUnique<FieldInfo>(field_descriptor, output_index, default_value));
    701     }
    702 
    703     message_prototype_ = message_factory_.GetPrototype(message_desc);
    704     OP_REQUIRES(context, message_prototype_ != nullptr,
    705                 errors::InvalidArgument("Couldn't get prototype message: ",
    706                                         message_desc->full_name()));
    707     string format;
    708     OP_REQUIRES_OK(context, context->GetAttr("message_format", &format));
    709     OP_REQUIRES(
    710         context, format == "binary" || format == "text",
    711         errors::InvalidArgument("format must be one of binary or text"));
    712     is_binary_ = format == "binary";
    713 
    714     // Enable the initial protobuf sanitizer, which is much more expensive than
    715     // the decoder.
    716     // TODO(nix): Remove this once the fast decoder has passed security review.
    717     OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_));
    718   }
    719 
    720   void Compute(OpKernelContext* ctx) override {
    721     const Tensor& buf_tensor = ctx->input(0);
    722     int message_count = buf_tensor.NumElements();
    723     OP_REQUIRES(ctx, message_count >= 1,
    724                 errors::InvalidArgument(
    725                     "Bufs argument must contain at least one value"));
    726 
    727     int field_count = fields_.size();
    728 
    729     // Save the argument shape for later, then flatten the input Tensor since we
    730     // are working componentwise. We will restore the same shape in the returned
    731     // Tensor.
    732     const TensorShape& shape_prefix = buf_tensor.shape();
    733 
    734     TensorShape sizes_shape = shape_prefix;
    735     sizes_shape.AddDim(field_count);
    736     Tensor* sizes_tensor = nullptr;
    737     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor));
    738 
    739     // This is used to allocate binary bufs if used. It serves only to define
    740     // memory ownership.
    741     std::vector<string> tmp_binary_bufs(message_count);
    742 
    743     // These are the actual buffers to use, which may be in tmp_binary_bufs
    744     // or may be pointers into the buf_tensor. Either way they are not owned
    745     // here.
    746     std::vector<const string*> bufs;
    747 
    748     if (is_binary_ && !sanitize_) {
    749       // Fast path.
    750       for (int mi = 0; mi < message_count; ++mi) {
    751         const string* buf = &buf_tensor.flat<string>()(mi);
    752         bufs.push_back(buf);
    753       }
    754     } else {
    755       // We will have to allocate a copy, either to convert from text to binary
    756       // or to sanitize a binary proto.
    757       for (int mi = 0; mi < message_count; ++mi) {
    758         ReserializeMessage(ctx, buf_tensor.flat<string>()(mi),
    759                            &tmp_binary_bufs[mi]);
    760         if (!ctx->status().ok()) {
    761           return;
    762         }
    763         bufs.push_back(&tmp_binary_bufs[mi]);
    764       }
    765     }
    766 
    767     // Walk through all the strings in the input tensor, counting the number of
    768     // fields in each. We can't allocate our actual output Tensor until we know
    769     // the maximum repeat count, so we do a first pass through the serialized
    770     // proto just counting fields. We always allocate at least one value so that
    771     // optional fields are populated with default values - this avoids a TF
    772     // conditional when handling the output data. The caller can distinguish
    773     // between real data and defaults using the repeat count matrix that is
    774     // returned by decode_proto.
    775     std::vector<int32> max_sizes(field_count, 1);
    776     for (int mi = 0; mi < message_count; ++mi) {
    777       CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes);
    778       if (!ctx->status().ok()) {
    779         return;
    780       }
    781     }
    782 
    783     // Allocate the output tensors now that we've seen the max size.
    784     // TODO(nix): Use allocate_output_or_forward_input for the largest
    785     //   output tensor. This can avoid one large allocation by re-using
    786     //   the memory of the input tensor.
    787     std::vector<Tensor*> outputs(field_count);
    788     for (int fi = 0; fi < field_count; ++fi) {
    789       TensorShape flat_shape = {static_cast<int64>(message_count),
    790                                 max_sizes[fi]};
    791       TensorShape out_shape = shape_prefix;
    792       out_shape.AddDim(max_sizes[fi]);
    793 
    794       // Surprisingly we don't specify the types from the output_types
    795       // attribute: that is done for us based on the Op declaration:
    796       //  REGISTER_OP(...)
    797       //    .Attr("output_types: list(type) >= 0")
    798       //    .Output("values: output_types")
    799       OP_REQUIRES_OK(ctx, ctx->allocate_output(fields_[fi]->output_index + 1,
    800                                                out_shape, &outputs[fi]));
    801     }
    802 
    803     // Make the second pass through the serialized proto, decoding into
    804     // preallocated tensors.
    805     AccumulateFields(ctx, bufs, outputs);
    806   }
    807 
    808  private:
    809   // Copy a serialized message to binary, e.g. to handle text proto inputs.
    810   void ReserializeMessage(OpKernelContext* ctx, const string& buf,
    811                           string* binary_buf) {
    812     // Handle text protos by translating them to binary.
    813     std::unique_ptr<Message> message(message_prototype_->New());
    814     OP_REQUIRES(ctx, message, errors::DataLoss("Initializing message failed"));
    815 
    816     if (is_binary_) {
    817       // If we get here we are sanitizing the input protobuf by parsing
    818       // and reserializing it with a trusted (but very slow) library.
    819       OP_REQUIRES(ctx, message->ParseFromString(buf),
    820                   errors::DataLoss("Unable to parse binary protobuf"));
    821     } else {
    822       OP_REQUIRES(ctx, TextFormat::ParseFromString(buf, message.get()),
    823                   errors::DataLoss("Unable to parse text protobuf"));
    824     }
    825 
    826     OP_REQUIRES(ctx, message->SerializeToString(binary_buf),
    827                 errors::DataLoss("Unable to reserialize text proto as binary"));
    828   }
    829 
    830   // Count the number of occurrences of each requested field in a message batch.
    831   void CountFields(OpKernelContext* ctx, int message_index, const string& buf,
    832                    Tensor* sizes_tensor, std::vector<int32>* max_sizes) {
    833     int field_count = fields_.size();
    834 
    835     CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
    836                            buf.size());
    837 
    838     std::vector<int32> field_sizes(field_count, 0);
    839     std::vector<CountCollector> counters;
    840     counters.reserve(field_count);
    841     for (int i = 0; i < field_count; i++) {
    842       counters.emplace_back(&field_sizes[i]);
    843     }
    844 
    845     Status st = Collect(&input, &counters);
    846     if (st.ok() && !input.ConsumedEntireMessage()) {
    847       st = errors::DataLoss("CountFields: Failed to consume entire buffer");
    848     }
    849     if (kFailOnDecodeError) {
    850       OP_REQUIRES_OK(ctx, st);  // NOLINT
    851     }
    852     if (!st.ok()) {
    853       // This code suppresses the corrupt proto, treating it as empty
    854       // to avoid crashing the process.
    855       LOG(WARNING) << "Proto counting error for message type " << message_type_
    856                    << ": " << st;
    857 
    858       for (int fi = 0; fi < field_count; fi++) {
    859         field_sizes[fi] = 0;
    860       }
    861       // Finished decoding this message.
    862       return;
    863     }
    864 
    865     // Update the size tensor and max repeat size for each field.
    866     auto sizes = sizes_tensor->flat_inner_dims<int32>();
    867     for (int fi = 0; fi < field_count; fi++) {
    868       int32 size = field_sizes[fi];
    869       sizes(message_index, fields_[fi]->output_index) = size;
    870       if ((*max_sizes)[fi] < size) {
    871         (*max_sizes)[fi] = size;
    872       }
    873     }
    874   }
    875 
    876   // Parse fields from a serialized message into preallocated tensors.
    877   void AccumulateFields(OpKernelContext* ctx,
    878                         const std::vector<const string*>& bufs,
    879                         std::vector<Tensor*> outputs) {
    880     struct TensorInfo {
    881       explicit TensorInfo(Tensor* tensor) {
    882         // Note that we can decode only max_repeat_count values before overflow.
    883         // No other bounds checking is done for repeated fields. For
    884         // optional fields there is a check to make sure that only the last
    885         // value on the wire appears in the output tensor.
    886         dtype = tensor->dtype();
    887         last_dim_size = tensor->dim_size(tensor->dims() - 1);
    888 
    889         if (dtype != DT_STRING) {
    890           const int element_size = DataTypeSize(dtype);
    891           CHECK_GT(element_size, 0);
    892           stride = last_dim_size * element_size;
    893 
    894           const int64 flatshape[1] = {tensor->NumElements() * element_size};
    895           data = tensor->bit_casted_shaped<uint8, 1>(flatshape).data();
    896         } else {
    897           // DataTypeSize() returns 0 for string types.
    898           stride = last_dim_size * sizeof(string);
    899           data = reinterpret_cast<uint8*>(tensor->flat<string>().data());
    900         }
    901       }
    902 
    903       DataType dtype;
    904       int last_dim_size;
    905       int stride;
    906       uint8* data;
    907     };
    908 
    909     int field_count = fields_.size();
    910 
    911     std::vector<TensorInfo> tensors;
    912     tensors.reserve(field_count);
    913     for (int fi = 0; fi < field_count; fi++) {
    914       tensors.emplace_back(outputs[fi]);
    915     }
    916 
    917     for (int message_index = 0; message_index < bufs.size(); ++message_index) {
    918       const string& buf = *bufs[message_index];
    919 
    920       std::vector<DenseCollector> collectors;
    921       collectors.reserve(field_count);
    922       for (int output_index = 0; output_index < field_count; ++output_index) {
    923         const TensorInfo& info = tensors[output_index];
    924         const FieldInfo* field_info = fields_[output_index].get();
    925         DCHECK(field_info != nullptr);
    926         const DefaultValue default_value = field_info->default_value;
    927         collectors.emplace_back(info.data + message_index * info.stride,
    928                                 default_value, info.last_dim_size);
    929       }
    930 
    931       // Fill in output tensors from the wire.
    932       CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
    933                              buf.size());
    934       Status st = Collect(&input, &collectors);
    935       if (st.ok() && !input.ConsumedEntireMessage()) {
    936         st = errors::DataLoss(
    937             "AccumulateFields: Failed to consume entire buffer");
    938       }
    939       if (kFailOnDecodeError) {
    940         OP_REQUIRES_OK(ctx, st);  // NOLINT
    941       }
    942       if (!st.ok()) {
    943         // This code suppresses the corrupt proto, treating it as empty
    944         // to avoid crashing training.
    945         LOG(WARNING) << "Proto counting error for message type "
    946                      << message_type_ << ": " << st;
    947       }
    948 
    949       // Fill the remainder of the dense outputs with default values.
    950       for (auto& collector : collectors) {
    951         OP_REQUIRES_OK(ctx, collector.FillWithDefaults());
    952       }
    953     }
    954   }
    955 
    956   // Look up the FieldDescriptor for a particular field number.
    957   bool LookupField(int field_number, int* field_index) {
    958     // Look up the FieldDescriptor using linear search.
    959     //
    960     // TODO(nix): this could be sped up with binary search, but we are
    961     // already way off the fastpath at this point. If you see a hotspot
    962     // here, somebody is sending you very inefficient protos.
    963     for (int fi = fields_.size() - 1; fi >= 0; fi--) {
    964       if (field_number == fields_[fi]->number) {
    965         *field_index = fi;
    966         return true;
    967       }
    968     }
    969     return false;
    970   }
    971 
    972   // Traverses a serialized protobuf, dispatching values to the collectors.
    973   template <class CollectorClass>
    974   Status Collect(CodedInputStream* input,
    975                  std::vector<CollectorClass>* collectors) {
    976     int last_good_field_index = -1;
    977     bool fields_disordered = false;
    978     int prev_field_number = -1;
    979     int field_number = -1;
    980     int last_good_field_number = -1;
    981     int next_good_field_number = fields_[0]->number;
    982 
    983     // The 'tag' variable should always be treated as tainted.
    984     for (uint32 tag = input->ReadTag();
    985          tag != 0 && WireFormatLite::GetTagWireType(tag) !=
    986                          WireFormatLite::WIRETYPE_END_GROUP;
    987          tag = input->ReadTag(), prev_field_number = field_number) {
    988       field_number = WireFormatLite::GetTagFieldNumber(tag);
    989       const FieldInfo* field = nullptr;
    990 
    991       // This takes advantage of the sorted field numbers in most serialized
    992       // protos: it tries the next expected field first rather than doing
    993       // a lookup by field number.
    994       //
    995       // TODO(nix): haberman@ suggests a hybrid approach with a lookup table
    996       // for small field numbers and a hash table for larger ones. This would
    997       // be a simpler approach that should offer comparable speed in most
    998       // cases.
    999       if (field_number == last_good_field_number) {
   1000         field = fields_[last_good_field_index].get();
   1001       } else {
   1002         if (field_number < prev_field_number) {
   1003           fields_disordered = true;
   1004         }
   1005 
   1006         // If fields are out of order, fall back to slow lookup.
   1007         if (fields_disordered) {
   1008           int field_index;
   1009           if (LookupField(field_number, &field_index)) {
   1010             field = fields_[field_index].get();
   1011             last_good_field_index = field_index;
   1012           }
   1013         } else {
   1014           // If we see a field that is past the next field we want, it was
   1015           // empty. Look for the one after that. Repeat until we run out of
   1016           // fields that we care about.
   1017           while (field_number >= next_good_field_number) {
   1018             if (field_number == next_good_field_number) {
   1019               last_good_field_number = field_number;
   1020               field = fields_[last_good_field_index + 1].get();
   1021             }
   1022 
   1023             // Start looking for the field after the current one.
   1024             ++last_good_field_index;
   1025             if (last_good_field_index < fields_.size() - 1) {
   1026               next_good_field_number =
   1027                   fields_[last_good_field_index + 1]->number;
   1028             } else {
   1029               // Saw something past the last field we care about. Continue
   1030               // parsing the message just in case there are disordered fields
   1031               // later, but any remaining ordered fields will have no effect.
   1032               next_good_field_number = INT_MAX;
   1033             }
   1034           }
   1035         }
   1036       }
   1037 
   1038       if (!field) {
   1039         // Unknown and unrequested fields are skipped.
   1040         if (!WireFormatLite::SkipField(input, tag)) {
   1041           return errors::DataLoss("Failed skipping unrequested field");
   1042         }
   1043         continue;
   1044       }
   1045 
   1046       Status st = CollectField(*field, WireFormatLite::GetTagWireType(tag),
   1047                                input, &(*collectors)[last_good_field_index]);
   1048       if (!st.ok()) {
   1049         return st;
   1050       }
   1051     }
   1052     return Status::OK();
   1053   }
   1054 
   1055   // Collects values for a single field.
   1056   template <class CollectorClass>
   1057   Status CollectField(const FieldInfo& field,
   1058                       WireFormatLite::WireType wire_type,
   1059                       CodedInputStream* input, CollectorClass* collector) {
   1060     // The wire format library defines the same constants used in
   1061     // descriptor.proto. This static_cast is safe because they are guaranteed to
   1062     // stay in sync.
   1063     //
   1064     // We need the field type from the FieldDescriptor here because the wire
   1065     // format doesn't tell us anything about what happens inside a packed
   1066     // repeated field: there is enough information in the wire format to skip
   1067     // the whole field but not enough to know how to parse what's inside. For
   1068     // that we go to the schema.
   1069     WireFormatLite::WireType schema_wire_type =
   1070         WireFormatLite::WireTypeForFieldType(field.type);
   1071 
   1072     // Handle packed repeated fields. SkipField would skip the whole
   1073     // length-delimited blob without letting us count the values, so we have to
   1074     // scan them ourselves.
   1075     if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED &&
   1076         schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
   1077       // Handle packed repeated primitives.
   1078       int length;
   1079       if (!input->ReadVarintSizeAsInt(&length)) {
   1080         return errors::DataLoss("CollectField: Failed reading packed size");
   1081       }
   1082       return collector->ReadPackedValues(input, field, length);
   1083     }
   1084 
   1085     // Read ordinary values, including strings, bytes, and messages.
   1086     if (wire_type != schema_wire_type) {
   1087       if (!WireFormatLite::SkipField(
   1088               input, WireFormatLite::MakeTag(field.number, wire_type))) {
   1089         return errors::DataLoss(
   1090             "CollectField: Failed skipping malformed field");
   1091       }
   1092       return Status::OK();
   1093     }
   1094     return collector->ReadValue(input, field);
   1095   }
   1096 
          // Name of the message type being decoded. In the code shown it is
          // only read by the decode-error warning logs.
   1097   string message_type_;
   1098   // Note that fields are sorted by increasing field number, which is not in
   1099   // general the order given by the user-specified field_names and output_types
   1100   // Op attributes.
          // Each FieldInfo carries the caller-visible output_index: the field's
          // values go to output `output_index + 1`, and its per-message counts
          // go to column `output_index` of output 0.
   1101   std::vector<std::unique_ptr<const FieldInfo>> fields_;
   1102 
   1103   // owned_desc_pool_ is null when using descriptor_source=local.
          // NOTE(review): this is declared before message_factory_, so it is
          // destroyed after it. Presumably the factory's prototypes refer to
          // descriptors in this pool, so keep this order (confirm before
          // reordering).
   1104   std::unique_ptr<DescriptorPool> owned_desc_pool_;
          // Source of message_prototype_ (via GetPrototype in the constructor).
   1105   DynamicMessageFactory message_factory_;
          // Obtained from message_factory_.GetPrototype() and never deleted by
          // this class. ReserializeMessage calls New() on it to get scratch
          // messages.
   1106   const Message* message_prototype_;
   1107 
   1108   // True if decoding binary format, false if decoding text format.
          // Set from the "message_format" attr, which must be "binary" or "text".
   1109   bool is_binary_;
   1110 
   1111   // True if the protos should be sanitized before parsing. Enables the initial
   1112   // protobuf sanitizer, which is much more expensive than the decoder. The flag
   1113   // defaults to true but can be set to false for trusted sources.
   1114   //
   1115   // TODO(nix): Flip the default to false when the fast decoder has passed
   1116   // security review.
          // Read from the "sanitize" attr. When it is true, or when is_binary_ is
          // false, Compute routes every input through ReserializeMessage.
   1117   bool sanitize_;
   1118 
   1119   TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp);
   1120 };
   1121 
   1122 REGISTER_KERNEL_BUILDER(Name("DecodeProtoV2").Device(DEVICE_CPU),
   1123                         DecodeProtoOp);
   1124 
   1125 }  // namespace
   1126 }  // namespace tensorflow
   1127