Home | History | Annotate | Download | only in rnn_colorbot
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import tensorflow as tf
     21 
     22 from tensorflow.contrib.eager.python import tfe
     23 from tensorflow.contrib.eager.python.examples.rnn_colorbot import rnn_colorbot
     24 
     25 
# Width of the random label vectors built by random_dataset(); also passed as
# `label_dimension` to every RNNColorbot constructed by the tests below.
LABEL_DIMENSION = 5
     27 
     28 
     29 def device():
     30   return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0"
     31 
     32 
     33 def random_dataset():
     34   batch_size = 64
     35   time_steps = 10
     36   alphabet = 50
     37   chars = tf.one_hot(
     38       tf.random_uniform(
     39           [batch_size, time_steps], minval=0, maxval=alphabet, dtype=tf.int32),
     40       alphabet)
     41   sequence_length = tf.constant(
     42       [time_steps for _ in range(batch_size)], dtype=tf.int64)
     43   labels = tf.random_normal([batch_size, LABEL_DIMENSION])
     44   return tf.data.Dataset.from_tensors((labels, chars, sequence_length))
     45 
     46 
     47 class RNNColorbotTest(tf.test.TestCase):
     48 
     49   def testTrainOneEpoch(self):
     50     model = rnn_colorbot.RNNColorbot(
     51         rnn_cell_sizes=[256, 128, 64],
     52         label_dimension=LABEL_DIMENSION,
     53         keep_prob=1.0)
     54     optimizer = tf.train.AdamOptimizer(learning_rate=.01)
     55     dataset = random_dataset()
     56     with tf.device(device()):
     57       rnn_colorbot.train_one_epoch(model, optimizer, dataset)
     58 
     59   def testTest(self):
     60     model = rnn_colorbot.RNNColorbot(
     61         rnn_cell_sizes=[256],
     62         label_dimension=LABEL_DIMENSION,
     63         keep_prob=1.0)
     64     dataset = random_dataset()
     65     with tf.device(device()):
     66       rnn_colorbot.test(model, dataset)
     67 
     68 
if __name__ == "__main__":
  # The test bodies call the model and the train/test helpers directly, with
  # no Session, so they appear to assume eager execution. Eager mode has to
  # be switched on at program startup, i.e. before tf.test.main() runs them.
  tf.enable_eager_execution()
  tf.test.main()
     72