# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Penn Treebank RNN model definition compatible with eager execution.

Model similar to
https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb

Usage: python ./rnn_ptb.py --data-path=<path_to_dataset>

Penn Treebank (PTB) dataset from:
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import time

import numpy as np
import tensorflow as tf

from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
from tensorflow.contrib.eager.python import tfe


class RNN(tfe.Network):
  """A static RNN.

  Similar to tf.nn.static_rnn, implemented as a tf.layers.Layer.
  """

  def __init__(self, hidden_dim, num_layers, keep_ratio):
    super(RNN, self).__init__()
    self.keep_ratio = keep_ratio
    for _ in range(num_layers):
      self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim))

  def call(self, input_seq, training):
    batch_size = int(input_seq.shape[1])
    for c in self.layers:
      state = c.zero_state(batch_size, tf.float32)
      outputs = []
      input_seq = tf.unstack(input_seq, num=int(input_seq.shape[0]), axis=0)
      for inp in input_seq:
        output, state = c(inp, state)
        outputs.append(output)

      input_seq = tf.stack(outputs, axis=0)
      if training:
        input_seq = tf.nn.dropout(input_seq, self.keep_ratio)
    return input_seq, None


class Embedding(tf.layers.Layer):
  """An Embedding layer."""

  def __init__(self, vocab_size, embedding_dim, **kwargs):
    super(Embedding, self).__init__(**kwargs)
    self.vocab_size = vocab_size
    self.embedding_dim = embedding_dim

  def build(self, _):
    self.embedding = self.add_variable(
        "embedding_kernel",
        shape=[self.vocab_size, self.embedding_dim],
        dtype=tf.float32,
        initializer=tf.random_uniform_initializer(-0.1, 0.1),
        trainable=True)

  def call(self, x):
    return tf.nn.embedding_lookup(self.embedding, x)
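
# A hypothetical shape check (not part of the model; safe to delete): under
# eager execution it runs random ids through the two layers above to show the
# time-major shape contract that PTBModel below relies on. All names and
# dimensions here are made up for illustration.
def _shape_demo(seq_len=4, batch_size=2):
  """Returns RNN outputs of shape [seq_len, batch_size, hidden_dim]."""
  ids = tf.random_uniform([seq_len, batch_size], maxval=100, dtype=tf.int64)
  embed = Embedding(vocab_size=100, embedding_dim=16)
  rnn = RNN(hidden_dim=16, num_layers=1, keep_ratio=1.0)
  y = embed(ids)  # [seq_len, batch_size, embedding_dim]
  outputs, _ = rnn(y, training=False)  # [seq_len, batch_size, hidden_dim]
  return outputs
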
class PTBModel(tfe.Network):
  """LSTM for word language modeling.

  Model described in:
  (Zaremba, et al.) Recurrent Neural Network Regularization
  http://arxiv.org/abs/1409.2329

  See also:
  https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb
  """

  def __init__(self,
               vocab_size,
               embedding_dim,
               hidden_dim,
               num_layers,
               dropout_ratio,
               use_cudnn_rnn=True):
    super(PTBModel, self).__init__()

    self.keep_ratio = 1 - dropout_ratio
    self.use_cudnn_rnn = use_cudnn_rnn
    self.embedding = self.track_layer(Embedding(vocab_size, embedding_dim))

    if self.use_cudnn_rnn:
      self.rnn = cudnn_rnn.CudnnLSTM(
          num_layers, hidden_dim, dropout=dropout_ratio)
    else:
      self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio)
    self.track_layer(self.rnn)

    self.linear = self.track_layer(
        tf.layers.Dense(
            vocab_size,
            kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)))
    # The RNN emits hidden_dim units per timestep; flatten the [length, batch]
    # axes before the softmax projection.
    self._output_shape = [-1, hidden_dim]

  def call(self, input_seq, training):
    """Run the forward pass of PTBModel.

    Args:
      input_seq: [length, batch] shape int64 tensor.
      training: Whether this is a training call.
    Returns:
      Logits of shape [length * batch, vocab_size].
    """
    y = self.embedding(input_seq)
    if training:
      y = tf.nn.dropout(y, self.keep_ratio)
    y, _ = self.rnn(y, training=training)
    return self.linear(tf.reshape(y, self._output_shape))


def clip_gradients(grads_and_vars, clip_ratio):
  gradients, variables = zip(*grads_and_vars)
  clipped, _ = tf.clip_by_global_norm(gradients, clip_ratio)
  return zip(clipped, variables)


def loss_fn(model, inputs, targets, training):
  labels = tf.reshape(targets, [-1])
  outputs = model(inputs, training)
  return tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=outputs))


def _divide_into_batches(data, batch_size):
  """Convert a sequence to a batch of sequences."""
  nbatch = data.shape[0] // batch_size
  data = data[:nbatch * batch_size]
  data = data.reshape(batch_size, -1).transpose()
  return data


def _get_batch(data, i, seq_len):
  slen = min(seq_len, data.shape[0] - 1 - i)
  inputs = data[i:i + slen, :]
  target = data[i + 1:i + 1 + slen, :]
  return tf.constant(inputs), tf.constant(target)
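
# Worked example of the two helpers above (illustrative numbers only): with a
# token sequence [0, 1, ..., 9] and batch_size=2, _divide_into_batches returns
# the time-major array
#
#   [[0, 5],
#    [1, 6],
#    [2, 7],
#    [3, 8],
#    [4, 9]]
#
# and _get_batch(batched, i=0, seq_len=2) pairs inputs [[0, 5], [1, 6]] with
# targets [[1, 6], [2, 7]]: each target row is the next token of its input row.
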
def evaluate(model, data):
  """Evaluate an epoch; returns the mean loss per batch."""
  total_loss = 0.0
  total_batches = 0
  start = time.time()
  for i in range(0, data.shape[0] - 1, FLAGS.seq_len):
    inp, target = _get_batch(data, i, FLAGS.seq_len)
    loss = loss_fn(model, inp, target, training=False)
    total_loss += loss.numpy()
    total_batches += 1
  time_in_ms = (time.time() - start) * 1000
  sys.stderr.write("eval loss %.2f (eval took %d ms)\n" %
                   (total_loss / total_batches, time_in_ms))
  # Return the same mean loss that was just reported so that callers compare
  # like-for-like across epochs.
  return total_loss / total_batches


def train(model, optimizer, train_data, sequence_length, clip_ratio):
  """Train for one epoch."""

  def model_loss(inputs, targets):
    return loss_fn(model, inputs, targets, training=True)

  grads = tfe.implicit_gradients(model_loss)

  total_time = 0
  for batch, i in enumerate(
      range(0, train_data.shape[0] - 1, sequence_length)):
    train_seq, train_target = _get_batch(train_data, i, sequence_length)
    start = time.time()
    optimizer.apply_gradients(
        clip_gradients(grads(train_seq, train_target), clip_ratio))
    total_time += (time.time() - start)
    if batch % 10 == 0:
      time_in_ms = (total_time * 1000) / (batch + 1)
      sys.stderr.write("batch %d: training loss %.2f, avg step time %d ms\n" %
                       (batch, model_loss(train_seq, train_target).numpy(),
                        time_in_ms))


class Datasets(object):
  """Processed form of the Penn Treebank dataset."""

  def __init__(self, path):
    """Load the Penn Treebank dataset.

    Args:
      path: Path to the data/ directory of the dataset from Tomas Mikolov's
        webpage - http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
    """

    self.word2idx = {}  # string -> integer id
    self.idx2word = []  # integer id -> word string
    # Files represented as a list of integer ids (as opposed to list of string
    # words).
    self.train = self.tokenize(os.path.join(path, "ptb.train.txt"))
    self.valid = self.tokenize(os.path.join(path, "ptb.valid.txt"))

  def vocab_size(self):
    return len(self.idx2word)

  def add(self, word):
    if word not in self.word2idx:
      self.idx2word.append(word)
      self.word2idx[word] = len(self.idx2word) - 1

  def tokenize(self, path):
    """Read text file in path and return a list of integer token ids."""
    tokens = 0
    with tf.gfile.Open(path, "r") as f:
      for line in f:
        words = line.split() + ["<eos>"]
        tokens += len(words)
        for word in words:
          self.add(word)

    # Tokenize file content
    with tf.gfile.Open(path, "r") as f:
      ids = np.zeros(tokens).astype(np.int64)
      token = 0
      for line in f:
        words = line.split() + ["<eos>"]
        for word in words:
          ids[token] = self.word2idx[word]
          token += 1

    return ids


def small_model(use_cudnn_rnn):
  """Returns a PTBModel with a 'small' configuration."""
  return PTBModel(
      vocab_size=10000,
      embedding_dim=200,
      hidden_dim=200,
      num_layers=2,
      dropout_ratio=0.,
      use_cudnn_rnn=use_cudnn_rnn)


def large_model(use_cudnn_rnn):
  """Returns a PTBModel with a 'large' configuration."""
  return PTBModel(
      vocab_size=10000,
      embedding_dim=650,
      hidden_dim=650,
      num_layers=2,
      dropout_ratio=0.5,
      use_cudnn_rnn=use_cudnn_rnn)


def test_model(use_cudnn_rnn):
  """Returns a tiny PTBModel for unit tests."""
  return PTBModel(
      vocab_size=100,
      embedding_dim=20,
      hidden_dim=20,
      num_layers=2,
      dropout_ratio=0.,
      use_cudnn_rnn=use_cudnn_rnn)
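
# A minimal single-step sketch (illustrative only, never invoked): it wires
# test_model, loss_fn, tfe.implicit_gradients, and clip_gradients together the
# same way train() above does, but on random made-up data. Assumes eager
# execution has been enabled.
def _train_step_demo():
  model = test_model(use_cudnn_rnn=False)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
  # [sequence_length, batch] int64 ids, the layout produced by _get_batch.
  inputs = tf.random_uniform([4, 2], maxval=100, dtype=tf.int64)
  targets = tf.random_uniform([4, 2], maxval=100, dtype=tf.int64)
  grads = tfe.implicit_gradients(
      lambda: loss_fn(model, inputs, targets, training=True))
  optimizer.apply_gradients(clip_gradients(grads(), clip_ratio=0.25))
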
def main(_):
  tfe.enable_eager_execution()

  if not FLAGS.data_path:
    raise ValueError("Must specify --data-path")
  corpus = Datasets(FLAGS.data_path)
  train_data = _divide_into_batches(corpus.train, FLAGS.batch_size)
  eval_data = _divide_into_batches(corpus.valid, 10)

  have_gpu = tfe.num_gpus() > 0
  use_cudnn_rnn = not FLAGS.no_use_cudnn_rnn and have_gpu

  with tfe.restore_variables_on_create(
      tf.train.latest_checkpoint(FLAGS.logdir)):
    with tf.device("/device:GPU:0" if have_gpu else None):
      # Make learning_rate a Variable so it can be included in the checkpoint
      # and we can resume training with the last saved learning_rate.
      learning_rate = tfe.Variable(20.0, name="learning_rate")
      sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy())
      model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim,
                       FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
                       use_cudnn_rnn)
      optimizer = tf.train.GradientDescentOptimizer(learning_rate)

      best_loss = None
      for _ in range(FLAGS.epoch):
        train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip)
        eval_loss = evaluate(model, eval_data)
        if not best_loss or eval_loss < best_loss:
          if FLAGS.logdir:
            tfe.Saver(model.trainable_weights + [learning_rate]).save(
                os.path.join(FLAGS.logdir, "ckpt"))
          best_loss = eval_loss
        else:
          learning_rate.assign(learning_rate / 4.0)
          sys.stderr.write("eval_loss did not reduce in this epoch, "
                           "changing learning rate to %f for the next epoch\n" %
                           learning_rate.numpy())


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--data-path",
      type=str,
      default="",
      help="Data directory of the Penn Treebank dataset from "
      "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz")
  parser.add_argument(
      "--logdir", type=str, default="", help="Directory for checkpoint.")
  parser.add_argument("--epoch", type=int, default=20, help="Number of epochs.")
  parser.add_argument("--batch-size", type=int, default=20, help="Batch size.")
  parser.add_argument(
      "--seq-len", type=int, default=35, help="Sequence length.")
  parser.add_argument(
      "--embedding-dim", type=int, default=200, help="Embedding dimension.")
  parser.add_argument(
      "--hidden-dim", type=int, default=200, help="Hidden layer dimension.")
  parser.add_argument(
      "--num-layers", type=int, default=2, help="Number of RNN layers.")
  parser.add_argument(
      "--dropout", type=float, default=0.2, help="Dropout ratio.")
  parser.add_argument(
      "--clip", type=float, default=0.25, help="Gradient clipping ratio.")
  parser.add_argument(
      "--no-use-cudnn-rnn",
      action="store_true",
      default=False,
      help="Disable the fast CuDNN RNN; it is also disabled automatically "
      "when no GPU is available.")

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
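
# Example invocation (assumes the simple-examples archive referenced above has
# been extracted so that ./simple-examples/data contains ptb.train.txt and
# ptb.valid.txt):
#
#   python ./rnn_ptb.py --data-path=./simple-examples/data --logdir=/tmp/ptb
#
# A checkpoint is written to --logdir after each epoch that improves the
# validation loss, and training resumes from the latest checkpoint on restart.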