1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Converts MNIST data to TFRecords file format with Example protos.""" 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import argparse 22 import os 23 import sys 24 25 import tensorflow as tf 26 27 from tensorflow.contrib.learn.python.learn.datasets import mnist 28 29 FLAGS = None 30 31 32 def _int64_feature(value): 33 return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 34 35 36 def _bytes_feature(value): 37 return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 38 39 40 def convert_to(data_set, name): 41 """Converts a dataset to tfrecords.""" 42 images = data_set.images 43 labels = data_set.labels 44 num_examples = data_set.num_examples 45 46 if images.shape[0] != num_examples: 47 raise ValueError('Images size %d does not match label size %d.' % 48 (images.shape[0], num_examples)) 49 rows = images.shape[1] 50 cols = images.shape[2] 51 depth = images.shape[3] 52 53 filename = os.path.join(FLAGS.directory, name + '.tfrecords') 54 print('Writing', filename) 55 with tf.python_io.TFRecordWriter(filename) as writer: 56 for index in range(num_examples): 57 image_raw = images[index].tostring() 58 example = tf.train.Example( 59 features=tf.train.Features( 60 feature={ 61 'height': _int64_feature(rows), 62 'width': _int64_feature(cols), 63 'depth': _int64_feature(depth), 64 'label': _int64_feature(int(labels[index])), 65 'image_raw': _bytes_feature(image_raw) 66 })) 67 writer.write(example.SerializeToString()) 68 69 70 def main(unused_argv): 71 # Get the data. 72 data_sets = mnist.read_data_sets(FLAGS.directory, 73 dtype=tf.uint8, 74 reshape=False, 75 validation_size=FLAGS.validation_size) 76 77 # Convert to Examples and write the result to TFRecords. 78 convert_to(data_sets.train, 'train') 79 convert_to(data_sets.validation, 'validation') 80 convert_to(data_sets.test, 'test') 81 82 83 if __name__ == '__main__': 84 parser = argparse.ArgumentParser() 85 parser.add_argument( 86 '--directory', 87 type=str, 88 default='/tmp/data', 89 help='Directory to download data files and write the converted result' 90 ) 91 parser.add_argument( 92 '--validation_size', 93 type=int, 94 default=5000, 95 help="""\ 96 Number of examples to separate from the training data for the validation 97 set.\ 98 """ 99 ) 100 FLAGS, unparsed = parser.parse_known_args() 101 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 102