# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Smoke test for reading records from GCS to TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import sys
import time

import numpy as np
import tensorflow as tf
from tensorflow.core.example import example_pb2
from tensorflow.python.lib.io import file_io

flags = tf.app.flags
flags.DEFINE_string("gcs_bucket_url", "",
                    "The URL to the GCS bucket in which the temporary "
                    "tfrecord file is to be written and read, e.g., "
                    "gs://my-gcs-bucket/test-directory")
flags.DEFINE_integer("num_examples", 10, "Number of examples to generate")

FLAGS = flags.FLAGS
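
# Example invocation (the bucket URL is illustrative and the script file name
# is assumed to be gcs_smoke.py):
#   python gcs_smoke.py --gcs_bucket_url=gs://my-gcs-bucket/test-directory \
#       --num_examples=20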


def create_examples(num_examples, input_mean):
  """Creates Example protos containing data."""
  ids = np.arange(num_examples).reshape([num_examples, 1])
  inputs = np.random.randn(num_examples, 1) + input_mean
  target = inputs - input_mean
  examples = []
  for row in range(num_examples):
    ex = example_pb2.Example()
    # Encode the id so the bytes_list also accepts it under Python 3, where
    # the proto bytes field requires bytes rather than str.
    ex.features.feature["id"].bytes_list.value.append(
        str(ids[row, 0]).encode("utf-8"))
    ex.features.feature["target"].float_list.value.append(target[row, 0])
    ex.features.feature["inputs"].float_list.value.append(inputs[row, 0])
    examples.append(ex)
  return examples

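# Each Example produced above serializes roughly to this text form (feature
# values are illustrative):
#   features {
#     feature { key: "id" value { bytes_list { value: "3" } } }
#     feature { key: "inputs" value { float_list { value: 5.2 } } }
#     feature { key: "target" value { float_list { value: 0.2 } } }
#   }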

def create_dir_test():
  """Verifies file_io directory handling methods."""

  # Test directory creation.
  starttime_ms = int(round(time.time() * 1000))
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
  print("Creating dir %s" % dir_name)
  file_io.create_dir(dir_name)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Created directory in: %d milliseconds" % elapsed_ms)

  # Check that the directory exists.
  dir_exists = file_io.is_directory(dir_name)
  assert dir_exists
  print("%s directory exists: %s" % (dir_name, dir_exists))

  # Test recursive directory creation.
  starttime_ms = int(round(time.time() * 1000))
  recursive_dir_name = "%s/%s/%s" % (dir_name,
                                     "nested_dir1",
                                     "nested_dir2")
  print("Creating recursive dir %s" % recursive_dir_name)
  file_io.recursive_create_dir(recursive_dir_name)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Created directory recursively in: %d milliseconds" % elapsed_ms)

  # Check that the directory exists.
  recursive_dir_exists = file_io.is_directory(recursive_dir_name)
  assert recursive_dir_exists
  print("%s directory exists: %s" % (recursive_dir_name, recursive_dir_exists))

  # Create some files in the newly created directory and list the contents.
  num_files = 10
  files_to_create = ["file_%d.txt" % n for n in range(num_files)]
  for file_num in files_to_create:
    file_name = "%s/%s" % (dir_name, file_num)
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file.")

  print("Listing directory %s." % dir_name)
  starttime_ms = int(round(time.time() * 1000))
  directory_contents = file_io.list_directory(dir_name)
  print(directory_contents)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Listed directory %s in %s milliseconds" % (dir_name, elapsed_ms))
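  # Note: list_directory returns immediate children only; the nested directory
  # created above appears as a single "nested_dir1/" entry with a trailing
  # slash, which is what the assertion below expects.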
  assert set(directory_contents) == set(files_to_create + ["nested_dir1/"])

  # Test directory renaming.
  dir_to_rename = "%s/old_dir" % dir_name
  new_dir_name = "%s/new_dir" % dir_name
  file_io.create_dir(dir_to_rename)
  assert file_io.is_directory(dir_to_rename)
  assert not file_io.is_directory(new_dir_name)

  starttime_ms = int(round(time.time() * 1000))
  print("Will try renaming directory %s to %s" % (dir_to_rename, new_dir_name))
  file_io.rename(dir_to_rename, new_dir_name)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Renamed directory %s to %s in %s milliseconds" % (
      dir_to_rename, new_dir_name, elapsed_ms))
  assert not file_io.is_directory(dir_to_rename)
  assert file_io.is_directory(new_dir_name)

  # Test deleting the directory recursively.
  print("Deleting directory recursively %s." % dir_name)
  starttime_ms = int(round(time.time() * 1000))
  file_io.delete_recursively(dir_name)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  dir_exists = file_io.is_directory(dir_name)
  assert not dir_exists
  print("Deleted directory recursively %s in %s milliseconds" % (
      dir_name, elapsed_ms))


def create_object_test():
  """Verifies file_io's object manipulation methods."""
  starttime_ms = int(round(time.time() * 1000))
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
  print("Creating dir %s." % dir_name)
  file_io.create_dir(dir_name)

  num_files = 5
  # Create files of 2 different patterns in this directory.
  files_pattern_1 = ["%s/test_file_%d.txt" % (dir_name, n)
                     for n in range(num_files)]
  files_pattern_2 = ["%s/testfile%d.txt" % (dir_name, n)
                     for n in range(num_files)]

  starttime_ms = int(round(time.time() * 1000))
  files_to_create = files_pattern_1 + files_pattern_2
  for file_name in files_to_create:
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file creation.")
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Created %d files in %s milliseconds" % (
      len(files_to_create), elapsed_ms))

  # Listing files of pattern1.
  list_files_pattern = "%s/test_file*.txt" % dir_name
  print("Getting files matching pattern %s." % list_files_pattern)
  starttime_ms = int(round(time.time() * 1000))
  files_list = file_io.get_matching_files(list_files_pattern)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Listed files in %s milliseconds" % elapsed_ms)
  print(files_list)
  assert set(files_list) == set(files_pattern_1)

  # Listing files of pattern2.
  list_files_pattern = "%s/testfile*.txt" % dir_name
  print("Getting files matching pattern %s." % list_files_pattern)
  starttime_ms = int(round(time.time() * 1000))
  files_list = file_io.get_matching_files(list_files_pattern)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("Listed files in %s milliseconds" % elapsed_ms)
  print(files_list)
  assert set(files_list) == set(files_pattern_2)

  # Test renaming file.
  file_to_rename = "%s/oldname.txt" % dir_name
  file_new_name = "%s/newname.txt" % dir_name
  file_io.write_string_to_file(file_to_rename, "test file.")
  assert file_io.file_exists(file_to_rename)
  assert not file_io.file_exists(file_new_name)

  print("Will try renaming file %s to %s" % (file_to_rename, file_new_name))
  starttime_ms = int(round(time.time() * 1000))
  file_io.rename(file_to_rename, file_new_name)
  elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
  print("File %s renamed to %s in %s milliseconds" % (
      file_to_rename, file_new_name, elapsed_ms))
  assert not file_io.file_exists(file_to_rename)
  assert file_io.file_exists(file_new_name)

  # Delete directory.
  print("Deleting directory %s." % dir_name)
  file_io.delete_recursively(dir_name)


def main(argv):
  del argv  # Unused.

  # Sanity check on the GCS bucket URL.
  if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
    print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
    sys.exit(1)

  # Generate a random tfrecord path name.
  input_path = FLAGS.gcs_bucket_url + "/"
  input_path += "".join(random.choice("0123456789ABCDEF") for i in range(8))
  input_path += ".tfrecord"
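  # input_path now looks like, e.g. (the 8-character hex suffix is random):
  #   gs://my-gcs-bucket/test-directory/0F5E9A1C.tfrecord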
  print("Using input path: %s" % input_path)

  # Verify that writing to the records file in GCS works.
  print("\n=== Testing writing and reading of GCS record file... ===")
  example_data = create_examples(FLAGS.num_examples, 5)
  with tf.python_io.TFRecordWriter(input_path) as hf:
    for e in example_data:
      hf.write(e.SerializeToString())
  print("Data written to: %s" % input_path)

  # Verify that reading from the tfrecord file via tf_record_iterator works.
  record_iter = tf.python_io.tf_record_iterator(input_path)
  read_count = 0
  for _ in record_iter:
    read_count += 1
  print("Read %d records using tf_record_iterator" % read_count)

  if read_count != FLAGS.num_examples:
    print("FAIL: The number of records read from tf_record_iterator (%d) "
          "differs from the expected number (%d)" % (read_count,
                                                     FLAGS.num_examples))
    sys.exit(1)

  # Verify that running the read op in a session works.
  print("\n=== Testing TFRecordReader.read op in a session... ===")
  with tf.Graph().as_default():
    filename_queue = tf.train.string_input_producer([input_path], num_epochs=1)
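    # string_input_producer with num_epochs set creates a local epoch-counter
    # variable, which is why local_variables_initializer() must be run below
    # before starting the queue runners.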
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.train.start_queue_runners()
      index = 0
      for _ in range(FLAGS.num_examples):
        print("Read record: %d" % index)
        sess.run(serialized_example)
        index += 1

      # Reading one more record should trigger an exception.
      try:
        sess.run(serialized_example)
        print("FAIL: Failed to catch the expected OutOfRangeError while "
              "reading one more record than is available")
        sys.exit(1)
      except tf.errors.OutOfRangeError:
        print("Successfully caught the expected OutOfRangeError while "
              "reading one more record than is available")

  create_dir_test()
  create_object_test()


if __name__ == "__main__":
  tf.app.run(main)