# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 r"""TensorFlow Eager Execution Example: RNN Colorbot.
     16 
     17 This example builds, trains, and evaluates a multi-layer RNN that can be
     18 run with eager execution enabled. The RNN is trained to map color names to
     19 their RGB values: it takes as input a one-hot encoded character sequence and
     20 outputs a three-tuple (R, G, B) (scaled by 1/255).
     21 
     22 For example, say we'd like the RNN Colorbot to generate the RGB values for the
     23 color white. To represent our query in a form that the Colorbot could
     24 understand, we would create a sequence of five 256-long vectors encoding the
     25 ASCII values of the characters in "white". The first vector in our sequence
     26 would be 0 everywhere except for the ord("w")-th position, where it would be
     27 1, the second vector would be 0 everywhere except for the
     28 ord("h")-th position, where it would be 1, and similarly for the remaining three
     29 vectors. We refer to such indicator vectors as "one-hot encodings" of
     30 characters. After consuming these vectors, a well-trained Colorbot would output
     31 the three tuple (1, 1, 1), since the RGB values for white are (255, 255, 255).
     32 We are of course free to ask the colorbot to generate colors for any string we'd
     33 like, such as "steel gray," "tensorflow orange," or "green apple," though
     34 your mileage may vary as your queries increase in creativity.
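
A minimal sketch of this encoding (illustrative; it mirrors the `parse`
function below, assuming eager execution is enabled):

  chars = tf.one_hot(tf.decode_raw(tf.constant("white"), tf.uint8), depth=256)
  # chars.shape == (5, 256); chars[0, ord("w")] == 1.0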

This example shows how to:
  1. read, process, (one-hot) encode, and pad text data via the
     Datasets API;
  2. build a trainable model;
  3. implement a multi-layer RNN using Python control flow
     constructs (e.g., a for loop);
  4. train a model using an iterative gradient-based method; and
  5. log training and evaluation loss summaries for consumption by
     TensorBoard.

The data used in this example is licensed under the Creative Commons
Attribution-ShareAlike License and is available at
  https://en.wikipedia.org/wiki/List_of_colors:_A-F
  https://en.wikipedia.org/wiki/List_of_colors:_G-M
  https://en.wikipedia.org/wiki/List_of_colors:_N-Z

This example was adapted from
  https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot
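
Sample invocation (the script file name here is illustrative; the flags are
defined at the bottom of this file):
  python rnn_colorbot.py --num_epochs 20 --batch_size 64 --learning_rate 0.01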
     52 """
     53 
     54 from __future__ import absolute_import
     55 from __future__ import division
     56 from __future__ import print_function
     57 
     58 import argparse
     59 import functools
     60 import os
     61 import sys
     62 import time
     63 
     64 import six
     65 import tensorflow as tf
     66 
     67 from tensorflow.contrib.eager.python import tfe
     68 
     69 try:
     70   import matplotlib.pyplot as plt  # pylint: disable=g-import-not-at-top
     71   HAS_MATPLOTLIB = True
     72 except ImportError:
     73   HAS_MATPLOTLIB = False
     74 
     75 
def parse(line):
  """Parses a line from the colors dataset."""

  # Each line of the dataset is comma-separated and formatted as
  #    color_name, r, g, b
  # so `items` is a list [color_name, r, g, b].
  items = tf.string_split([line], ",").values
  rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.
  # Represent the color name as a one-hot encoded character sequence.
  color_name = items[0]
  chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)
  # The sequence length is needed by our RNN.
  length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)
  return rgb, chars, length
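
# Example (a sketch, assuming eager execution is enabled):
#   rgb, chars, length = parse(tf.constant("red,255,0,0"))
#   # rgb    -> [1.0, 0.0, 0.0]
#   # chars  -> a [3, 256] one-hot matrix for the characters "r", "e", "d"
#   # length -> 3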


def load_dataset(data_dir, url, batch_size):
  """Loads the colors data at path into a padded tf.data.Dataset."""

  # Downloads data at url into data_dir/basename(url). The dataset has a header
  # row (color_name, r, g, b) followed by comma-separated lines.
  path = tf.contrib.learn.datasets.base.maybe_download(
      os.path.basename(url), data_dir, url)

  # This chain of commands loads our data by:
  #   1. skipping the header; (.skip(1))
  #   2. parsing the subsequent lines; (.map(parse))
  #   3. shuffling the data; (.shuffle(...))
  #   4. grouping the data into padded batches (.padded_batch(...)).
  dataset = tf.data.TextLineDataset(path).skip(1).map(parse).shuffle(
      buffer_size=10000).padded_batch(
          batch_size, padded_shapes=([None], [None, None], []))
  return dataset
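
# Each element of the returned dataset is a batch (rgb, chars, length), where
#   rgb    has shape [batch_size, 3],
#   chars  has shape [batch_size, max_name_length_in_batch, 256], and
#   length has shape [batch_size].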


# pylint: disable=not-callable
class RNNColorbot(tfe.Network):
  """Multi-layer (LSTM) RNN that regresses on real-valued vector labels."""

  def __init__(self, rnn_cell_sizes, label_dimension, keep_prob):
    """Constructs an RNNColorbot.

    Args:
      rnn_cell_sizes: list of integers denoting the size of each LSTM cell in
        the RNN; rnn_cell_sizes[i] is the size of the i-th layer cell
      label_dimension: the length of the labels on which to regress
      keep_prob: (1 - dropout probability); dropout is applied to the outputs
        of each LSTM layer
    """
    super(RNNColorbot, self).__init__(name="")
    self.label_dimension = label_dimension
    self.keep_prob = keep_prob

    # Note the calls to `track_layer` below; these calls register the layers
    # as network components that house trainable variables.
    self.cells = [
        self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(size))
        for size in rnn_cell_sizes
    ]
    self.relu = self.track_layer(
        tf.layers.Dense(label_dimension, activation=tf.nn.relu, name="relu"))

  def call(self, chars, sequence_length, training=False):
    """Implements the RNN logic and prediction generation.

    Args:
      chars: a Tensor of dimension [batch_size, time_steps, 256] holding a
        batch of one-hot encoded color names
      sequence_length: a Tensor of dimension [batch_size] holding the length
        of each character sequence (i.e., color name)
      training: whether the invocation is happening during training

    Returns:
      A tensor of dimension [batch_size, label_dimension] that is produced by
      passing chars through a multi-layer RNN and applying a ReLU to the final
      hidden state.
    """
    # Transpose the first and second dimensions so that chars is of shape
    # [time_steps, batch_size, dimension].
    chars = tf.transpose(chars, [1, 0, 2])
    # The outer loop cycles through the layers of the RNN; the inner loop
    # executes the time steps for a particular layer.
    batch_size = int(chars.shape[1])
    for l in range(len(self.cells)):
      cell = self.cells[l]
      outputs = []
      state = cell.zero_state(batch_size, tf.float32)
      # Unstack the inputs to obtain a list of batches, one for each time step.
      chars = tf.unstack(chars, axis=0)
      for ch in chars:
        output, state = cell(ch, state)
        outputs.append(output)
      # The outputs of this layer are the inputs of the subsequent layer.
      chars = tf.stack(outputs, axis=0)
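      # `chars` now has shape [time_steps, batch_size, cell_output_size].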
      if training:
        chars = tf.nn.dropout(chars, self.keep_prob)
    # Extract the correct output (i.e., hidden state) for each example. All
    # the character sequences in this batch were padded to the same fixed
    # length so that they could be easily fed through the above RNN loop. The
    # `sequence_length` vector tells us the true lengths of the character
    # sequences, letting us obtain for each sequence the hidden state that was
    # generated by its non-padding characters.
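    # For example, with batch_size=2 and sequence_length=[3, 5], `indices`
    # below is [[2, 0], [4, 1]], selecting chars[2, 0, :] and chars[4, 1, :].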
    batch_range = [i for i in range(batch_size)]
    indices = tf.stack([sequence_length - 1, batch_range], axis=1)
    hidden_states = tf.gather_nd(chars, indices)
    return self.relu(hidden_states)
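
# Example use of RNNColorbot (a sketch; `chars` and `sequence_length` are as
# described in `call` above):
#   model = RNNColorbot(rnn_cell_sizes=[256, 128], label_dimension=3,
#                       keep_prob=0.5)
#   predictions = model(chars, sequence_length, training=False)  # [batch, 3]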


def loss(labels, predictions):
  """Computes mean squared loss."""
  return tf.reduce_mean(tf.square(predictions - labels))
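
# For example, loss([[1., 1., 1.]], [[0.5, 1., 1.]]) is (0.25 + 0 + 0) / 3,
# i.e. roughly 0.0833.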


def test(model, eval_data):
  """Computes the average loss on eval_data, which should be a Dataset."""
  avg_loss = tfe.metrics.Mean("loss")
  for (labels, chars, sequence_length) in tfe.Iterator(eval_data):
    predictions = model(chars, sequence_length, training=False)
    avg_loss(loss(labels, predictions))
  print("eval/loss: %.6f\n" % avg_loss.result())
  with tf.contrib.summary.always_record_summaries():
    tf.contrib.summary.scalar("loss", avg_loss.result())


def train_one_epoch(model, optimizer, train_data, log_interval=10):
  """Trains model on train_data using optimizer."""

  tf.train.get_or_create_global_step()

  def model_loss(labels, chars, sequence_length):
    predictions = model(chars, sequence_length, training=True)
    loss_value = loss(labels, predictions)
    tf.contrib.summary.scalar("loss", loss_value)
    return loss_value

  for (batch, (labels, chars, sequence_length)) in enumerate(
      tfe.Iterator(train_data)):
    with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval):
      batch_model_loss = functools.partial(model_loss, labels, chars,
                                           sequence_length)
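      # `batch_model_loss` is a no-argument callable; in eager mode,
      # `optimizer.minimize` invokes it to compute and differentiate the loss.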
      optimizer.minimize(
          batch_model_loss, global_step=tf.train.get_global_step())
      if log_interval and batch % log_interval == 0:
        print("train/batch #%d\tloss: %.6f" % (batch, batch_model_loss()))


SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv"
SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv"


def main(_):
  data_dir = os.path.join(FLAGS.dir, "data")
  train_data = load_dataset(
      data_dir=data_dir, url=SOURCE_TRAIN_URL, batch_size=FLAGS.batch_size)
  eval_data = load_dataset(
      data_dir=data_dir, url=SOURCE_TEST_URL, batch_size=FLAGS.batch_size)

  model = RNNColorbot(
      rnn_cell_sizes=FLAGS.rnn_cell_sizes,
      label_dimension=3,
      keep_prob=FLAGS.keep_probability)
  optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

  if FLAGS.no_gpu or tfe.num_gpus() <= 0:
    device = "/cpu:0"
  else:
    device = "/gpu:0"
  print("Using device %s." % device)

  log_dir = os.path.join(FLAGS.dir, "summaries")
  tf.gfile.MakeDirs(log_dir)
  train_summary_writer = tf.contrib.summary.create_file_writer(
      os.path.join(log_dir, "train"), flush_millis=10000)
  test_summary_writer = tf.contrib.summary.create_file_writer(
      os.path.join(log_dir, "eval"), flush_millis=10000, name="eval")
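  # Summaries can be viewed by pointing TensorBoard at `log_dir`, e.g.
  #   tensorboard --logdir /tmp/rnn_colorbot/summaries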

  with tf.device(device):
    for epoch in range(FLAGS.num_epochs):
      start = time.time()
      with train_summary_writer.as_default():
        train_one_epoch(model, optimizer, train_data, FLAGS.log_interval)
      end = time.time()
      print("train/time for epoch #%d: %.2f" % (epoch, end - start))
      with test_summary_writer.as_default():
        test(model, eval_data)

  print("Colorbot is ready to generate colors!")
  while True:
    try:
      color_name = six.moves.input(
          "Give me a color name (or press enter to exit): ")
    except EOFError:
      return

    if not color_name:
      return

    _, chars, length = parse(color_name)
    with tf.device(device):
      (chars, length) = (tf.identity(chars), tf.identity(length))
      chars = tf.expand_dims(chars, 0)
      length = tf.expand_dims(length, 0)
      preds = tf.unstack(model(chars, length, training=False)[0])

    # Predictions cannot be negative, as they are generated by a ReLU layer;
    # they may, however, be greater than 1.
    clipped_preds = tuple(min(float(p), 1.0) for p in preds)
    rgb = tuple(int(p * 255) for p in clipped_preds)
    print("rgb:", rgb)
    data = [[clipped_preds]]
    if HAS_MATPLOTLIB:
      plt.imshow(data)
      plt.title(color_name)
      plt.show()


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--dir",
      type=str,
      default="/tmp/rnn_colorbot/",
      help="Directory to download data files and save logs.")
  parser.add_argument(
      "--log_interval",
      type=int,
      default=10,
      metavar="N",
      help="Log training loss every log_interval batches.")
  parser.add_argument(
      "--num_epochs", type=int, default=20, help="Number of epochs to train.")
  parser.add_argument(
      "--rnn_cell_sizes",
      type=int,
      nargs="+",
      default=[256, 128],
      help="List of sizes for each layer of the RNN.")
  parser.add_argument(
      "--batch_size",
      type=int,
      default=64,
      help="Batch size for training and eval.")
  parser.add_argument(
      "--keep_probability",
      type=float,
      default=0.5,
      help="Keep probability for dropout between layers.")
  parser.add_argument(
      "--learning_rate",
      type=float,
      default=0.01,
      help="Learning rate to be used during training.")
  parser.add_argument(
      "--no_gpu",
      action="store_true",
      default=False,
      help="Disables GPU usage even if a GPU is available.")

  FLAGS, unparsed = parser.parse_known_args()
  tfe.run(main=main, argv=[sys.argv[0]] + unparsed)
    338