# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 r"""Train a ConvNet on MNIST using K-FAC.
     16 
     17 This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the
     18 following structure,
     19 
     20 - Conv Layer: 5x5 kernel, 16 output channels.
     21 - Max Pool: 3x3 kernel, stride 2.
     22 - Conv Layer: 5x5 kernel, 16 output channels.
     23 - Max Pool: 3x3 kernel, stride 2.
     24 - Linear: 10 output dims.
     25 
     26 After 3k~6k steps, this should reach perfect accuracy on the training set.
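
A minimal single-machine usage sketch (the data directory below is
illustrative, not part of this library):

  from tensorflow.contrib.kfac.examples import convnet
  convnet.train_mnist_single_machine(data_dir="/tmp/mnist", num_epochs=4)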
     27 """
     28 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf

from tensorflow.contrib.kfac.examples import mlp
from tensorflow.contrib.kfac.examples import mnist

lc = tf.contrib.kfac.layer_collection
oq = tf.contrib.kfac.op_queue
opt = tf.contrib.kfac.optimizer

__all__ = [
    "conv_layer",
    "max_pool_layer",
    "linear_layer",
    "build_model",
    "minimize_loss_single_machine",
    "minimize_loss_distributed",
    "train_mnist_single_machine",
    "train_mnist_multitower",
    "train_mnist_distributed",
]


def conv_layer(layer_id, inputs, kernel_size, out_channels):
  """Builds a convolutional layer with ReLU non-linearity.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
      corresponds to a single example.
    kernel_size: int. Width and height of the convolution kernel. The kernel is
      assumed to be square.
    out_channels: int. Number of output features per pixel.

  Returns:
    preactivations: Tensor of shape [num_examples, width, height, out_channels].
      Values of the layer immediately before the activation function.
    activations: Tensor of shape [num_examples, width, height, out_channels].
      Values of the layer immediately after the activation function.
    params: Tuple of (kernel, bias), parameters for this layer.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  layer = tf.layers.Conv2D(
      out_channels,
      kernel_size=[kernel_size, kernel_size],
      kernel_initializer=tf.random_normal_initializer(stddev=0.01),
      padding="SAME",
      name="conv_%d" % layer_id)
  preactivations = layer(inputs)
  activations = tf.nn.relu(preactivations)

  # layer.weights is a list. This converts it to a (hashable) tuple.
  return preactivations, activations, (layer.kernel, layer.bias)

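# A shape sketch for conv_layer (sizes are illustrative): with SAME padding
# and tf.layers.Conv2D's default stride of 1, spatial dimensions are
# preserved.
#
#   x = tf.random_normal([128, 28, 28, 1])
#   pre, act, params = conv_layer(
#       layer_id=0, inputs=x, kernel_size=5, out_channels=16)
#   # pre and act have shape [128, 28, 28, 16]; params is (kernel, bias).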

def max_pool_layer(layer_id, inputs, kernel_size, stride):
  """Builds a max-pooling layer.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
      corresponds to a single example.
    kernel_size: int. Width and height to pool over per input channel. The
      kernel is assumed to be square.
    stride: int. Step size between pooling operations.

  Returns:
    Tensor of shape [num_examples, width/stride, height/stride, in_channels].
    Result of applying max pooling to 'inputs'.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  with tf.variable_scope("pool_%d" % layer_id):
    return tf.nn.max_pool(
        inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],
        padding="SAME",
        name="pool")

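# Worked output-size arithmetic: with SAME padding the pooled spatial size is
# ceil(input_size / stride), so a [128, 28, 28, 16] input pooled with
# kernel_size=3, stride=2 comes out as [128, 14, 14, 16].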

def linear_layer(layer_id, inputs, output_size):
  """Builds the final linear layer for an MNIST classification problem.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, input_size]. Each row corresponds to
      a single, already flattened example.
    output_size: int. Number of output dims per example.

  Returns:
    preactivations: Tensor of shape [num_examples, output_size]. Values of the
      layer immediately before the (softmax) activation function.
    params: Tuple of (weights, bias), parameters for this layer.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  pre, _, params = mlp.fc_layer(layer_id, inputs, output_size)
  return pre, params

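# Example (shapes only): in this model the flattened activations have shape
# [batch_size, 784], so linear_layer(4, flat_act3, 10) returns logits of
# shape [batch_size, 10] and a (weights, bias) parameter tuple.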

def build_model(examples, labels, num_labels, layer_collection):
  """Builds a ConvNet classification model.

  Args:
    examples: Tensor of shape [num_examples, width, height, in_channels].
      Represents the inputs of the model, one image per row.
    labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
      by softmax for each example.
    num_labels: int. Number of distinct values 'labels' can take on.
    layer_collection: LayerCollection instance. Layers will be registered here.

  Returns:
    loss: 0-D Tensor representing loss to be minimized.
    accuracy: 0-D Tensor representing model's accuracy.
  """
  # Build a ConvNet. For each layer with parameters, we'll keep track of the
  # preactivations, activations, weights, and bias.
  tf.logging.info("Building model.")
  pre0, act0, params0 = conv_layer(
      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
  pre2, act2, params2 = conv_layer(
      layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
  act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
  flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
  logits, params4 = linear_layer(
      layer_id=4, inputs=flat_act3, output_size=num_labels)
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))

  tf.summary.scalar("loss", loss)
  tf.summary.scalar("accuracy", accuracy)

  # Register parameters. K-FAC needs to know about the inputs, outputs, and
  # parameters of each conv/fully connected layer and the logits powering the
  # posterior probability over classes.
  tf.logging.info("Building LayerCollection.")
  layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
                                   pre0)
  layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
  layer_collection.register_fully_connected(params4, flat_act3, logits)
  layer_collection.register_categorical_predictive_distribution(
      logits, name="logits")

  return loss, accuracy

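# Shape walk-through for 28x28 MNIST inputs (a sketch; sizes follow from SAME
# padding and the ceil(size / stride) pooling arithmetic noted above):
#
#   [N, 28, 28, 1]  --conv 5x5, 16ch--> [N, 28, 28, 16]
#                   --pool 3x3, s=2 --> [N, 14, 14, 16]
#                   --conv 5x5, 16ch--> [N, 14, 14, 16]
#                   --pool 3x3, s=2 --> [N, 7, 7, 16]
#                   --flatten       --> [N, 784]
#                   --linear        --> [N, 10]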

def minimize_loss_single_machine(loss,
                                 accuracy,
                                 layer_collection,
                                 session_config=None):
  """Minimize loss with K-FAC on a single machine.

  A single Session is responsible for running all of K-FAC's ops.

  Args:
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.
    session_config: None or tf.ConfigProto. Configuration for tf.Session().

  Returns:
    final value for 'accuracy'.
  """
  # Train with K-FAC.
  global_step = tf.train.get_or_create_global_step()
  optimizer = opt.KfacOptimizer(
      learning_rate=0.0001,
      cov_ema_decay=0.95,
      damping=0.001,
      layer_collection=layer_collection,
      momentum=0.9)
  train_op = optimizer.minimize(loss, global_step=global_step)

  tf.logging.info("Starting training.")
  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
    while not sess.should_stop():
      # Take a gradient step and update covariance statistics on every
      # iteration; the (more expensive) inverses are refreshed every 100 steps.
      global_step_, loss_, accuracy_, _, _ = sess.run(
          [global_step, loss, accuracy, train_op, optimizer.cov_update_op])

      if global_step_ % 100 == 0:
        sess.run(optimizer.inv_update_op)
        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                        global_step_, loss_, accuracy_)

  return accuracy_


def _is_gradient_task(task_id, num_tasks):
  """Returns True if this task should update the weights."""
  if num_tasks < 3:
    return True
  return 0 <= task_id < 0.6 * num_tasks


def _is_cov_update_task(task_id, num_tasks):
  """Returns True if this task should update K-FAC's covariance matrices."""
  if num_tasks < 3:
    return False
  return 0.6 * num_tasks <= task_id < num_tasks - 1


def _is_inv_update_task(task_id, num_tasks):
  """Returns True if this task should update K-FAC's preconditioner."""
  if num_tasks < 3:
    return False
  return task_id == num_tasks - 1


def _num_gradient_tasks(num_tasks):
  """Number of tasks that will update weights."""
  if num_tasks < 3:
    return num_tasks
  return int(np.ceil(0.6 * num_tasks))

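# Worked example of the split above: with num_tasks = 5, tasks 0-2 compute
# gradients (0 <= task_id < 0.6 * 5 = 3), task 3 accumulates covariances
# (3 <= task_id < 4), and task 4 computes inverses; _num_gradient_tasks(5)
# returns ceil(0.6 * 5) = 3, matching the number of gradient workers.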

def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
                              checkpoint_dir, loss, accuracy, layer_collection):
  """Minimize loss with a synchronous implementation of K-FAC.

  Different tasks are responsible for different parts of K-FAC's Ops. The first
  60% of tasks update weights; the next 20% accumulate covariance statistics;
  the last 20% invert the matrices used to precondition gradients.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    num_worker_tasks: int. Number of workers in this distributed training setup.
    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
      parameter servers are not used.
    master: string. IP and port of TensorFlow runtime process. Set to empty
      string to run locally.
    checkpoint_dir: string or None. Path to store checkpoints under.
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.

  Returns:
    final value for 'accuracy'.

  Raises:
    ValueError: if task_id >= num_worker_tasks.
  """
  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
    global_step = tf.train.get_or_create_global_step()
    optimizer = opt.KfacOptimizer(
        learning_rate=0.0001,
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        momentum=0.9)
    inv_update_queue = oq.OpQueue(optimizer.inv_update_ops)
    sync_optimizer = tf.train.SyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks))
    train_op = sync_optimizer.minimize(loss, global_step=global_step)

  tf.logging.info("Starting training.")
  is_chief = (task_id == 0)
  hooks = [sync_optimizer.make_session_run_hook(is_chief)]
  with tf.train.MonitoredTrainingSession(
      master=master,
      is_chief=is_chief,
      checkpoint_dir=checkpoint_dir,
      hooks=hooks,
      stop_grace_period_secs=0) as sess:
    while not sess.should_stop():
      # Choose which op this task is responsible for running.
      if _is_gradient_task(task_id, num_worker_tasks):
        learning_op = train_op
      elif _is_cov_update_task(task_id, num_worker_tasks):
        learning_op = optimizer.cov_update_op
      elif _is_inv_update_task(task_id, num_worker_tasks):
        # TODO(duckworthd): Running this op before cov_update_op has been run a
        # few times can result in "InvalidArgumentError: Cholesky decomposition
        # was not successful." Delay running this op until cov_update_op has
        # been run a few times.
        learning_op = inv_update_queue.next_op(sess)
      else:
        raise ValueError("Which op should task %d do?" % task_id)

      global_step_, loss_, accuracy_, _ = sess.run(
          [global_step, loss, accuracy, learning_op])
      tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
                      loss_, accuracy_)

  return accuracy_

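# A hypothetical invocation for one worker in a 5-task cluster (the master
# address and checkpoint path are placeholders, not defaults):
#
#   minimize_loss_distributed(
#       task_id=3,  # per the split above, task 3 accumulates covariances
#       num_worker_tasks=5,
#       num_ps_tasks=1,
#       master="grpc://worker3.example.com:2222",
#       checkpoint_dir="/tmp/kfac_checkpoints",
#       loss=loss,
#       accuracy=accuracy,
#       layer_collection=layer_collection)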

def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
  """Train a ConvNet on MNIST.

  Args:
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  # Load a dataset.
  tf.logging.info("Loading MNIST into memory.")
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=128,
      use_fake_data=use_fake_data,
      flatten_images=False)

  # Build a ConvNet.
  layer_collection = lc.LayerCollection()
  loss, accuracy = build_model(
      examples, labels, num_labels=10, layer_collection=layer_collection)

  # Fit model.
  return minimize_loss_single_machine(loss, accuracy, layer_collection)


def train_mnist_multitower(data_dir, num_epochs, num_towers,
                           use_fake_data=True):
  """Train a ConvNet on MNIST.

  Args:
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    num_towers: int. Number of CPUs to split inference across.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  # Load a dataset.
  tower_batch_size = 128
  batch_size = tower_batch_size * num_towers
  tf.logging.info(
      ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
       "tower batch size.") % (batch_size, num_towers, tower_batch_size))
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=batch_size,
      use_fake_data=use_fake_data,
      flatten_images=False)

  # Split minibatch across towers.
  examples = tf.split(examples, num_towers)
  labels = tf.split(labels, num_towers)

  # Build a ConvNet. Each tower's layers will be added to the LayerCollection.
  layer_collection = lc.LayerCollection()
  tower_results = []
  for tower_id in range(num_towers):
    with tf.device("/cpu:%d" % tower_id):
      with tf.name_scope("tower%d" % tower_id):
        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
          tf.logging.info("Building tower %d." % tower_id)
          tower_results.append(
              build_model(examples[tower_id], labels[tower_id], 10,
                          layer_collection))
  losses, accuracies = zip(*tower_results)

  # Average across towers.
  loss = tf.reduce_mean(losses)
  accuracy = tf.reduce_mean(accuracies)

  # Fit model.
  session_config = tf.ConfigProto(
      allow_soft_placement=False, device_count={"CPU": num_towers})
  return minimize_loss_single_machine(
      loss, accuracy, layer_collection, session_config=session_config)


def train_mnist_distributed(task_id,
                            num_worker_tasks,
                            num_ps_tasks,
                            master,
                            data_dir,
                            num_epochs,
                            use_fake_data=False):
  """Train a ConvNet on MNIST.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    num_worker_tasks: int. Number of workers in this distributed training setup.
    num_ps_tasks: int. Number of parameter servers holding variables.
    master: string. IP and port of TensorFlow runtime process.
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  # Load a dataset.
  tf.logging.info("Loading MNIST into memory.")
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=128,
      use_fake_data=use_fake_data,
      flatten_images=False)

  # Build a ConvNet.
  layer_collection = lc.LayerCollection()
  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
    loss, accuracy = build_model(
        examples, labels, num_labels=10, layer_collection=layer_collection)

  # Fit model.
  checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
  return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks,
                                   master, checkpoint_dir, loss, accuracy,
                                   layer_collection)


if __name__ == "__main__":
  tf.app.run()