Home | History | Annotate | Download | only in datasets
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Fashion-MNIST dataset.
     16 """
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import gzip
     22 import os
     23 
     24 import numpy as np
     25 
     26 from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
     27 
     28 
     29 def load_data():
     30   """Loads the Fashion-MNIST dataset.
     31 
     32   Returns:
     33       Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
     34   """
     35   dirname = os.path.join('datasets', 'fashion-mnist')
     36   base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
     37   files = [
     38       'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
     39       't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
     40   ]
     41 
     42   paths = []
     43   for fname in files:
     44     paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))
     45 
     46   with gzip.open(paths[0], 'rb') as lbpath:
     47     y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
     48 
     49   with gzip.open(paths[1], 'rb') as imgpath:
     50     x_train = np.frombuffer(
     51         imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
     52 
     53   with gzip.open(paths[2], 'rb') as lbpath:
     54     y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)
     55 
     56   with gzip.open(paths[3], 'rb') as imgpath:
     57     x_test = np.frombuffer(
     58         imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
     59 
     60   return (x_train, y_train), (x_test, y_test)
     61