1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 import tensorflow as tf 22 from tensorflow.contrib.eager.python.examples.mnist import mnist 23 24 25 def data_format(): 26 return "channels_first" if tf.test.is_gpu_available() else "channels_last" 27 28 29 class MNISTGraphTest(tf.test.TestCase): 30 31 def testTrainGraph(self): 32 # The MNISTModel class can be executed eagerly (as in mnist.py and 33 # mnist_test.py) and also be used to construct a TensorFlow graph, which is 34 # then trained in a session. 35 with tf.Graph().as_default(): 36 # Generate some random data. 37 batch_size = 64 38 images = np.random.randn(batch_size, 784).astype(np.float32) 39 digits = np.random.randint(low=0, high=10, size=batch_size) 40 labels = np.zeros((batch_size, 10)) 41 labels[np.arange(batch_size), digits] = 1. 42 43 # Create a model, optimizer, and dataset as would be done 44 # for eager execution as well. 45 model = mnist.MNISTModel(data_format()) 46 optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 47 dataset = tf.data.Dataset.from_tensors((images, labels)) 48 49 # Define the loss tensor (as opposed to a loss function when 50 # using eager execution). 51 (images, labels) = dataset.make_one_shot_iterator().get_next() 52 predictions = model(images, training=True) 53 loss = mnist.loss(predictions, labels) 54 55 train_op = optimizer.minimize(loss) 56 init = tf.global_variables_initializer() 57 with tf.Session() as sess: 58 # Variables have to be initialized in the session. 59 sess.run(init) 60 # Train using the optimizer. 61 sess.run(train_op) 62 63 64 if __name__ == "__main__": 65 tf.test.main() 66