# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.compiler.jit."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.compiler import jit
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top


_REGISTERED_OPS = op_def_registry.get_registered_ops()


def enable_jit_nonstateful(node_def):
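  """Returns True iff the op in `node_def` is registered and not stateful.

  Raises:
    ValueError: if the op is not registered.
  """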
  try:
    return not _REGISTERED_OPS[node_def.op].is_stateful
  except KeyError:
    raise ValueError("Unregistered op being created: %s" % node_def)
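
# A minimal usage sketch (an illustration, not exercised by the tests below):
# `experimental_jit_scope` also accepts a callable like the one above, which
# is consulted per node to decide whether to mark it with _XlaCompile:
#
#   x = constant_op.constant([[1.0]])
#   with jit.experimental_jit_scope(enable_jit_nonstateful):
#     y = math_ops.matmul(x, x)                    # stateless -> compiled
#     z = random_ops.random_uniform((1,), seed=1)  # stateful -> not compiled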


class JITTest(test.TestCase):
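  """Tests for `experimental_jit_scope` and the attrs it sets on ops."""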

  def compute(self, use_jit, compute_fn):
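    """Runs `compute_fn` under a jit scope; returns the tensor and its value."""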
    random_seed.set_random_seed(1234)
    with self.test_session(graph=ops.Graph()) as sess:
      with jit.experimental_jit_scope(use_jit):
        r = compute_fn()
      sess.run(variables.global_variables_initializer())
      return (r, sess.run(r))

  def testJITCreateOpsLambda(self):
    """Test several ways of customizing the compilation attribute."""
    def create_ops():
      with variable_scope.variable_scope(
          "root",
          initializer=init_ops.random_uniform_initializer(
              -0.1, 0.1, seed=2)):
        inputs = random_ops.random_uniform((1,), seed=1)
        return inputs
    v_false_1_t, v_false_1 = self.compute(False, create_ops)
    _, v_false_2 = self.compute(False, create_ops)
    v_true_1_t, v_true_1 = self.compute(enable_jit_nonstateful, create_ops)
    _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops)
    v_all_true_t, _ = self.compute(True, create_ops)
    self.assertFalse(v_false_1_t.op.get_attr("_XlaCompile"))
    v_true_1_t_sampler_op = v_true_1_t.graph.get_operation_by_name(
        "root/random_uniform/RandomUniform")
    v_all_true_t_sampler_op = v_all_true_t.graph.get_operation_by_name(
        "root/random_uniform/RandomUniform")

    self.assertFalse(v_true_1_t_sampler_op.get_attr("_XlaCompile"))
    self.assertTrue(v_all_true_t_sampler_op.get_attr("_XlaCompile"))

    self.assertTrue(v_true_1_t.op.get_attr("_XlaCompile"))
    self.assertTrue(v_all_true_t.op.get_attr("_XlaCompile"))

    # Additionally ensure that where no JIT compilation happens on the
    # random_uniform op, the output values are identical to the case
    # where no JIT compilation happens anywhere.
    self.assertAllClose(v_false_1, v_false_2)
    self.assertAllClose(v_true_1, v_true_2)
    self.assertAllClose(v_false_1, v_true_1)

  def testJITXlaScope(self):
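    """Sibling jit scopes get fresh _XlaScope values; nested scopes reuse them."""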
    with self.test_session(graph=ops.Graph()):
      with jit.experimental_jit_scope(True):
        # XlaScope 0
        a1 = constant_op.constant(1)
      with jit.experimental_jit_scope(True):
        # XlaScope 1
        a2 = constant_op.constant(1)
        with jit.experimental_jit_scope(True):
          # XlaScope still 1, depth 1
          a3 = constant_op.constant(1)
          with jit.experimental_jit_scope(True):
            # XlaScope still 1, depth 2
            a4 = constant_op.constant(1)
          # XlaScope still 1, depth 1
          a5 = constant_op.constant(1)
      with jit.experimental_jit_scope(True):
        # XlaScope now 2, depth 0
        a6 = constant_op.constant(1)

    self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a3.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a4.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a5.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_2", a6.op.get_attr("_XlaScope"))

  def testJITVariableSeed(self):
    """Test that the stateful initializer is not marked for compilation.

    XLA does not currently support seeded initialization, so XLA initializers
    return different values than their non-XLA counterparts.  Here we ensure
    that if we disable JIT compilation for the initializers, we get the same
    variable values as if no JIT compilation had happened.
    """
    def create_ops():
      with variable_scope.variable_scope(
          "root",
          initializer=init_ops.random_uniform_initializer(
              -0.1, 0.1, seed=2)):
        inputs = variable_scope.get_variable("var", (1,))
        return inputs
    _, v_false_1 = self.compute(False, create_ops)
    _, v_false_2 = self.compute(False, create_ops)
    _, v_true_1 = self.compute(enable_jit_nonstateful, create_ops)
    _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops)
    self.assertAllClose(v_false_1, v_false_2)
    self.assertAllClose(v_true_1, v_true_2)
    self.assertAllClose(v_false_1, v_true_1)

  def testDefunNoJitScope(self):
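    """A compiled Defun outside any jit scope picks its own _XlaScope name."""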
    with self.test_session(graph=ops.Graph()):
      @function.Defun(compiled=True, noinline=True)
      def mulop(x1, x2):
        return x1 * x2
      x = constant_op.constant(1.0)
      r = mulop(x, x)

      # Ensure the forward function is compiled.
      graph_def = r.graph.as_graph_def()
      func_attrs = graph_def.library.function[0].attr
      self.assertTrue(func_attrs["_XlaCompile"].b)
      # No enclosing jit scope so function sets its own value for _XlaScope.
      self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)

  def testDefunInheritsJitScope(self):
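    """A compiled Defun inside a jit scope inherits the enclosing _XlaScope."""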
    with self.test_session(graph=ops.Graph()):
      with jit.experimental_jit_scope(True):
        @function.Defun(compiled=True, noinline=True)
        def mulop(x1, x2):
          return x1 * x2
        x = constant_op.constant(1.0)
        r = mulop(x, x)

      # Ensure the forward function is compiled.
      graph_def = r.graph.as_graph_def()
      func_attrs = graph_def.library.function[0].attr
      self.assertTrue(func_attrs["_XlaCompile"].b)
      # Ensure _XlaScope is inherited from enclosing context.
      self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)


@test_util.with_c_api
class CompilationEnabledInGradientTest(test.TestCase):
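  """Tests that gradients of jit-scoped ops are also marked for compilation."""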

  def testCompilationInGradient(self):
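    """Gradient ops of a compiled op get _XlaCompile; others do not."""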
    with self.test_session():
      x = constant_op.constant([[3]])
      y_nc = math_ops.matmul(x, x, name="not_compiled")
      with jit.experimental_jit_scope():
        y_c = math_ops.matmul(y_nc, y_nc, name="compiled")
      x_grads = gradients.gradients([y_c], [x])[0]
      operations = x.graph.get_operations()
      c_grad_ops = [
          op for op in operations if "gradients/compiled" in op.name]
      nc_grad_ops = [
          op for op in operations if "gradients/not_compiled" in op.name]
      self.assertGreater(len(c_grad_ops), 0)
      self.assertGreater(len(nc_grad_ops), 0)
      for cg in c_grad_ops:
        self.assertTrue(cg.get_attr("_XlaCompile"))
      for ncg in nc_grad_ops:
        with self.assertRaisesRegexp(ValueError, "[Nn]o attr named"):
          ncg.get_attr("_XlaCompile")

      # d/dx (x ** 4) = 4 * (x ** 3); at x = 3 this is 4 * 27 = 108.
      self.assertAllClose([[108]], x_grads.eval())

  def testCompilationGradientScopeNames(self):
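    """Gradient ops inherit the _XlaScope of their forward ops."""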
    with self.test_session(graph=ops.Graph()):
      with jit.experimental_jit_scope():
        # XlaScope 0
        a1 = constant_op.constant([[1]])
        a1t = math_ops.matmul(a1, a1)
      with jit.experimental_jit_scope():
        # XlaScope 1
        a2 = constant_op.constant([[1]])
        a2t = math_ops.matmul(a2, a2)

      self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
      grad_a1 = gradients.gradients(a1t, a1, name="GA")[0]
      grad_a2 = gradients.gradients(a2t, a2, name="GB")[0]
      grad_a1 = grad_a1.op.inputs[0]
      grad_a2 = grad_a2.op.inputs[0]
      self.assertTrue(grad_a1.op.get_attr("_XlaCompile"))
      self.assertTrue(grad_a2.op.get_attr("_XlaCompile"))
      self.assertEqual(b"jit_scope_0", grad_a1.op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1", grad_a2.op.get_attr("_XlaScope"))

  def testCompilationSeparateGradientScopeNames(self):
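    """With separate_compiled_gradients, gradients get their own _XlaScope."""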
    with self.test_session(graph=ops.Graph()):
      with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
        # XlaScope 0
        a1 = constant_op.constant([[1]])
        a1t = math_ops.matmul(a1, a1)
      with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
        # XlaScope 1
        a2 = constant_op.constant([[1]])
        a2t = math_ops.matmul(a2, a2)

      self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
      grad_a1 = gradients.gradients(a1t, a1, name="GA")[0]
      grad_a2 = gradients.gradients(a2t, a2, name="GB")[0]
      grad_a1 = grad_a1.op.inputs[0]
      grad_a2 = grad_a2.op.inputs[0]
      self.assertTrue(grad_a1.op.get_attr("_XlaCompile"))
      self.assertTrue(grad_a2.op.get_attr("_XlaCompile"))
      self.assertEqual(b"jit_scope_0_grad_GA",
                       grad_a1.op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1_grad_GB",
                       grad_a2.op.get_attr("_XlaScope"))

  def testPlaysNicelyWithDefun(self):
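    """A compiled Defun's SymbolicGradient shares the enclosing _XlaScope."""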
    with self.test_session(graph=ops.Graph()) as sess:
      with jit.experimental_jit_scope(True):
        @function.Defun(compiled=True, noinline=True)
        def mulop(x1, x2):
          return x1 * x2
        x = constant_op.constant(1.0)
        r = mulop(x, x)
        g_r = gradients.gradients(r, x, name="GA")[0]

      # Ensure the forward function is compiled.
      graph_def = r.graph.as_graph_def()
      func_attrs = graph_def.library.function[0].attr
      self.assertTrue(func_attrs["_XlaCompile"].b)
      self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)

      # Ensure the gradient (SymbolicGradient) is compiled, with the same
      # _XlaScope as the function itself.
      grad_op = g_r.op.inputs[0].op
      self.assertTrue(grad_op.get_attr("_XlaCompile"))
      self.assertEqual(b"jit_scope_0", grad_op.get_attr("_XlaScope"))

      # Ensure the ops run: grad(x1*x1) = 2*x1
      self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))

  def testPlaysNicelyWithDefunSeparateGradientScope(self):
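    """With separate gradients, the SymbolicGradient gets its own _XlaScope."""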
    with self.test_session(graph=ops.Graph()) as sess:
      with jit.experimental_jit_scope(True):

        @function.Defun(
            compiled=True, noinline=True, separate_compiled_gradients=True)
        def mulop(x1, x2):
          return x1 * x2

        x = constant_op.constant(1.0)
        r = mulop(x, x)
        g_r = gradients.gradients(r, x, name="GA")[0]

      # Ensure the forward function is compiled.
      graph_def = r.graph.as_graph_def()
      func_attrs = graph_def.library.function[0].attr
      self.assertTrue(func_attrs["_XlaCompile"].b)
      self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)

      # Ensure the gradient (SymbolicGradient) is compiled, with a different
      # _XlaScope from the function itself.
      grad_op = g_r.op.inputs[0].op
      self.assertTrue(grad_op.get_attr("_XlaCompile"))
      self.assertEqual(b"jit_scope_0_grad_GA",
                       grad_op.get_attr("_XlaScope"))

      # Ensure the ops run: grad(x1*x1) = 2*x1
      self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))


if __name__ == "__main__":
  test.main()