Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Operations for working with string Tensors."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import numpy as np
     23 
     24 from tensorflow.python.compat import compat
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.framework import sparse_tensor
     29 from tensorflow.python.framework import tensor_util
     30 from tensorflow.python.ops import array_ops
     31 from tensorflow.python.ops import gen_parsing_ops
     32 from tensorflow.python.ops import gen_string_ops
     33 from tensorflow.python.ops import math_ops
     34 
     35 # go/tf-wildcard-import
     36 # pylint: disable=wildcard-import
     37 # pylint: disable=g-bad-import-order
     38 from tensorflow.python.ops.gen_string_ops import *
     39 from tensorflow.python.util import compat as util_compat
     40 from tensorflow.python.util import deprecation
     41 from tensorflow.python.util import dispatch
     42 from tensorflow.python.util.tf_export import tf_export
     43 # pylint: enable=g-bad-import-order
     44 # pylint: enable=wildcard-import
     45 
     46 
     47 # pylint: disable=redefined-builtin
     48 @tf_export("strings.regex_full_match")
     49 @dispatch.add_dispatch_support
     50 def regex_full_match(input, pattern, name=None):
     51   r"""Match elements of `input` with regex `pattern`.
     52 
     53   Args:
     54     input: string `Tensor`, the source strings to process.
     55     pattern: string or scalar string `Tensor`, regular expression to use,
     56       see more details at https://github.com/google/re2/wiki/Syntax
     57     name: Name of the op.
     58 
     59   Returns:
     60     bool `Tensor` of the same shape as `input` with match results.
     61   """
     62   # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
     63   if not compat.forward_compatible(2018, 11, 10):
     64     return gen_string_ops.regex_full_match(
     65         input=input, pattern=pattern, name=name)
     66   if isinstance(pattern, util_compat.bytes_or_text_types):
     67     # When `pattern` is static through the life of the op we can
     68     # use a version which performs the expensive regex compilation once at
     69     # creation time.
     70     return gen_string_ops.static_regex_full_match(
     71         input=input, pattern=pattern, name=name)
     72   return gen_string_ops.regex_full_match(
     73       input=input, pattern=pattern, name=name)
     74 
     75 regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
     76 
     77 
     78 @tf_export(
     79     "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
     80 @deprecation.deprecated_endpoints("regex_replace")
     81 @dispatch.add_dispatch_support
     82 def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
     83   r"""Replace elements of `input` matching regex `pattern` with `rewrite`.
     84 
     85   Args:
     86     input: string `Tensor`, the source strings to process.
     87     pattern: string or scalar string `Tensor`, regular expression to use,
     88       see more details at https://github.com/google/re2/wiki/Syntax
     89     rewrite: string or scalar string `Tensor`, value to use in match
     90       replacement, supports backslash-escaped digits (\1 to \9) can be to insert
     91       text matching corresponding parenthesized group.
     92     replace_global: `bool`, if `True` replace all non-overlapping matches,
     93       else replace only the first match.
     94     name: A name for the operation (optional).
     95 
     96   Returns:
     97     string `Tensor` of the same shape as `input` with specified replacements.
     98   """
     99   if (isinstance(pattern, util_compat.bytes_or_text_types) and
    100       isinstance(rewrite, util_compat.bytes_or_text_types)):
    101     # When `pattern` and `rewrite` are static through the life of the op we can
    102     # use a version which performs the expensive regex compilation once at
    103     # creation time.
    104     return gen_string_ops.static_regex_replace(
    105         input=input, pattern=pattern,
    106         rewrite=rewrite, replace_global=replace_global,
    107         name=name)
    108   return gen_string_ops.regex_replace(
    109       input=input, pattern=pattern,
    110       rewrite=rewrite, replace_global=replace_global,
    111       name=name)
    112 
    113 
    114 @tf_export("strings.format")
    115 def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
    116   r"""Formats a string template using a list of tensors.
    117 
    118   Formats a string template using a list of tensors, abbreviating tensors by
    119   only printing the first and last `summarize` elements of each dimension
    120   (recursively). If formatting only one tensor into a template, the tensor does
    121   not have to be wrapped in a list.
    122 
    123   Example:
    124     Formatting a single-tensor template:
    125     ```python
    126     sess = tf.Session()
    127     with sess.as_default():
    128         tensor = tf.range(10)
    129         formatted = tf.strings.format("tensor: {}, suffix", tensor)
    130         out = sess.run(formatted)
    131         expected = "tensor: [0 1 2 ... 7 8 9], suffix"
    132 
    133         assert(out.decode() == expected)
    134     ```
    135 
    136     Formatting a multi-tensor template:
    137     ```python
    138     sess = tf.Session()
    139     with sess.as_default():
    140         tensor_one = tf.reshape(tf.range(100), [10, 10])
    141         tensor_two = tf.range(10)
    142         formatted = tf.strings.format("first: {}, second: {}, suffix",
    143           (tensor_one, tensor_two))
    144 
    145         out = sess.run(formatted)
    146         expected = ("first: [[0 1 2 ... 7 8 9]\n"
    147               " [10 11 12 ... 17 18 19]\n"
    148               " [20 21 22 ... 27 28 29]\n"
    149               " ...\n"
    150               " [70 71 72 ... 77 78 79]\n"
    151               " [80 81 82 ... 87 88 89]\n"
    152               " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")
    153 
    154         assert(out.decode() == expected)
    155     ```
    156 
    157   Args:
    158     template: A string template to format tensor values into.
    159     inputs: A list of `Tensor` objects, or a single Tensor.
    160       The list of tensors to format into the template string. If a solitary
    161       tensor is passed in, the input tensor will automatically be wrapped as a
    162       list.
    163     placeholder: An optional `string`. Defaults to `{}`.
    164       At each placeholder occurring in the template, a subsequent tensor
    165       will be inserted.
    166     summarize: An optional `int`. Defaults to `3`.
    167       When formatting the tensors, show the first and last `summarize`
    168       entries of each tensor dimension (recursively). If set to -1, all
    169       elements of the tensor will be shown.
    170     name: A name for the operation (optional).
    171 
    172   Returns:
    173     A scalar `Tensor` of type `string`.
    174 
    175   Raises:
    176     ValueError: if the number of placeholders does not match the number of
    177       inputs.
    178   """
    179   # If there is only one tensor to format, we will automatically wrap it in a
    180   # list to simplify the user experience
    181   if tensor_util.is_tensor(inputs):
    182     inputs = [inputs]
    183   if template.count(placeholder) != len(inputs):
    184     raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
    185                      " provided as input" % (template.count(placeholder),
    186                                              len(inputs)))
    187 
    188   return gen_string_ops.string_format(inputs,
    189                                       template=template,
    190                                       placeholder=placeholder,
    191                                       summarize=summarize,
    192                                       name=name)
    193 
    194 
    195 @tf_export(v1=["string_split"])
    196 @deprecation.deprecated_args(None,
    197                              "delimiter is deprecated, please use sep instead.",
    198                              "delimiter")
    199 def string_split(source, sep=None, skip_empty=True, delimiter=None):  # pylint: disable=invalid-name
    200   """Split elements of `source` based on `delimiter` into a `SparseTensor`.
    201 
    202   Let N be the size of source (typically N will be the batch size). Split each
    203   element of `source` based on `delimiter` and return a `SparseTensor`
    204   containing the split tokens. Empty tokens are ignored.
    205 
    206   If `sep` is an empty string, each element of the `source` is split
    207   into individual strings, each containing one byte. (This includes splitting
    208   multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
    209   treated as a set of delimiters with each considered a potential split point.
    210 
    211   For example:
    212   N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output
    213   will be
    214 
    215   st.indices = [0, 0;
    216                 0, 1;
    217                 1, 0;
    218                 1, 1;
    219                 1, 2]
    220   st.shape = [2, 3]
    221   st.values = ['hello', 'world', 'a', 'b', 'c']
    222 
    223   Args:
    224     source: `1-D` string `Tensor`, the strings to split.
    225     sep: `0-D` string `Tensor`, the delimiter character, the string should
    226       be length 0 or 1. Default is ' '.
    227     skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    228     delimiter: deprecated alias for `sep`.
    229 
    230   Raises:
    231     ValueError: If delimiter is not a string.
    232 
    233   Returns:
    234     A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    235     The first column of the indices corresponds to the row in `source` and the
    236     second column corresponds to the index of the split component in this row.
    237   """
    238   delimiter = deprecation.deprecated_argument_lookup(
    239       "sep", sep, "delimiter", delimiter)
    240 
    241   if delimiter is None:
    242     delimiter = " "
    243   delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string)
    244   source = ops.convert_to_tensor(source, dtype=dtypes.string)
    245 
    246   indices, values, shape = gen_string_ops.string_split(
    247       source, delimiter=delimiter, skip_empty=skip_empty)
    248   indices.set_shape([None, 2])
    249   values.set_shape([None])
    250   shape.set_shape([2])
    251   return sparse_tensor.SparseTensor(indices, values, shape)
    252 
    253 
    254 @tf_export("strings.split")
    255 def string_split_v2(source, sep=None, maxsplit=-1):
    256   """Split elements of `source` based on `sep` into a `SparseTensor`.
    257 
    258   Let N be the size of source (typically N will be the batch size). Split each
    259   element of `source` based on `sep` and return a `SparseTensor`
    260   containing the split tokens. Empty tokens are ignored.
    261 
    262   For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
    263   then the output will be
    264 
    265   st.indices = [0, 0;
    266                 0, 1;
    267                 1, 0;
    268                 1, 1;
    269                 1, 2]
    270   st.shape = [2, 3]
    271   st.values = ['hello', 'world', 'a', 'b', 'c']
    272 
    273   If `sep` is given, consecutive delimiters are not grouped together and are
    274   deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
    275   sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
    276   string, consecutive whitespace are regarded as a single separator, and the
    277   result will contain no empty strings at the start or end if the string has
    278   leading or trailing whitespace.
    279 
    280   Note that the above mentioned behavior matches python's str.split.
    281 
    282   Args:
    283     source: `1-D` string `Tensor`, the strings to split.
    284     sep: `0-D` string `Tensor`, the delimiter character.
    285     maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.
    286 
    287   Raises:
    288     ValueError: If sep is not a string.
    289 
    290   Returns:
    291     A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    292     The first column of the indices corresponds to the row in `source` and the
    293     second column corresponds to the index of the split component in this row.
    294   """
    295   if sep is None:
    296     sep = ""
    297   sep = ops.convert_to_tensor(sep, dtype=dtypes.string)
    298   source = ops.convert_to_tensor(source, dtype=dtypes.string)
    299 
    300   indices, values, shape = gen_string_ops.string_split_v2(
    301       source, sep=sep, maxsplit=maxsplit)
    302   indices.set_shape([None, 2])
    303   values.set_shape([None])
    304   shape.set_shape([2])
    305   return sparse_tensor.SparseTensor(indices, values, shape)
    306 
    307 
    308 def _reduce_join_reduction_dims(x, axis, reduction_indices):
    309   """Returns range(rank(x) - 1, 0, -1) if reduction_indices is None."""
    310   # TODO(aselle): Remove this after deprecation
    311   if reduction_indices is not None:
    312     if axis is not None:
    313       raise ValueError("Can't specify both 'axis' and 'reduction_indices'.")
    314     axis = reduction_indices
    315   if axis is not None:
    316     return axis
    317   else:
    318     # Fast path: avoid creating Rank and Range ops if ndims is known.
    319     if x.get_shape().ndims is not None:
    320       return constant_op.constant(
    321           np.arange(x.get_shape().ndims - 1, -1, -1), dtype=dtypes.int32)
    322 
    323     # Otherwise, we rely on Range and Rank to do the right thing at run-time.
    324     return math_ops.range(array_ops.rank(x) - 1, -1, -1)
    325 
    326 
    327 @tf_export(v1=["strings.reduce_join", "reduce_join"])
    328 @deprecation.deprecated_endpoints("reduce_join")
    329 def reduce_join(inputs, axis=None,  # pylint: disable=missing-docstring
    330                 keep_dims=False,
    331                 separator="",
    332                 name=None,
    333                 reduction_indices=None,
    334                 keepdims=None):
    335   keep_dims = deprecation.deprecated_argument_lookup(
    336       "keepdims", keepdims, "keep_dims", keep_dims)
    337   inputs_t = ops.convert_to_tensor(inputs)
    338   reduction_indices = _reduce_join_reduction_dims(
    339       inputs_t, axis, reduction_indices)
    340   return gen_string_ops.reduce_join(
    341       inputs=inputs_t,
    342       reduction_indices=reduction_indices,
    343       keep_dims=keep_dims,
    344       separator=separator,
    345       name=name)
    346 
    347 
    348 @tf_export("strings.reduce_join", v1=[])
    349 def reduce_join_v2(  # pylint: disable=missing-docstring
    350     inputs,
    351     axis=None,
    352     keepdims=False,
    353     separator="",
    354     name=None):
    355   return reduce_join(
    356       inputs, axis, keep_dims=keepdims, separator=separator, name=name)
    357 
    358 
    359 reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
    360     gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
    361 reduce_join.__doc__ = reduce_join.__doc__.replace("tf.reduce_join(",
    362                                                   "tf.strings.reduce_join(")
    363 
    364 
# This wrapper provides backwards compatibility for code that predates the
# unit argument and that passed 'name' as a positional argument.
# For that reason `name` stays the second positional parameter here and
# `unit` comes after it. The docstring is assigned after the definition from
# the generated op's documentation.
@tf_export(v1=["strings.length"])
@dispatch.add_dispatch_support
def string_length(input, name=None, unit="BYTE"):
  # Forward to the generated op, passing `unit` and `name` by keyword.
  return gen_string_ops.string_length(input, unit=unit, name=name)
    371 
    372 
@tf_export("strings.length", v1=[])
@dispatch.add_dispatch_support
def string_length_v2(input, unit="BYTE", name=None):
  """Computes string lengths by delegating to `string_length`.

  Same operation as `string_length`; only the parameter order differs, with
  `unit` placed before `name`.

  Args:
    input: The strings to measure; forwarded unchanged.
    unit: The unit in which lengths are counted; forwarded to the underlying
      op. Defaults to `"BYTE"`.
    name: A name for the operation (optional).

  Returns:
    The result of `string_length` for the given arguments.
  """
  # `string_length` takes (input, name, unit) positionally, hence the swap.
  return string_length(input, name, unit)
    377 
    378 
# Reuse the generated op's documentation for the v1 wrapper.
string_length.__doc__ = gen_string_ops.string_length.__doc__
    380 
    381 
# Deprecated top-level `tf.substr` endpoint. It only forwards to `substr`
# (exported as `tf.strings.substr`), and the decorator below emits the
# deprecation notice pointing users at the new name.
@tf_export(v1=["substr"])
@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
  return substr(input, pos, len, name=name, unit=unit)

# Reuse the generated op's documentation as this endpoint's docstring.
substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
    388 
    389 
# v1 `tf.strings.substr` endpoint. `name` precedes `unit` in this signature,
# whereas the v2 endpoint (`substr_v2`) puts `unit` first.
@tf_export(v1=["strings.substr"])
@dispatch.add_dispatch_support
def substr(input, pos, len, name=None, unit="BYTE"):
  # Forward to the generated op, passing `unit` and `name` by keyword.
  return gen_string_ops.substr(input, pos, len, unit=unit, name=name)

# Reuse the generated op's documentation as this endpoint's docstring.
substr.__doc__ = gen_string_ops.substr.__doc__
    396 
    397 
# v2 `tf.strings.substr` endpoint. Identical to `substr` except that `unit`
# precedes `name` in the signature.
@tf_export("strings.substr", v1=[])
@dispatch.add_dispatch_support
def substr_v2(input, pos, len, unit="BYTE", name=None):
  # Forward to the generated op, passing `unit` and `name` by keyword.
  return gen_string_ops.substr(input, pos, len, unit=unit, name=name)

# Reuse the generated op's documentation as this endpoint's docstring.
substr_v2.__doc__ = gen_string_ops.substr.__doc__
    404 
    405 
# Mark the following string op types as not differentiable. These calls run
# once, when this module is imported.
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
ops.NotDifferentiable("StringToHashBucketStrong")
ops.NotDifferentiable("ReduceJoin")
ops.NotDifferentiable("StringJoin")
ops.NotDifferentiable("StringSplit")
ops.NotDifferentiable("AsString")
ops.NotDifferentiable("EncodeBase64")
ops.NotDifferentiable("DecodeBase64")
    416 
    417 
@tf_export("strings.to_number", v1=[])
@dispatch.add_dispatch_support
def string_to_number(input, out_type=dtypes.float32, name=None):
  r"""Converts each string in the input Tensor to the specified numeric type.

  (Note that int32 overflow results in an error while float overflow
  results in a rounded value.)

  Args:
    input: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.float64, tf.int32,
      tf.int64`. Defaults to `tf.float32`.
      The numeric type to interpret each string in `input` as.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
  # All three arguments are forwarded positionally to the generated op.
  return gen_parsing_ops.string_to_number(input, out_type, name)
    437 
    438 
    439 @tf_export(v1=["strings.to_number", "string_to_number"])
    440 def string_to_number_v1(
    441     string_tensor=None,
    442     out_type=dtypes.float32,
    443     name=None,
    444     input=None):
    445   string_tensor = deprecation.deprecated_argument_lookup(
    446       "input", input, "string_tensor", string_tensor)
    447   return gen_parsing_ops.string_to_number(string_tensor, out_type, name)
    448 
    449 string_to_number_v1.__doc__ = gen_parsing_ops.string_to_number.__doc__
    450 
    451 
@tf_export("strings.to_hash_bucket", v1=[])
@dispatch.add_dispatch_support
def string_to_hash_bucket(input, num_buckets, name=None):
  # pylint: disable=line-too-long
  r"""Converts each string in the input Tensor to its hash mod by a number of buckets.

  The hash function is deterministic on the content of the string within the
  process.

  Note that the hash function may change from time to time.
  This functionality will be deprecated and it's recommended to use
  `tf.strings.to_hash_bucket_fast()` or `tf.strings.to_hash_bucket_strong()`.

  Args:
    input: A `Tensor` of type `string`.
    num_buckets: An `int` that is `>= 1`. The number of buckets.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int64`.
  """
  # pylint: enable=line-too-long
  # All three arguments are forwarded positionally to the generated op.
  return gen_string_ops.string_to_hash_bucket(input, num_buckets, name)
    475 
    476 
    477 @tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"])
    478 def string_to_hash_bucket_v1(
    479     string_tensor=None,
    480     num_buckets=None,
    481     name=None,
    482     input=None):
    483   string_tensor = deprecation.deprecated_argument_lookup(
    484       "input", input, "string_tensor", string_tensor)
    485   return gen_string_ops.string_to_hash_bucket(string_tensor, num_buckets, name)
    486 
    487 string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__
    488