1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Fashion-MNIST dataset. 16 """ 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import gzip 22 import os 23 24 import numpy as np 25 26 from tensorflow.python.keras._impl.keras.utils.data_utils import get_file 27 28 29 def load_data(): 30 """Loads the Fashion-MNIST dataset. 31 32 Returns: 33 Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. 34 """ 35 dirname = os.path.join('datasets', 'fashion-mnist') 36 base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' 37 files = [ 38 'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', 39 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz' 40 ] 41 42 paths = [] 43 for fname in files: 44 paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname)) 45 46 with gzip.open(paths[0], 'rb') as lbpath: 47 y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) 48 49 with gzip.open(paths[1], 'rb') as imgpath: 50 x_train = np.frombuffer( 51 imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28) 52 53 with gzip.open(paths[2], 'rb') as lbpath: 54 y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8) 55 56 with gzip.open(paths[3], 'rb') as imgpath: 57 x_test = np.frombuffer( 58 imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28) 59 60 return (x_train, y_train), (x_test, y_test) 61