Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/parsing_ops.cc.
#include <limits>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
     23 #include "tensorflow/core/lib/strings/numbers.h"
     24 
     25 namespace tensorflow {
     26 
     27 class DecodeCSVOp : public OpKernel {
     28  public:
     29   explicit DecodeCSVOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     30     string delim;
     31 
     32     OP_REQUIRES_OK(ctx, ctx->GetAttr("OUT_TYPE", &out_type_));
     33     OP_REQUIRES(ctx, out_type_.size() < std::numeric_limits<int>::max(),
     34                 errors::InvalidArgument("Out type too large"));
     35     OP_REQUIRES_OK(ctx, ctx->GetAttr("field_delim", &delim));
     36     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_quote_delim", &use_quote_delim_));
     37     OP_REQUIRES(ctx, delim.size() == 1,
     38                 errors::InvalidArgument("field_delim should be only 1 char"));
     39     delim_ = delim[0];
     40     OP_REQUIRES_OK(ctx, ctx->GetAttr("na_value", &na_value_));
     41   }
     42 
     43   void Compute(OpKernelContext* ctx) override {
     44     const Tensor* records;
     45     OpInputList record_defaults;
     46 
     47     OP_REQUIRES_OK(ctx, ctx->input("records", &records));
     48     OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
     49 
     50     for (int i = 0; i < record_defaults.size(); ++i) {
     51       OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
     52                   errors::InvalidArgument(
     53                       "There should only be 1 default per field but field ", i,
     54                       " has ", record_defaults[i].NumElements()));
     55     }
     56 
     57     auto records_t = records->flat<string>();
     58     int64 records_size = records_t.size();
     59 
     60     OpOutputList output;
     61     OP_REQUIRES_OK(ctx, ctx->output_list("output", &output));
     62 
     63     for (int i = 0; i < static_cast<int>(out_type_.size()); ++i) {
     64       Tensor* out = nullptr;
     65       OP_REQUIRES_OK(ctx, output.allocate(i, records->shape(), &out));
     66     }
     67 
     68     for (int64 i = 0; i < records_size; ++i) {
     69       const StringPiece record(records_t(i));
     70       std::vector<string> fields;
     71       ExtractFields(ctx, record, &fields);
     72       OP_REQUIRES(ctx, fields.size() == out_type_.size(),
     73                   errors::InvalidArgument("Expect ", out_type_.size(),
     74                                           " fields but have ", fields.size(),
     75                                           " in record ", i));
     76 
     77       // Check each field in the record
     78       for (int f = 0; f < static_cast<int>(out_type_.size()); ++f) {
     79         const DataType& dtype = out_type_[f];
     80         switch (dtype) {
     81           case DT_INT32: {
     82             // If this field is empty or NA value, check if default is given:
     83             // If yes, use default value; Otherwise report error.
     84             if (fields[f].empty() || fields[f] == na_value_) {
     85               OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
     86                           errors::InvalidArgument(
     87                               "Field ", f,
     88                               " is required but missing in record ", i, "!"));
     89 
     90               output[f]->flat<int32>()(i) = record_defaults[f].flat<int32>()(0);
     91             } else {
     92               int32 value;
     93               OP_REQUIRES(ctx, strings::safe_strto32(fields[f], &value),
     94                           errors::InvalidArgument(
     95                               "Field ", f, " in record ", i,
     96                               " is not a valid int32: ", fields[f]));
     97               output[f]->flat<int32>()(i) = value;
     98             }
     99             break;
    100           }
    101           case DT_INT64: {
    102             // If this field is empty or NA value, check if default is given:
    103             // If yes, use default value; Otherwise report error.
    104             if (fields[f].empty() || fields[f] == na_value_) {
    105               OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
    106                           errors::InvalidArgument(
    107                               "Field ", f,
    108                               " is required but missing in record ", i, "!"));
    109 
    110               output[f]->flat<int64>()(i) = record_defaults[f].flat<int64>()(0);
    111             } else {
    112               int64 value;
    113               OP_REQUIRES(ctx, strings::safe_strto64(fields[f], &value),
    114                           errors::InvalidArgument(
    115                               "Field ", f, " in record ", i,
    116                               " is not a valid int64: ", fields[f]));
    117               output[f]->flat<int64>()(i) = value;
    118             }
    119             break;
    120           }
    121           case DT_FLOAT: {
    122             // If this field is empty or NA value, check if default is given:
    123             // If yes, use default value; Otherwise report error.
    124             if (fields[f].empty() || fields[f] == na_value_) {
    125               OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
    126                           errors::InvalidArgument(
    127                               "Field ", f,
    128                               " is required but missing in record ", i, "!"));
    129               output[f]->flat<float>()(i) = record_defaults[f].flat<float>()(0);
    130             } else {
    131               float value;
    132               OP_REQUIRES(ctx, strings::safe_strtof(fields[f].c_str(), &value),
    133                           errors::InvalidArgument(
    134                               "Field ", f, " in record ", i,
    135                               " is not a valid float: ", fields[f]));
    136               output[f]->flat<float>()(i) = value;
    137             }
    138             break;
    139           }
    140           case DT_DOUBLE: {
    141             // If this field is empty or NA value, check if default is given:
    142             // If yes, use default value; Otherwise report error.
    143             if (fields[f].empty() || fields[f] == na_value_) {
    144               OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
    145                           errors::InvalidArgument(
    146                               "Field ", f,
    147                               " is required but missing in record ", i, "!"));
    148               output[f]->flat<double>()(i) =
    149                   record_defaults[f].flat<double>()(0);
    150             } else {
    151               double value;
    152               OP_REQUIRES(ctx, strings::safe_strtod(fields[f].c_str(), &value),
    153                           errors::InvalidArgument(
    154                               "Field ", f, " in record ", i,
    155                               " is not a valid double: ", fields[f]));
    156               output[f]->flat<double>()(i) = value;
    157             }
    158             break;
    159           }
    160           case DT_STRING: {
    161             // If this field is empty or NA value, check if default is given:
    162             // If yes, use default value; Otherwise report error.
    163             if (fields[f].empty() || fields[f] == na_value_) {
    164               OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
    165                           errors::InvalidArgument(
    166                               "Field ", f,
    167                               " is required but missing in record ", i, "!"));
    168               output[f]->flat<string>()(i) =
    169                   record_defaults[f].flat<string>()(0);
    170             } else {
    171               output[f]->flat<string>()(i) = fields[f];
    172             }
    173             break;
    174           }
    175           default:
    176             OP_REQUIRES(ctx, false,
    177                         errors::InvalidArgument("csv: data type ", dtype,
    178                                                 " not supported in field ", f));
    179         }
    180       }
    181     }
    182   }
    183 
    184  private:
    185   std::vector<DataType> out_type_;
    186   char delim_;
    187   bool use_quote_delim_;
    188   string na_value_;
    189 
    190   void ExtractFields(OpKernelContext* ctx, StringPiece input,
    191                      std::vector<string>* result) {
    192     int64 current_idx = 0;
    193     if (!input.empty()) {
    194       while (static_cast<size_t>(current_idx) < input.size()) {
    195         if (input[current_idx] == '\n' || input[current_idx] == '\r') {
    196           current_idx++;
    197           continue;
    198         }
    199 
    200         bool quoted = false;
    201         if (use_quote_delim_ && input[current_idx] == '"') {
    202           quoted = true;
    203           current_idx++;
    204         }
    205 
    206         // This is the body of the field;
    207         string field;
    208         if (!quoted) {
    209           while (static_cast<size_t>(current_idx) < input.size() &&
    210                  input[current_idx] != delim_) {
    211             OP_REQUIRES(ctx,
    212                         (!use_quote_delim_ || input[current_idx] != '"') &&
    213                             input[current_idx] != '\n' &&
    214                             input[current_idx] != '\r',
    215                         errors::InvalidArgument(
    216                             "Unquoted fields cannot have quotes/CRLFs inside"));
    217             field += input[current_idx];
    218             current_idx++;
    219           }
    220 
    221           // Go to next field or the end
    222           current_idx++;
    223         } else if (use_quote_delim_) {
    224           // Quoted field needs to be ended with '"' and delim or end
    225           while (
    226               (static_cast<size_t>(current_idx) < input.size() - 1) &&
    227               (input[current_idx] != '"' || input[current_idx + 1] != delim_)) {
    228             if (input[current_idx] != '"') {
    229               field += input[current_idx];
    230               current_idx++;
    231             } else {
    232               OP_REQUIRES(
    233                   ctx, input[current_idx + 1] == '"',
    234                   errors::InvalidArgument("Quote inside a string has to be "
    235                                           "escaped by another quote"));
    236               field += '"';
    237               current_idx += 2;
    238             }
    239           }
    240 
    241           OP_REQUIRES(
    242               ctx,
    243               (static_cast<size_t>(current_idx) < input.size() &&
    244                input[current_idx] == '"' &&
    245                (static_cast<size_t>(current_idx) == input.size() - 1 ||
    246                 input[current_idx + 1] == delim_)),
    247               errors::InvalidArgument("Quoted field has to end with quote "
    248                                       "followed by delim or end"));
    249 
    250           current_idx += 2;
    251         }
    252 
    253         result->push_back(field);
    254       }
    255 
    256       // Check if the last field is missing
    257       if (input[input.size() - 1] == delim_) result->push_back(string());
    258     }
    259   }
    260 };
    261 
    262 REGISTER_KERNEL_BUILDER(Name("DecodeCSV").Device(DEVICE_CPU), DecodeCSVOp);
    263 
    264 }  // namespace tensorflow
    265