# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Smoke tests for the eager-mode rnn_colorbot example."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.contrib.eager.python import tfe
from tensorflow.contrib.eager.python.examples.rnn_colorbot import rnn_colorbot


LABEL_DIMENSION = 5


def device():
  """Picks the device string the tests run under.

  Returns:
    "/device:GPU:0" when tfe reports at least one GPU, else "/device:CPU:0".
  """
  if tfe.num_gpus():
    return "/device:GPU:0"
  return "/device:CPU:0"


def random_dataset():
  """Builds a one-element dataset holding a single random batch.

  The element is the tuple (labels, chars, sequence_length):
    * labels: random_normal values of shape [64, LABEL_DIMENSION].
    * chars: one-hot encodings (depth 50) of random int32 ids, i.e. a batch of
      64 sequences, each 10 steps long.
    * sequence_length: int64 vector of 64 entries, every one equal to 10, so
      every sequence is treated as full length.

  Returns:
    A tf.data.Dataset created with `from_tensors`, yielding that tuple once.
  """
  batch_size, time_steps, alphabet = 64, 10, 50

  # The uniform draw happens before the normal draw below, matching the order
  # in which these random ops have always been created.
  char_ids = tf.random_uniform(
      [batch_size, time_steps], minval=0, maxval=alphabet, dtype=tf.int32)
  chars = tf.one_hot(char_ids, alphabet)

  sequence_length = tf.constant([time_steps] * batch_size, dtype=tf.int64)
  labels = tf.random_normal([batch_size, LABEL_DIMENSION])
  return tf.data.Dataset.from_tensors((labels, chars, sequence_length))


class RNNColorbotTest(tf.test.TestCase):
  """Runs the example's training and evaluation entry points end to end.

  Neither test asserts on results; each passes as long as the call completes
  without raising.
  """

  def _build_model(self, cell_sizes):
    """Creates an RNNColorbot with dropout disabled (keep_prob=1.0)."""
    return rnn_colorbot.RNNColorbot(
        rnn_cell_sizes=cell_sizes,
        label_dimension=LABEL_DIMENSION,
        keep_prob=1.0)

  def testTrainOneEpoch(self):
    """Trains a three-layer model for one epoch over a random batch."""
    model = self._build_model([256, 128, 64])
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    dataset = random_dataset()
    with tf.device(device()):
      rnn_colorbot.train_one_epoch(model, optimizer, dataset)

  def testTest(self):
    """Evaluates a single-layer model over a random batch."""
    model = self._build_model([256])
    dataset = random_dataset()
    with tf.device(device()):
      rnn_colorbot.test(model, dataset)


if __name__ == "__main__":
  # Eager execution is switched on only when this file is run directly; the
  # tests above are written for eager mode, so a different runner that
  # imports this module presumably needs to enable it itself — confirm.
  tf.enable_eager_execution()
  tf.test.main()