# TensorFlow Eager Execution

## What is this?

Eager execution is a feature that makes TensorFlow execute operations
immediately: concrete values are returned, instead of a computational graph to
be executed later.

As a result, enabling eager execution provides:

-   A [NumPy](http://www.numpy.org/)-like library for numerical computation with
    support for GPU acceleration and automatic differentiation.
-   A flexible platform for machine learning research and experimentation.

Eager execution is under active development. This guide walks through an
alpha/preview release. In particular, not all TensorFlow APIs currently work
with eager execution enabled, and some models may be slow to execute, compared
to models defined without using eager execution.

## Installation

Eager execution is included in TensorFlow versions 1.5 and above.
Installation instructions are available at https://www.tensorflow.org/install/

The contents of this guide are compatible with TensorFlow 1.5.
However, if you run into bugs that are fixed in source but not the
release, you may want to either [build from
source](https://www.tensorflow.org/install/install_sources)
or try a nightly build. The nightly builds are available as:

- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and

- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images.

For example, to run the latest nightly docker image:

```sh
# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker
docker pull tensorflow/tensorflow:nightly-gpu
docker run --runtime=nvidia -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu

# If you do not have a GPU, use the CPU-only image
docker pull tensorflow/tensorflow:nightly
docker run -it -p 8888:8888 tensorflow/tensorflow:nightly
```

And then visit http://localhost:8888 in your browser for a Jupyter notebook
environment.

## Getting Started

With TensorFlow installed, eager execution is enabled via a single call:

```python
import tensorflow as tf

import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()
```

Enabling eager execution changes how TensorFlow functions behave (in particular,
`Tensor` objects will reference concrete values instead of being symbolic
handles to nodes in a computational graph). As a result, eager execution should
be enabled at the beginning of a program and cannot be disabled afterwards in
the same program.

Code examples in the rest of this guide assume that eager execution has been
enabled.

## A library for numerical computation

A significant fraction of the [TensorFlow
API](https://www.tensorflow.org/api_docs/python/) consists of numerical
operations:
[arithmetic operations](https://www.tensorflow.org/api_guides/python/math_ops#Arithmetic_Operators),
[matrix operations](https://www.tensorflow.org/api_guides/python/math_ops#Matrix_Math_Functions),
[linear algebra operations](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg),
etc.

With eager execution enabled, these operations consume and return
multi-dimensional arrays as `Tensor` objects, similar to NumPy
[`ndarray`s](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ndarray.html).
For example:

```python
# Multiply two 2x2 matrices
x = tf.matmul([[1, 2],
               [3, 4]],
              [[4, 5],
               [6, 7]])
# Add one to each element
# (tf.add supports broadcasting)
y = tf.add(x, 1)

# Create a random 5x3 matrix
z = tf.random_uniform([5, 3])

print(x)
print(y)
print(z)
```

Output:

```
tf.Tensor(
[[16 19]
 [36 43]], shape=(2, 2), dtype=int32)
tf.Tensor(
[[17 20]
 [37 44]], shape=(2, 2), dtype=int32)
tf.Tensor(
[[ 0.25058532  0.0929395   0.54113817]
 [ 0.3108716   0.93350542  0.84909797]
 [ 0.53081679  0.12788558  0.01767385]
 [ 0.29725885  0.33540785  0.83588314]
 [ 0.38877153  0.39720535  0.78914213]], shape=(5, 3), dtype=float32)
```

For convenience, these operations can also be triggered via operator overloading
of the `Tensor` object. For example, the `+` operator is equivalent to `tf.add`,
`-` to `tf.subtract`, `*` to `tf.multiply`, etc.:

```python
x = (tf.ones([1], dtype=tf.float32) + 1) * 2 - 1
print(x)
```

Output:

```
tf.Tensor([ 3.], shape=(1,), dtype=float32)
```

### Converting to and from NumPy

The operations above automatically convert Python objects (like lists of
numbers) and NumPy arrays to `Tensor` objects. `Tensor` objects can also be
consumed by NumPy operations as if they were NumPy arrays.

```python
import numpy as np

x = tf.add(1, 1)                     # tf.Tensor with a value of 2
y = tf.add(np.array(1), np.array(1)) # tf.Tensor with a value of 2
z = np.multiply(x, y)                # numpy.int64 with a value of 4
```

Alternatively, they can be explicitly converted using
[`tf.constant`](https://www.tensorflow.org/api_docs/python/tf/constant), as
shown in the next example.

Conversely, you can call the `numpy()` method of a `Tensor` object to obtain
its NumPy `ndarray` value. For example:

```python
import numpy as np

np_x = np.array(2., dtype=np.float32)
x = tf.constant(np_x)

py_y = 3.
y = tf.constant(py_y)

z = x + y + 1

print(z)
print(z.numpy())
```

Output:

```
tf.Tensor(6.0, shape=(), dtype=float32)
6.0
```

### GPU acceleration

Many TensorFlow operations support GPU acceleration. With eager execution
enabled, [computation is *not* automatically
offloaded](https://www.tensorflow.org/tutorials/using_gpu) to GPUs. Instead, you
must explicitly specify when GPUs should be used.

The simplest way to do this is to enclose your computation in a `with
tf.device('/gpu:0')` block. Also of interest is the `tfe.num_gpus()` function,
which returns the number of available GPUs.

For example, consider this snippet, which measures the time needed to multiply
a 1000x1000 matrix by itself 10 times, first on the CPU and then, if one is
available, on a GPU:

```python
import time

def measure(x):
  # The very first time a GPU is used by TensorFlow, it is initialized.
  # So exclude the first run from timing.
  tf.matmul(x, x)

  start = time.time()
  for i in range(10):
    tf.matmul(x, x)
  end = time.time()

  return "Took %s seconds to multiply a %s matrix by itself 10 times" % (end - start, x.shape)

# Run on CPU:
with tf.device("/cpu:0"):
  print("CPU: %s" % measure(tf.random_normal([1000, 1000])))

# If a GPU is available, run on GPU:
if tfe.num_gpus() > 0:
  with tf.device("/gpu:0"):
    print("GPU: %s" % measure(tf.random_normal([1000, 1000])))
```

Output (exact numbers will depend on the characteristics of the hardware):

```
CPU: Took 0.145531892776 seconds to multiply a (1000, 1000) matrix by itself 10 times
GPU: Took 0.000458955764771 seconds to multiply a (1000, 1000) matrix by itself 10 times
```

Alternatively, methods on the `Tensor` object can be used to explicitly copy the
`Tensor` to a different device. Operations are typically executed on the device
on which the inputs are placed. For example:

```python
x = tf.random_normal([10, 10])

x_gpu0 = x.gpu()
x_cpu = x.cpu()

_ = tf.matmul(x_cpu, x_cpu)  # Runs on CPU
_ = tf.matmul(x_gpu0, x_gpu0)  # Runs on GPU:0

if tfe.num_gpus() > 1:
  x_gpu1 = x.gpu(1)
  _ = tf.matmul(x_gpu1, x_gpu1)  # Runs on GPU:1
```

### Automatic Differentiation

[Automatic
differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) is
very useful when implementing many machine learning algorithms (e.g.,
[backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training
neural networks). For this purpose, TensorFlow eager execution provides an
[autograd](https://github.com/HIPS/autograd)-style API for automatic
differentiation. Specifically, the functions:

-   `tfe.gradients_function(f)`: Returns a Python function that computes the
    derivatives of the Python function `f` with respect to its arguments. `f`
    must return a scalar value. When the returned function is invoked, it
    returns a list of `Tensor` objects (one element for each argument of `f`).
-   `tfe.value_and_gradients_function(f)`: Similar to `tfe.gradients_function`,
    except that when the returned function is invoked, it returns the value of
    `f` in addition to the list of derivatives of `f` with respect to its
    arguments (a short sketch follows this list).

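For instance, here is a minimal sketch of `tfe.value_and_gradients_function`
(assuming, as described above, that the returned function yields the value of
`f` followed by the list of per-argument gradients):

```python
def f(x, y):
  return x * x + y

val_and_grads = tfe.value_and_gradients_function(f)
value, grads = val_and_grads(2., 3.)

assert 7 == value.numpy()      # f(2, 3) = 2*2 + 3
assert 4 == grads[0].numpy()   # df/dx = 2*x = 4
assert 1 == grads[1].numpy()   # df/dy = 1
```
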
These functions naturally apply to higher order differentiation as well. For
example:

```python
def f(x):
  return tf.multiply(x, x)  # Or x * x
assert 9 == f(3.).numpy()

df = tfe.gradients_function(f)
assert 6 == df(3.)[0].numpy()

# Second order derivative.
d2f = tfe.gradients_function(lambda x: df(x)[0])
assert 2 == d2f(3.)[0].numpy()

# Third order derivative.
d3f = tfe.gradients_function(lambda x: d2f(x)[0])
assert 0 == d3f(3.)[0].numpy()
```

These functions can be used to train models. For example, consider the following
simple linear regression model:

```python
def prediction(input, weight, bias):
  return input * weight + bias

# A toy dataset of points around 3 * x + 2
NUM_EXAMPLES = 1000
training_inputs = tf.random_normal([NUM_EXAMPLES])
noise = tf.random_normal([NUM_EXAMPLES])
training_outputs = training_inputs * 3 + 2 + noise

# A loss function: Mean-squared error
def loss(weight, bias):
  error = prediction(training_inputs, weight, bias) - training_outputs
  return tf.reduce_mean(tf.square(error))

# Function that returns the derivative of loss with respect to
# weight and bias
grad = tfe.gradients_function(loss)

# Train for 200 steps (starting from some random choice for W and B, on the same
# batch of data).
W = 5.
B = 10.
learning_rate = 0.01
print("Initial loss: %f" % loss(W, B).numpy())
for i in range(200):
  (dW, dB) = grad(W, B)
  W -= dW * learning_rate
  B -= dB * learning_rate
  if i % 20 == 0:
    print("Loss at step %d: %f" % (i, loss(W, B).numpy()))
print("Final loss: %f" % loss(W, B).numpy())
print("W, B = %f, %f" % (W.numpy(), B.numpy()))
```

Output (the exact numbers will vary because of the random noise):

```
Initial loss: 66.730003
Loss at step 0: 64.200096
Loss at step 20: 29.872814
Loss at step 40: 14.233772
Loss at step 60: 7.090570
Loss at step 80: 3.819887
Loss at step 100: 2.318821
Loss at step 120: 1.628385
Loss at step 140: 1.310142
Loss at step 160: 1.163167
Loss at step 180: 1.095162
Final loss: 1.064711
W, B = 3.094944, 2.161383
```

To utilize the GPU, place the code above within a `with tf.device("/gpu:0"):`
block. (However, this particular model, with only two floating point parameters,
is unlikely to benefit from GPU acceleration.)

### Customizing gradients

One may want to define custom gradients for an operation, or for a function.
This may be useful for multiple reasons, including providing a more efficient
or more [numerically stable](https://en.wikipedia.org/wiki/Numerical_stability)
gradient for a sequence of operations.

For example, consider the function `log(1 + e^x)`, which commonly occurs in the
computation of cross entropy and log likelihoods.

```python
def log1pexp(x):
  return tf.log(1 + tf.exp(x))
grad_log1pexp = tfe.gradients_function(log1pexp)

# Works fine at x = 0.
assert 0.5 == float(grad_log1pexp(0.)[0])

# Returns a `nan` at x = 100 due to numerical instability.
import math
assert math.isnan(float(grad_log1pexp(100.)[0]))
```

We can define a custom gradient for the above function that analytically
simplifies the gradient expression.

```python
@tfe.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):
    return dy * (1 - 1 / (1 + e))
  return tf.log(1 + e), grad
grad_log1pexp = tfe.gradients_function(log1pexp)

# Works as before at x = 0.
assert 0.5 == float(grad_log1pexp(0.)[0])

# But now works at x = 100 as well.
assert 1.0 == float(grad_log1pexp(100.)[0])
```

Also notice how the gradient function implementation reuses an expression
(`tf.exp(x)`) computed during the forward pass, making the gradient computation
more efficient by avoiding redundant work.

## Building and training models

In practice, your computation may have many parameters to be optimized (by
computing derivatives). Encapsulating them into re-usable classes/objects
makes the code easier to follow than writing a single top-level function with
many arguments.

In fact, eager execution encourages use of the [Keras](https://keras.io)-style
"Layer" classes in the
[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers)
module.

Furthermore, you may want to apply more sophisticated techniques to compute
parameter updates, such as those in
[`tf.train.Optimizer`](https://www.tensorflow.org/api_guides/python/train#Optimizers)
implementations.

This next section walks through using the same `Optimizer` and `Layer` APIs used
to build trainable TensorFlow graphs in an environment where eager execution is
enabled.

### Variables and Optimizers

`tfe.Variable` objects store mutable `Tensor` values that can be accessed during
training, making automatic differentiation easier. In particular, parameters of
a model can be encapsulated in Python classes as variables.

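As a minimal sketch (the variable name and values here are arbitrary), a
`tfe.Variable` holds a value that can be read and updated in place:

```python
w = tfe.Variable(10., name='my_weight')

print(w.numpy())        # 10.0
w.assign(3.)            # Replace the stored value.
w.assign_add(1.)        # In-place addition.
print(w.numpy())        # 4.0

# Variables can be used in operations like any other Tensor.
print((w * 2).numpy())  # 8.0
```
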
`tfe.gradients_function(f)`, introduced earlier, computes the derivatives of `f`
with respect to its arguments. However, it requires all parameters of interest
to be arguments of `f`, which becomes cumbersome when `f` depends on a large
number of trainable parameters.

`tfe.implicit_gradients` is an alternative function with some useful properties:

-   It computes the derivatives of `f` with respect to all the `tfe.Variable`s
    used by `f`.
-   When the returned function is invoked, it returns a list of
    (gradient value, Variable object) tuples.

Representing model parameters as `Variable` objects, along with the use of
`tfe.implicit_gradients`, typically results in better encapsulation. For
example, the linear regression model described above can be written as a
class:

```python
class Model(object):
  def __init__(self):
    self.W = tfe.Variable(5., name='weight')
    self.B = tfe.Variable(10., name='bias')

  def predict(self, inputs):
    return inputs * self.W + self.B


# The loss function to be optimized
def loss(model, inputs, targets):
  error = model.predict(inputs) - targets
  return tf.reduce_mean(tf.square(error))

# A toy dataset of points around 3 * x + 2
NUM_EXAMPLES = 1000
training_inputs = tf.random_normal([NUM_EXAMPLES])
noise = tf.random_normal([NUM_EXAMPLES])
training_outputs = training_inputs * 3 + 2 + noise

# Define:
# 1. A model
# 2. Derivatives of a loss function with respect to model parameters
# 3. A strategy for updating the variables based on the derivatives
model = Model()
grad = tfe.implicit_gradients(loss)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

# The training loop
print("Initial loss: %f" %
      loss(model, training_inputs, training_outputs).numpy())
for i in range(201):
  optimizer.apply_gradients(grad(model, training_inputs, training_outputs))
  if i % 20 == 0:
    print("Loss at step %d: %f" %
          (i, loss(model, training_inputs, training_outputs).numpy()))
print("Final loss: %f" % loss(model, training_inputs, training_outputs).numpy())
print("W, B = %s, %s" % (model.W.numpy(), model.B.numpy()))
```

Output:

```
Initial loss: 69.693184
Loss at step 0: 66.987854
Loss at step 20: 30.553387
Loss at step 40: 14.250237
Loss at step 60: 6.955020
Loss at step 80: 3.690550
Loss at step 100: 2.229739
Loss at step 120: 1.576032
Loss at step 140: 1.283496
Loss at step 160: 1.152584
Loss at step 180: 1.093999
Final loss: 1.067780
W, B = 3.0114281, 2.0865183
```

Using `implicit_gradients` avoids the need to provide all the trainable
parameters of the model as arguments to the `loss` function.

### Using Keras and the Layers API

[Keras](https://keras.io) is a popular API for defining model structures. The
[`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers)
module provides a set of building blocks for models and is implemented using the
`tf.layers.Layer` subclasses in the
[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers)
module. We encourage the use of these same building blocks when using
TensorFlow's eager execution feature. For example, the very same linear
regression model can be built using `tf.layers.Dense`:

```python
class Model(object):
  def __init__(self):
    self.layer = tf.layers.Dense(1)

  def predict(self, inputs):
    return self.layer(inputs)
```

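As a rough sketch of how this class could be trained (the dataset and loss here
mirror the earlier linear-regression example and are not part of the original
snippet), note that `tf.layers.Dense` expects at least 2-D inputs, so the toy
1-D dataset is reshaped into a `[NUM_EXAMPLES, 1]` matrix:

```python
# A toy dataset of points around 3 * x + 2, shaped [NUM_EXAMPLES, 1] because
# tf.layers.Dense operates on (batch, features) matrices.
NUM_EXAMPLES = 1000
training_inputs = tf.random_normal([NUM_EXAMPLES, 1])
training_outputs = training_inputs * 3 + 2 + tf.random_normal([NUM_EXAMPLES, 1])

def loss(model, inputs, targets):
  return tf.reduce_mean(tf.square(model.predict(inputs) - targets))

model = Model()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
grad = tfe.implicit_gradients(loss)

for i in range(200):
  optimizer.apply_gradients(grad(model, training_inputs, training_outputs))

# The layer's variables are created on first use; after training, the kernel
# and bias should approach 3 and 2 respectively.
print(model.layer.variables)
```
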
The `tf.layers` API makes it more convenient to define more sophisticated
models. For example, the following will train an MNIST model:

```python
class MNISTModel(object):
  def __init__(self, data_format):
    # 'channels_first' is typically faster on GPUs
    # while 'channels_last' is typically faster on CPUs.
    # See: https://www.tensorflow.org/performance/performance_guide#data_formats
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    else:
      self._input_shape = [-1, 28, 28, 1]
    self.conv1 = tf.layers.Conv2D(32, 5,
                                  padding='same',
                                  activation=tf.nn.relu,
                                  data_format=data_format)
    self.max_pool2d = tf.layers.MaxPooling2D(
        (2, 2), (2, 2), padding='same', data_format=data_format)
    self.conv2 = tf.layers.Conv2D(64, 5,
                                  padding='same',
                                  activation=tf.nn.relu,
                                  data_format=data_format)
    self.dense1 = tf.layers.Dense(1024, activation=tf.nn.relu)
    self.dropout = tf.layers.Dropout(0.5)
    self.dense2 = tf.layers.Dense(10)

  def predict(self, inputs):
    x = tf.reshape(inputs, self._input_shape)
    x = self.max_pool2d(self.conv1(x))
    x = self.max_pool2d(self.conv2(x))
    x = tf.layers.flatten(x)
    x = self.dropout(self.dense1(x))
    return self.dense2(x)

def loss(model, inputs, targets):
  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          logits=model.predict(inputs), labels=targets))


# Load the training and validation data
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets("./mnist_data", one_hot=True)

# Train
device = "gpu:0" if tfe.num_gpus() else "cpu:0"
model = MNISTModel('channels_first' if tfe.num_gpus() else 'channels_last')
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
grad = tfe.implicit_gradients(loss)
for i in range(20001):
  with tf.device(device):
    (inputs, targets) = data.train.next_batch(50)
    optimizer.apply_gradients(grad(model, inputs, targets))
    if i % 100 == 0:
      print("Step %d: Loss on training set : %f" %
            (i, loss(model, inputs, targets).numpy()))
print("Loss on test set: %f" % loss(model, data.test.images, data.test.labels).numpy())
```

For a more complete example, see
[`tensorflow/contrib/eager/python/examples/mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py).

### Checkpointing trained variables

TensorFlow Variables (`tfe.Variable`) provide a way to represent shared,
persistent state of your model. The `tfe.Saver` class (which is a thin wrapper
over the
[`tf.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/train/Saver)
class) provides a means to save and restore variables to and from _checkpoints_.

For example:

```python
# Create variables.
x = tfe.Variable(10., name='x')
y = tfe.Variable(5., name='y')

# Create a Saver.
saver = tfe.Saver([x, y])

# Assign new values to the variables and save.
x.assign(2.)
saver.save('/tmp/ckpt')

# Change the variable after saving.
x.assign(11.)
assert 16. == (x + y).numpy()  # 11 + 5

# Restore the values in the checkpoint.
saver.restore('/tmp/ckpt')

assert 7. == (x + y).numpy()  # 2 + 5
```

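The same mechanism applies to model parameters. For example, a minimal sketch
that checkpoints the weight and bias of the linear-regression `Model` class from
the "Variables and Optimizers" section (the checkpoint path is arbitrary):

```python
model = Model()  # The Model class with W and B defined earlier.

# Save the current parameter values.
saver = tfe.Saver([model.W, model.B])
saver.save('/tmp/linear_model_ckpt')

# ... training might update model.W and model.B here ...

# Restore the saved values back into the same variables.
saver.restore('/tmp/linear_model_ckpt')
```
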
### `tfe.Network`

You may often want to organize your models using classes, like the `MNISTModel`
class described above. We recommend inheriting from the `tfe.Network` class as
it provides conveniences like keeping track of all model variables and methods
to save and restore from checkpoints.

Sub-classes of `tfe.Network` may register `Layer`s (like classes in
[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers),
or [Keras
layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers))
using a call to `self.track_layer()` and define the computation in an
implementation of `call()`.

Note that `tf.layers.Layer` objects (like `tf.layers.Dense`) create variables
lazily, when the first input is encountered.

For example, consider the following two-layer neural network:

```python
class TwoLayerNet(tfe.Network):
  def __init__(self):
    super(TwoLayerNet, self).__init__()
    self.layer1 = self.track_layer(
      tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False))
    self.layer2 = self.track_layer(tf.layers.Dense(3, use_bias=False))

  def call(self, x):
    return self.layer2(self.layer1(x))

net = TwoLayerNet()

# No variables created yet
assert 0 == len(net.variables)

# They are created on first input:
inp = tf.constant([[1.]])

# Since the input is a 1x1 matrix, layer1 has 2 units, and layer2 has 3 units,
# the output is a 1x3 matrix: the product of a 1x1 input, a 1x2 weight matrix,
# and a 2x3 weight matrix.
assert [1, 3] == net(inp).shape.as_list()  # Invoke net; get output shape.
assert 1 == len(net.layer1.variables)
assert 1 == len(net.layer2.variables)
assert 2 == len(net.variables)  # weights for each layer.
assert [1, 2] == net.variables[0].shape.as_list()  # weights of layer1.
assert [2, 3] == net.variables[1].shape.as_list()  # weights of layer2.
```

The `tfe.Network` class is itself a sub-class of `tf.layers.Layer`. This allows
instances of `tfe.Network` to be embedded in other networks. For example:

```python
class ThreeLayerNet(tfe.Network):
  def __init__(self):
    super(ThreeLayerNet, self).__init__()
    self.a = self.track_layer(TwoLayerNet())
    self.b = self.track_layer(tf.layers.Dense(4, use_bias=False))

  def call(self, x):
    return self.b(self.a(x))

net = ThreeLayerNet()

assert [1, 4] == net(inp).shape.as_list()
assert 3 == len(net.variables)
assert [1, 2] == net.variables[0].shape.as_list()
assert [2, 3] == net.variables[1].shape.as_list()
assert [3, 4] == net.variables[2].shape.as_list()
```

See more examples in
[`tensorflow/contrib/eager/python/examples`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples).

`tfe.Saver` in combination with `tfe.restore_variables_on_create` provides a
convenient way to save and load checkpoints without changing the program once
the checkpoint has been created. For example, we can set an objective for the
output of our network, choose an optimizer, and a location for the checkpoint:

```python
import os

objective = tf.constant([[2., 3., 4., 5.]])
optimizer = tf.train.AdamOptimizer(0.01)
checkpoint_directory = '/tmp/tfe_example'
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
net = ThreeLayerNet()
```

Note that variables have not been created yet. We want them to be restored from
a checkpoint, if one exists, so we create them inside a
`tfe.restore_variables_on_create` context manager. Then our training loop is the
same whether starting training or resuming from a previous checkpoint:

```python
with tfe.restore_variables_on_create(
    tf.train.latest_checkpoint(checkpoint_directory)):
  global_step = tf.train.get_or_create_global_step()
  for _ in range(100):
    loss_fn = lambda: tf.norm(net(inp) - objective)
    optimizer.minimize(loss_fn, global_step=global_step)
    if tf.equal(global_step % 20, 0):
      print("Step %d, output %s" % (global_step.numpy(),
                                    net(inp).numpy()))
      all_variables = (
          net.variables
          + optimizer.variables()
          + [global_step])
      # Save the checkpoint.
      tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step)
```

The first time it runs, `Network` variables are initialized randomly. Then the
output is trained to match the objective we've set:

```
Step 20, output [[ 0.03575622  0.29863232  0.03474367  0.24735749]]
Step 40, output [[ 0.40646029  0.9856872   0.46851286  0.95358551]]
Step 60, output [[ 1.74541104  2.800704    1.79055595  2.74783421]]
Step 80, output [[ 2.14977384  3.44340849  3.96120024  5.16242075]]
Step 100, output [[ 1.99943113  3.02364397  3.93500996  4.9610076 ]]
```

In subsequent iterations, variables are initialized with the values read from
the latest checkpoint. Running the same code again, we continue from where we
left off:

```
Step 120, output [[ 1.99234128  3.0271616   3.98732996  4.96401167]]
Step 140, output [[ 2.00133467  3.01270437  4.00616646  5.00406504]]
Step 160, output [[ 1.99647415  2.9956708   3.99064088  4.99632359]]
Step 180, output [[ 2.00699997  3.00904822  4.00706148  5.01193142]]
Step 200, output [[ 1.98334622  2.98249531  3.97375059  4.97123432]]
```


### Summaries, metrics and TensorBoard

[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
is a popular tool for understanding, debugging and optimizing the model training
process. To benefit from the visualizations offered by TensorBoard, summary
events need to be written during the course of execution of your program. You
might find many TensorFlow programs that include the
[`tf.summary`](https://www.tensorflow.org/api_guides/python/summary) operations
during graph construction.

`tf.summary` operations are *not* compatible with eager execution, but an
equivalent alternative exists in
[`tf.contrib.summary`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/summary)
that is compatible with both eager execution and graph construction.

During model construction, simply insert summary operations like
`tf.contrib.summary.scalar`. These operations do nothing by default, unless a
summary writer is currently active and a writing policy is set.

For example, to record summaries once every 100 global steps, use:

```python
tf.train.get_or_create_global_step()  # Ensure that the global step variable exists
writer = tf.contrib.summary.create_file_writer(logdir)

for _ in range(iterations):
  with writer.as_default():
    with tf.contrib.summary.record_summaries_every_n_global_steps(100):
      # your model code goes here
      tf.contrib.summary.scalar('loss', loss)
      # ...
```

See the mnist example in
[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist)
for a complete model that uses `tf.contrib.summary`.

Similarly to summaries, the metrics in `tf.metrics` are currently not compatible
with eager execution. We instead provide object-oriented metrics in the
`tfe.metrics` package, which are compatible with graph construction as well.

Metrics in the `tfe.metrics` package, such as `tfe.metrics.Mean` and
`tfe.metrics.Accuracy`, all implement an intuitive object-oriented
interface. Here's an example of how to use the `tfe.metrics.Mean` metric:

```python
# Metrics are objects, which can be created and destroyed.
my_mean = tfe.metrics.Mean(name='my_mean')
# While a metric is active, you can call it as a function to accumulate into its
# internal state.
my_mean(0.0)
my_mean(10.0)
# Once you've finished updating the metric, you can get its result. In this case
# a simple average over all the calls to it. If a summary writer is active the
# metric will write the appropriate summaries using the metric name.
assert 5.0 == my_mean.result().numpy()
```

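Similarly, here is a minimal sketch of `tfe.metrics.Accuracy` (assuming it is
called with `(labels, predictions)` pairs, analogous to the `Mean` example
above):

```python
my_accuracy = tfe.metrics.Accuracy(name='my_accuracy')

# Accumulate (label, prediction) pairs; here 3 of the 4 predictions match.
my_accuracy(tf.constant([0, 1, 2, 3]), tf.constant([0, 1, 2, 0]))

assert 0.75 == my_accuracy.result().numpy()
```
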
For a full example of a model using metrics for evaluation, see the mnist
example in
[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist).

### Input Pipelines

The discussion above has been centered around the computation executed by your
model. The
[`tf.data`](https://www.tensorflow.org/api_docs/python/tf/data)
module provides APIs to build complex input pipelines from simple, reusable
pieces.

If you're familiar with constructing `tf.data.Dataset` objects when building
TensorFlow graphs, the same API calls are used when eager execution is enabled.
However, the process of iterating over elements of the dataset differs between
eager execution and graph construction. When eager execution is enabled, the
discussion on iterator creation using `make_one_shot_iterator()` and
`get_next()` in the
[Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is
*not* applicable. Instead, a more Pythonic `Iterator` class is available.

For example:

```python
# Create a source Dataset from in-memory numpy arrays.
# For reading from files on disk, you may want to use other Dataset classes
# like the TextLineDataset or the TFRecordDataset.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])

# Apply transformations, shuffling, batching etc.
dataset = dataset.map(tf.square).shuffle(2).batch(2)

# Use tfe.Iterator to iterate over the dataset.
for x in tfe.Iterator(dataset):
  print(x)
```

Output:

```
tf.Tensor([4 9], shape=(2,), dtype=int32)
tf.Tensor([16 25], shape=(2,), dtype=int32)
tf.Tensor([36  1], shape=(2,), dtype=int32)
```

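Since `tfe.Iterator` yields ordinary `Tensor` values, a `Dataset` can drive a
training loop directly. A rough sketch (the feature/label arrays and the
single-layer model here are made up for illustration):

```python
import numpy as np

# Hypothetical in-memory features and labels.
features = np.random.randn(100, 3).astype(np.float32)
labels = np.random.randn(100, 1).astype(np.float32)

dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
           .shuffle(100)
           .batch(10))

dense = tf.layers.Dense(1)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)

def loss(x, y):
  return tf.reduce_mean(tf.square(dense(x) - y))

grad = tfe.implicit_gradients(loss)
for x, y in tfe.Iterator(dataset):
  optimizer.apply_gradients(grad(x, y))
```
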
## Interoperating with Graphs

Eager execution improves the process of model development in Python; however,
because it is in its earliest stages, it does not yet support some features
available to [TensorFlow
graphs](https://www.tensorflow.org/get_started/get_started#the_computational_graph)
that are desirable when deploying models in production. In particular, eager
execution does not yet support distributed training, exporting models (to other
[programming languages](https://www.tensorflow.org/api_docs/), [TensorFlow
serving](https://www.tensorflow.org/serving/), and mobile applications), and
various memory and computation optimizations that are applied to TensorFlow's
dataflow graphs.

That said, the APIs used to build models are exactly the same whether executing
eagerly or constructing graphs. This means that you can iteratively develop your
model with eager execution enabled and later, if needed, use the same code to
reap the benefits of representing models as computational graphs.

For example,
[`mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py)
defines a model that is eagerly executed. That same code is used to construct
and execute a graph in
[`mnist_graph_test.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py).

Other models in the [examples
directory](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/)
demonstrate this as well.

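As a rough sketch of what this looks like (the `compute` function and shapes
here are made up for illustration), the same Python function that runs eagerly
in the rest of this guide can also be staged into a graph and executed with a
`tf.Session` in a separate program where eager execution has *not* been enabled:

```python
# In a separate program where eager execution has NOT been enabled:
import tensorflow as tf

def compute(x):
  # Ordinary TensorFlow code; it executes immediately when eager execution is
  # enabled, and builds graph nodes otherwise.
  return tf.reduce_sum(tf.square(x))

with tf.Graph().as_default():
  inp = tf.placeholder(tf.float32, shape=[3])
  out = compute(inp)               # Builds graph operations.
  with tf.Session() as sess:
    print(sess.run(out, feed_dict={inp: [1., 2., 3.]}))  # 14.0
```
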
Some differences worth noting:

-   There is no notion of a `tf.placeholder` or a `tf.Session` when eager
    execution is enabled.
-   Many properties on the `tf.Tensor` object, like `tf.Tensor.name`,
    `tf.Tensor.op`, `tf.Tensor.inputs` are not meaningful when eager execution
    is enabled and their use will raise an `AttributeError`.
-   To use `tfe.implicit_gradients` in graph construction, variables must be
    created with `use_resource=True` passed to
    [`tf.get_variable()`](https://www.tensorflow.org/api_docs/python/tf/get_variable)
    or
    [`tf.variable_scope()`](https://www.tensorflow.org/api_docs/python/tf/variable_scope)
    (a brief sketch follows this list).
-   Some API calls (such as the functional-style `tf.layers.dense`,
    `tf.layers.conv2d`) are not compatible with eager execution. Use of such
    methods should raise an error indicating the alternative (e.g., the
    `tf.layers.Dense` and `tf.layers.Conv2D` classes).

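Here is a rough sketch of that third point, assuming the graph-mode usage it
describes (the variable name, initializer, and loss are arbitrary placeholders):

```python
# Graph construction (eager execution NOT enabled in this program): create a
# resource variable so that tfe.implicit_gradients can track it.
w = tf.get_variable('w', shape=[], initializer=tf.ones_initializer(),
                    use_resource=True)

def loss():
  return tf.square(w - 3.)

# As in the eager examples, the returned function yields (gradient, variable)
# pairs, which can be handed to an optimizer.
grads_and_vars = tfe.implicit_gradients(loss)()
train_op = tf.train.GradientDescentOptimizer(0.1).apply_gradients(grads_and_vars)
```
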
## What next?

Please give eager execution a spin. This feature is in early stages and is
evolving, so we welcome your feedback via issues on GitHub (see [known
issues](https://github.com/tensorflow/tensorflow/labels/comp:eager)).

You may want to browse through some sample code, including benchmarks for some
of the models:

-   [Linear Regression](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/linear_regression)
-   [MNIST handwritten digit classifier](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist)
-   [ResNet50 image classification](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/resnet50)
-   [RNN to generate colors](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_colorbot)
-   [RNN language model](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_ptb)