# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Utilities for file download and caching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from abc import abstractmethod
from contextlib import closing
import hashlib
import multiprocessing
from multiprocessing.pool import ThreadPool
import os
import random
import shutil
import sys
import tarfile
import threading
import time
import zipfile

import numpy as np
import six
from six.moves.urllib.error import HTTPError
from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlopen

from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export


try:
  import queue
except ImportError:
  import Queue as queue


if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    """Replacement for `urlretrieve` for Python 2.

    Under Python 2, `urlretrieve` relies on `FancyURLopener` from the legacy
    `urllib` module, which is known to have issues with proxy management.

    Arguments:
        url: url to retrieve.
        filename: where to store the retrieved data locally.
        reporthook: a hook function that will be called once
            on establishment of the network connection and once
            after each block read thereafter.
            The hook will be passed three arguments:
            a count of blocks transferred so far,
            a block size in bytes, and the total size of the file.
        data: `data` argument passed to `urlopen`.
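
    Example of a `reporthook` (a minimal sketch; `print_progress` is a
    hypothetical name, not part of this module):

    ```python
        def print_progress(count, block_size, total_size):
          # `total_size` is -1 when the server sends no Content-Length.
          if total_size > 0:
            percent = min(100.0, 100.0 * count * block_size / total_size)
            print('%.1f%%' % percent)
    ```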
    """

    def chunk_read(response, chunk_size=8192, reporthook=None):
      # The Content-Length header, when present, gives the total file size.
      content_length = response.info().get('Content-Length')
      total_size = -1
      if content_length is not None:
        total_size = int(content_length.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    response = urlopen(url, data)
    with open(filename, 'wb') as fd:
      for chunk in chunk_read(response, reporthook=reporthook):
        fd.write(chunk)
else:
  from six.moves.urllib.request import urlretrieve


def is_generator_or_sequence(x):
  """Check if `x` is a Keras generator type."""
  return tf_inspect.isgenerator(x) or isinstance(x, Sequence)


def _extract_archive(file_path, path='.', archive_format='auto'):
  """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.

  Arguments:
      file_path: path to the archive file
      path: path to extract the archive file
      archive_format: Archive format to try for extracting the file.
          Options are 'auto', 'tar', 'zip', and None.
          'tar' includes tar, tar.gz, and tar.bz files.
          The default 'auto' is ['tar', 'zip'].
          None or an empty list means no format is tried, so no
          match is found.

  Returns:
      True if a match was found and an archive extraction was completed,
      False otherwise.
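
  Example (a minimal sketch; the paths are placeholders):

  ```python
      if _extract_archive('/tmp/data.tar.gz', '/tmp/data'):
          print('Extracted to /tmp/data')
  ```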
  """
  if archive_format is None:
    return False
  if archive_format == 'auto':
    archive_format = ['tar', 'zip']
  if isinstance(archive_format, six.string_types):
    archive_format = [archive_format]

  for archive_type in archive_format:
    if archive_type == 'tar':
      open_fn = tarfile.open
      is_match_fn = tarfile.is_tarfile
    elif archive_type == 'zip':
      open_fn = zipfile.ZipFile
      is_match_fn = zipfile.is_zipfile
    else:
      # Skip unknown format names instead of raising NameError (or
      # silently reusing the previous iteration's functions).
      continue

    if is_match_fn(file_path):
      with open_fn(file_path) as archive:
        try:
          archive.extractall(path)
        except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
          if os.path.exists(path):
            if os.path.isfile(path):
              os.remove(path)
            else:
              shutil.rmtree(path)
          raise
      return True
  return False


@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it is not already in the cache.

  By default the file at the url `origin` is downloaded to the
  cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
  and given the filename `fname`. The final location of a file
  `example.txt` would therefore be `~/.keras/datasets/example.txt`.

  Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
  Passing a hash will verify the file after download. The command line
  programs `shasum` and `sha256sum` can compute the hash.

  Arguments:
      fname: Name of the file. If an absolute path `/path/to/file.txt` is
          specified the file will be saved at that location.
      origin: Original URL of the file.
      untar: Deprecated in favor of 'extract'.
          Boolean, whether the file should be decompressed.
      md5_hash: Deprecated in favor of 'file_hash'.
          md5 hash of the file for verification.
      file_hash: The expected hash string of the file after download.
          The sha256 and md5 hash algorithms are both supported.
      cache_subdir: Subdirectory under the Keras cache dir where the file is
          saved. If an absolute path `/path/to/folder` is
          specified the file will be saved at that location.
      hash_algorithm: Select the hash algorithm to verify the file.
          Options are 'md5', 'sha256', and 'auto'.
          The default 'auto' detects the hash algorithm in use.
      extract: True tries extracting the file as an Archive, like tar or zip.
      archive_format: Archive format to try for extracting the file.
          Options are 'auto', 'tar', 'zip', and None.
          'tar' includes tar, tar.gz, and tar.bz files.
          The default 'auto' is ['tar', 'zip'].
          None or an empty list means no format is tried, so no
          match is found.
      cache_dir: Location to store cached files, when None, it
          defaults to the [Keras
            Directory](/faq/#where-is-the-keras-configuration-file-stored).

  Returns:
      Path to the downloaded file
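
  Example (a minimal sketch; the URL is a placeholder, and the hash shown
  is the sha256 of an empty file):

  ```python
      path = get_file('example.tar.gz',
                      origin='https://example.com/example.tar.gz',
                      file_hash='e3b0c44298fc1c149afbf4c8996fb924'
                                '27ae41e4649b934ca495991b7852b855',
                      extract=True)
  ```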
  """
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  if not os.path.exists(datadir):
    os.makedirs(datadir)

  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ', so the data will be re-downloaded.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath


def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
  """Calculates a file sha256 or md5 hash.

  Example:

  ```python
      >>> from tensorflow.python.keras.utils.data_utils import _hash_file
      >>> _hash_file('/path/to/file.zip')
      'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
  ```

  Arguments:
      fpath: path to the file being validated
      algorithm: hash algorithm, one of 'sha256' or 'md5'.
          'auto' is also accepted and treated as 'sha256', since with no
          expected hash there is nothing to detect the algorithm from.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      The file hash
  """
  if algorithm in ('sha256', 'auto'):
    hasher = hashlib.sha256()
  else:
    hasher = hashlib.md5()

  with open(fpath, 'rb') as fpath_file:
    for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
      hasher.update(chunk)

  return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  """Validates a file against a sha256 or md5 hash.

  Arguments:
      fpath: path to the file being validated
      file_hash: The expected hash string of the file.
          The sha256 and md5 hash algorithms are both supported.
      algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
          The default 'auto' detects the hash algorithm in use.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      Whether the file is valid
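
  Example (a minimal sketch; the path and hash are placeholders):

  ```python
      # Under 'auto', a 64-character hash is treated as sha256;
      # anything else falls back to md5.
      if not validate_file('/path/to/file.zip', 'e3b0c44298fc1c14...'):
          print('Hash mismatch; file is incomplete or outdated.')
  ```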
  """
  if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64):
    hasher = 'sha256'
  else:
    hasher = 'md5'

  return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)


@keras_export('keras.utils.Sequence')
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs you may implement
  `on_epoch_end`.
  The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence` is a safer way to do multiprocessing. This structure guarantees
  that the network will only train once on each sample per epoch, which is
  not the case with generators.

  Examples:

  ```python
      from skimage.io import imread
      from skimage.transform import resize
      import numpy as np
      import math

      # Here, `x_set` is a list of paths to the images
      # and `y_set` are the associated classes.

      class CIFAR10Sequence(Sequence):

          def __init__(self, x_set, y_set, batch_size):
              self.x, self.y = x_set, y_set
              self.batch_size = batch_size

          def __len__(self):
              return math.ceil(len(self.x) / self.batch_size)

          def __getitem__(self, idx):
              batch_x = self.x[idx * self.batch_size:
                               (idx + 1) * self.batch_size]
              batch_y = self.y[idx * self.batch_size:
                               (idx + 1) * self.batch_size]

              return np.array([
                  resize(imread(file_name), (200, 200))
                  for file_name in batch_x]), np.array(batch_y)
  ```
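
  To reshuffle samples between epochs, `on_epoch_end` can be overridden.
  A minimal sketch building on the `CIFAR10Sequence` example above:

  ```python
      class ShuffledCIFAR10Sequence(CIFAR10Sequence):

          def __init__(self, x_set, y_set, batch_size):
              super(ShuffledCIFAR10Sequence, self).__init__(
                  x_set, y_set, batch_size)
              self.indices = np.arange(len(self.x))

          def __getitem__(self, idx):
              inds = self.indices[idx * self.batch_size:
                                  (idx + 1) * self.batch_size]
              batch_x = [self.x[i] for i in inds]
              batch_y = [self.y[i] for i in inds]
              return np.array([
                  resize(imread(file_name), (200, 200))
                  for file_name in batch_x]), np.array(batch_y)

          def on_epoch_end(self):
              # Called once per epoch; reshuffle the sample order.
              np.random.shuffle(self.indices)
  ```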
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Arguments:
        index: position of the batch in the Sequence.

    Returns:
        A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batches in the Sequence.

    Returns:
        The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch."""
    pass

  def __iter__(self):
    """Creates a generator that iterates over the Sequence."""
    for item in (self[i] for i in range(len(self))):
      yield item


def iter_sequence_infinite(seq):
  """Iterates indefinitely over a Sequence.

  Arguments:
    seq: Sequence instance.

  Yields:
    Batches of data from the Sequence.
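
  Example (a minimal sketch, assuming `seq` is a `Sequence` instance):

  ```python
      gen = iter_sequence_infinite(seq)
      batch = next(gen)  # Wraps around once `seq` is exhausted.
  ```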
  """
  while True:
    for item in seq:
      yield item


# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
# We use a Value to provide unique id to different processes.
_SEQUENCE_COUNTER = None


def init_pool(seqs):
  """Initializer for worker pools: shares the Sequences with the workers."""
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = seqs


def get_index(uid, i):
  """Gets the value from the Sequence `uid` at index `i`.

  To allow multiple Sequences to be used at the same time, we use `uid` to
  get a specific one. With a single shared Sequence, the validation Sequence
  would overwrite the training Sequence.

  Arguments:
      uid: int, Sequence identifier
      i: index

  Returns:
      The value at index `i`.
  """
  return _SHARED_SEQUENCES[uid][i]


@keras_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
      enqueuer = SequenceEnqueuer(...)
      enqueuer.start()
      gen = enqueuer.get()
      for data in gen:
          # Use the inputs; training, evaluating, predicting.
          # ... stop sometime.
      enqueuer.stop()
  ```

  `enqueuer.get()` yields an infinite stream of data.
  """

  def __init__(self, sequence,
               use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing Multiprocessing.Value += x is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Arguments:
        workers: Number of workers.
        max_queue_size: queue size
            (when full, workers could block on `put()`)
    """
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: ThreadPool(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends the current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and waits for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Arguments:
        timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  @abstractmethod
  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
        workers: Number of workers.

    Returns:
        A function to initialize the pool.
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skips the data if it is `None`.

    Returns:
        Generator yielding tuples `(inputs, targets)`
            or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError


@keras_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds an Enqueuer from a Sequence.

  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

  Arguments:
      sequence: A `tf.keras.utils.data_utils.Sequence` object.
      use_multiprocessing: use multiprocessing if True, otherwise threading
      shuffle: whether to shuffle the data at the beginning of each epoch
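
  Example (a minimal sketch, assuming `seq` is a `Sequence` instance):

  ```python
      enqueuer = OrderedEnqueuer(seq, use_multiprocessing=False, shuffle=True)
      enqueuer.start(workers=2, max_queue_size=10)
      gen = enqueuer.get()
      for _ in range(len(seq)):
          batch = next(gen)
          # ... train on the batch.
      enqueuer.stop()
  ```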
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
        workers: Number of workers.

    Returns:
        A function to initialize the pool.
    """
    def pool_fn(seqs):
      return multiprocessing.Pool(
          workers, initializer=init_pool_generator, initargs=(seqs, None))

    return pool_fn

  def _wait_queue(self):
    """Waits for the queue to be empty."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return
          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skips the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except Exception:  # pylint: disable=broad-except
      self.stop()
      six.reraise(*sys.exc_info())


def init_pool_generator(gens, random_seed=None):
  """Initializer for worker pools: shares generators and seeds NumPy."""
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  if random_seed is not None:
    # Derive a distinct seed per worker process from its ident.
    ident = multiprocessing.current_process().ident
    np.random.seed(random_seed + ident)


def next_sample(uid):
  """Gets the next value from the generator `uid`.

  To allow multiple generators to be used at the same time, we use `uid` to
  get a specific one. With a single shared generator, the validation
  generator would overwrite the training generator.

  Arguments:
      uid: int, generator identifier

  Returns:
      The next value of generator `uid`.
  """
  return six.next(_SHARED_SEQUENCES[uid])


@keras_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite, in which case the class will throw
  a `StopIteration` exception.

  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

  Arguments:
      sequence: a generator which yields data
      use_multiprocessing: use multiprocessing if True, otherwise threading
      random_seed: Initial seed for workers; each worker process offsets it
          by its process ident so workers draw different random sequences.
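
  Example (a minimal sketch, assuming `gen` yields batches indefinitely):

  ```python
      enqueuer = GeneratorEnqueuer(gen, use_multiprocessing=False)
      enqueuer.start(workers=1, max_queue_size=10)
      for batch in enqueuer.get():
          # ... consume batches, breaking out when done.
          break
      enqueuer.stop()
  ```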
  """

  def __init__(self, sequence,
               use_multiprocessing=False,
               random_seed=None):
    super(GeneratorEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Arguments:
      workers: Number of workers.

    Returns:
        A function to initialize the pool.
    """
    def pool_fn(seqs):
      return multiprocessing.Pool(workers,
                                  initializer=init_pool_generator,
                                  initargs=(seqs, self.random_seed))
    return pool_fn

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return
        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skips the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [future.get() for future in last_ones if future.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`.')
      six.reraise(*sys.exc_info())
    768