1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 from os import path 22 import shutil 23 import tempfile 24 25 from tensorflow.python.data.ops import dataset_ops 26 from tensorflow.python.framework import dtypes 27 from tensorflow.python.framework import errors 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.platform import test 30 from tensorflow.python.util import compat 31 32 33 class ListFilesDatasetOpTest(test.TestCase): 34 35 def setUp(self): 36 self.tmp_dir = tempfile.mkdtemp() 37 38 def tearDown(self): 39 shutil.rmtree(self.tmp_dir, ignore_errors=True) 40 41 def _touchTempFiles(self, filenames): 42 for filename in filenames: 43 open(path.join(self.tmp_dir, filename), 'a').close() 44 45 def testEmptyDirectory(self): 46 dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) 47 with self.test_session() as sess: 48 itr = dataset.make_one_shot_iterator() 49 with self.assertRaises(errors.OutOfRangeError): 50 sess.run(itr.get_next()) 51 52 def testSimpleDirectory(self): 53 filenames = ['a', 'b', 'c'] 54 self._touchTempFiles(filenames) 55 56 dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) 57 with self.test_session() as sess: 58 itr = dataset.make_one_shot_iterator() 59 60 full_filenames = [] 61 produced_filenames = [] 62 for filename in filenames: 63 full_filenames.append( 64 compat.as_bytes(path.join(self.tmp_dir, filename))) 65 produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) 66 self.assertItemsEqual(full_filenames, produced_filenames) 67 with self.assertRaises(errors.OutOfRangeError): 68 sess.run(itr.get_next()) 69 70 def testEmptyDirectoryInitializer(self): 71 filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 72 dataset = dataset_ops.Dataset.list_files(filename_placeholder) 73 74 with self.test_session() as sess: 75 itr = dataset.make_initializable_iterator() 76 sess.run( 77 itr.initializer, 78 feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) 79 80 with self.assertRaises(errors.OutOfRangeError): 81 sess.run(itr.get_next()) 82 83 def testSimpleDirectoryInitializer(self): 84 filenames = ['a', 'b', 'c'] 85 self._touchTempFiles(filenames) 86 87 filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 88 dataset = dataset_ops.Dataset.list_files(filename_placeholder) 89 90 with self.test_session() as sess: 91 itr = dataset.make_initializable_iterator() 92 sess.run( 93 itr.initializer, 94 feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) 95 96 full_filenames = [] 97 produced_filenames = [] 98 for filename in filenames: 99 full_filenames.append( 100 compat.as_bytes(path.join(self.tmp_dir, filename))) 101 produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) 102 103 self.assertItemsEqual(full_filenames, produced_filenames) 104 105 with self.assertRaises(errors.OutOfRangeError): 106 sess.run(itr.get_next()) 107 108 def testFileSuffixes(self): 109 filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc'] 110 self._touchTempFiles(filenames) 111 112 filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 113 dataset = dataset_ops.Dataset.list_files(filename_placeholder) 114 115 with self.test_session() as sess: 116 itr = dataset.make_initializable_iterator() 117 sess.run( 118 itr.initializer, 119 feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')}) 120 121 full_filenames = [] 122 produced_filenames = [] 123 for filename in filenames[1:-1]: 124 full_filenames.append( 125 compat.as_bytes(path.join(self.tmp_dir, filename))) 126 produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) 127 self.assertItemsEqual(full_filenames, produced_filenames) 128 129 with self.assertRaises(errors.OutOfRangeError): 130 sess.run(itr.get_next()) 131 132 def testFileMiddles(self): 133 filenames = ['a.txt', 'b.py', 'c.pyc'] 134 self._touchTempFiles(filenames) 135 136 filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 137 dataset = dataset_ops.Dataset.list_files(filename_placeholder) 138 139 with self.test_session() as sess: 140 itr = dataset.make_initializable_iterator() 141 sess.run( 142 itr.initializer, 143 feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')}) 144 145 full_filenames = [] 146 produced_filenames = [] 147 for filename in filenames[1:]: 148 full_filenames.append( 149 compat.as_bytes(path.join(self.tmp_dir, filename))) 150 produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) 151 152 self.assertItemsEqual(full_filenames, produced_filenames) 153 154 with self.assertRaises(errors.OutOfRangeError): 155 sess.run(itr.get_next()) 156 157 158 if __name__ == '__main__': 159 test.main() 160