1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import random 21 22 import numpy as np 23 24 from tensorflow.python.client import session 25 from tensorflow.python.data.ops import dataset_ops 26 from tensorflow.python.framework import errors 27 from tensorflow.python.framework import sparse_tensor 28 from tensorflow.python.ops import sparse_ops 29 from tensorflow.python.platform import test 30 from tensorflow.python.training import server_lib 31 32 33 class FlatMapDatasetTest(test.TestCase): 34 35 # pylint: disable=g-long-lambda 36 def testFlatMapDataset(self): 37 repeats = [1, 2, 3, 4, 5, 0, 1] 38 components = np.array(repeats, dtype=np.int64) 39 iterator = ( 40 dataset_ops.Dataset.from_tensor_slices(components) 41 .flat_map(lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x)) 42 .make_initializable_iterator()) 43 init_op = iterator.initializer 44 get_next = iterator.get_next() 45 46 with self.test_session() as sess: 47 sess.run(init_op) 48 for i in repeats: 49 for _ in range(i): 50 self.assertEqual(i, sess.run(get_next)) 51 with self.assertRaises(errors.OutOfRangeError): 52 sess.run(get_next) 53 54 def testNestedFlatMapDataset(self): 55 repeats = [[1, 2], [3, 4], [5, 0], [1, 7]] 56 components = np.array(repeats, dtype=np.int64) 57 iterator = ( 58 dataset_ops.Dataset.from_tensor_slices(components) 59 .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x) 60 .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y) 61 .repeat(y))).make_initializable_iterator()) 62 init_op = iterator.initializer 63 get_next = iterator.get_next() 64 65 with self.test_session() as sess: 66 sess.run(init_op) 67 for row in repeats: 68 for i in row: 69 for _ in range(i): 70 self.assertEqual(i, sess.run(get_next)) 71 72 with self.assertRaises(errors.OutOfRangeError): 73 sess.run(get_next) 74 75 def testSharedResourceNestedFlatMapDataset(self): 76 repeats = [[1, 2], [3, 4], [5, 0], [1, 7]] 77 components = np.array(repeats, dtype=np.int64) 78 iterator = ( 79 dataset_ops.Dataset.from_tensor_slices(components) 80 .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x) 81 .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y) 82 .repeat(y))).make_initializable_iterator( 83 shared_name="shared_flat_map_iterator")) 84 init_op = iterator.initializer 85 get_next = iterator.get_next() 86 87 # Create two concurrent sessions that share the same iterator 88 # resource on the same server, and verify that a random 89 # interleaving of `Session.run(get_next)` calls on the two 90 # sessions yields the expected result. 91 server = server_lib.Server.create_local_server() 92 with session.Session(server.target) as sess1: 93 with session.Session(server.target) as sess2: 94 for _ in range(3): 95 sess = random.choice([sess1, sess2]) 96 sess.run(init_op) 97 for row in repeats: 98 for i in row: 99 for _ in range(i): 100 sess = random.choice([sess1, sess2]) 101 self.assertEqual(i, sess.run(get_next)) 102 103 with self.assertRaises(errors.OutOfRangeError): 104 sess = random.choice([sess1, sess2]) 105 sess.run(get_next) 106 107 def testMapDict(self): 108 iterator = (dataset_ops.Dataset.range(10) 109 .map(lambda x: {"foo": x * 2, "bar": x ** 2}) 110 .flat_map(lambda d: dataset_ops.Dataset.from_tensors(d["foo"]) 111 .repeat(d["bar"])) 112 .make_initializable_iterator()) 113 init_op = iterator.initializer 114 get_next = iterator.get_next() 115 116 with self.test_session() as sess: 117 sess.run(init_op) 118 for i in range(10): 119 for _ in range(i ** 2): 120 self.assertEqual(i * 2, sess.run(get_next)) 121 with self.assertRaises(errors.OutOfRangeError): 122 sess.run(get_next) 123 # pylint: enable=g-long-lambda 124 125 def testSparse(self): 126 def _map_fn(i): 127 return sparse_tensor.SparseTensorValue( 128 indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) 129 130 def _flat_map_fn(x): 131 return dataset_ops.Dataset.from_tensor_slices( 132 sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) 133 134 iterator = ( 135 dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) 136 .make_initializable_iterator()) 137 init_op = iterator.initializer 138 get_next = iterator.get_next() 139 140 with self.test_session() as sess: 141 sess.run(init_op) 142 for i in range(10): 143 for j in range(2): 144 expected = [i, 0] if j % 2 == 0 else [0, -i] 145 self.assertAllEqual(expected, sess.run(get_next)) 146 with self.assertRaises(errors.OutOfRangeError): 147 sess.run(get_next) 148 149 150 if __name__ == "__main__": 151 test.main() 152