1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the private `MatchingFilesDataset`.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import os 21 import shutil 22 import tempfile 23 24 from tensorflow.python.data.experimental.ops import matching_files 25 from tensorflow.python.data.kernel_tests import test_base 26 from tensorflow.python.framework import errors 27 from tensorflow.python.framework import test_util 28 from tensorflow.python.platform import test 29 from tensorflow.python.util import compat 30 31 32 @test_util.run_all_in_graph_and_eager_modes 33 class MatchingFilesDatasetTest(test_base.DatasetTestBase): 34 35 def setUp(self): 36 self.tmp_dir = tempfile.mkdtemp() 37 38 def tearDown(self): 39 shutil.rmtree(self.tmp_dir, ignore_errors=True) 40 41 def _touchTempFiles(self, filenames): 42 for filename in filenames: 43 open(os.path.join(self.tmp_dir, filename), 'a').close() 44 45 def testNonExistingDirectory(self): 46 """Test the MatchingFiles dataset with a non-existing directory.""" 47 48 self.tmp_dir = os.path.join(self.tmp_dir, 'nonexistingdir') 49 dataset = matching_files.MatchingFilesDataset( 50 os.path.join(self.tmp_dir, '*')) 51 self.assertDatasetProduces( 52 dataset, expected_error=(errors.NotFoundError, '')) 53 54 def testEmptyDirectory(self): 55 """Test the MatchingFiles dataset with an empty directory.""" 56 57 dataset = matching_files.MatchingFilesDataset( 58 os.path.join(self.tmp_dir, '*')) 59 self.assertDatasetProduces( 60 dataset, expected_error=(errors.NotFoundError, '')) 61 62 def testSimpleDirectory(self): 63 """Test the MatchingFiles dataset with a simple directory.""" 64 65 filenames = ['a', 'b', 'c'] 66 self._touchTempFiles(filenames) 67 68 dataset = matching_files.MatchingFilesDataset( 69 os.path.join(self.tmp_dir, '*')) 70 self.assertDatasetProduces( 71 dataset, 72 expected_output=[ 73 compat.as_bytes(os.path.join(self.tmp_dir, filename)) 74 for filename in filenames 75 ], 76 assert_items_equal=True) 77 78 def testFileSuffixes(self): 79 """Test the MatchingFiles dataset using the suffixes of filename.""" 80 81 filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc'] 82 self._touchTempFiles(filenames) 83 84 dataset = matching_files.MatchingFilesDataset( 85 os.path.join(self.tmp_dir, '*.py')) 86 self.assertDatasetProduces( 87 dataset, 88 expected_output=[ 89 compat.as_bytes(os.path.join(self.tmp_dir, filename)) 90 for filename in filenames[1:-1] 91 ], 92 assert_items_equal=True) 93 94 def testFileMiddles(self): 95 """Test the MatchingFiles dataset using the middles of filename.""" 96 97 filenames = ['aa.txt', 'bb.py', 'bbc.pyc', 'cc.pyc'] 98 self._touchTempFiles(filenames) 99 100 dataset = matching_files.MatchingFilesDataset( 101 os.path.join(self.tmp_dir, 'b*.py*')) 102 self.assertDatasetProduces( 103 dataset, 104 expected_output=[ 105 compat.as_bytes(os.path.join(self.tmp_dir, filename)) 106 for filename in filenames[1:3] 107 ], 108 assert_items_equal=True) 109 110 def testNestedDirectories(self): 111 """Test the MatchingFiles dataset with nested directories.""" 112 113 filenames = [] 114 width = 8 115 depth = 4 116 for i in range(width): 117 for j in range(depth): 118 new_base = os.path.join(self.tmp_dir, str(i), 119 *[str(dir_name) for dir_name in range(j)]) 120 os.makedirs(new_base) 121 child_files = ['a.py', 'b.pyc'] if j < depth - 1 else ['c.txt', 'd.log'] 122 for f in child_files: 123 filename = os.path.join(new_base, f) 124 filenames.append(filename) 125 open(filename, 'w').close() 126 127 patterns = [ 128 os.path.join(self.tmp_dir, os.path.join(*['**' for _ in range(depth)]), 129 suffix) for suffix in ['*.txt', '*.log'] 130 ] 131 132 dataset = matching_files.MatchingFilesDataset(patterns) 133 next_element = self.getNext(dataset) 134 expected_filenames = [ 135 compat.as_bytes(filename) 136 for filename in filenames 137 if filename.endswith('.txt') or filename.endswith('.log') 138 ] 139 actual_filenames = [] 140 while True: 141 try: 142 actual_filenames.append(compat.as_bytes(self.evaluate(next_element()))) 143 except errors.OutOfRangeError: 144 break 145 146 self.assertItemsEqual(expected_filenames, actual_filenames) 147 148 149 if __name__ == '__main__': 150 test.main() 151