Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from os import path
     22 import shutil
     23 import tempfile
     24 
     25 from tensorflow.python.data.ops import dataset_ops
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.framework import errors
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.platform import test
     30 from tensorflow.python.util import compat
     31 
     32 
     33 class ListFilesDatasetOpTest(test.TestCase):
     34 
     35   def setUp(self):
     36     self.tmp_dir = tempfile.mkdtemp()
     37 
     38   def tearDown(self):
     39     shutil.rmtree(self.tmp_dir, ignore_errors=True)
     40 
     41   def _touchTempFiles(self, filenames):
     42     for filename in filenames:
     43       open(path.join(self.tmp_dir, filename), 'a').close()
     44 
     45   def testEmptyDirectory(self):
     46     dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
     47     with self.test_session() as sess:
     48       itr = dataset.make_one_shot_iterator()
     49       with self.assertRaises(errors.OutOfRangeError):
     50         sess.run(itr.get_next())
     51 
     52   def testSimpleDirectory(self):
     53     filenames = ['a', 'b', 'c']
     54     self._touchTempFiles(filenames)
     55 
     56     dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
     57     with self.test_session() as sess:
     58       itr = dataset.make_one_shot_iterator()
     59 
     60       full_filenames = []
     61       produced_filenames = []
     62       for filename in filenames:
     63         full_filenames.append(
     64             compat.as_bytes(path.join(self.tmp_dir, filename)))
     65         produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
     66       self.assertItemsEqual(full_filenames, produced_filenames)
     67       with self.assertRaises(errors.OutOfRangeError):
     68         sess.run(itr.get_next())
     69 
     70   def testEmptyDirectoryInitializer(self):
     71     filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
     72     dataset = dataset_ops.Dataset.list_files(filename_placeholder)
     73 
     74     with self.test_session() as sess:
     75       itr = dataset.make_initializable_iterator()
     76       sess.run(
     77           itr.initializer,
     78           feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
     79 
     80       with self.assertRaises(errors.OutOfRangeError):
     81         sess.run(itr.get_next())
     82 
     83   def testSimpleDirectoryInitializer(self):
     84     filenames = ['a', 'b', 'c']
     85     self._touchTempFiles(filenames)
     86 
     87     filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
     88     dataset = dataset_ops.Dataset.list_files(filename_placeholder)
     89 
     90     with self.test_session() as sess:
     91       itr = dataset.make_initializable_iterator()
     92       sess.run(
     93           itr.initializer,
     94           feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
     95 
     96       full_filenames = []
     97       produced_filenames = []
     98       for filename in filenames:
     99         full_filenames.append(
    100             compat.as_bytes(path.join(self.tmp_dir, filename)))
    101         produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
    102 
    103       self.assertItemsEqual(full_filenames, produced_filenames)
    104 
    105       with self.assertRaises(errors.OutOfRangeError):
    106         sess.run(itr.get_next())
    107 
    108   def testFileSuffixes(self):
    109     filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
    110     self._touchTempFiles(filenames)
    111 
    112     filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    113     dataset = dataset_ops.Dataset.list_files(filename_placeholder)
    114 
    115     with self.test_session() as sess:
    116       itr = dataset.make_initializable_iterator()
    117       sess.run(
    118           itr.initializer,
    119           feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')})
    120 
    121       full_filenames = []
    122       produced_filenames = []
    123       for filename in filenames[1:-1]:
    124         full_filenames.append(
    125             compat.as_bytes(path.join(self.tmp_dir, filename)))
    126         produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
    127       self.assertItemsEqual(full_filenames, produced_filenames)
    128 
    129       with self.assertRaises(errors.OutOfRangeError):
    130         sess.run(itr.get_next())
    131 
    132   def testFileMiddles(self):
    133     filenames = ['a.txt', 'b.py', 'c.pyc']
    134     self._touchTempFiles(filenames)
    135 
    136     filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    137     dataset = dataset_ops.Dataset.list_files(filename_placeholder)
    138 
    139     with self.test_session() as sess:
    140       itr = dataset.make_initializable_iterator()
    141       sess.run(
    142           itr.initializer,
    143           feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')})
    144 
    145       full_filenames = []
    146       produced_filenames = []
    147       for filename in filenames[1:]:
    148         full_filenames.append(
    149             compat.as_bytes(path.join(self.tmp_dir, filename)))
    150         produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
    151 
    152       self.assertItemsEqual(full_filenames, produced_filenames)
    153 
    154       with self.assertRaises(errors.OutOfRangeError):
    155         sess.run(itr.get_next())
    156 
    157 
# Run the test cases in this module when executed directly as a script.
if __name__ == '__main__':
  test.main()
    160