/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/string_ops.cc.

#include <string>

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {
namespace {
// Split input string `str` based on a character delimiter.
// Returns a vector of StringPieces which are valid as long as input `str`
// is valid.
// Note: The single character delimiter is a common case and is implemented as
// a series of finds in the input string, making it much more efficient than
// SplitOnCharSet.
36 template <typename Predicate> 37 std::vector<StringPiece> SplitOnChar(const string& str, const char delim, 38 Predicate p) { 39 std::vector<StringPiece> result; 40 StringPiece text(str); 41 auto f = text.find(delim); 42 while (f != StringPiece::npos) { 43 StringPiece token = text.substr(0, f); 44 if (p(token)) { 45 result.emplace_back(token); 46 } 47 text.remove_prefix(f + 1); 48 f = text.find(delim); 49 } 50 if (p(text)) { 51 result.push_back(text); 52 } 53 return result; 54 } 55 56 // Split input string `str` based on a set of character delimiters. 57 // Returns a vector of StringPieces which are valid as long as input `str` 58 // is valid. 59 // Based on str_util::Split. 60 template <typename Predicate> 61 std::vector<StringPiece> SplitOnCharSet(const string& str, 62 const string& delim_set, Predicate p) { 63 std::vector<StringPiece> result; 64 StringPiece text(str); 65 StringPiece delims(delim_set); 66 size_t token_start = 0; 67 for (size_t i = 0; i < text.size() + 1; i++) { 68 if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) { 69 StringPiece token(text.data() + token_start, i - token_start); 70 if (p(token)) { 71 result.emplace_back(token); 72 } 73 token_start = i + 1; 74 } 75 } 76 return result; 77 } 78 79 // Split input string `str` based on given delimiter. 80 // Returns a vector of StringPieces which are valid as long as input `str` 81 // is valid. 
82 template <typename Predicate> 83 std::vector<StringPiece> Split(const string& str, const string& delimiter, 84 Predicate predicate) { 85 if (str.empty()) { 86 return std::vector<StringPiece>(); 87 } 88 if (delimiter.empty()) { 89 std::vector<StringPiece> result; 90 result.resize(str.size()); 91 for (size_t i = 0; i < str.size(); ++i) { 92 result[i] = StringPiece(str.data() + i, 1); 93 } 94 return result; 95 } 96 if (delimiter.size() == 1) { 97 return SplitOnChar(str, delimiter[0], predicate); 98 } 99 return SplitOnCharSet(str, delimiter, predicate); 100 } 101 102 std::vector<StringPiece> SplitV2(const string& str, StringPiece sep, 103 int maxsplit) { 104 // This SplitV2 method matches the behavior of python's str.split: 105 // If sep is given, consecutive delimiters are not grouped together 106 // and are deemed to delimit empty strings (for example, '1,,2'.split(',') 107 // returns ['1', '', '2']). The sep argument may consist of multiple 108 // characters (for example, '1<>2<>3'.split('<>') returns ['1', '2', '3']). 109 // Splitting an empty string with a specified separator returns ['']. 110 // 111 // If sep is not specified or is None, a different splitting algorithm is 112 // applied: runs of consecutive whitespace are regarded as a single 113 // separator, and the result will contain no empty strings at the start or 114 // end if the string has leading or trailing whitespace. Consequently, 115 // splitting an empty string or a string consisting of just whitespace 116 // with a None separator returns []. 117 118 std::vector<StringPiece> result; 119 120 StringPiece text(str); 121 if (maxsplit == 0) { 122 result.emplace_back(text); 123 return result; 124 } 125 126 if (sep.empty()) { 127 StringPiece token; 128 // Remove leading whitespaces. 
129 str_util::RemoveLeadingWhitespace(&text); 130 int split = 0; 131 while (str_util::ConsumeNonWhitespace(&text, &token)) { 132 result.push_back(token); 133 str_util::RemoveLeadingWhitespace(&text); 134 ++split; 135 if (maxsplit > 0 && split == maxsplit) { 136 result.push_back(text); 137 return result; 138 } 139 } 140 return result; 141 } 142 auto p = std::search(text.begin(), text.end(), sep.begin(), sep.end()); 143 int split = 0; 144 while (p != text.end()) { 145 StringPiece token = text.substr(0, p - text.begin()); 146 result.push_back(token); 147 text.remove_prefix(token.size()); 148 text.remove_prefix(sep.size()); 149 ++split; 150 if (maxsplit > 0 && split == maxsplit) { 151 result.push_back(StringPiece(text)); 152 return result; 153 } 154 p = std::search(text.begin(), text.end(), sep.begin(), sep.end()); 155 } 156 result.push_back(text); 157 return result; 158 } 159 160 } // namespace 161 162 class StringSplitOp : public OpKernel { 163 public: 164 explicit StringSplitOp(OpKernelConstruction* context) 165 : OpKernel(context), skip_empty_(true) { 166 bool skip_empty; 167 // By default skip_empty_ is true. We only get the value from attr if it is 168 // available, so that it is backward compatible. 
169 if (context->GetAttr("skip_empty", &skip_empty).ok()) { 170 skip_empty_ = skip_empty; 171 } 172 } 173 174 void Compute(OpKernelContext* ctx) override { 175 const Tensor* input_tensor; 176 OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); 177 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()), 178 errors::InvalidArgument("input must be a vector, got shape: ", 179 input_tensor->shape().DebugString())); 180 181 const auto input_vec = input_tensor->vec<string>(); 182 const int64 batch_size = input_vec.dimension(0); 183 184 const Tensor* delimiter_tensor; 185 OP_REQUIRES_OK(ctx, ctx->input("delimiter", &delimiter_tensor)); 186 OP_REQUIRES( 187 ctx, TensorShapeUtils::IsScalar(delimiter_tensor->shape()), 188 errors::InvalidArgument("delimiter must be a scalar, got shape: ", 189 delimiter_tensor->shape().DebugString())); 190 const auto delimiter_vec = delimiter_tensor->flat<string>(); 191 const string& delimiter = delimiter_vec(0); 192 // Empty delimiter means split the input character by character. 193 std::vector<StringPiece> tokens; 194 // Guess that we'll be unpacking a handful of tokens per example. 195 static constexpr int kReserveSize = 4; 196 tokens.reserve(batch_size * kReserveSize); 197 198 int64 output_size = 0; 199 int64 max_num_entries = 0; 200 std::vector<int64> num_indices(batch_size); 201 for (int64 i = 0; i < batch_size; ++i) { 202 std::vector<StringPiece> parts = 203 skip_empty_ ? 
Split(input_vec(i), delimiter, str_util::SkipEmpty()) 204 : Split(input_vec(i), delimiter, str_util::AllowEmpty()); 205 int64 n_entries = parts.size(); 206 num_indices[i] = n_entries; 207 output_size += n_entries; 208 max_num_entries = std::max(max_num_entries, n_entries); 209 tokens.insert(tokens.end(), std::make_move_iterator(parts.begin()), 210 std::make_move_iterator(parts.end())); 211 } 212 213 Tensor* sp_indices_t; 214 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}), 215 &sp_indices_t)); 216 Tensor* sp_tokens_t; 217 OP_REQUIRES_OK( 218 ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t)); 219 Tensor* sp_shape_t; 220 OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t)); 221 222 auto sp_indices = sp_indices_t->matrix<int64>(); 223 auto sp_tokens = sp_tokens_t->vec<string>(); 224 auto sp_shape = sp_shape_t->vec<int64>(); 225 sp_shape(0) = batch_size; 226 sp_shape(1) = max_num_entries; 227 size_t c = 0; 228 for (size_t i = 0; i < batch_size; ++i) { 229 for (size_t j = 0; j < num_indices[i]; ++j) { 230 sp_indices(c, 0) = i; 231 sp_indices(c, 1) = j; 232 sp_tokens(c).assign(tokens[c].data(), tokens[c].size()); 233 ++c; 234 } 235 } 236 } 237 238 private: 239 bool skip_empty_; 240 }; 241 242 class StringSplitV2Op : public OpKernel { 243 public: 244 explicit StringSplitV2Op(OpKernelConstruction* context) 245 : OpKernel(context), maxsplit_(-1) { 246 OP_REQUIRES_OK(context, context->GetAttr("maxsplit", &maxsplit_)); 247 } 248 249 void Compute(OpKernelContext* ctx) override { 250 const Tensor* input_tensor; 251 OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); 252 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()), 253 errors::InvalidArgument("input must be a vector, got shape: ", 254 input_tensor->shape().DebugString())); 255 256 const auto input_vec = input_tensor->vec<string>(); 257 const int64 batch_size = input_vec.dimension(0); 258 259 const Tensor* sep_tensor; 260 
OP_REQUIRES_OK(ctx, ctx->input("sep", &sep_tensor)); 261 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sep_tensor->shape()), 262 errors::InvalidArgument("sep must be a scalar, got shape: ", 263 sep_tensor->shape().DebugString())); 264 const auto sep_vec = sep_tensor->flat<string>(); 265 StringPiece sep(sep_vec(0)); 266 std::vector<StringPiece> tokens; 267 // Guess that we'll be unpacking a handful of tokens per example. 268 static constexpr int kReserveSize = 4; 269 tokens.reserve(batch_size * kReserveSize); 270 271 int64 output_size = 0; 272 int64 max_num_entries = 0; 273 std::vector<int64> num_indices(batch_size); 274 for (int64 i = 0; i < batch_size; ++i) { 275 std::vector<StringPiece> parts = SplitV2(input_vec(i), sep, maxsplit_); 276 int64 n_entries = parts.size(); 277 num_indices[i] = n_entries; 278 output_size += n_entries; 279 max_num_entries = std::max(max_num_entries, n_entries); 280 tokens.insert(tokens.end(), parts.begin(), parts.end()); 281 } 282 283 Tensor* sp_indices_t; 284 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}), 285 &sp_indices_t)); 286 Tensor* sp_tokens_t; 287 OP_REQUIRES_OK( 288 ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t)); 289 Tensor* sp_shape_t; 290 OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t)); 291 292 auto sp_indices = sp_indices_t->matrix<int64>(); 293 auto sp_tokens = sp_tokens_t->vec<string>(); 294 auto sp_shape = sp_shape_t->vec<int64>(); 295 sp_shape(0) = batch_size; 296 sp_shape(1) = max_num_entries; 297 size_t c = 0; 298 for (size_t i = 0; i < batch_size; ++i) { 299 for (size_t j = 0; j < num_indices[i]; ++j) { 300 sp_indices(c, 0) = i; 301 sp_indices(c, 1) = j; 302 sp_tokens(c).assign(tokens[c].data(), tokens[c].size()); 303 ++c; 304 } 305 } 306 } 307 308 private: 309 int maxsplit_; 310 }; 311 312 REGISTER_KERNEL_BUILDER(Name("StringSplit").Device(DEVICE_CPU), StringSplitOp); 313 REGISTER_KERNEL_BUILDER(Name("StringSplitV2").Device(DEVICE_CPU), 
314 StringSplitV2Op); 315 316 } // namespace tensorflow 317