Home | History | Annotate | Download | only in mnist
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 import tensorflow as tf
     22 from tensorflow.contrib.eager.python.examples.mnist import mnist
     23 
     24 
     25 def data_format():
     26   return "channels_first" if tf.test.is_gpu_available() else "channels_last"
     27 
     28 
     29 class MNISTGraphTest(tf.test.TestCase):
     30 
     31   def testTrainGraph(self):
     32     # The MNISTModel class can be executed eagerly (as in mnist.py and
     33     # mnist_test.py) and also be used to construct a TensorFlow graph, which is
     34     # then trained in a session.
     35     with tf.Graph().as_default():
     36       # Generate some random data.
     37       batch_size = 64
     38       images = np.random.randn(batch_size, 784).astype(np.float32)
     39       digits = np.random.randint(low=0, high=10, size=batch_size)
     40       labels = np.zeros((batch_size, 10))
     41       labels[np.arange(batch_size), digits] = 1.
     42 
     43       # Create a model, optimizer, and dataset as would be done
     44       # for eager execution as well.
     45       model = mnist.MNISTModel(data_format())
     46       optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
     47       dataset = tf.data.Dataset.from_tensors((images, labels))
     48 
     49       # Define the loss tensor (as opposed to a loss function when
     50       # using eager execution).
     51       (images, labels) = dataset.make_one_shot_iterator().get_next()
     52       predictions = model(images, training=True)
     53       loss = mnist.loss(predictions, labels)
     54 
     55       train_op = optimizer.minimize(loss)
     56       init = tf.global_variables_initializer()
     57       with tf.Session() as sess:
     58         # Variables have to be initialized in the session.
     59         sess.run(init)
     60         # Train using the optimizer.
     61         sess.run(train_op)
     62 
     63 
if __name__ == "__main__":
  # Run the test cases defined in this module via TensorFlow's test runner.
  tf.test.main()
     66