Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/string_ops.cc.
     17 
     18 #include <string>
     19 
     20 #include "tensorflow/core/framework/kernel_def_builder.h"
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 #include "tensorflow/core/lib/core/status.h"
     25 #include "tensorflow/core/lib/core/stringpiece.h"
     26 #include "tensorflow/core/lib/strings/str_util.h"
     27 
     28 namespace tensorflow {
     29 namespace {
     30 // Split input string `str` based on a character delimiter.
     31 // Returns a vector of StringPieces which are valid as long as input `str`
     32 // is valid.
     33 // Note: The single character delimiter is a common case and is implemented as
     34 // a series of finds in the input string, making it much more effcient than
     35 // SplitOnCharSet.
     36 template <typename Predicate>
     37 std::vector<StringPiece> SplitOnChar(const string& str, const char delim,
     38                                      Predicate p) {
     39   std::vector<StringPiece> result;
     40   StringPiece text(str);
     41   auto f = text.find(delim);
     42   while (f != StringPiece::npos) {
     43     StringPiece token = text.substr(0, f);
     44     if (p(token)) {
     45       result.emplace_back(token);
     46     }
     47     text.remove_prefix(f + 1);
     48     f = text.find(delim);
     49   }
     50   if (p(text)) {
     51     result.push_back(text);
     52   }
     53   return result;
     54 }
     55 
     56 // Split input string `str` based on a set of character delimiters.
     57 // Returns a vector of StringPieces which are valid as long as input `str`
     58 // is valid.
     59 // Based on str_util::Split.
     60 template <typename Predicate>
     61 std::vector<StringPiece> SplitOnCharSet(const string& str,
     62                                         const string& delim_set, Predicate p) {
     63   std::vector<StringPiece> result;
     64   StringPiece text(str);
     65   StringPiece delims(delim_set);
     66   size_t token_start = 0;
     67   for (size_t i = 0; i < text.size() + 1; i++) {
     68     if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) {
     69       StringPiece token(text.data() + token_start, i - token_start);
     70       if (p(token)) {
     71         result.emplace_back(token);
     72       }
     73       token_start = i + 1;
     74     }
     75   }
     76   return result;
     77 }
     78 
     79 // Split input string `str` based on given delimiter.
     80 // Returns a vector of StringPieces which are valid as long as input `str`
     81 // is valid.
     82 template <typename Predicate>
     83 std::vector<StringPiece> Split(const string& str, const string& delimiter,
     84                                Predicate predicate) {
     85   if (str.empty()) {
     86     return std::vector<StringPiece>();
     87   }
     88   if (delimiter.empty()) {
     89     std::vector<StringPiece> result;
     90     result.resize(str.size());
     91     for (size_t i = 0; i < str.size(); ++i) {
     92       result[i] = StringPiece(str.data() + i, 1);
     93     }
     94     return result;
     95   }
     96   if (delimiter.size() == 1) {
     97     return SplitOnChar(str, delimiter[0], predicate);
     98   }
     99   return SplitOnCharSet(str, delimiter, predicate);
    100 }
    101 
    102 std::vector<StringPiece> SplitV2(const string& str, StringPiece sep,
    103                                  int maxsplit) {
    104   // This SplitV2 method matches the behavior of python's str.split:
    105   //   If sep is given, consecutive delimiters are not grouped together
    106   //   and are deemed to delimit empty strings (for example, '1,,2'.split(',')
    107   //   returns ['1', '', '2']). The sep argument may consist of multiple
    108   //   characters (for example, '1<>2<>3'.split('<>') returns ['1', '2', '3']).
    109   //   Splitting an empty string with a specified separator returns [''].
    110   //
    111   //   If sep is not specified or is None, a different splitting algorithm is
    112   //   applied: runs of consecutive whitespace are regarded as a single
    113   //   separator, and the result will contain no empty strings at the start or
    114   //   end if the string has leading or trailing whitespace. Consequently,
    115   //   splitting an empty string or a string consisting of just whitespace
    116   //   with a None separator returns [].
    117 
    118   std::vector<StringPiece> result;
    119 
    120   StringPiece text(str);
    121   if (maxsplit == 0) {
    122     result.emplace_back(text);
    123     return result;
    124   }
    125 
    126   if (sep.empty()) {
    127     StringPiece token;
    128     // Remove leading whitespaces.
    129     str_util::RemoveLeadingWhitespace(&text);
    130     int split = 0;
    131     while (str_util::ConsumeNonWhitespace(&text, &token)) {
    132       result.push_back(token);
    133       str_util::RemoveLeadingWhitespace(&text);
    134       ++split;
    135       if (maxsplit > 0 && split == maxsplit) {
    136         result.push_back(text);
    137         return result;
    138       }
    139     }
    140     return result;
    141   }
    142   auto p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
    143   int split = 0;
    144   while (p != text.end()) {
    145     StringPiece token = text.substr(0, p - text.begin());
    146     result.push_back(token);
    147     text.remove_prefix(token.size());
    148     text.remove_prefix(sep.size());
    149     ++split;
    150     if (maxsplit > 0 && split == maxsplit) {
    151       result.push_back(StringPiece(text));
    152       return result;
    153     }
    154     p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
    155   }
    156   result.push_back(text);
    157   return result;
    158 }
    159 
    160 }  // namespace
    161 
    162 class StringSplitOp : public OpKernel {
    163  public:
    164   explicit StringSplitOp(OpKernelConstruction* context)
    165       : OpKernel(context), skip_empty_(true) {
    166     bool skip_empty;
    167     // By default skip_empty_ is true. We only get the value from attr if it is
    168     // available, so that it is backward compatible.
    169     if (context->GetAttr("skip_empty", &skip_empty).ok()) {
    170       skip_empty_ = skip_empty;
    171     }
    172   }
    173 
    174   void Compute(OpKernelContext* ctx) override {
    175     const Tensor* input_tensor;
    176     OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
    177     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()),
    178                 errors::InvalidArgument("input must be a vector, got shape: ",
    179                                         input_tensor->shape().DebugString()));
    180 
    181     const auto input_vec = input_tensor->vec<string>();
    182     const int64 batch_size = input_vec.dimension(0);
    183 
    184     const Tensor* delimiter_tensor;
    185     OP_REQUIRES_OK(ctx, ctx->input("delimiter", &delimiter_tensor));
    186     OP_REQUIRES(
    187         ctx, TensorShapeUtils::IsScalar(delimiter_tensor->shape()),
    188         errors::InvalidArgument("delimiter must be a scalar, got shape: ",
    189                                 delimiter_tensor->shape().DebugString()));
    190     const auto delimiter_vec = delimiter_tensor->flat<string>();
    191     const string& delimiter = delimiter_vec(0);
    192     // Empty delimiter means split the input character by character.
    193     std::vector<StringPiece> tokens;
    194     // Guess that we'll be unpacking a handful of tokens per example.
    195     static constexpr int kReserveSize = 4;
    196     tokens.reserve(batch_size * kReserveSize);
    197 
    198     int64 output_size = 0;
    199     int64 max_num_entries = 0;
    200     std::vector<int64> num_indices(batch_size);
    201     for (int64 i = 0; i < batch_size; ++i) {
    202       std::vector<StringPiece> parts =
    203           skip_empty_ ? Split(input_vec(i), delimiter, str_util::SkipEmpty())
    204                       : Split(input_vec(i), delimiter, str_util::AllowEmpty());
    205       int64 n_entries = parts.size();
    206       num_indices[i] = n_entries;
    207       output_size += n_entries;
    208       max_num_entries = std::max(max_num_entries, n_entries);
    209       tokens.insert(tokens.end(), std::make_move_iterator(parts.begin()),
    210                     std::make_move_iterator(parts.end()));
    211     }
    212 
    213     Tensor* sp_indices_t;
    214     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}),
    215                                              &sp_indices_t));
    216     Tensor* sp_tokens_t;
    217     OP_REQUIRES_OK(
    218         ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t));
    219     Tensor* sp_shape_t;
    220     OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t));
    221 
    222     auto sp_indices = sp_indices_t->matrix<int64>();
    223     auto sp_tokens = sp_tokens_t->vec<string>();
    224     auto sp_shape = sp_shape_t->vec<int64>();
    225     sp_shape(0) = batch_size;
    226     sp_shape(1) = max_num_entries;
    227     size_t c = 0;
    228     for (size_t i = 0; i < batch_size; ++i) {
    229       for (size_t j = 0; j < num_indices[i]; ++j) {
    230         sp_indices(c, 0) = i;
    231         sp_indices(c, 1) = j;
    232         sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
    233         ++c;
    234       }
    235     }
    236   }
    237 
    238  private:
    239   bool skip_empty_;
    240 };
    241 
    242 class StringSplitV2Op : public OpKernel {
    243  public:
    244   explicit StringSplitV2Op(OpKernelConstruction* context)
    245       : OpKernel(context), maxsplit_(-1) {
    246     OP_REQUIRES_OK(context, context->GetAttr("maxsplit", &maxsplit_));
    247   }
    248 
    249   void Compute(OpKernelContext* ctx) override {
    250     const Tensor* input_tensor;
    251     OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
    252     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()),
    253                 errors::InvalidArgument("input must be a vector, got shape: ",
    254                                         input_tensor->shape().DebugString()));
    255 
    256     const auto input_vec = input_tensor->vec<string>();
    257     const int64 batch_size = input_vec.dimension(0);
    258 
    259     const Tensor* sep_tensor;
    260     OP_REQUIRES_OK(ctx, ctx->input("sep", &sep_tensor));
    261     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sep_tensor->shape()),
    262                 errors::InvalidArgument("sep must be a scalar, got shape: ",
    263                                         sep_tensor->shape().DebugString()));
    264     const auto sep_vec = sep_tensor->flat<string>();
    265     StringPiece sep(sep_vec(0));
    266     std::vector<StringPiece> tokens;
    267     // Guess that we'll be unpacking a handful of tokens per example.
    268     static constexpr int kReserveSize = 4;
    269     tokens.reserve(batch_size * kReserveSize);
    270 
    271     int64 output_size = 0;
    272     int64 max_num_entries = 0;
    273     std::vector<int64> num_indices(batch_size);
    274     for (int64 i = 0; i < batch_size; ++i) {
    275       std::vector<StringPiece> parts = SplitV2(input_vec(i), sep, maxsplit_);
    276       int64 n_entries = parts.size();
    277       num_indices[i] = n_entries;
    278       output_size += n_entries;
    279       max_num_entries = std::max(max_num_entries, n_entries);
    280       tokens.insert(tokens.end(), parts.begin(), parts.end());
    281     }
    282 
    283     Tensor* sp_indices_t;
    284     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}),
    285                                              &sp_indices_t));
    286     Tensor* sp_tokens_t;
    287     OP_REQUIRES_OK(
    288         ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t));
    289     Tensor* sp_shape_t;
    290     OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t));
    291 
    292     auto sp_indices = sp_indices_t->matrix<int64>();
    293     auto sp_tokens = sp_tokens_t->vec<string>();
    294     auto sp_shape = sp_shape_t->vec<int64>();
    295     sp_shape(0) = batch_size;
    296     sp_shape(1) = max_num_entries;
    297     size_t c = 0;
    298     for (size_t i = 0; i < batch_size; ++i) {
    299       for (size_t j = 0; j < num_indices[i]; ++j) {
    300         sp_indices(c, 0) = i;
    301         sp_indices(c, 1) = j;
    302         sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
    303         ++c;
    304       }
    305     }
    306   }
    307 
    308  private:
    309   int maxsplit_;
    310 };
    311 
    312 REGISTER_KERNEL_BUILDER(Name("StringSplit").Device(DEVICE_CPU), StringSplitOp);
    313 REGISTER_KERNEL_BUILDER(Name("StringSplitV2").Device(DEVICE_CPU),
    314                         StringSplitV2Op);
    315 
    316 }  // namespace tensorflow
    317