Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the private `MatchingFilesDataset`."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import os
     21 import shutil
     22 import tempfile
     23 
     24 from tensorflow.python.data.experimental.ops import matching_files
     25 from tensorflow.python.data.kernel_tests import test_base
     26 from tensorflow.python.framework import errors
     27 from tensorflow.python.framework import test_util
     28 from tensorflow.python.platform import test
     29 from tensorflow.python.util import compat
     30 
     31 
     32 @test_util.run_all_in_graph_and_eager_modes
     33 class MatchingFilesDatasetTest(test_base.DatasetTestBase):
     34 
     35   def setUp(self):
     36     self.tmp_dir = tempfile.mkdtemp()
     37 
     38   def tearDown(self):
     39     shutil.rmtree(self.tmp_dir, ignore_errors=True)
     40 
     41   def _touchTempFiles(self, filenames):
     42     for filename in filenames:
     43       open(os.path.join(self.tmp_dir, filename), 'a').close()
     44 
     45   def testNonExistingDirectory(self):
     46     """Test the MatchingFiles dataset with a non-existing directory."""
     47 
     48     self.tmp_dir = os.path.join(self.tmp_dir, 'nonexistingdir')
     49     dataset = matching_files.MatchingFilesDataset(
     50         os.path.join(self.tmp_dir, '*'))
     51     self.assertDatasetProduces(
     52         dataset, expected_error=(errors.NotFoundError, ''))
     53 
     54   def testEmptyDirectory(self):
     55     """Test the MatchingFiles dataset with an empty directory."""
     56 
     57     dataset = matching_files.MatchingFilesDataset(
     58         os.path.join(self.tmp_dir, '*'))
     59     self.assertDatasetProduces(
     60         dataset, expected_error=(errors.NotFoundError, ''))
     61 
     62   def testSimpleDirectory(self):
     63     """Test the MatchingFiles dataset with a simple directory."""
     64 
     65     filenames = ['a', 'b', 'c']
     66     self._touchTempFiles(filenames)
     67 
     68     dataset = matching_files.MatchingFilesDataset(
     69         os.path.join(self.tmp_dir, '*'))
     70     self.assertDatasetProduces(
     71         dataset,
     72         expected_output=[
     73             compat.as_bytes(os.path.join(self.tmp_dir, filename))
     74             for filename in filenames
     75         ],
     76         assert_items_equal=True)
     77 
     78   def testFileSuffixes(self):
     79     """Test the MatchingFiles dataset using the suffixes of filename."""
     80 
     81     filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
     82     self._touchTempFiles(filenames)
     83 
     84     dataset = matching_files.MatchingFilesDataset(
     85         os.path.join(self.tmp_dir, '*.py'))
     86     self.assertDatasetProduces(
     87         dataset,
     88         expected_output=[
     89             compat.as_bytes(os.path.join(self.tmp_dir, filename))
     90             for filename in filenames[1:-1]
     91         ],
     92         assert_items_equal=True)
     93 
     94   def testFileMiddles(self):
     95     """Test the MatchingFiles dataset using the middles of filename."""
     96 
     97     filenames = ['aa.txt', 'bb.py', 'bbc.pyc', 'cc.pyc']
     98     self._touchTempFiles(filenames)
     99 
    100     dataset = matching_files.MatchingFilesDataset(
    101         os.path.join(self.tmp_dir, 'b*.py*'))
    102     self.assertDatasetProduces(
    103         dataset,
    104         expected_output=[
    105             compat.as_bytes(os.path.join(self.tmp_dir, filename))
    106             for filename in filenames[1:3]
    107         ],
    108         assert_items_equal=True)
    109 
    110   def testNestedDirectories(self):
    111     """Test the MatchingFiles dataset with nested directories."""
    112 
    113     filenames = []
    114     width = 8
    115     depth = 4
    116     for i in range(width):
    117       for j in range(depth):
    118         new_base = os.path.join(self.tmp_dir, str(i),
    119                                 *[str(dir_name) for dir_name in range(j)])
    120         os.makedirs(new_base)
    121         child_files = ['a.py', 'b.pyc'] if j < depth - 1 else ['c.txt', 'd.log']
    122         for f in child_files:
    123           filename = os.path.join(new_base, f)
    124           filenames.append(filename)
    125           open(filename, 'w').close()
    126 
    127     patterns = [
    128         os.path.join(self.tmp_dir, os.path.join(*['**' for _ in range(depth)]),
    129                      suffix) for suffix in ['*.txt', '*.log']
    130     ]
    131 
    132     dataset = matching_files.MatchingFilesDataset(patterns)
    133     next_element = self.getNext(dataset)
    134     expected_filenames = [
    135         compat.as_bytes(filename)
    136         for filename in filenames
    137         if filename.endswith('.txt') or filename.endswith('.log')
    138     ]
    139     actual_filenames = []
    140     while True:
    141       try:
    142         actual_filenames.append(compat.as_bytes(self.evaluate(next_element())))
    143       except errors.OutOfRangeError:
    144         break
    145 
    146     self.assertItemsEqual(expected_filenames, actual_filenames)
    147 
    148 
    149 if __name__ == '__main__':
    150   test.main()
    151