1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 r"""TensorFlow Eager Execution Example: RNN Colorbot. 16 17 This example builds, trains, and evaluates a multi-layer RNN that can be 18 run with eager execution enabled. The RNN is trained to map color names to 19 their RGB values: it takes as input a one-hot encoded character sequence and 20 outputs a three-tuple (R, G, B) (scaled by 1/255). 21 22 For example, say we'd like the RNN Colorbot to generate the RGB values for the 23 color white. To represent our query in a form that the Colorbot could 24 understand, we would create a sequence of five 256-long vectors encoding the 25 ASCII values of the characters in "white". The first vector in our sequence 26 would be 0 everywhere except for the ord("w")-th position, where it would be 27 1, the second vector would be 0 everywhere except for the 28 ord("h")-th position, where it would be 1, and similarly for the remaining three 29 vectors. We refer to such indicator vectors as "one-hot encodings" of 30 characters. After consuming these vectors, a well-trained Colorbot would output 31 the three tuple (1, 1, 1), since the RGB values for white are (255, 255, 255). 32 We are of course free to ask the colorbot to generate colors for any string we'd 33 like, such as "steel gray," "tensorflow orange," or "green apple," though 34 your mileage may vary as your queries increase in creativity. 35 36 This example shows how to: 37 1. read, process, (one-hot) encode, and pad text data via the 38 Datasets API; 39 2. build a trainable model; 40 3. implement a multi-layer RNN using Python control flow 41 constructs (e.g., a for loop); 42 4. train a model using an iterative gradient-based method; and 43 44 The data used in this example is licensed under the Creative Commons 45 Attribution-ShareAlike License and is available at 46 https://en.wikipedia.org/wiki/List_of_colors:_A-F 47 https://en.wikipedia.org/wiki/List_of_colors:_G-M 48 https://en.wikipedia.org/wiki/List_of_colors:_N-Z 49 50 This example was adapted from 51 https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot 52 """ 53 54 from __future__ import absolute_import 55 from __future__ import division 56 from __future__ import print_function 57 58 import argparse 59 import functools 60 import os 61 import sys 62 import time 63 64 import six 65 import tensorflow as tf 66 67 from tensorflow.contrib.eager.python import tfe 68 69 try: 70 import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top 71 HAS_MATPLOTLIB = True 72 except ImportError: 73 HAS_MATPLOTLIB = False 74 75 76 def parse(line): 77 """Parse a line from the colors dataset.""" 78 79 # Each line of the dataset is comma-separated and formatted as 80 # color_name, r, g, b 81 # so `items` is a list [color_name, r, g, b]. 82 items = tf.string_split([line], ",").values 83 rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255. 84 # Represent the color name as a one-hot encoded character sequence. 85 color_name = items[0] 86 chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256) 87 # The sequence length is needed by our RNN. 88 length = tf.cast(tf.shape(chars)[0], dtype=tf.int64) 89 return rgb, chars, length 90 91 92 def load_dataset(data_dir, url, batch_size): 93 """Loads the colors data at path into a PaddedDataset.""" 94 95 # Downloads data at url into data_dir/basename(url). The dataset has a header 96 # row (color_name, r, g, b) followed by comma-separated lines. 97 path = tf.contrib.learn.datasets.base.maybe_download( 98 os.path.basename(url), data_dir, url) 99 100 # This chain of commands loads our data by: 101 # 1. skipping the header; (.skip(1)) 102 # 2. parsing the subsequent lines; (.map(parse)) 103 # 3. shuffling the data; (.shuffle(...)) 104 # 3. grouping the data into padded batches (.padded_batch(...)). 105 dataset = tf.data.TextLineDataset(path).skip(1).map(parse).shuffle( 106 buffer_size=10000).padded_batch( 107 batch_size, padded_shapes=([None], [None, None], [])) 108 return dataset 109 110 111 # pylint: disable=not-callable 112 class RNNColorbot(tfe.Network): 113 """Multi-layer (LSTM) RNN that regresses on real-valued vector labels. 114 """ 115 116 def __init__(self, rnn_cell_sizes, label_dimension, keep_prob): 117 """Constructs an RNNColorbot. 118 119 Args: 120 rnn_cell_sizes: list of integers denoting the size of each LSTM cell in 121 the RNN; rnn_cell_sizes[i] is the size of the i-th layer cell 122 label_dimension: the length of the labels on which to regress 123 keep_prob: (1 - dropout probability); dropout is applied to the outputs of 124 each LSTM layer 125 """ 126 super(RNNColorbot, self).__init__(name="") 127 self.label_dimension = label_dimension 128 self.keep_prob = keep_prob 129 130 # Note the calls to `track_layer` below; these calls register the layers as 131 # network components that house trainable variables. 132 self.cells = [ 133 self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(size)) 134 for size in rnn_cell_sizes 135 ] 136 self.relu = self.track_layer( 137 tf.layers.Dense(label_dimension, activation=tf.nn.relu, name="relu")) 138 139 def call(self, chars, sequence_length, training=False): 140 """Implements the RNN logic and prediction generation. 141 142 Args: 143 chars: a Tensor of dimension [batch_size, time_steps, 256] holding a 144 batch of one-hot encoded color names 145 sequence_length: a Tensor of dimension [batch_size] holding the length 146 of each character sequence (i.e., color name) 147 training: whether the invocation is happening during training 148 149 Returns: 150 A tensor of dimension [batch_size, label_dimension] that is produced by 151 passing chars through a multi-layer RNN and applying a ReLU to the final 152 hidden state. 153 """ 154 # Transpose the first and second dimensions so that chars is of shape 155 # [time_steps, batch_size, dimension]. 156 chars = tf.transpose(chars, [1, 0, 2]) 157 # The outer loop cycles through the layers of the RNN; the inner loop 158 # executes the time steps for a particular layer. 159 batch_size = int(chars.shape[1]) 160 for l in range(len(self.cells)): 161 cell = self.cells[l] 162 outputs = [] 163 state = cell.zero_state(batch_size, tf.float32) 164 # Unstack the inputs to obtain a list of batches, one for each time step. 165 chars = tf.unstack(chars, axis=0) 166 for ch in chars: 167 output, state = cell(ch, state) 168 outputs.append(output) 169 # The outputs of this layer are the inputs of the subsequent layer. 170 chars = tf.stack(outputs, axis=0) 171 if training: 172 chars = tf.nn.dropout(chars, self.keep_prob) 173 # Extract the correct output (i.e., hidden state) for each example. All the 174 # character sequences in this batch were padded to the same fixed length so 175 # that they could be easily fed through the above RNN loop. The 176 # `sequence_length` vector tells us the true lengths of the character 177 # sequences, letting us obtain for each sequence the hidden state that was 178 # generated by its non-padding characters. 179 batch_range = [i for i in range(batch_size)] 180 indices = tf.stack([sequence_length - 1, batch_range], axis=1) 181 hidden_states = tf.gather_nd(chars, indices) 182 return self.relu(hidden_states) 183 184 185 def loss(labels, predictions): 186 """Computes mean squared loss.""" 187 return tf.reduce_mean(tf.square(predictions - labels)) 188 189 190 def test(model, eval_data): 191 """Computes the average loss on eval_data, which should be a Dataset.""" 192 avg_loss = tfe.metrics.Mean("loss") 193 for (labels, chars, sequence_length) in tfe.Iterator(eval_data): 194 predictions = model(chars, sequence_length, training=False) 195 avg_loss(loss(labels, predictions)) 196 print("eval/loss: %.6f\n" % avg_loss.result()) 197 with tf.contrib.summary.always_record_summaries(): 198 tf.contrib.summary.scalar("loss", avg_loss.result()) 199 200 201 def train_one_epoch(model, optimizer, train_data, log_interval=10): 202 """Trains model on train_data using optimizer.""" 203 204 tf.train.get_or_create_global_step() 205 206 def model_loss(labels, chars, sequence_length): 207 predictions = model(chars, sequence_length, training=True) 208 loss_value = loss(labels, predictions) 209 tf.contrib.summary.scalar("loss", loss_value) 210 return loss_value 211 212 for (batch, (labels, chars, sequence_length)) in enumerate( 213 tfe.Iterator(train_data)): 214 with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval): 215 batch_model_loss = functools.partial(model_loss, labels, chars, 216 sequence_length) 217 optimizer.minimize( 218 batch_model_loss, global_step=tf.train.get_global_step()) 219 if log_interval and batch % log_interval == 0: 220 print("train/batch #%d\tloss: %.6f" % (batch, batch_model_loss())) 221 222 223 SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv" 224 SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv" 225 226 227 def main(_): 228 data_dir = os.path.join(FLAGS.dir, "data") 229 train_data = load_dataset( 230 data_dir=data_dir, url=SOURCE_TRAIN_URL, batch_size=FLAGS.batch_size) 231 eval_data = load_dataset( 232 data_dir=data_dir, url=SOURCE_TEST_URL, batch_size=FLAGS.batch_size) 233 234 model = RNNColorbot( 235 rnn_cell_sizes=FLAGS.rnn_cell_sizes, 236 label_dimension=3, 237 keep_prob=FLAGS.keep_probability) 238 optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) 239 240 if FLAGS.no_gpu or tfe.num_gpus() <= 0: 241 print(tfe.num_gpus()) 242 device = "/cpu:0" 243 else: 244 device = "/gpu:0" 245 print("Using device %s." % device) 246 247 log_dir = os.path.join(FLAGS.dir, "summaries") 248 tf.gfile.MakeDirs(log_dir) 249 train_summary_writer = tf.contrib.summary.create_file_writer( 250 os.path.join(log_dir, "train"), flush_millis=10000) 251 test_summary_writer = tf.contrib.summary.create_file_writer( 252 os.path.join(log_dir, "eval"), flush_millis=10000, name="eval") 253 254 with tf.device(device): 255 for epoch in range(FLAGS.num_epochs): 256 start = time.time() 257 with train_summary_writer.as_default(): 258 train_one_epoch(model, optimizer, train_data, FLAGS.log_interval) 259 end = time.time() 260 print("train/time for epoch #%d: %.2f" % (epoch, end - start)) 261 with test_summary_writer.as_default(): 262 test(model, eval_data) 263 264 print("Colorbot is ready to generate colors!") 265 while True: 266 try: 267 color_name = six.moves.input( 268 "Give me a color name (or press enter to exit): ") 269 except EOFError: 270 return 271 272 if not color_name: 273 return 274 275 _, chars, length = parse(color_name) 276 with tf.device(device): 277 (chars, length) = (tf.identity(chars), tf.identity(length)) 278 chars = tf.expand_dims(chars, 0) 279 length = tf.expand_dims(length, 0) 280 preds = tf.unstack(model(chars, length, training=False)[0]) 281 282 # Predictions cannot be negative, as they are generated by a ReLU layer; 283 # they may, however, be greater than 1. 284 clipped_preds = tuple(min(float(p), 1.0) for p in preds) 285 rgb = tuple(int(p * 255) for p in clipped_preds) 286 print("rgb:", rgb) 287 data = [[clipped_preds]] 288 if HAS_MATPLOTLIB: 289 plt.imshow(data) 290 plt.title(color_name) 291 plt.show() 292 293 294 if __name__ == "__main__": 295 parser = argparse.ArgumentParser() 296 parser.add_argument( 297 "--dir", 298 type=str, 299 default="/tmp/rnn_colorbot/", 300 help="Directory to download data files and save logs.") 301 parser.add_argument( 302 "--log_interval", 303 type=int, 304 default=10, 305 metavar="N", 306 help="Log training loss every log_interval batches.") 307 parser.add_argument( 308 "--num_epochs", type=int, default=20, help="Number of epochs to train.") 309 parser.add_argument( 310 "--rnn_cell_sizes", 311 type=int, 312 nargs="+", 313 default=[256, 128], 314 help="List of sizes for each layer of the RNN.") 315 parser.add_argument( 316 "--batch_size", 317 type=int, 318 default=64, 319 help="Batch size for training and eval.") 320 parser.add_argument( 321 "--keep_probability", 322 type=float, 323 default=0.5, 324 help="Keep probability for dropout between layers.") 325 parser.add_argument( 326 "--learning_rate", 327 type=float, 328 default=0.01, 329 help="Learning rate to be used during training.") 330 parser.add_argument( 331 "--no_gpu", 332 action="store_true", 333 default=False, 334 help="Disables GPU usage even if a GPU is available.") 335 336 FLAGS, unparsed = parser.parse_known_args() 337 tfe.run(main=main, argv=[sys.argv[0]] + unparsed) 338