Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for SparseTensorsMap."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.client import session
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import sparse_ops
     29 from tensorflow.python.ops import variables
     30 from tensorflow.python.platform import test
     31 
     32 # pylint: disable=protected-access
     33 add_sparse_to_tensors_map = sparse_ops._add_sparse_to_tensors_map
     34 add_many_sparse_to_tensors_map = sparse_ops._add_many_sparse_to_tensors_map
     35 take_many_sparse_from_tensors_map = (
     36     sparse_ops._take_many_sparse_from_tensors_map)
     37 
     38 # pylint: enable=protected-access
     39 
     40 
     41 class SparseTensorsMapTest(test.TestCase):
     42 
     43   def _SparseTensorPlaceholder(self, dtype=None):
     44     if dtype is None:
     45       dtype = dtypes.int32
     46     return sparse_tensor_lib.SparseTensor(
     47         array_ops.placeholder(dtypes.int64),
     48         array_ops.placeholder(dtype), array_ops.placeholder(dtypes.int64))
     49 
     50   def _SparseTensorValue_5x6(self, permutation):
     51     ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4], [3, 2],
     52                     [3, 3]]).astype(np.int64)
     53     val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32)
     54 
     55     ind = ind[permutation]
     56     val = val[permutation]
     57 
     58     shape = np.array([5, 6]).astype(np.int64)
     59     return sparse_tensor_lib.SparseTensorValue(ind, val, shape)
     60 
     61   def _SparseTensorValue_3x4(self, permutation):
     62     ind = np.array([[0, 0], [1, 0], [1, 2], [1, 3], [2, 2],
     63                     [2, 3]]).astype(np.int64)
     64     val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32)
     65 
     66     ind = ind[permutation]
     67     val = val[permutation]
     68 
     69     shape = np.array([3, 4]).astype(np.int64)
     70     return sparse_tensor_lib.SparseTensorValue(ind, val, shape)
     71 
     72   def _SparseTensorValue_1x1x1(self):
     73     ind = np.array([[0, 0, 0]]).astype(np.int64)
     74     val = np.array([0]).astype(np.int32)
     75     shape = np.array([3, 4, 5]).astype(np.int64)
     76     return sparse_tensor_lib.SparseTensorValue(ind, val, shape)
     77 
     78   def testAddTakeMany(self):
     79     with self.test_session(graph=ops.Graph(), use_gpu=False) as sess:
     80       sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
     81       sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
     82       handle0 = add_sparse_to_tensors_map(sp_input0, shared_name="a")
     83       handle1 = add_sparse_to_tensors_map(sp_input1, shared_name="a")
     84       self.assertEqual(handle0.get_shape(), ())
     85       handles_concat = array_ops.stack([handle0, handle1])
     86 
     87       sp_out = take_many_sparse_from_tensors_map(
     88           sparse_map_op=handle0.op, sparse_handles=handles_concat)
     89 
     90       combined_indices, combined_values, combined_shape = sess.run(sp_out)
     91 
     92       self.assertAllEqual(combined_indices[:6, 0], [0] * 6)  # minibatch 0
     93       self.assertAllEqual(combined_indices[:6, 1:], sp_input0[0])
     94       self.assertAllEqual(combined_indices[6:, 0], [1] * 6)  # minibatch 1
     95       self.assertAllEqual(combined_indices[6:, 1:], sp_input1[0])
     96       self.assertAllEqual(combined_values[:6], sp_input0[1])
     97       self.assertAllEqual(combined_values[6:], sp_input1[1])
     98       self.assertAllEqual(combined_shape, [2, 5, 6])
     99 
    100   def testFeedAddTakeMany(self):
    101     with self.test_session(use_gpu=False) as sess:
    102       sp_input = self._SparseTensorPlaceholder()
    103       input0_val = self._SparseTensorValue_5x6(np.arange(6))
    104       input1_val = self._SparseTensorValue_3x4(np.arange(6))
    105       handle = add_sparse_to_tensors_map(sp_input)
    106 
    107       handle0_value = sess.run(handle, feed_dict={sp_input: input0_val})
    108       handle1_value = sess.run(handle, feed_dict={sp_input: input1_val})
    109 
    110       sparse_handles = ops.convert_to_tensor(
    111           [handle0_value, handle1_value], dtype=dtypes.int64)
    112 
    113       sp_roundtrip = take_many_sparse_from_tensors_map(
    114           sparse_map_op=handle.op, sparse_handles=sparse_handles)
    115 
    116       combined_indices, combined_values, combined_shape = sess.run(sp_roundtrip)
    117 
    118       self.assertAllEqual(combined_indices[:6, 0], [0] * 6)  # minibatch 0
    119       self.assertAllEqual(combined_indices[:6, 1:], input0_val[0])
    120       self.assertAllEqual(combined_indices[6:, 0], [1] * 6)  # minibatch 1
    121       self.assertAllEqual(combined_indices[6:, 1:], input1_val[0])
    122       self.assertAllEqual(combined_values[:6], input0_val[1])
    123       self.assertAllEqual(combined_values[6:], input1_val[1])
    124       self.assertAllEqual(combined_shape, [2, 5, 6])
    125 
    126   def testAddManyTakeManyRoundTrip(self):
    127     with self.test_session(use_gpu=False) as sess:
    128       # N == 4 because shape_value == [4, 5]
    129       indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
    130       values_value = np.array([b"a", b"b", b"c"])
    131       shape_value = np.array([4, 5], dtype=np.int64)
    132       sparse_tensor = self._SparseTensorPlaceholder(dtype=dtypes.string)
    133       handles = add_many_sparse_to_tensors_map(sparse_tensor)
    134       roundtrip = take_many_sparse_from_tensors_map(
    135           sparse_map_op=handles.op, sparse_handles=handles)
    136       handles_value, roundtrip_value = sess.run(
    137           [handles, roundtrip],
    138           feed_dict={
    139               sparse_tensor.indices: indices_value,
    140               sparse_tensor.values: values_value,
    141               sparse_tensor.dense_shape: shape_value
    142           })
    143       self.assertEqual(handles_value.shape, (4,))
    144       self.assertAllEqual(roundtrip_value.indices, indices_value)
    145       self.assertAllEqual(roundtrip_value.values, values_value)
    146       self.assertAllEqual(roundtrip_value.dense_shape, shape_value)
    147 
    148   def testDeserializeFailsInconsistentRank(self):
    149     with self.test_session(use_gpu=False) as sess:
    150       sp_input = self._SparseTensorPlaceholder()
    151       input0_val = self._SparseTensorValue_5x6(np.arange(6))
    152       input1_val = self._SparseTensorValue_1x1x1()
    153       handle = add_sparse_to_tensors_map(sp_input)
    154 
    155       handle0_value = sess.run(handle, feed_dict={sp_input: input0_val})
    156       handle1_value = sess.run(handle, feed_dict={sp_input: input1_val})
    157 
    158       handle_concat = ops.convert_to_tensor(
    159           [handle0_value, handle1_value], dtype=dtypes.int64)
    160 
    161       sp_roundtrip = take_many_sparse_from_tensors_map(
    162           sparse_map_op=handle.op, sparse_handles=handle_concat)
    163 
    164       with self.assertRaisesOpError(
    165           r"Inconsistent rank across SparseTensors: rank prior to "
    166           r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4"):
    167         sess.run(sp_roundtrip)
    168 
    169   def testTakeManyFailsWrongInputOp(self):
    170     with self.test_session(use_gpu=False) as sess:
    171       input_val = self._SparseTensorValue_5x6(np.arange(6))
    172       handle = add_sparse_to_tensors_map(input_val)
    173       handle_value = sess.run(handle)
    174       bad_handle = handle_value + 10
    175       sp_roundtrip = take_many_sparse_from_tensors_map(
    176           sparse_map_op=handle.op, sparse_handles=[handle_value, bad_handle])
    177 
    178       with self.assertRaisesOpError(r"Unable to find SparseTensor: 10"):
    179         sess.run(sp_roundtrip)
    180 
    181 
    182 class BenchmarkSparseTensorsMapVsSerialization(test.Benchmark):
    183 
    184   def benchmarkVeryLarge2DFloatSparseTensor(self):
    185     np.random.seed(127)
    186     num_elements = 10000
    187     batch_size = 64
    188     indices_batch = np.random.randint(
    189         batch_size, size=num_elements, dtype=np.int64)
    190     indices_value = np.arange(num_elements, dtype=np.int64)
    191     indices = np.asarray(
    192         sorted(zip(indices_batch, indices_value)), dtype=np.int64)
    193     values = ["feature_value_for_embedding_lookup"] * num_elements
    194     shape = np.asarray([batch_size, num_elements], dtype=np.int64)
    195     with session.Session() as sess:
    196       with ops.device("/cpu:0"):
    197         indices = variables.Variable(indices)
    198         values = variables.Variable(values)
    199         shape = variables.Variable(shape)
    200         st = sparse_tensor_lib.SparseTensor(indices, values, shape)
    201 
    202         st_handles = add_many_sparse_to_tensors_map(st)
    203         st_roundtrip = take_many_sparse_from_tensors_map(
    204             sparse_map_op=st_handles.op, sparse_handles=st_handles)
    205         st_roundtrip_op = st_roundtrip.values.op
    206 
    207         st_serialized = sparse_ops.serialize_many_sparse(st)
    208         st_deserialized = sparse_ops.deserialize_many_sparse(
    209             st_serialized, dtype=values.dtype)
    210         st_deserialized_op = st_deserialized.values.op
    211 
    212         variables.global_variables_initializer().run()
    213 
    214         st_roundtrip_values = sess.run(st_roundtrip)
    215         st_deserialized_values = sess.run(st_deserialized)
    216         np.testing.assert_equal(st_roundtrip_values.values,
    217                                 st_deserialized_values.values)
    218         np.testing.assert_equal(st_roundtrip_values.indices,
    219                                 st_deserialized_values.indices)
    220         np.testing.assert_equal(st_roundtrip_values.dense_shape,
    221                                 st_deserialized_values.dense_shape)
    222 
    223         self.run_op_benchmark(
    224             sess,
    225             st_roundtrip_op,
    226             min_iters=2000,
    227             name="benchmark_very_large_2d_float_st_tensor_maps")
    228         self.run_op_benchmark(
    229             sess,
    230             st_deserialized_op,
    231             min_iters=2000,
    232             name="benchmark_very_large_2d_float_st_serialization")
    233 
    234 
if __name__ == "__main__":
  # When executed directly as a script, run the tests via test.main().
  test.main()
    237