# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.compiler.jit."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.compiler import jit
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top


_REGISTERED_OPS = op_def_registry.get_registered_ops()


def enable_jit_nonstateful(node_def):
  """Returns whether the op described by `node_def` should be JIT-compiled.

  Passed as the first argument of `jit.experimental_jit_scope` (via
  `JITTest.compute`) so that only ops whose registered definition is not
  stateful are marked for compilation.

  Args:
    node_def: Description of the op being created (presumably a `NodeDef`
      proto); only its `op` field is read.

  Returns:
    True if the registered op is not stateful, False otherwise.

  Raises:
    ValueError: If `node_def.op` is not a registered op.
  """
  try:
    return not _REGISTERED_OPS[node_def.op].is_stateful
  except KeyError:
    raise ValueError("Unregistered op being created: %s" % node_def)


class JITTest(test.TestCase):

  def compute(self, use_jit, compute_fn):
    """Builds and evaluates `compute_fn` in a fresh graph under a JIT scope.

    Args:
      use_jit: Forwarded unchanged to `jit.experimental_jit_scope`; the tests
        pass `True`, `False`, or the callable `enable_jit_nonstateful`.
      compute_fn: Zero-argument callable that creates ops and returns the
        tensor to evaluate.

    Returns:
      A `(tensor, value)` tuple: the tensor returned by `compute_fn` and its
      value after all global variables have been initialized.
    """
    with self.test_session(graph=ops.Graph()) as sess:
      # `set_random_seed` seeds the *current default* graph. It must run here,
      # where the fresh graph is the default; calling it before entering the
      # session would seed an unrelated graph and leave this one unseeded.
      random_seed.set_random_seed(1234)
      with jit.experimental_jit_scope(use_jit):
        r = compute_fn()
      # Initialization happens outside the JIT scope, so the initializer op
      # itself is not marked for compilation.
      sess.run(variables.global_variables_initializer())
      return (r, sess.run(r))

  def testJITCreateOpsLambda(self):
    """Test several ways of customizing the compilation attribute."""
    def create_ops():
      with variable_scope.variable_scope(
          "root",
          initializer=init_ops.random_uniform_initializer(
              -0.1, 0.1, seed=2)):
        inputs = random_ops.random_uniform((1,), seed=1)
        return inputs
    v_false_1_t, v_false_1 = self.compute(False, create_ops)
    _, v_false_2 = self.compute(False, create_ops)
    v_true_1_t, v_true_1 = self.compute(enable_jit_nonstateful, create_ops)
    _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops)
    v_all_true_t, _ = self.compute(True, create_ops)
    self.assertFalse(v_false_1_t.op.get_attr("_XlaCompile"))
    v_true_1_t_sampler_op = v_true_1_t.graph.get_operation_by_name(
        "root/random_uniform/RandomUniform")
    v_all_true_t_sampler_op = v_all_true_t.graph.get_operation_by_name(
        "root/random_uniform/RandomUniform")

    self.assertFalse(v_true_1_t_sampler_op.get_attr("_XlaCompile"))
    self.assertTrue(v_all_true_t_sampler_op.get_attr("_XlaCompile"))

    self.assertTrue(v_true_1_t.op.get_attr("_XlaCompile"))
    self.assertTrue(v_all_true_t.op.get_attr("_XlaCompile"))

    # Additionally ensure that where no JIT compilation happens on the
    # random_uniform op, the output values are identical to the case
    # where no JIT compilation happens anywhere.
    self.assertAllClose(v_false_1, v_false_2)
    self.assertAllClose(v_true_1, v_true_2)
    self.assertAllClose(v_false_1, v_true_1)

  def testJITXlaScope(self):
    with self.test_session(graph=ops.Graph()):
      with jit.experimental_jit_scope(True):
        # XlaScope 0
        a1 = constant_op.constant(1)
      with jit.experimental_jit_scope(True):
        # XlaScope 1
        a2 = constant_op.constant(1)
        with jit.experimental_jit_scope(True):
          # XlaScope still 1, depth 1
          a3 = constant_op.constant(1)
          with jit.experimental_jit_scope(True):
            # XlaScope still 1, depth 2
            a4 = constant_op.constant(1)
          # XlaScope still 1, depth 1
          a5 = constant_op.constant(1)
      with jit.experimental_jit_scope(True):
        # XlaScope now 2, depth 0
        a6 = constant_op.constant(1)

    self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a3.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a4.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a5.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_2", a6.op.get_attr("_XlaScope"))

  def testJITVariableSeed(self):
    """Test that the stateful initializer is not marked for compilation.

    XLA does not currently support seeded initialization and XLA initializers
    therefore return different values than non-XLA counterparts.  Here
    we ensure that we can disable JIT compilation for the initializers and
    get the same variable values as if no JIT compilation happened.
    """
    def create_ops():
      with variable_scope.variable_scope(
          "root",
          initializer=init_ops.random_uniform_initializer(
              -0.1, 0.1, seed=2)):
        inputs = variable_scope.get_variable("var", (1,))
        return inputs
    _, v_false_1 = self.compute(False, create_ops)
    _, v_false_2 = self.compute(False, create_ops)
    _, v_true_1 = self.compute(enable_jit_nonstateful, create_ops)
    _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops)
    self.assertAllClose(v_false_1, v_false_2)
    self.assertAllClose(v_true_1, v_true_2)
    self.assertAllClose(v_false_1, v_true_1)

  def testDefunNoJitScope(self):
    with self.test_session(graph=ops.Graph()):
      @function.Defun(compiled=True, noinline=True)
      def mulop(x1, x2):
        return x1 * x2
      x = constant_op.constant(1.0)
      r = mulop(x, x)

      # Ensure the forward function is compiled.
      graph_def = r.graph.as_graph_def()
      func_attrs = graph_def.library.function[0].attr
      self.assertTrue(func_attrs["_XlaCompile"].b)
      # No enclosing jit scope so function sets its own value for _XlaScope.
      self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)

  def testDefunInheritsJitScope(self):
    with self.test_session(graph=ops.Graph()):
      with jit.experimental_jit_scope(True):
        @function.Defun(compiled=True, noinline=True)
        def mulop(x1, x2):
          return x1 * x2
        x = constant_op.constant(1.0)
        r = mulop(x, x)

      # Ensure the forward function is compiled.
      graph_def = r.graph.as_graph_def()
      func_attrs = graph_def.library.function[0].attr
      self.assertTrue(func_attrs["_XlaCompile"].b)
      # Ensure _XlaScope is inherited from enclosing context.
      self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)


@test_util.with_c_api
class CompilationEnabledInGradientTest(test.TestCase):

  def testCompilationInGradient(self):
    # Build in a fresh graph, like every other test in this file, so the
    # op-name scan below only sees ops created by this test.
    with self.test_session(graph=ops.Graph()):
      x = constant_op.constant([[3]])
      y_nc = math_ops.matmul(x, x, name="not_compiled")
      with jit.experimental_jit_scope():
        y_c = math_ops.matmul(y_nc, y_nc, name="compiled")
      x_grads = gradients.gradients([y_c], [x])[0]
      operations = x.graph.get_operations()
      c_grad_ops = [
          op for op in operations if "gradients/compiled" in op.name]
      nc_grad_ops = [
          op for op in operations if "gradients/not_compiled" in op.name]
      self.assertGreater(len(c_grad_ops), 0)
      self.assertGreater(len(nc_grad_ops), 0)
      for cg in c_grad_ops:
        self.assertTrue(cg.get_attr("_XlaCompile"))
      for ncg in nc_grad_ops:
        with self.assertRaisesRegexp(ValueError, "[Nn]o attr named"):
          ncg.get_attr("_XlaCompile")

      # d/dx (x ** 4) = 4 * (x ** 3)
      self.assertAllClose([[108]], x_grads.eval())

  def testCompilationGradientScopeNames(self):
    with self.test_session(graph=ops.Graph()):
      with jit.experimental_jit_scope():
        # XlaScope 0
        a1 = constant_op.constant([[1]])
        a1t = math_ops.matmul(a1, a1)
      with jit.experimental_jit_scope():
        # XlaScope 1
        a2 = constant_op.constant([[1]])
        a2t = math_ops.matmul(a2, a2)

      self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
      grad_a1 = gradients.gradients(a1t, a1, name="GA")[0]
      grad_a2 = gradients.gradients(a2t, a2, name="GB")[0]
      grad_a1 = grad_a1.op.inputs[0]
      grad_a2 = grad_a2.op.inputs[0]
      self.assertTrue(grad_a1.op.get_attr("_XlaCompile"))
      self.assertTrue(grad_a2.op.get_attr("_XlaCompile"))
      self.assertEqual(b"jit_scope_0", grad_a1.op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1", grad_a2.op.get_attr("_XlaScope"))

  def testCompilationSeparateGradientScopeNames(self):
    with self.test_session(graph=ops.Graph()):
      with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
        # XlaScope 0
        a1 = constant_op.constant([[1]])
        a1t = math_ops.matmul(a1, a1)
      with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
        # XlaScope 1
        a2 = constant_op.constant([[1]])
        a2t = math_ops.matmul(a2, a2)

      self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
      grad_a1 = gradients.gradients(a1t, a1, name="GA")[0]
      grad_a2 = gradients.gradients(a2t, a2, name="GB")[0]
      grad_a1 = grad_a1.op.inputs[0]
      grad_a2 = grad_a2.op.inputs[0]
      self.assertTrue(grad_a1.op.get_attr("_XlaCompile"))
      self.assertTrue(grad_a2.op.get_attr("_XlaCompile"))
      self.assertEqual(b"jit_scope_0_grad_GA",
                       grad_a1.op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1_grad_GB",
                       grad_a2.op.get_attr("_XlaScope"))

  def testPlaysNicelyWithDefun(self):
    with self.test_session(graph=ops.Graph()) as sess:
      with jit.experimental_jit_scope(True):
        @function.Defun(compiled=True, noinline=True)
        def mulop(x1, x2):
          return x1 * x2
        x = constant_op.constant(1.0)
        r = mulop(x, x)
        g_r = gradients.gradients(r, x, name="GA")[0]

      # Ensure the forward function is compiled.
      graph_def = r.graph.as_graph_def()
      func_attrs = graph_def.library.function[0].attr
      self.assertTrue(func_attrs["_XlaCompile"].b)
      self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)

      # Ensure the gradient (SymbolicGradient) is compiled, with the same
      # _XlaScope as the function itself.
      grad_op = g_r.op.inputs[0].op
      self.assertTrue(grad_op.get_attr("_XlaCompile"))
      self.assertEqual(b"jit_scope_0", grad_op.get_attr("_XlaScope"))

      # Ensure the ops run: grad(x1*x1) = 2*x1
      self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))

  def testPlaysNicelyWithDefunSeparateGradientScope(self):
    with self.test_session(graph=ops.Graph()) as sess:
      with jit.experimental_jit_scope(True):

        @function.Defun(
            compiled=True, noinline=True, separate_compiled_gradients=True)
        def mulop(x1, x2):
          return x1 * x2

        x = constant_op.constant(1.0)
        r = mulop(x, x)
        g_r = gradients.gradients(r, x, name="GA")[0]

      # Ensure the forward function is compiled.
      graph_def = r.graph.as_graph_def()
      func_attrs = graph_def.library.function[0].attr
      self.assertTrue(func_attrs["_XlaCompile"].b)
      self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)

      # Ensure the gradient (SymbolicGradient) is compiled, with a different
      # _XlaScope from the function itself.
      grad_op = g_r.op.inputs[0].op
      self.assertTrue(grad_op.get_attr("_XlaCompile"))
      self.assertEqual(b"jit_scope_0_grad_GA",
                       grad_op.get_attr("_XlaScope"))

      # Ensure the ops run: grad(x1*x1) = 2*x1
      self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))


if __name__ == "__main__":
  test.main()