Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/string_ops.cc.
     17 
#include <algorithm>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
     26 
     27 namespace tensorflow {
     28 
     29 namespace {
     30 
     31 std::vector<string> Split(const string& str, const string& delimiter,
     32                           const bool skipEmpty) {
     33   if (!delimiter.empty()) {
     34     if (skipEmpty) {
     35       return str_util::Split(str, delimiter, str_util::SkipEmpty());
     36     }
     37     return str_util::Split(str, delimiter);
     38   }
     39   std::vector<string> char_vector(str.size());
     40   for (size_t i = 0; i < str.size(); ++i) {
     41     char_vector[i] = str[i];
     42   }
     43   return char_vector;
     44 }
     45 
     46 }  // namespace
     47 
     48 class StringSplitOp : public OpKernel {
     49  public:
     50   explicit StringSplitOp(OpKernelConstruction* context)
     51       : OpKernel(context), skip_empty_(true) {
     52     bool skip_empty;
     53     // By default skip_empty_ is true. We only get the value from attr if it is
     54     // available, so that it is backward compatible.
     55     if (context->GetAttr("skip_empty", &skip_empty).ok()) {
     56       skip_empty_ = skip_empty;
     57     }
     58   }
     59 
     60   void Compute(OpKernelContext* ctx) override {
     61     const Tensor* input_tensor;
     62     OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
     63     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()),
     64                 errors::InvalidArgument("input must be a vector, got shape: ",
     65                                         input_tensor->shape().DebugString()));
     66 
     67     const auto input_vec = input_tensor->vec<string>();
     68     const int64 batch_size = input_vec.dimension(0);
     69 
     70     const Tensor* delimiter_tensor;
     71     OP_REQUIRES_OK(ctx, ctx->input("delimiter", &delimiter_tensor));
     72     OP_REQUIRES(
     73         ctx, TensorShapeUtils::IsScalar(delimiter_tensor->shape()),
     74         errors::InvalidArgument("delimiter must scalar, got shape: ",
     75                                 delimiter_tensor->shape().DebugString()));
     76     const auto delimiter_vec = delimiter_tensor->flat<string>();
     77     const string& delimiter = delimiter_vec(0);
     78     // Empty delimiter means split the input character by character.
     79     std::vector<string> tokens;
     80     // Guess that we'll be unpacking a handful of tokens per example.
     81     static constexpr int kReserveSize = 4;
     82     tokens.reserve(batch_size * kReserveSize);
     83 
     84     int64 output_size = 0;
     85     int64 max_num_entries = 0;
     86     std::vector<int64> num_indices(batch_size);
     87     for (int64 i = 0; i < batch_size; ++i) {
     88       std::vector<string> parts = Split(input_vec(i), delimiter, skip_empty_);
     89       int64 n_entries = parts.size();
     90       num_indices[i] = n_entries;
     91       output_size += n_entries;
     92       max_num_entries = std::max(max_num_entries, n_entries);
     93       tokens.insert(tokens.end(), parts.begin(), parts.end());
     94     }
     95 
     96     Tensor* sp_indices_t;
     97     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}),
     98                                              &sp_indices_t));
     99     Tensor* sp_tokens_t;
    100     OP_REQUIRES_OK(
    101         ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t));
    102     Tensor* sp_shape_t;
    103     OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t));
    104 
    105     auto sp_indices = sp_indices_t->matrix<int64>();
    106     auto sp_tokens = sp_tokens_t->vec<string>();
    107     auto sp_shape = sp_shape_t->vec<int64>();
    108     sp_shape(0) = batch_size;
    109     sp_shape(1) = max_num_entries;
    110     size_t c = 0;
    111     for (size_t i = 0; i < batch_size; ++i) {
    112       for (size_t j = 0; j < num_indices[i]; ++j) {
    113         sp_indices(c, 0) = i;
    114         sp_indices(c, 1) = j;
    115         sp_tokens(c) = tokens[c];
    116         ++c;
    117       }
    118     }
    119   }
    120 
    121  private:
    122   bool skip_empty_;
    123 };
    124 
    125 REGISTER_KERNEL_BUILDER(Name("StringSplit").Device(DEVICE_CPU), StringSplitOp);
    126 
    127 }  // namespace tensorflow
    128