Home | History | Annotate | Download | only in reading_data
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Converts MNIST data to TFRecords file format with Example protos."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import argparse
     22 import os
     23 import sys
     24 
     25 import tensorflow as tf
     26 
     27 from tensorflow.contrib.learn.python.learn.datasets import mnist
     28 
     29 FLAGS = None
     30 
     31 
     32 def _int64_feature(value):
     33   return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
     34 
     35 
     36 def _bytes_feature(value):
     37   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
     38 
     39 
     40 def convert_to(data_set, name):
     41   """Converts a dataset to tfrecords."""
     42   images = data_set.images
     43   labels = data_set.labels
     44   num_examples = data_set.num_examples
     45 
     46   if images.shape[0] != num_examples:
     47     raise ValueError('Images size %d does not match label size %d.' %
     48                      (images.shape[0], num_examples))
     49   rows = images.shape[1]
     50   cols = images.shape[2]
     51   depth = images.shape[3]
     52 
     53   filename = os.path.join(FLAGS.directory, name + '.tfrecords')
     54   print('Writing', filename)
     55   with tf.python_io.TFRecordWriter(filename) as writer:
     56     for index in range(num_examples):
     57       image_raw = images[index].tostring()
     58       example = tf.train.Example(
     59           features=tf.train.Features(
     60               feature={
     61                   'height': _int64_feature(rows),
     62                   'width': _int64_feature(cols),
     63                   'depth': _int64_feature(depth),
     64                   'label': _int64_feature(int(labels[index])),
     65                   'image_raw': _bytes_feature(image_raw)
     66               }))
     67       writer.write(example.SerializeToString())
     68 
     69 
     70 def main(unused_argv):
     71   # Get the data.
     72   data_sets = mnist.read_data_sets(FLAGS.directory,
     73                                    dtype=tf.uint8,
     74                                    reshape=False,
     75                                    validation_size=FLAGS.validation_size)
     76 
     77   # Convert to Examples and write the result to TFRecords.
     78   convert_to(data_sets.train, 'train')
     79   convert_to(data_sets.validation, 'validation')
     80   convert_to(data_sets.test, 'test')
     81 
     82 
     83 if __name__ == '__main__':
     84   parser = argparse.ArgumentParser()
     85   parser.add_argument(
     86       '--directory',
     87       type=str,
     88       default='/tmp/data',
     89       help='Directory to download data files and write the converted result'
     90   )
     91   parser.add_argument(
     92       '--validation_size',
     93       type=int,
     94       default=5000,
     95       help="""\
     96       Number of examples to separate from the training data for the validation
     97       set.\
     98       """
     99   )
    100   FLAGS, unparsed = parser.parse_known_args()
    101   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
    102