Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.data.ops import dataset_ops
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import errors
     25 from tensorflow.python.framework import sparse_tensor
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import functional_ops
     28 from tensorflow.python.ops import math_ops
     29 from tensorflow.python.platform import test
     30 
     31 
     32 class FilterDatasetTest(test.TestCase):
     33 
     34   def testFilterDataset(self):
     35     components = (
     36         np.arange(7, dtype=np.int64),
     37         np.array([[1, 2, 3]], dtype=np.int64) * np.arange(
     38             7, dtype=np.int64)[:, np.newaxis],
     39         np.array(37.0, dtype=np.float64) * np.arange(7)
     40     )
     41     count = array_ops.placeholder(dtypes.int64, shape=[])
     42     modulus = array_ops.placeholder(dtypes.int64)
     43 
     44     def _map_fn(x, y, z):
     45       return math_ops.square(x), math_ops.square(y), math_ops.square(z)
     46 
     47     iterator = (
     48         dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
     49         .repeat(count)
     50         .filter(lambda x, _y, _z: math_ops.equal(math_ops.mod(x, modulus), 0))
     51         .make_initializable_iterator())
     52     init_op = iterator.initializer
     53     get_next = iterator.get_next()
     54 
     55     self.assertEqual([c.shape[1:] for c in components],
     56                      [t.shape for t in get_next])
     57 
     58     with self.test_session() as sess:
     59       # Test that we can dynamically feed a different modulus value for each
     60       # iterator.
     61       def do_test(count_val, modulus_val):
     62         sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val})
     63         for _ in range(count_val):
     64           for i in [x for x in range(7) if x**2 % modulus_val == 0]:
     65             result = sess.run(get_next)
     66             for component, result_component in zip(components, result):
     67               self.assertAllEqual(component[i]**2, result_component)
     68         with self.assertRaises(errors.OutOfRangeError):
     69           sess.run(get_next)
     70 
     71       do_test(14, 2)
     72       do_test(4, 18)
     73 
     74       # Test an empty dataset.
     75       do_test(0, 1)
     76 
     77   def testFilterRange(self):
     78     dataset = dataset_ops.Dataset.range(100).filter(
     79         lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
     80     iterator = dataset.make_one_shot_iterator()
     81     get_next = iterator.get_next()
     82 
     83     with self.test_session() as sess:
     84       self.assertEqual(0, sess.run(get_next))
     85       self.assertEqual(1, sess.run(get_next))
     86       self.assertEqual(3, sess.run(get_next))
     87 
     88   def testFilterDict(self):
     89     iterator = (dataset_ops.Dataset.range(10)
     90                 .map(lambda x: {"foo": x * 2, "bar": x ** 2})
     91                 .filter(lambda d: math_ops.equal(d["bar"] % 2, 0))
     92                 .map(lambda d: d["foo"] + d["bar"])
     93                 .make_initializable_iterator())
     94     init_op = iterator.initializer
     95     get_next = iterator.get_next()
     96 
     97     with self.test_session() as sess:
     98       sess.run(init_op)
     99       for i in range(10):
    100         if (i ** 2) % 2 == 0:
    101           self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
    102       with self.assertRaises(errors.OutOfRangeError):
    103         sess.run(get_next)
    104 
    105   def testUseStepContainerInFilter(self):
    106     input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
    107 
    108     # Define a predicate that returns true for the first element of
    109     # the sequence and not the second, and uses `tf.map_fn()`.
    110     def _predicate(xs):
    111       squared_xs = functional_ops.map_fn(lambda x: x * x, xs)
    112       summed = math_ops.reduce_sum(squared_xs)
    113       return math_ops.equal(summed, 1 + 4 + 9)
    114 
    115     iterator = (
    116         dataset_ops.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])
    117         .filter(_predicate)
    118         .make_initializable_iterator())
    119     init_op = iterator.initializer
    120     get_next = iterator.get_next()
    121 
    122     with self.test_session() as sess:
    123       sess.run(init_op)
    124       self.assertAllEqual(input_data[0], sess.run(get_next))
    125       with self.assertRaises(errors.OutOfRangeError):
    126         sess.run(get_next)
    127 
    128   def assertSparseValuesEqual(self, a, b):
    129     self.assertAllEqual(a.indices, b.indices)
    130     self.assertAllEqual(a.values, b.values)
    131     self.assertAllEqual(a.dense_shape, b.dense_shape)
    132 
    133   def testSparse(self):
    134 
    135     def _map_fn(i):
    136       return sparse_tensor.SparseTensorValue(
    137           indices=np.array([[0, 0]]),
    138           values=(i * np.array([1])),
    139           dense_shape=np.array([1, 1])), i
    140 
    141     def _filter_fn(_, i):
    142       return math_ops.equal(i % 2, 0)
    143 
    144     iterator = (
    145         dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
    146             lambda x, i: x).make_initializable_iterator())
    147     init_op = iterator.initializer
    148     get_next = iterator.get_next()
    149 
    150     with self.test_session() as sess:
    151       sess.run(init_op)
    152       for i in range(5):
    153         actual = sess.run(get_next)
    154         self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
    155         self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0])
    156       with self.assertRaises(errors.OutOfRangeError):
    157         sess.run(get_next)
    158 
    159 
    160 if __name__ == "__main__":
    161   test.main()
    162