# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from itertools import cycle
import os
import tarfile
import threading
import unittest
import zipfile

import numpy as np
from six.moves.urllib.parse import urljoin
from six.moves.urllib.request import pathname2url

from tensorflow.python.keras._impl import keras
from tensorflow.python.platform import test


class TestGetFileAndValidateIt(test.TestCase):

  def test_get_file_and_validate_it(self):
    """Tests get_file from a URL, plus extraction and validation."""
    dest_dir = self.get_temp_dir()
    orig_dir = self.get_temp_dir()

    text_file_path = os.path.join(orig_dir, 'test.txt')
    zip_file_path = os.path.join(orig_dir, 'test.zip')
    tar_file_path = os.path.join(orig_dir, 'test.tar.gz')

    with open(text_file_path, 'w') as text_file:
      text_file.write('Float like a butterfly, sting like a bee.')

    with tarfile.open(tar_file_path, 'w:gz') as tar_file:
      tar_file.add(text_file_path)

    with zipfile.ZipFile(zip_file_path, 'w') as zip_file:
      zip_file.write(text_file_path)

    origin = urljoin('file://', pathname2url(os.path.abspath(tar_file_path)))

    path = keras.utils.data_utils.get_file('test.txt', origin,
                                           untar=True, cache_subdir=dest_dir)
    filepath = path + '.tar.gz'
    hashval_sha256 = keras.utils.data_utils._hash_file(filepath)
    hashval_md5 = keras.utils.data_utils._hash_file(filepath, algorithm='md5')
    path = keras.utils.data_utils.get_file(
        'test.txt', origin, md5_hash=hashval_md5,
        untar=True, cache_subdir=dest_dir)
    path = keras.utils.data_utils.get_file(
        filepath, origin, file_hash=hashval_sha256,
        extract=True, cache_subdir=dest_dir)
    self.assertTrue(os.path.exists(filepath))
    self.assertTrue(keras.utils.data_utils.validate_file(filepath,
                                                         hashval_sha256))
    self.assertTrue(keras.utils.data_utils.validate_file(filepath, hashval_md5))
    os.remove(filepath)

    origin = urljoin('file://', pathname2url(os.path.abspath(zip_file_path)))

    hashval_sha256 = keras.utils.data_utils._hash_file(zip_file_path)
    hashval_md5 = keras.utils.data_utils._hash_file(zip_file_path,
                                                    algorithm='md5')
    path = keras.utils.data_utils.get_file(
        'test', origin, md5_hash=hashval_md5,
        extract=True, cache_subdir=dest_dir)
    path = keras.utils.data_utils.get_file(
        'test', origin, file_hash=hashval_sha256,
        extract=True, cache_subdir=dest_dir)
    self.assertTrue(os.path.exists(path))
    self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_sha256))
    self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_md5))


class ThreadsafeIter(object):
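  """Wraps an iterator so that calls to `next` are serialized with a lock."""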

  def __init__(self, it):
    self.it = it
    self.lock = threading.Lock()

  def __iter__(self):
    return self

  def __next__(self):
    return self.next()

  def next(self):
    with self.lock:
      return next(self.it)


def threadsafe_generator(f):
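  """Decorator making a generator function yield a `ThreadsafeIter`."""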

  def g(*a, **kw):
    return ThreadsafeIter(f(*a, **kw))

  return g


class TestSequence(keras.utils.data_utils.Sequence):
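  """Sequence returning `np.ones(shape) * index * inner`.

  `on_epoch_end` multiplies `inner` by 5, so values observed after an epoch
  boundary reveal whether the enqueuer picked up the update.
  """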

  def __init__(self, shape, value=1.):
    self.shape = shape
    self.inner = value

  def __getitem__(self, item):
    return np.ones(self.shape, dtype=np.uint32) * item * self.inner

  def __len__(self):
    return 100

  def on_epoch_end(self):
    self.inner *= 5.0


class FaultSequence(keras.utils.data_utils.Sequence):
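  """Sequence whose `__getitem__` always raises, to exercise failure paths."""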

  def __getitem__(self, item):
    raise IndexError(item, 'item is not present')

  def __len__(self):
    return 100


@threadsafe_generator
def create_generator_from_sequence_threads(ds):
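  """Yields the items of `ds` forever; thread-safe via the decorator."""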
  for i in cycle(range(len(ds))):
    yield ds[i]


def create_generator_from_sequence_pcs(ds):
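  """Same as above, but undecorated for the multiprocessing-based tests."""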
  for i in cycle(range(len(ds))):
    yield ds[i]


class TestEnqueuers(test.TestCase):

  def test_generator_enqueuer_threads(self):
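    """All values a threaded GeneratorEnqueuer yields come from the data."""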
    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
        create_generator_from_sequence_threads(TestSequence([3, 200, 200, 3])),
        use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(100):
      acc.append(int(next(gen_output)[0, 0, 0, 0]))

    self.assertEqual(len(set(acc) - set(range(100))), 0)
    enqueuer.stop()

  @unittest.skipIf(
      os.name == 'nt',
      'use_multiprocessing=True does not work properly on Windows.')
  def test_generator_enqueuer_processes(self):
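    """A process-based GeneratorEnqueuer need not preserve item order."""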
    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
        create_generator_from_sequence_pcs(TestSequence([3, 200, 200, 3])),
        use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(100):
      acc.append(int(next(gen_output)[0, 0, 0, 0]))
    self.assertNotEqual(acc, list(range(100)))
    enqueuer.stop()

  def test_generator_enqueuer_fail_threads(self):
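    """An exception raised inside the generator propagates to the consumer."""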
    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
        create_generator_from_sequence_threads(FaultSequence()),
        use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with self.assertRaises(IndexError):
      next(gen_output)

  @unittest.skipIf(
      os.name == 'nt',
      'use_multiprocessing=True does not work properly on Windows.')
  def test_generator_enqueuer_fail_processes(self):
    enqueuer = keras.utils.data_utils.GeneratorEnqueuer(
        create_generator_from_sequence_pcs(FaultSequence()),
        use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with self.assertRaises(IndexError):
      next(gen_output)

  def test_ordered_enqueuer_threads(self):
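    """OrderedEnqueuer with threads yields items in index order."""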
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(100):
      acc.append(next(gen_output)[0, 0, 0, 0])
    self.assertEqual(acc, list(range(100)))
    enqueuer.stop()

  def test_ordered_enqueuer_processes(self):
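    """OrderedEnqueuer with processes also yields items in index order."""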
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(100):
      acc.append(next(gen_output)[0, 0, 0, 0])
    self.assertEqual(acc, list(range(100)))
    enqueuer.stop()

  def test_ordered_enqueuer_fail_threads(self):
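    """A failing Sequence surfaces as StopIteration from an OrderedEnqueuer."""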
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        FaultSequence(), use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with self.assertRaises(StopIteration):
      next(gen_output)

  def test_ordered_enqueuer_fail_processes(self):
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        FaultSequence(), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    with self.assertRaises(StopIteration):
      next(gen_output)

  def test_on_epoch_end_processes(self):
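    """on_epoch_end updates are visible in the second epoch with processes."""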
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    acc = []
    for _ in range(200):
      acc.append(next(gen_output)[0, 0, 0, 0])
    # Check that order was kept in the OrderedEnqueuer with processes and
    # that on_epoch_end scaled the second epoch's values by 5.
    self.assertEqual(acc[100:], [k * 5 for k in range(100)])
    enqueuer.stop()

  def test_context_switch(self):
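    """Two enqueuers can run concurrently without mixing their Sequences."""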
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
    enqueuer2 = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3], value=15), use_multiprocessing=True)
    enqueuer.start(3, 10)
    enqueuer2.start(3, 10)
    gen_output = enqueuer.get()
    gen_output2 = enqueuer2.get()
    acc = []
    for _ in range(100):
      acc.append(next(gen_output)[0, 0, 0, 0])
    self.assertEqual(acc[-1], 99)
    # One epoch is complete, so enqueuer calls on_epoch_end on its Sequence.

    acc = []
    for _ in range(100):
      acc.append(next(gen_output2)[0, 0, 0, 0])
    self.assertEqual(acc[-1], 99 * 15)
    # One epoch is complete for enqueuer2 as well.

    # Make sure both Sequences were updated by their on_epoch_end calls.
    self.assertEqual(next(gen_output)[0, 0, 0, 0], 0)
    self.assertEqual(next(gen_output)[0, 0, 0, 0], 5)
    self.assertEqual(next(gen_output2)[0, 0, 0, 0], 0)
    self.assertEqual(next(gen_output2)[0, 0, 0, 0], 15 * 5)

    # Tear down everything.
    enqueuer.stop()
    enqueuer2.stop()

  def test_on_epoch_end_threads(self):
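    """on_epoch_end updates are visible in the second epoch with threads."""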
    enqueuer = keras.utils.data_utils.OrderedEnqueuer(
        TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
    enqueuer.start(3, 10)
    gen_output = enqueuer.get()
    # Consume the first epoch, then collect the second one.
    acc = []
    for _ in range(100):
      acc.append(next(gen_output)[0, 0, 0, 0])
    acc = []
    for _ in range(100):
      acc.append(next(gen_output)[0, 0, 0, 0])
    # Check that order was kept in the OrderedEnqueuer with threads and that
    # on_epoch_end scaled the second epoch's values by 5.
    self.assertEqual(acc, [k * 5 for k in range(100)])
    enqueuer.stop()


if __name__ == '__main__':
  # Bazel sets these environment variables to very long paths.
  # The tempfile module uses them to create long paths, and the
  # multiprocessing library in turn tries to create sockets named after those
  # paths. Delete whatever Bazel writes to these variables so that tests do
  # not fail because the socket addresses are too long.
  for var in ('TMPDIR', 'TMP', 'TEMP'):
    if var in os.environ:
      del os.environ[var]

  test.main()