1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.python.data.ops import dataset_ops 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.framework import errors 25 from tensorflow.python.framework import sparse_tensor 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import functional_ops 28 from tensorflow.python.ops import math_ops 29 from tensorflow.python.platform import test 30 31 32 class FilterDatasetTest(test.TestCase): 33 34 def testFilterDataset(self): 35 components = ( 36 np.arange(7, dtype=np.int64), 37 np.array([[1, 2, 3]], dtype=np.int64) * np.arange( 38 7, dtype=np.int64)[:, np.newaxis], 39 np.array(37.0, dtype=np.float64) * np.arange(7) 40 ) 41 count = array_ops.placeholder(dtypes.int64, shape=[]) 42 modulus = array_ops.placeholder(dtypes.int64) 43 44 def _map_fn(x, y, z): 45 return math_ops.square(x), math_ops.square(y), math_ops.square(z) 46 47 iterator = ( 48 dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) 49 .repeat(count) 50 .filter(lambda x, _y, _z: math_ops.equal(math_ops.mod(x, modulus), 0)) 51 .make_initializable_iterator()) 52 init_op = iterator.initializer 53 get_next = iterator.get_next() 54 55 self.assertEqual([c.shape[1:] for c in components], 56 [t.shape for t in get_next]) 57 58 with self.test_session() as sess: 59 # Test that we can dynamically feed a different modulus value for each 60 # iterator. 61 def do_test(count_val, modulus_val): 62 sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val}) 63 for _ in range(count_val): 64 for i in [x for x in range(7) if x**2 % modulus_val == 0]: 65 result = sess.run(get_next) 66 for component, result_component in zip(components, result): 67 self.assertAllEqual(component[i]**2, result_component) 68 with self.assertRaises(errors.OutOfRangeError): 69 sess.run(get_next) 70 71 do_test(14, 2) 72 do_test(4, 18) 73 74 # Test an empty dataset. 75 do_test(0, 1) 76 77 def testFilterRange(self): 78 dataset = dataset_ops.Dataset.range(100).filter( 79 lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2)) 80 iterator = dataset.make_one_shot_iterator() 81 get_next = iterator.get_next() 82 83 with self.test_session() as sess: 84 self.assertEqual(0, sess.run(get_next)) 85 self.assertEqual(1, sess.run(get_next)) 86 self.assertEqual(3, sess.run(get_next)) 87 88 def testFilterDict(self): 89 iterator = (dataset_ops.Dataset.range(10) 90 .map(lambda x: {"foo": x * 2, "bar": x ** 2}) 91 .filter(lambda d: math_ops.equal(d["bar"] % 2, 0)) 92 .map(lambda d: d["foo"] + d["bar"]) 93 .make_initializable_iterator()) 94 init_op = iterator.initializer 95 get_next = iterator.get_next() 96 97 with self.test_session() as sess: 98 sess.run(init_op) 99 for i in range(10): 100 if (i ** 2) % 2 == 0: 101 self.assertEqual(i * 2 + i ** 2, sess.run(get_next)) 102 with self.assertRaises(errors.OutOfRangeError): 103 sess.run(get_next) 104 105 def testUseStepContainerInFilter(self): 106 input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) 107 108 # Define a predicate that returns true for the first element of 109 # the sequence and not the second, and uses `tf.map_fn()`. 110 def _predicate(xs): 111 squared_xs = functional_ops.map_fn(lambda x: x * x, xs) 112 summed = math_ops.reduce_sum(squared_xs) 113 return math_ops.equal(summed, 1 + 4 + 9) 114 115 iterator = ( 116 dataset_ops.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]]) 117 .filter(_predicate) 118 .make_initializable_iterator()) 119 init_op = iterator.initializer 120 get_next = iterator.get_next() 121 122 with self.test_session() as sess: 123 sess.run(init_op) 124 self.assertAllEqual(input_data[0], sess.run(get_next)) 125 with self.assertRaises(errors.OutOfRangeError): 126 sess.run(get_next) 127 128 def assertSparseValuesEqual(self, a, b): 129 self.assertAllEqual(a.indices, b.indices) 130 self.assertAllEqual(a.values, b.values) 131 self.assertAllEqual(a.dense_shape, b.dense_shape) 132 133 def testSparse(self): 134 135 def _map_fn(i): 136 return sparse_tensor.SparseTensorValue( 137 indices=np.array([[0, 0]]), 138 values=(i * np.array([1])), 139 dense_shape=np.array([1, 1])), i 140 141 def _filter_fn(_, i): 142 return math_ops.equal(i % 2, 0) 143 144 iterator = ( 145 dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( 146 lambda x, i: x).make_initializable_iterator()) 147 init_op = iterator.initializer 148 get_next = iterator.get_next() 149 150 with self.test_session() as sess: 151 sess.run(init_op) 152 for i in range(5): 153 actual = sess.run(get_next) 154 self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) 155 self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0]) 156 with self.assertRaises(errors.OutOfRangeError): 157 sess.run(get_next) 158 159 160 if __name__ == "__main__": 161 test.main() 162