# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for JIT compilation on the CPU and GPU devices."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.compiler import jit
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test

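# `experimental_jit_scope` marks the ops created inside its `with` block for
# XLA compilation; in contrib this is done by attaching XLA compilation
# attributes (such as _XlaCompile) to those ops. See
# tensorflow/contrib/compiler/jit.py.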
jit_scope = jit.experimental_jit_scope


def CompiledKernel(fn, *inputs, **kwargs):
  """Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
  name = kwargs.pop("name", None)
  noinline = kwargs.pop("noinline", None)

  @function.Defun(func_name=name, noinline=noinline, compiled=True)
  def Compiled(*args):
    return fn(*args)

  return Compiled(*inputs)
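
# Example usage of CompiledKernel (a hypothetical sketch, not exercised by the
# tests below; `some_tensor` stands in for an existing float32 tensor):
#
#   def Double(x):
#     return math_ops.multiply(x, 2.0)
#
#   doubled = CompiledKernel(Double, some_tensor, name="Double")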


def RunMetadataLabels(run_metadata):
  """Returns all labels in run_metadata."""
  labels = []
  for dev_stats in run_metadata.step_stats.dev_stats:
    for node_stats in dev_stats.node_stats:
      labels.append(node_stats.timeline_label)
  return labels


def InLabels(labels, substr):
  """Returns true iff one of the labels contains substr."""
  return any(substr in x for x in labels)


def MetadataHasXlaLaunch(run_metadata):
  """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline."""

  # TODO(phawkins): find a less hacky way to test whether a kernel ran.
  return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch")


class JitLaunchTest(test.TestCase):

  # Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
  # Verifies that the outputs match and that XLA was invoked. 'fn' must take
  # as many tensor arguments as there are entries in 'args', and must return
  # a tuple of output tensors.
  # If 'require_kernel_launch' is True, verifies that a _XlaLaunch node
  # actually ran. However, _XlaLaunch ops can sometimes be constant-folded
  # away, so the check is optional.
  def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
    with session_lib.Session() as sess:
      placeholders = []
      feeds = {}
      for arg in args:
        placeholder = array_ops.placeholder(
            dtypes.as_dtype(arg.dtype), list(arg.shape))
        placeholders.append(placeholder)
        feeds[placeholder] = arg

      compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline)
      direct_op = fn(*placeholders)

      run_metadata = config_pb2.RunMetadata()
      compiled = sess.run(compiled_op,
                          feeds,
                          run_metadata=run_metadata,
                          options=config_pb2.RunOptions(
                              trace_level=config_pb2.RunOptions.FULL_TRACE))
      print("Compiled Result {}".format(compiled))

      if require_kernel_launch:
        self.assertTrue(MetadataHasXlaLaunch(run_metadata))

        direct = sess.run(direct_op, feeds)
        print("Direct Result {}".format(direct))

        if (isinstance(compiled, (tuple, list)) and
            isinstance(direct, (tuple, list))):
          for (x, y) in zip(compiled, direct):
            self.assertAllClose(x, y, rtol=1e-1)
        else:
          self.assertAllClose(compiled, direct)

  def testNoOutputs(self):
    with session_lib.Session() as sess:

      # Check that calling the result as a compiled kernel doesn't crash.
      @function.Defun(compiled=True)
      def KernelWithNoOutputs():
        a = constant_op.constant(100)  # pylint: disable=unused-variable

      call = KernelWithNoOutputs()  # pylint: disable=assignment-from-no-return
      sess.run(call, {})

  def testAliasing(self):
    """Regression test for compiled functions that return an aliased buffer.

       XLA returns aliased buffers if outputs are identical. Tests that
       we handle that case.
    """

    def AddOnceReturnTwice(x):
      y = math_ops.add(x, x)
      return y, y

    # Exercises compiling a function (say, Foo) which calls another
    # function (say, Bar) that is not inlined. When the compiler compiles
    # Foo, it needs to symbolically execute Bar correctly regardless of
    # whether Bar is inlined or not.

    # TODO(b/36139787): Re-enable this test when noinline works again.
    # Tests compiled=True and noinline=True.
    # self._compare(
    #     AddOnceReturnTwice, [np.array(
    #         [[[0.5, -1.0]]], dtype=np.float32)],
    #     noinline=True)

    # Tests compiled=True and noinline=False.
    self._compare(
        AddOnceReturnTwice, [np.array(
            [[[0.5, -1.0]]], dtype=np.float32)],
        noinline=False)

  def testOneConstOutput(self):
    """Test consisting of a single constant return value."""

    def OneConstOutput():
      return constant_op.constant([-3, 44, 99])

    self._compare(OneConstOutput, [], require_kernel_launch=False)

  def testConstZeroElementOutput(self):
    """Test consisting of a constant zero element return value."""

    def ConstZeroElementOutput():
      return array_ops.fill([7, 0], 3.0)

    self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)

  def testSomeConstOutputs(self):
    """Test kernels that return a mixture of const and non-const outputs."""

    def SomeConstOutputs(x):
      return constant_op.constant(
          [-2, 7]), array_ops.identity(x), constant_op.constant(3.5)

    self._compare(
        SomeConstOutputs, [np.array(
            [[1, 2, 3], [4, 5, 6]], dtype=np.float32)])

  def testInt32Input(self):
    """Test an int32-typed input.

       On a GPU, int32 tensors will be placed in host memory.
    """

    def AddToSelf(x):
      return math_ops.add(x, x)

    self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])

  def testMandatoryConstantInput(self):
    """Tests an operator that has a mandatory-constant shape input."""

    def FillWithFloat(x):
      return array_ops.fill(x, 9.5)

    self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)])

  def testMnistForwardFunc(self):
    """Compute inference function from MNIST beginners tutorial."""
    batch_size = 16
    image_size = 28 * 28
    num_classes = 10

    # Define a TensorFlow function to compute the forward pass.
    def MnistForward(w, b, x):
      return nn_ops.softmax(math_ops.matmul(x, w) + b)

    w = np.random.random_sample((image_size, num_classes)).astype(np.float32)
    b = np.random.random_sample((num_classes,)).astype(np.float32)
    x = np.random.random_sample((batch_size, image_size)).astype(np.float32)
    self._compare(MnistForward, [w, b, x])

  def testExplicitMarking(self):
    """Test explicit marking of operators to compile."""
    batch_size = 16
    image_size = 28 * 28
    num_classes = 10

    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      w = array_ops.placeholder(dtypes.float32)
      b = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y1 = math_ops.matmul(x, w)
      y2 = math_ops.add(y1, b)
      with jit_scope():
        y = math_ops.square(y2)
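      # The matmul and the square are marked for compilation by the two
      # jit_scope blocks; the add in between is not, so the two marked ops are
      # expected to end up in separate compiled clusters.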

      dw = np.random.random_sample((image_size, num_classes)).astype(np.float32)
      db = np.random.random_sample((num_classes,)).astype(np.float32)
      dx = np.random.random_sample((batch_size, image_size)).astype(np.float32)
      with session_lib.Session() as sess:
        run_metadata = config_pb2.RunMetadata()
        output = sess.run(y, {x: dx,
                              w: dw,
                              b: db},
                          run_metadata=run_metadata,
                          options=config_pb2.RunOptions(
                              trace_level=config_pb2.RunOptions.FULL_TRACE))

        # TODO(phawkins): really we would like to test that there were exactly
        # two kernel launches. However, we have no reliable way to determine
        # that.
        self.assertTrue(MetadataHasXlaLaunch(run_metadata))

        expected = np.square(np.dot(dx, dw) + db)
        self.assertAllClose(expected, output, rtol=1e-1)


class XlaCompilationTest(test.TestCase):
  """Tests for auto-compilation on CPU/GPU devices."""

  def testReshape(self):
    """Tests an operator with compile-time constant and non-constant inputs."""

    with self.test_session() as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        # Reshape's first argument is non-constant in the JIT, but its second
        # (shape) argument will be treated as a compile-time constant for
        # each JIT compilation.
        # We do not use a tf.constant() argument since we want to ensure the
        # shape is still a run-time argument to the JIT, and not
        # statically known as part of the JIT compilation's input graph.
        z = array_ops.reshape(x, y)
      run_metadata = config_pb2.RunMetadata()
      out = sess.run(z,
                     {x: np.array([1, 2, 3, 4, 5, 6], np.float32),
                      y: [-1, 3]},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaLaunch(run_metadata))
      self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)

  def testIgnoredArguments(self):
    """Tests that JIT computations can ignore formal parameters."""

    with self.test_session() as sess:
      x = array_ops.placeholder(dtypes.int32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        z = math_ops.add(x, x)
        w = math_ops.add(y, y)
        # Pulls 'w' into the same compilation via control dependencies.
        with ops.control_dependencies([w]):
          n = control_flow_ops.no_op()
        with ops.control_dependencies([n]):
          t = math_ops.add(z, z)

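      # t = (x + x) + (x + x) = 4 * x, so feeding x = 7 should yield 28; the
      # value fed for y reaches only the control-dependency chain and does not
      # affect the result.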
      run_metadata = config_pb2.RunMetadata()
      out = sess.run(t, {x: np.int32(7),
                         y: np.int32(404)},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaLaunch(run_metadata))
      self.assertAllClose(28, out)

  def testLoops(self):
    """Tests that compilation accepts computations containing loops."""

    with self.test_session() as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        c = lambda i, _: math_ops.less(i, 5)
        b = lambda i, x: (i + 1, x * 2.0 + 1.0)
        _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))

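      # Starting from x = 2.0, five iterations of x -> 2 * x + 1 give
      # 2 -> 5 -> 11 -> 23 -> 47 -> 95, the value asserted below.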
      run_metadata = config_pb2.RunMetadata()
      result = session.run(y, {x: np.float32(2)},
                           run_metadata=run_metadata,
                           options=config_pb2.RunOptions(
                               trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaLaunch(run_metadata))
      self.assertAllClose(result, np.float32(95), rtol=1e-1)

  def testCond(self):
    """Tests that compilation handles switch operators."""

    with self.test_session() as session:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)
      c = array_ops.placeholder(dtypes.bool)
      with jit_scope():
        z = x + 1.0
        w = control_flow_ops.cond(c, lambda: z, lambda: y)
        t = math_ops.add(z, w)

      # If JIT compilation chooses to cluster z and t together, execution will
      # deadlock: the cond runs outside the cluster, yet it consumes z (an
      # output of the cluster) and produces w (an input of the cluster), so
      # the cluster would have to run both before and after it.

      run_metadata = config_pb2.RunMetadata()
      result = session.run(t, {x: np.float32(2),
                               y: np.float32(4),
                               c: True},
                           run_metadata=run_metadata,
                           options=config_pb2.RunOptions(
                               trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaLaunch(run_metadata))
      self.assertAllClose(result, np.float32(6), rtol=1e-1)

  def testNestedFunction(self):
    g = ops.Graph()
    with g.as_default():

      @function.Defun(compiled=True)
      def Bar(x, y):
        return x + 2 * y

      @function.Defun(compiled=True)
      def Foo(x):
        return Bar(x * x, x * x * x)

      @function.Defun()
      def Entry(x):
        return Foo(x)

      inp = array_ops.placeholder(dtypes.float32)
      out = Entry(inp)

    with self.test_session(graph=g, use_gpu=True) as sess:
      run_metadata = config_pb2.RunMetadata()
      val = sess.run(out,
                     feed_dict={inp: [2., 10.]},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
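      # Foo(x) = Bar(x * x, x * x * x) = x * x + 2 * x**3, so the inputs
      # [2., 10.] should produce [20., 2100.].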
      self.assertAllClose(val, [20., 2100.])

  def testLoopDeadlock(self):
    """Regression test for bug that caused deadlocks in graphs with loops."""

    with self.test_session() as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y = x + 1.0
        c = lambda i, _x, _y: math_ops.less(i, 5)
        b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
        _, _, w = control_flow_ops.while_loop(c, b,
                                              (constant_op.constant(0), y, x))
        u = w + y
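      # With x = 2.0: y = 3.0, and the loop body maps (i, x, _y) to
      # (i + 1, 2 * x + 1, x - 3) five times starting from (0, y, x), which
      # leaves w = 60.0, so u = w + y = 63.0.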
      result = session.run(u, {x: np.float32(2)})
      self.assertAllClose(result, np.float32(63), rtol=1e-1)

  def testGradient(self):
    """Tests that the backprop function is properly compiled."""

    def _Run(compiled):

      @function.Defun(compiled=compiled)
      def Forward(x):
        return math_ops.log(x)

      g = ops.Graph()
      with g.as_default():
        x = array_ops.placeholder(dtypes.float32)
        y = Forward(x)
        dx, = gradients_impl.gradients(y, [x], 1.0)

      cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
          optimizer_options=config_pb2.OptimizerOptions(
              opt_level=config_pb2.OptimizerOptions.L1,
              do_function_inlining=True)))
      with session_lib.Session(graph=g, config=cfg) as sess:
        run_metadata = config_pb2.RunMetadata()
        dx_val = sess.run(dx,
                          feed_dict={x: 100.},
                          run_metadata=run_metadata,
                          options=config_pb2.RunOptions(
                              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertAllClose(dx_val, 0.01)
      return RunMetadataLabels(run_metadata)

    # SymGrad[f=log(x)](x, dy) = 1/x * dy
    #
    # Note: we don't need to compute log(x) for dx due to graph pruning.
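    #
    # With x = 100 and dy = 1.0, dx = dy / x = 0.01, which is what the
    # assertAllClose in _Run above checks.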

    # Do not compile the backprop. We should see one Reciprocal and one Mul.
    labels = _Run(compiled=False)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertTrue(InLabels(labels, "Reciprocal"))
    self.assertTrue(InLabels(labels, "Mul"))
    self.assertFalse(InLabels(labels, "_XlaLaunch"))

    # Compile the backprop. One _XlaLaunch.
    labels = _Run(compiled=True)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertFalse(InLabels(labels, "Reciprocal"))
    self.assertFalse(InLabels(labels, "Mul"))
    self.assertTrue(InLabels(labels, "_XlaLaunch"))


if __name__ == "__main__":
  test.main()