# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for JIT compilation on the CPU and GPU devices."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.compiler import jit
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test

jit_scope = jit.experimental_jit_scope


def CompiledKernel(fn, *inputs, **kwargs):
  """Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
  name = kwargs.pop("name", None)
  noinline = kwargs.pop("noinline", None)

  @function.Defun(func_name=name, noinline=noinline, compiled=True)
  def Compiled(*args):
    return fn(*args)

  return Compiled(*inputs)


def RunMetadataLabels(run_metadata):
  """Returns all labels in run_metadata."""
  labels = []
  for dev_stats in run_metadata.step_stats.dev_stats:
    for node_stats in dev_stats.node_stats:
      labels.append(node_stats.timeline_label)
  return labels


def InLabels(labels, substr):
  """Returns true iff one of the labels contains substr."""
  return any([substr in x for x in labels])


def MetadataHasXlaLaunch(run_metadata):
  """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline."""

  # TODO(phawkins): find a less hacky way to test whether a kernel ran.
  return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch")


class JitLaunchTest(test.TestCase):

  # Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
  # Verifies that the outputs match and that XLA was invoked. 'fn' must take
  # the same number of tensor arguments as there are entries in 'args', and
  # must return a tuple of output tensors.
  # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node
  # actually ran. However, it is sometimes possible for _XlaLaunch ops to be
  # constant-folded away, so the check is optional.
  def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
    with session_lib.Session() as sess:
      placeholders = []
      feeds = {}
      for arg in args:
        placeholder = array_ops.placeholder(
            dtypes.as_dtype(arg.dtype), list(arg.shape))
        placeholders.append(placeholder)
        feeds[placeholder] = arg

      compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline)
      direct_op = fn(*placeholders)

      run_metadata = config_pb2.RunMetadata()
      compiled = sess.run(compiled_op,
                          feeds,
                          run_metadata=run_metadata,
                          options=config_pb2.RunOptions(
                              trace_level=config_pb2.RunOptions.FULL_TRACE))
      print("Compiled Result {}".format(compiled))

      if require_kernel_launch:
        self.assert_(MetadataHasXlaLaunch(run_metadata))

      direct = sess.run(direct_op, feeds)
      print("Direct Result {}".format(direct))

      if (isinstance(compiled, (tuple, list)) and
          isinstance(direct, (tuple, list))):
        for (x, y) in zip(compiled, direct):
          self.assertAllClose(x, y, rtol=1e-1)
      else:
        self.assertAllClose(compiled, direct)

  def testNoOutputs(self):
    with session_lib.Session() as sess:

      # Check that calling the result as a compiled kernel doesn't crash.
      @function.Defun(compiled=True)
      def KernelWithNoOutputs():
        a = constant_op.constant(100)  # pylint: disable=unused-variable

      call = KernelWithNoOutputs()  # pylint: disable=assignment-from-no-return
      sess.run(call, {})

  def testAliasing(self):
    """Regression test for compiled functions that return an aliased buffer.

    XLA returns aliased buffers if outputs are identical. Tests that
    we handle that case.
    """

    def AddOnceReturnTwice(x):
      y = math_ops.add(x, x)
      return y, y

    # Exercises compiling a function (say, Foo) that calls another
    # function (say, Bar) that is not inlined. When the compiler compiles
    # Foo, it needs to symbolically execute Bar correctly regardless of
    # whether Bar is inlined or not.

    # TODO(b/36139787): Re-enable this test when noinline works again.
    # Tests compiled=True and noinline=True.
    # self._compare(
    #     AddOnceReturnTwice, [np.array(
    #         [[[0.5, -1.0]]], dtype=np.float32)],
    #     noinline=True)

    # Tests compiled=True and noinline=False.
    self._compare(
        AddOnceReturnTwice, [np.array(
            [[[0.5, -1.0]]], dtype=np.float32)],
        noinline=False)

  def testOneConstOutput(self):
    """Test consisting of a single constant return value."""

    def OneConstOutput():
      return constant_op.constant([-3, 44, 99])

    self._compare(OneConstOutput, [], require_kernel_launch=False)

  def testConstZeroElementOutput(self):
    """Test consisting of a constant zero-element return value."""

    def ConstZeroElementOutput():
      return array_ops.fill([7, 0], 3.0)

    self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)

  def testSomeConstOutputs(self):
    """Test kernels that return a mixture of const and non-const outputs."""

    def SomeConstOutputs(x):
      return constant_op.constant(
          [-2, 7]), array_ops.identity(x), constant_op.constant(3.5)

    self._compare(
        SomeConstOutputs, [np.array(
            [[1, 2, 3], [4, 5, 6]], dtype=np.float32)])

  def testInt32Input(self):
    """Test an int32-typed input.

    On a GPU, int32 tensors will be placed in host memory.
187 """ 188 189 def AddToSelf(x): 190 return math_ops.add(x, x) 191 192 self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)]) 193 194 def testMandatoryConstantInput(self): 195 """Tests an operator that has a mandatory-constant shape input.""" 196 197 def FillWithFloat(x): 198 return array_ops.fill(x, 9.5) 199 200 self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)]) 201 202 def testMnistForwardFunc(self): 203 """Compute inference function from MNIST beginners tutorial.""" 204 batch_size = 16 205 image_size = 28 * 28 206 num_classes = 10 207 208 # Define a TensorFlow function to compute the forward pass. 209 def MnistForward(w, b, x): 210 return nn_ops.softmax(math_ops.matmul(x, w) + b) 211 212 w = np.random.random_sample((image_size, num_classes)).astype(np.float32) 213 b = np.random.random_sample((num_classes)).astype(np.float32) 214 x = np.random.random_sample((batch_size, image_size)).astype(np.float32) 215 self._compare(MnistForward, [w, b, x]) 216 217 def testExplicitMarking(self): 218 """Test explicit marking of operators to compile.""" 219 batch_size = 16 220 image_size = 28 * 28 221 num_classes = 10 222 223 with ops.Graph().as_default(): 224 x = array_ops.placeholder(dtypes.float32) 225 w = array_ops.placeholder(dtypes.float32) 226 b = array_ops.placeholder(dtypes.float32) 227 with jit_scope(): 228 y1 = math_ops.matmul(x, w) 229 y2 = math_ops.add(y1, b) 230 with jit_scope(): 231 y = math_ops.square(y2) 232 233 dw = np.random.random_sample((image_size, num_classes)).astype(np.float32) 234 db = np.random.random_sample((num_classes)).astype(np.float32) 235 dx = np.random.random_sample((batch_size, image_size)).astype(np.float32) 236 with session_lib.Session() as sess: 237 run_metadata = config_pb2.RunMetadata() 238 output = sess.run(y, {x: dx, 239 w: dw, 240 b: db}, 241 run_metadata=run_metadata, 242 options=config_pb2.RunOptions( 243 trace_level=config_pb2.RunOptions.FULL_TRACE)) 244 245 # TODO(phawkins): really we would like to test that there were exactly 246 # two kernel launches. However, we have no reliable way to determine 247 # that. 248 self.assert_(MetadataHasXlaLaunch(run_metadata)) 249 250 expected = np.square(np.dot(dx, dw) + db) 251 self.assertAllClose(expected, output, rtol=1e-1) 252 253 254 class XlaCompilationTest(test.TestCase): 255 """Tests for auto-compilation on CPU/GPU devices.""" 256 257 def testReshape(self): 258 """Tests an operator with compile-time constant and non-constant inputs.""" 259 260 with self.test_session() as sess: 261 x = array_ops.placeholder(dtypes.float32) 262 y = array_ops.placeholder(dtypes.int32) 263 with jit_scope(): 264 # Reshape's first argument is non-constant in the JIT, but its second 265 # (shape) argument will be treated as a compile-time constant for 266 # each JIT compilation. 267 # We do not use a tf.const() argument since we want to ensure the 268 # shape is still a run-time argument to the JIT, and not 269 # statically known as part of the JIT compilation's input graph. 
        z = array_ops.reshape(x, y)
      run_metadata = config_pb2.RunMetadata()
      out = sess.run(z,
                     {x: np.array([1, 2, 3, 4, 5, 6], np.float32),
                      y: [-1, 3]},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaLaunch(run_metadata))
      self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)

  def testIgnoredArguments(self):
    """Tests that JIT computations can ignore formal parameters."""

    with self.test_session() as sess:
      x = array_ops.placeholder(dtypes.int32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        z = math_ops.add(x, x)
        w = math_ops.add(y, y)
        # Pulls 'w' into the same compilation via control dependencies.
        with ops.control_dependencies([w]):
          n = control_flow_ops.no_op()
        with ops.control_dependencies([n]):
          t = math_ops.add(z, z)

      run_metadata = config_pb2.RunMetadata()
      out = sess.run(t, {x: np.int32(7),
                         y: np.int32(404)},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaLaunch(run_metadata))
      self.assertAllClose(28, out)

  def testLoops(self):
    """Tests that compilation accepts computations containing loops."""

    with self.test_session() as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        c = lambda i, _: math_ops.less(i, 5)
        b = lambda i, x: (i + 1, x * 2.0 + 1.0)
        _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))

      run_metadata = config_pb2.RunMetadata()
      result = session.run(y, {x: np.float32(2)},
                           run_metadata=run_metadata,
                           options=config_pb2.RunOptions(
                               trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaLaunch(run_metadata))
      self.assertAllClose(result, np.float32(95), rtol=1e-1)

  def testCond(self):
    """Tests that compilation handles switch operators."""

    with self.test_session() as session:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)
      c = array_ops.placeholder(dtypes.bool)
      with jit_scope():
        z = x + 1.0
        w = control_flow_ops.cond(c, lambda: z, lambda: y)
        t = math_ops.add(z, w)

      # If JIT compilation chooses to cluster z and t, then execution will
      # deadlock.

      run_metadata = config_pb2.RunMetadata()
      result = session.run(t, {x: np.float32(2),
                               y: np.float32(4),
                               c: True},
                           run_metadata=run_metadata,
                           options=config_pb2.RunOptions(
                               trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaLaunch(run_metadata))
      self.assertAllClose(result, np.float32(6), rtol=1e-1)

  def testNestedFunction(self):
    """Tests compiled functions that call other compiled functions."""
    g = ops.Graph()
    with g.as_default():

      @function.Defun(compiled=True)
      def Bar(x, y):
        return x + 2 * y

      @function.Defun(compiled=True)
      def Foo(x):
        return Bar(x * x, x * x * x)

      @function.Defun()
      def Entry(x):
        return Foo(x)

      inp = array_ops.placeholder(dtypes.float32)
      out = Entry(inp)

    with self.test_session(graph=g, use_gpu=True) as sess:
      run_metadata = config_pb2.RunMetadata()
      val = sess.run(out,
                     feed_dict={inp: [2., 10.]},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertAllClose(val, [20., 2100.])

  def testLoopDeadlock(self):
    """Regression test for bug that caused deadlocks in graphs with loops."""

    with self.test_session() as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y = x + 1.0
        c = lambda i, _x, _y: math_ops.less(i, 5)
        b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
        _, _, w = control_flow_ops.while_loop(c, b,
                                              (constant_op.constant(0), y, x))
        u = w + y
      result = session.run(u, {x: np.float32(2)})
      self.assertAllClose(result, np.float32(63), rtol=1e-1)

  def testGradient(self):
    """Tests that the backprop function is properly compiled."""

    def _Run(compiled):

      @function.Defun(compiled=compiled)
      def Forward(x):
        return math_ops.log(x)

      g = ops.Graph()
      with g.as_default():
        x = array_ops.placeholder(dtypes.float32)
        y = Forward(x)
        dx, = gradients_impl.gradients(y, [x], 1.0)

      cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
          optimizer_options=config_pb2.OptimizerOptions(
              opt_level=config_pb2.OptimizerOptions.L1,
              do_function_inlining=True)))
      with session_lib.Session(graph=g, config=cfg) as sess:
        run_metadata = config_pb2.RunMetadata()
        dx_val = sess.run(dx,
                          feed_dict={x: 100.},
                          run_metadata=run_metadata,
                          options=config_pb2.RunOptions(
                              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertAllClose(dx_val, 0.01)
      return RunMetadataLabels(run_metadata)

    # SymGrad[f=log(x)](x, dy) = 1/x * dy
    #
    # Note: we don't need to compute log(x) for dx due to graph pruning.

    # Do not compile the backprop. We should see one Reciprocal and one Mul.
    labels = _Run(compiled=False)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertTrue(InLabels(labels, "Reciprocal"))
    self.assertTrue(InLabels(labels, "Mul"))
    self.assertFalse(InLabels(labels, "_XlaLaunch"))

    # Compile the backprop. One _XlaLaunch.
    labels = _Run(compiled=True)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertFalse(InLabels(labels, "Reciprocal"))
    self.assertFalse(InLabels(labels, "Mul"))
    self.assertTrue(InLabels(labels, "_XlaLaunch"))


if __name__ == "__main__":
  test.main()