# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Smoke tests for the eager-execution MNIST example.

Each test runs the example's training or evaluation loop once over a single
batch of random data, with and without `tfe.defun` applied to the model's
`call` method. The tests make no assertions; they pass as long as the loops
run without raising.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.mnist import mnist


def device():
  """Return the device to run on: the first GPU if any is available, else CPU."""
  return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0"


def data_format():
  """Return the image data format to build the model with.

  Uses "channels_first" when a GPU is available and "channels_last"
  otherwise, making the same GPU check as `device()`.
  """
  return "channels_first" if tfe.num_gpus() else "channels_last"


def random_dataset(batch_size=64):
  """Return a one-element dataset holding a single batch of random data.

  Args:
    batch_size: number of examples in the batch. Defaults to 64.

  Returns:
    A `tf.data.Dataset` whose only element is an `(images, labels)` tuple.
    `images` is a float tensor of shape [batch_size, 784] drawn from a
    standard normal distribution. `labels` is a one-hot tensor of shape
    [batch_size, 10] for digits drawn uniformly from [0, 10).
  """
  images = tf.random_normal([batch_size, 784])
  digits = tf.random_uniform(
      [batch_size], minval=0, maxval=10, dtype=tf.int32)
  labels = tf.one_hot(digits, 10)
  # from_tensors (not from_tensor_slices) yields the whole batch as one element.
  return tf.data.Dataset.from_tensors((images, labels))


def _build_model(defun):
  """Create an MNISTModel, optionally compiling its `call` with `tfe.defun`.

  Args:
    defun: if True, replace the model's `call` method with its
      `tfe.defun`-wrapped version.

  Returns:
    The constructed `mnist.MNISTModel`.
  """
  model = mnist.MNISTModel(data_format())
  if defun:
    model.call = tfe.defun(model.call)
  return model


def train_one_epoch(defun=False):
  """Train a fresh model for one epoch over `random_dataset()`.

  Args:
    defun: if True, wrap the model's `call` in `tfe.defun` before training.
  """
  model = _build_model(defun)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  dataset = random_dataset()
  with tf.device(device()):
    # NOTE(review): presumably mnist.train_one_epoch reads the global step
    # (e.g. for summaries), so create it before training starts. Confirm.
    tf.train.get_or_create_global_step()
    mnist.train_one_epoch(model, optimizer, dataset)


def evaluate(defun=False):
  """Run the example's evaluation loop on a fresh model over `random_dataset()`.

  Args:
    defun: if True, wrap the model's `call` in `tfe.defun` before evaluating.
  """
  model = _build_model(defun)
  dataset = random_dataset()
  with tf.device(device()):
    # NOTE(review): presumably mnist.test reads the global step
    # (e.g. for summaries), so create it before evaluating. Confirm.
    tf.train.get_or_create_global_step()
    mnist.test(model, dataset)


class MNISTTest(tf.test.TestCase):
  """Runs training and evaluation once each, with and without `tfe.defun`."""

  def testTrainOneEpoch(self):
    train_one_epoch(defun=False)

  def testTest(self):
    evaluate(defun=False)

  def testTrainOneEpochWithDefunCall(self):
    train_one_epoch(defun=True)

  def testTestWithDefunCall(self):
    evaluate(defun=True)


if __name__ == "__main__":
  # Eager execution must be enabled at program startup, before any TF ops run.
  tfe.enable_eager_execution()
  tf.test.main()