# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.

This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the
following structure:

- Conv Layer: 5x5 kernel, 16 output channels.
- Max Pool: 3x3 kernel, stride 2.
- Conv Layer: 5x5 kernel, 16 output channels.
- Max Pool: 3x3 kernel, stride 2.
- Linear: 10 output dims.

After 3k-6k steps, this should reach perfect accuracy on the training set.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf

from tensorflow.contrib.kfac.examples import mlp
from tensorflow.contrib.kfac.examples import mnist

lc = tf.contrib.kfac.layer_collection
oq = tf.contrib.kfac.op_queue
opt = tf.contrib.kfac.optimizer

__all__ = [
    "conv_layer",
    "max_pool_layer",
    "linear_layer",
    "build_model",
    "minimize_loss_single_machine",
    "minimize_loss_distributed",
    "train_mnist_single_machine",
    "train_mnist_multitower",
    "train_mnist_distributed",
]


def conv_layer(layer_id, inputs, kernel_size, out_channels):
  """Builds a convolutional layer with ReLU non-linearity.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each
      row corresponds to a single example.
    kernel_size: int. Width and height of the convolution kernel. The kernel
      is assumed to be square.
    out_channels: int. Number of output features per pixel.

  Returns:
    preactivations: Tensor of shape [num_examples, width, height,
      out_channels]. Values of the layer immediately before the activation
      function.
    activations: Tensor of shape [num_examples, width, height, out_channels].
      Values of the layer immediately after the activation function.
    params: Tuple of (kernel, bias), parameters for this layer.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  layer = tf.layers.Conv2D(
      out_channels,
      kernel_size=[kernel_size, kernel_size],
      kernel_initializer=tf.random_normal_initializer(stddev=0.01),
      padding="SAME",
      name="conv_%d" % layer_id)
  preactivations = layer(inputs)
  activations = tf.nn.relu(preactivations)

  # layer.weights is a list. This converts it to a (hashable) tuple.
  return preactivations, activations, (layer.kernel, layer.bias)
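

# A quick shape check for `conv_layer` (a sketch, not part of the training
# graph; the names below are local to this example). With stride 1 and "SAME"
# padding, spatial dimensions are preserved:
#
#   images = tf.zeros([128, 28, 28, 1])
#   pre, act, params = conv_layer(
#       layer_id=0, inputs=images, kernel_size=5, out_channels=16)
#   assert pre.shape.as_list() == [128, 28, 28, 16]
#   assert act.shape.as_list() == [128, 28, 28, 16]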


def max_pool_layer(layer_id, inputs, kernel_size, stride):
  """Builds a max-pooling layer.

  Args:
    layer_id: int. Integer ID for this layer's name scope.
    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each
      row corresponds to a single example.
    kernel_size: int. Width and height to pool over per input channel. The
      kernel is assumed to be square.
    stride: int. Step size between pooling operations.

  Returns:
    Tensor of shape [num_examples, width/stride, height/stride, in_channels].
    Result of applying max pooling to 'inputs'.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  with tf.variable_scope("pool_%d" % layer_id):
    return tf.nn.max_pool(
        inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],
        padding="SAME",
        name="pool")


def linear_layer(layer_id, inputs, output_size):
  """Builds the final linear layer for an MNIST classification problem.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, input_size]. Each row corresponds
      to a single example.
    output_size: int. Number of output dims per example.

  Returns:
    logits: Tensor of shape [num_examples, output_size]. Values of the layer
      before any activation function is applied.
    params: Tuple of (weights, bias), parameters for this layer.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  pre, _, params = mlp.fc_layer(layer_id, inputs, output_size)
  return pre, params
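

# Shape walk-through for the model built below (a sketch, assuming 28x28
# MNIST images). "SAME" max pooling with stride 2 rounds up, so:
#
#   [N, 28, 28, 1] -conv-> [N, 28, 28, 16] -pool-> [N, 14, 14, 16]
#                  -conv-> [N, 14, 14, 16] -pool-> [N, 7, 7, 16]
#
# leaving 7 * 7 * 16 = 784 features per example for the final linear layer.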


def build_model(examples, labels, num_labels, layer_collection):
  """Builds a ConvNet classification model.

  Args:
    examples: Tensor of shape [num_examples, width, height, channels].
      Represents the inputs of the model.
    labels: Tensor of shape [num_examples]. Contains integer IDs to be
      predicted by softmax for each example.
    num_labels: int. Number of distinct values 'labels' can take on.
    layer_collection: LayerCollection instance. Layers will be registered
      here.

  Returns:
    loss: 0-D Tensor representing loss to be minimized.
    accuracy: 0-D Tensor representing model's accuracy.
  """
  # Build a ConvNet. For each layer with parameters, we'll keep track of the
  # preactivations, activations, weights, and bias.
  tf.logging.info("Building model.")
  pre0, act0, params0 = conv_layer(
      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
  pre2, act2, params2 = conv_layer(
      layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
  act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
  flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
  logits, params4 = linear_layer(
      layer_id=4, inputs=flat_act3, output_size=num_labels)
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))

  tf.summary.scalar("loss", loss)
  tf.summary.scalar("accuracy", accuracy)

  # Register parameters. K-FAC needs to know about the inputs, outputs, and
  # parameters of each conv/fully connected layer and the logits powering the
  # posterior probability over classes.
  tf.logging.info("Building LayerCollection.")
  layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
                                   pre0)
  layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
  layer_collection.register_fully_connected(params4, flat_act3, logits)
  layer_collection.register_categorical_predictive_distribution(
      logits, name="logits")

  return loss, accuracy


def minimize_loss_single_machine(loss,
                                 accuracy,
                                 layer_collection,
                                 session_config=None):
  """Minimize loss with K-FAC on a single machine.

  A single Session is responsible for running all of K-FAC's ops.

  Args:
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.
    session_config: None or tf.ConfigProto. Configuration for tf.Session().

  Returns:
    final value for 'accuracy'.
  """
  # Train with K-FAC.
  global_step = tf.train.get_or_create_global_step()
  optimizer = opt.KfacOptimizer(
      learning_rate=0.0001,
      cov_ema_decay=0.95,
      damping=0.001,
      layer_collection=layer_collection,
      momentum=0.9)
  train_op = optimizer.minimize(loss, global_step=global_step)

  tf.logging.info("Starting training.")
  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
    while not sess.should_stop():
      global_step_, loss_, accuracy_, _, _ = sess.run(
          [global_step, loss, accuracy, train_op, optimizer.cov_update_op])

      # Every 100 steps, refresh the preconditioner's inverses and log
      # progress.
      if global_step_ % 100 == 0:
        sess.run(optimizer.inv_update_op)
        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                        global_step_, loss_, accuracy_)

  return accuracy_


def _is_gradient_task(task_id, num_tasks):
  """Returns True if this task should update the weights."""
  if num_tasks < 3:
    return True
  return 0 <= task_id < 0.6 * num_tasks


def _is_cov_update_task(task_id, num_tasks):
  """Returns True if this task should update K-FAC's covariance matrices."""
  if num_tasks < 3:
    return False
  return 0.6 * num_tasks <= task_id < num_tasks - 1


def _is_inv_update_task(task_id, num_tasks):
  """Returns True if this task should update K-FAC's preconditioner."""
  if num_tasks < 3:
    return False
  return task_id == num_tasks - 1


def _num_gradient_tasks(num_tasks):
  """Number of tasks that will update weights."""
  if num_tasks < 3:
    return num_tasks
  return int(np.ceil(0.6 * num_tasks))
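

# Worked example for the task split above (derived from the helpers; the
# numbers are illustrative). With num_tasks = 10, tasks 0-5 compute gradients
# (ceil(0.6 * 10) = 6 tasks), tasks 6-8 accumulate covariance statistics, and
# task 9 computes inverses.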


def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
                              checkpoint_dir, loss, accuracy,
                              layer_collection):
  """Minimize loss with a synchronous implementation of K-FAC.

  Different tasks are responsible for different parts of K-FAC's Ops. The
  first 60% of tasks update weights; the next 20% accumulate covariance
  statistics; the last 20% invert the matrices used to precondition gradients.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    num_worker_tasks: int. Number of workers in this distributed training
      setup.
    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
      parameter servers are not used.
    master: string. IP and port of TensorFlow runtime process. Set to empty
      string to run locally.
    checkpoint_dir: string or None. Path to store checkpoints under.
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.

  Returns:
    final value for 'accuracy'.

  Raises:
    ValueError: if task_id >= num_worker_tasks.
  """
  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
    global_step = tf.train.get_or_create_global_step()
    optimizer = opt.KfacOptimizer(
        learning_rate=0.0001,
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        momentum=0.9)
    inv_update_queue = oq.OpQueue(optimizer.inv_update_ops)
    sync_optimizer = tf.train.SyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks))
    train_op = sync_optimizer.minimize(loss, global_step=global_step)

  tf.logging.info("Starting training.")
  is_chief = (task_id == 0)
  hooks = [sync_optimizer.make_session_run_hook(is_chief)]
  with tf.train.MonitoredTrainingSession(
      master=master,
      is_chief=is_chief,
      checkpoint_dir=checkpoint_dir,
      hooks=hooks,
      stop_grace_period_secs=0) as sess:
    while not sess.should_stop():
      # Choose which op this task is responsible for running.
      if _is_gradient_task(task_id, num_worker_tasks):
        learning_op = train_op
      elif _is_cov_update_task(task_id, num_worker_tasks):
        learning_op = optimizer.cov_update_op
      elif _is_inv_update_task(task_id, num_worker_tasks):
        # TODO(duckworthd): Running this op before cov_update_op has been run
        # a few times can result in "InvalidArgumentError: Cholesky
        # decomposition was not successful." Delay running this op until
        # cov_update_op has been run a few times.
        learning_op = inv_update_queue.next_op(sess)
      else:
        raise ValueError("Which op should task %d do?" % task_id)

      global_step_, loss_, accuracy_, _ = sess.run(
          [global_step, loss, accuracy, learning_op])
      tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                      global_step_, loss_, accuracy_)

  return accuracy_


def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
  """Train a ConvNet on MNIST.

  Args:
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  # Load a dataset.
  tf.logging.info("Loading MNIST into memory.")
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=128,
      use_fake_data=use_fake_data,
      flatten_images=False)

  # Build a ConvNet.
  layer_collection = lc.LayerCollection()
  loss, accuracy = build_model(
      examples, labels, num_labels=10, layer_collection=layer_collection)

  # Fit model.
  return minimize_loss_single_machine(loss, accuracy, layer_collection)
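

# Example usage (a sketch; the path is hypothetical, and `use_fake_data=True`
# substitutes a synthetic dataset so no MNIST download is needed):
#
#   accuracy = train_mnist_single_machine(
#       data_dir="/tmp/mnist", num_epochs=1, use_fake_data=True)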


def train_mnist_multitower(data_dir, num_epochs, num_towers,
                           use_fake_data=True):
  """Train a ConvNet on MNIST, splitting the minibatch across towers.

  Args:
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    num_towers: int. Number of CPUs to split inference across.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  # Load a dataset.
  tower_batch_size = 128
  batch_size = tower_batch_size * num_towers
  tf.logging.info(
      ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
       "tower batch size.") % (batch_size, num_towers, tower_batch_size))
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=batch_size,
      use_fake_data=use_fake_data,
      flatten_images=False)

  # Split minibatch across towers.
  examples = tf.split(examples, num_towers)
  labels = tf.split(labels, num_towers)

  # Build a ConvNet. Each tower's layers will be added to the LayerCollection.
  layer_collection = lc.LayerCollection()
  tower_results = []
  for tower_id in range(num_towers):
    with tf.device("/cpu:%d" % tower_id):
      with tf.name_scope("tower%d" % tower_id):
        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
          tf.logging.info("Building tower %d." % tower_id)
          tower_results.append(
              build_model(examples[tower_id], labels[tower_id], 10,
                          layer_collection))
  losses, accuracies = zip(*tower_results)

  # Average across towers.
  loss = tf.reduce_mean(losses)
  accuracy = tf.reduce_mean(accuracies)

  # Fit model.
  session_config = tf.ConfigProto(
      allow_soft_placement=False, device_count={"CPU": num_towers})
  return minimize_loss_single_machine(
      loss, accuracy, layer_collection, session_config=session_config)


def train_mnist_distributed(task_id,
                            num_worker_tasks,
                            num_ps_tasks,
                            master,
                            data_dir,
                            num_epochs,
                            use_fake_data=False):
  """Train a ConvNet on MNIST using a distributed K-FAC optimizer.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    num_worker_tasks: int. Number of workers in this distributed training
      setup.
    num_ps_tasks: int. Number of parameter servers holding variables.
    master: string. IP and port of TensorFlow runtime process.
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  # Load a dataset.
  tf.logging.info("Loading MNIST into memory.")
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=128,
      use_fake_data=use_fake_data,
      flatten_images=False)

  # Build a ConvNet.
  layer_collection = lc.LayerCollection()
  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
    loss, accuracy = build_model(
        examples, labels, num_labels=10, layer_collection=layer_collection)

  # Fit model.
  checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
  return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks,
                                   master, checkpoint_dir, loss, accuracy,
                                   layer_collection)
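

def main(_):
  # Minimal entry point so `tf.app.run()` below has a `main` to dispatch to
  # (a sketch: the original module called tf.app.run() without defining one).
  # The data_dir is hypothetical; synthetic data keeps the run self-contained.
  train_mnist_single_machine(
      data_dir="/tmp/mnist", num_epochs=1, use_fake_data=True)


if __name__ == "__main__":
  tf.app.run()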