Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Implementation of tf.sets."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.framework import sparse_tensor
     25 from tensorflow.python.ops import gen_set_ops
     26 from tensorflow.python.util.tf_export import tf_export
     27 
     28 
     29 _VALID_DTYPES = set([
     30     dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
     31     dtypes.uint8, dtypes.uint16, dtypes.string])
     32 
     33 
     34 @tf_export("sets.size", v1=["sets.size", "sets.set_size"])
     35 def set_size(a, validate_indices=True):
     36   """Compute number of unique elements along last dimension of `a`.
     37 
     38   Args:
     39     a: `SparseTensor`, with indices sorted in row-major order.
     40     validate_indices: Whether to validate the order and range of sparse indices
     41        in `a`.
     42 
     43   Returns:
     44     `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with
     45     rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the
     46     number of unique elements in the corresponding `[0...n-1]` dimension of `a`.
     47 
     48   Raises:
     49     TypeError: If `a` is an invalid types.
     50   """
     51   a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
     52   if not isinstance(a, sparse_tensor.SparseTensor):
     53     raise TypeError("Expected `SparseTensor`, got %s." % a)
     54   if a.values.dtype.base_dtype not in _VALID_DTYPES:
     55     raise TypeError("Invalid dtype %s." % a.values.dtype)
     56   # pylint: disable=protected-access
     57   return gen_set_ops.set_size(
     58       a.indices, a.values, a.dense_shape, validate_indices)
     59 
# Register the set size op as not differentiable.
ops.NotDifferentiable("SetSize")


# Likewise register every dense/sparse input variant of the set operation op
# (the ops invoked by `_set_operation` below) as not differentiable.
ops.NotDifferentiable("DenseToDenseSetOperation")
ops.NotDifferentiable("DenseToSparseSetOperation")
ops.NotDifferentiable("SparseToSparseSetOperation")
     66 
     67 
     68 def _convert_to_tensors_or_sparse_tensors(a, b):
     69   """Convert to tensor types, and flip order if necessary.
     70 
     71   Args:
     72     a: `Tensor` or `SparseTensor` of the same type as `b`.
     73     b: `Tensor` or `SparseTensor` of the same type as `a`.
     74 
     75   Returns:
     76     Tuple of `(a, b, flipped)`, where `a` and `b` have been converted to
     77     `Tensor` or `SparseTensor`, and `flipped` indicates whether the order has
     78     been flipped to make it dense,sparse instead of sparse,dense (since the set
     79     ops do not support the latter).
     80   """
     81   a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
     82   if a.dtype.base_dtype not in _VALID_DTYPES:
     83     raise TypeError("'a' invalid dtype %s." % a.dtype)
     84   b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
     85   if b.dtype.base_dtype != a.dtype.base_dtype:
     86     raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
     87   if (isinstance(a, sparse_tensor.SparseTensor) and
     88       not isinstance(b, sparse_tensor.SparseTensor)):
     89     return b, a, True
     90   return a, b, False
     91 
     92 
     93 def _set_operation(a, b, set_operation, validate_indices=True):
     94   """Compute set operation of elements in last dimension of `a` and `b`.
     95 
     96   All but the last dimension of `a` and `b` must match.
     97 
     98   Args:
     99     a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
    100         must be sorted in row-major order.
    101     b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
    102         `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
    103         sorted in row-major order.
    104     set_operation: String indicating set operation. See
    105         SetOperationOp::SetOperationFromContext for valid values.
    106     validate_indices: Whether to validate the order and range of sparse indices
    107        in `a` and `b`.
    108 
    109   Returns:
    110     A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    111     dimension the same. Elements along the last dimension contain the results
    112     of the set operation.
    113 
    114   Raises:
    115     TypeError: If inputs are invalid types.
    116     ValueError: If `a` is sparse and `b` is dense.
    117   """
    118   if isinstance(a, sparse_tensor.SparseTensor):
    119     if isinstance(b, sparse_tensor.SparseTensor):
    120       indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation(
    121           a.indices, a.values, a.dense_shape,
    122           b.indices, b.values, b.dense_shape,
    123           set_operation, validate_indices)
    124     else:
    125       raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
    126                        "Please flip the order of your inputs.")
    127   elif isinstance(b, sparse_tensor.SparseTensor):
    128     indices, values, shape = gen_set_ops.dense_to_sparse_set_operation(
    129         a, b.indices, b.values, b.dense_shape, set_operation, validate_indices)
    130   else:
    131     indices, values, shape = gen_set_ops.dense_to_dense_set_operation(
    132         a, b, set_operation, validate_indices)
    133   return sparse_tensor.SparseTensor(indices, values, shape)
    134 
    135 
    136 @tf_export(
    137     "sets.intersection", v1=["sets.intersection", "sets.set_intersection"])
    138 def set_intersection(a, b, validate_indices=True):
    139   """Compute set intersection of elements in last dimension of `a` and `b`.
    140 
    141   All but the last dimension of `a` and `b` must match.
    142 
    143   Example:
    144 
    145   ```python
    146     import tensorflow as tf
    147     import collections
    148 
    149     # Represent the following array of sets as a sparse tensor:
    150     # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    151     a = collections.OrderedDict([
    152         ((0, 0, 0), 1),
    153         ((0, 0, 1), 2),
    154         ((0, 1, 0), 3),
    155         ((1, 0, 0), 4),
    156         ((1, 1, 0), 5),
    157         ((1, 1, 1), 6),
    158     ])
    159     a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2,2,2])
    160 
    161     # b = np.array([[{1}, {}], [{4}, {5, 6, 7, 8}]])
    162     b = collections.OrderedDict([
    163         ((0, 0, 0), 1),
    164         ((1, 0, 0), 4),
    165         ((1, 1, 0), 5),
    166         ((1, 1, 1), 6),
    167         ((1, 1, 2), 7),
    168         ((1, 1, 3), 8),
    169     ])
    170     b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])
    171 
    172     # `tf.sets.set_intersection` is applied to each aligned pair of sets.
    173     tf.sets.set_intersection(a, b)
    174 
    175     # The result will be equivalent to either of:
    176     #
    177     # np.array([[{1}, {}], [{4}, {5, 6}]])
    178     #
    179     # collections.OrderedDict([
    180     #     ((0, 0, 0), 1),
    181     #     ((1, 0, 0), 4),
    182     #     ((1, 1, 0), 5),
    183     #     ((1, 1, 1), 6),
    184     # ])
    185   ```
    186 
    187   Args:
    188     a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
    189         must be sorted in row-major order.
    190     b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
    191         must be sorted in row-major order.
    192     validate_indices: Whether to validate the order and range of sparse indices
    193        in `a` and `b`.
    194 
    195   Returns:
    196     A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    197     the last dimension the same. Elements along the last dimension contain the
    198     intersections.
    199   """
    200   a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b)
    201   return _set_operation(a, b, "intersection", validate_indices)
    202 
    203 
    204 @tf_export(
    205 	   "sets.difference", v1=["sets.difference", "sets.set_difference"])
    206 def set_difference(a, b, aminusb=True, validate_indices=True):
    207   """Compute set difference of elements in last dimension of `a` and `b`.
    208 
    209   All but the last dimension of `a` and `b` must match.
    210 
    211   Example:
    212 
    213   ```python
    214     import tensorflow as tf
    215     import collections
    216 
    217     # Represent the following array of sets as a sparse tensor:
    218     # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    219     a = collections.OrderedDict([
    220         ((0, 0, 0), 1),
    221         ((0, 0, 1), 2),
    222         ((0, 1, 0), 3),
    223         ((1, 0, 0), 4),
    224         ((1, 1, 0), 5),
    225         ((1, 1, 1), 6),
    226     ])
    227     a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])
    228 
    229     # np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
    230     b = collections.OrderedDict([
    231         ((0, 0, 0), 1),
    232         ((0, 0, 1), 3),
    233         ((0, 1, 0), 2),
    234         ((1, 0, 0), 4),
    235         ((1, 0, 1), 5),
    236         ((1, 1, 0), 5),
    237         ((1, 1, 1), 6),
    238         ((1, 1, 2), 7),
    239         ((1, 1, 3), 8),
    240     ])
    241     b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])
    242 
    243     # `set_difference` is applied to each aligned pair of sets.
    244     tf.sets.set_difference(a, b)
    245 
    246     # The result will be equivalent to either of:
    247     #
    248     # np.array([[{2}, {3}], [{}, {}]])
    249     #
    250     # collections.OrderedDict([
    251     #     ((0, 0, 0), 2),
    252     #     ((0, 1, 0), 3),
    253     # ])
    254   ```
    255 
    256   Args:
    257     a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
    258         must be sorted in row-major order.
    259     b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
    260         must be sorted in row-major order.
    261     aminusb: Whether to subtract `b` from `a`, vs vice versa.
    262     validate_indices: Whether to validate the order and range of sparse indices
    263        in `a` and `b`.
    264 
    265   Returns:
    266     A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    267     the last dimension the same. Elements along the last dimension contain the
    268     differences.
    269   """
    270   a, b, flipped = _convert_to_tensors_or_sparse_tensors(a, b)
    271   if flipped:
    272     aminusb = not aminusb
    273   return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices)
    274 
    275 
    276 @tf_export(
    277 	   "sets.union", v1=["sets.union", "sets.set_union"])
    278 def set_union(a, b, validate_indices=True):
    279   """Compute set union of elements in last dimension of `a` and `b`.
    280 
    281   All but the last dimension of `a` and `b` must match.
    282 
    283   Example:
    284 
    285   ```python
    286     import tensorflow as tf
    287     import collections
    288 
    289     # [[{1, 2}, {3}], [{4}, {5, 6}]]
    290     a = collections.OrderedDict([
    291         ((0, 0, 0), 1),
    292         ((0, 0, 1), 2),
    293         ((0, 1, 0), 3),
    294         ((1, 0, 0), 4),
    295         ((1, 1, 0), 5),
    296         ((1, 1, 1), 6),
    297     ])
    298     a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])
    299 
    300     # [[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]]
    301     b = collections.OrderedDict([
    302         ((0, 0, 0), 1),
    303         ((0, 0, 1), 3),
    304         ((0, 1, 0), 2),
    305         ((1, 0, 0), 4),
    306         ((1, 0, 1), 5),
    307         ((1, 1, 0), 5),
    308         ((1, 1, 1), 6),
    309         ((1, 1, 2), 7),
    310         ((1, 1, 3), 8),
    311     ])
    312     b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])
    313 
    314     # `set_union` is applied to each aligned pair of sets.
    315     tf.sets.set_union(a, b)
    316 
    317     # The result will be a equivalent to either of:
    318     #
    319     # np.array([[{1, 2, 3}, {2, 3}], [{4, 5}, {5, 6, 7, 8}]])
    320     #
    321     # collections.OrderedDict([
    322     #     ((0, 0, 0), 1),
    323     #     ((0, 0, 1), 2),
    324     #     ((0, 0, 2), 3),
    325     #     ((0, 1, 0), 2),
    326     #     ((0, 1, 1), 3),
    327     #     ((1, 0, 0), 4),
    328     #     ((1, 0, 1), 5),
    329     #     ((1, 1, 0), 5),
    330     #     ((1, 1, 1), 6),
    331     #     ((1, 1, 2), 7),
    332     #     ((1, 1, 3), 8),
    333     # ])
    334   ```
    335 
    336   Args:
    337     a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
    338         must be sorted in row-major order.
    339     b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
    340         must be sorted in row-major order.
    341     validate_indices: Whether to validate the order and range of sparse indices
    342        in `a` and `b`.
    343 
    344   Returns:
    345     A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    346     the last dimension the same. Elements along the last dimension contain the
    347     unions.
    348   """
    349   a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b)
    350   return _set_operation(a, b, "union", validate_indices)
    351