1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.kernels.bcast_ops.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.core.protobuf import config_pb2 24 from tensorflow.python.client import session 25 from tensorflow.python.framework import constant_op 26 from tensorflow.python.framework import dtypes 27 from tensorflow.python.framework import function 28 from tensorflow.python.framework import ops 29 from tensorflow.python.framework import sparse_tensor 30 from tensorflow.python.framework import test_util 31 from tensorflow.python.ops import array_ops 32 from tensorflow.python.ops import functional_ops 33 from tensorflow.python.ops import gradients_impl 34 from tensorflow.python.ops import init_ops 35 from tensorflow.python.ops import math_ops 36 from tensorflow.python.ops import variable_scope 37 from tensorflow.python.ops import variables 38 import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import 39 from tensorflow.python.platform import test 40 41 42 def simple_scoped_fn(a, x): 43 """Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope.""" 44 with variable_scope.variable_scope("body"): 45 # Dummy variable, just to check that scoping works as intended. 46 two = variable_scope.get_variable( 47 "two", [], 48 dtype=dtypes.int32, 49 initializer=init_ops.constant_initializer(2)) 50 return math_ops.multiply(math_ops.add(a, x), two) 51 52 53 class FunctionalOpsTest(test.TestCase): 54 55 @test_util.run_in_graph_and_eager_modes() 56 def testFoldl_Simple(self): 57 with self.test_session(): 58 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 59 60 r = functional_ops.foldl( 61 lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), 62 elems) 63 self.assertAllEqual(208, self.evaluate(r)) 64 65 r = functional_ops.foldl( 66 lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), 67 elems, 68 initializer=10) 69 self.assertAllEqual(880, self.evaluate(r)) 70 71 def testFoldl_Scoped(self): 72 with self.test_session() as sess: 73 with variable_scope.variable_scope("root") as varscope: 74 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 75 76 r = functional_ops.foldl(simple_scoped_fn, elems) 77 # Check that we have the one variable we asked for here. 78 self.assertEqual(len(variables.trainable_variables()), 1) 79 self.assertEqual(variables.trainable_variables()[0].name, 80 "root/body/two:0") 81 sess.run([variables.global_variables_initializer()]) 82 self.assertAllEqual(208, self.evaluate(r)) 83 84 # Now let's reuse our single variable. 85 varscope.reuse_variables() 86 r = functional_ops.foldl(simple_scoped_fn, elems, initializer=10) 87 self.assertEqual(len(variables.trainable_variables()), 1) 88 self.assertAllEqual(880, self.evaluate(r)) 89 90 @test_util.run_in_graph_and_eager_modes() 91 def testFoldr_Simple(self): 92 with self.test_session(): 93 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 94 95 r = functional_ops.foldr( 96 lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), 97 elems) 98 self.assertAllEqual(450, self.evaluate(r)) 99 100 r = functional_ops.foldr( 101 lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), 102 elems, 103 initializer=10) 104 self.assertAllEqual(1282, self.evaluate(r)) 105 106 def testFoldr_Scoped(self): 107 with self.test_session() as sess: 108 with variable_scope.variable_scope("root") as varscope: 109 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 110 111 r = functional_ops.foldr(simple_scoped_fn, elems) 112 # Check that we have the one variable we asked for here. 113 self.assertEqual(len(variables.trainable_variables()), 1) 114 self.assertEqual(variables.trainable_variables()[0].name, 115 "root/body/two:0") 116 sess.run([variables.global_variables_initializer()]) 117 self.assertAllEqual(450, self.evaluate(r)) 118 119 # Now let's reuse our single variable. 120 varscope.reuse_variables() 121 r = functional_ops.foldr(simple_scoped_fn, elems, initializer=10) 122 self.assertEqual(len(variables.trainable_variables()), 1) 123 self.assertAllEqual(1282, self.evaluate(r)) 124 125 # pylint: disable=unnecessary-lambda 126 def testFold_Grad(self): 127 with self.test_session(): 128 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") 129 v = constant_op.constant(2.0, name="v") 130 r = functional_ops.foldl( 131 lambda a, x: math_ops.multiply(a, x), elems, initializer=v) 132 r = gradients_impl.gradients(r, v)[0] 133 self.assertAllEqual(720.0, self.evaluate(r)) 134 135 r = functional_ops.foldr( 136 lambda a, x: math_ops.multiply(a, x), elems, initializer=v) 137 r = gradients_impl.gradients(r, v)[0] 138 self.assertAllEqual(720.0, self.evaluate(r)) 139 # pylint: enable=unnecessary-lambda 140 141 @test_util.run_in_graph_and_eager_modes() 142 def testMap_Simple(self): 143 with self.test_session(): 144 nums = [1, 2, 3, 4, 5, 6] 145 elems = constant_op.constant(nums, name="data") 146 r = functional_ops.map_fn( 147 lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems) 148 self.assertAllEqual( 149 np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) 150 151 def testMapSparseTensor(self): 152 with self.test_session(): 153 with self.assertRaises(TypeError): 154 functional_ops.map_fn( 155 lambda x: x, 156 sparse_tensor.SparseTensor( 157 indices=[[0, 0], [0, 1], [1, 0]], 158 values=constant_op.constant([0, 1, 2]), 159 dense_shape=[2, 2])) 160 161 def testMap_Scoped(self): 162 with self.test_session() as sess: 163 164 def double_scoped(x): 165 """2x with a dummy 2 that is scoped.""" 166 with variable_scope.variable_scope("body"): 167 # Dummy variable, just to check that scoping works as intended. 168 two = variable_scope.get_variable( 169 "two", [], 170 dtype=dtypes.int32, 171 initializer=init_ops.constant_initializer(2)) 172 return math_ops.multiply(x, two) 173 174 with variable_scope.variable_scope("root") as varscope: 175 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 176 doubles = np.array([2 * x for x in [1, 2, 3, 4, 5, 6]]) 177 178 r = functional_ops.map_fn(double_scoped, elems) 179 # Check that we have the one variable we asked for here. 180 self.assertEqual(len(variables.trainable_variables()), 1) 181 self.assertEqual(variables.trainable_variables()[0].name, 182 "root/body/two:0") 183 sess.run([variables.global_variables_initializer()]) 184 self.assertAllEqual(doubles, self.evaluate(r)) 185 186 # Now let's reuse our single variable. 187 varscope.reuse_variables() 188 r = functional_ops.map_fn(double_scoped, elems) 189 self.assertEqual(len(variables.trainable_variables()), 1) 190 self.assertAllEqual(doubles, self.evaluate(r)) 191 192 def testMap_Grad(self): 193 with self.test_session(): 194 param = constant_op.constant(2.0) 195 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems") 196 y = functional_ops.map_fn( 197 lambda x: math_ops.multiply(math_ops.square(x), param), elems) 198 r = gradients_impl.gradients(y, param)[0] 199 self.assertAllEqual(91.0, self.evaluate(r)) 200 r = gradients_impl.gradients(y, elems)[0] 201 self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r)) 202 203 @test_util.run_in_graph_and_eager_modes() 204 def testMap_SimpleNotTensor(self): 205 with self.test_session(): 206 nums = np.array([1, 2, 3, 4, 5, 6]) 207 r = functional_ops.map_fn( 208 lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums) 209 self.assertAllEqual( 210 np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) 211 212 @test_util.run_in_graph_and_eager_modes() 213 def testMap_SingleInputMultiOutput(self): 214 with self.test_session(): 215 nums = np.array([1, 2, 3, 4, 5, 6]) 216 r = functional_ops.map_fn( 217 lambda x: ((x + 3) * 2, -(x + 3) * 2), 218 nums, 219 dtype=(dtypes.int64, dtypes.int64)) 220 self.assertEqual(2, len(r)) 221 self.assertEqual((6,), r[0].get_shape()) 222 self.assertEqual((6,), r[1].get_shape()) 223 received = self.evaluate(r) 224 self.assertAllEqual((nums + 3) * 2, received[0]) 225 self.assertAllEqual(-(nums + 3) * 2, received[1]) 226 227 @test_util.run_in_graph_and_eager_modes() 228 def testMap_MultiOutputMismatchedDtype(self): 229 with self.test_session(): 230 nums = np.array([1, 2, 3, 4, 5, 6]) 231 with self.assertRaisesRegexp( 232 TypeError, r"two structures don't have the same sequence type."): 233 # lambda emits tuple, but dtype is a list 234 functional_ops.map_fn( 235 lambda x: ((x + 3) * 2, -(x + 3) * 2), 236 nums, 237 dtype=[dtypes.int64, dtypes.int64]) 238 239 @test_util.run_in_graph_and_eager_modes() 240 def testMap_MultiInputSingleOutput(self): 241 with self.test_session(): 242 nums = np.array([1, 2, 3, 4, 5, 6]) 243 r = functional_ops.map_fn( 244 lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)), 245 dtype=dtypes.int64) 246 self.assertEqual((6,), r.get_shape()) 247 received = self.evaluate(r) 248 self.assertAllEqual(nums * nums + (-nums), received) 249 250 @test_util.run_in_graph_and_eager_modes() 251 def testMap_MultiInputSameStructureOutput(self): 252 with self.test_session(): 253 nums = np.array([1, 2, 3, 4, 5, 6]) 254 r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])), 255 (nums, (2 * nums, -nums))) 256 r = [r[0], r[1][0], r[1][1]] 257 self.assertEqual((6,), r[0].get_shape()) 258 self.assertEqual((6,), r[1].get_shape()) 259 self.assertEqual((6,), r[2].get_shape()) 260 received = self.evaluate(r) 261 self.assertAllEqual(2 * nums, received[0]) 262 self.assertAllEqual(-nums, received[1]) 263 self.assertAllEqual(nums, received[2]) 264 265 @test_util.run_in_graph_and_eager_modes() 266 def testScan_Simple(self): 267 with self.test_session(): 268 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") 269 v = constant_op.constant(2.0, name="v") 270 271 # pylint: disable=unnecessary-lambda 272 r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems) 273 self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r)) 274 275 r = functional_ops.scan( 276 lambda a, x: math_ops.multiply(a, x), elems, initializer=v) 277 self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) 278 # pylint: enable=unnecessary-lambda 279 280 @test_util.run_in_graph_and_eager_modes() 281 def testScan_SingleInputMultiOutput(self): 282 with self.test_session(): 283 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 284 initializer = (np.array(1.0), np.array(-1.0)) 285 r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems, 286 initializer) 287 r_value = self.evaluate(r) 288 289 self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0]) 290 self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1]) 291 292 @test_util.run_in_graph_and_eager_modes() 293 def testScan_MultiInputSingleOutput(self): 294 with self.test_session(): 295 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 296 initializer = np.array(1.0) 297 # Multiply a * 1 each time 298 r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]), 299 (elems + 1, -elems), initializer) 300 self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r)) 301 302 @test_util.run_in_graph_and_eager_modes() 303 def testScan_MultiInputSameTypeOutput(self): 304 with self.test_session(): 305 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 306 r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]), 307 (elems, -elems)) 308 r_value = self.evaluate(r) 309 self.assertAllEqual(np.cumsum(elems), r_value[0]) 310 self.assertAllEqual(np.cumsum(-elems), r_value[1]) 311 312 @test_util.run_in_graph_and_eager_modes() 313 def testScan_MultiOutputMismatchedInitializer(self): 314 with self.test_session(): 315 elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 316 initializer = np.array(1.0) 317 # Multiply a * 1 each time 318 with self.assertRaisesRegexp( 319 ValueError, "two structures don't have the same number of elements"): 320 functional_ops.scan(lambda a, x: (a, -a), elems, initializer) 321 322 def testScan_Scoped(self): 323 with self.test_session() as sess: 324 with variable_scope.variable_scope("root") as varscope: 325 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 326 327 r = functional_ops.scan(simple_scoped_fn, elems) 328 # Check that we have the one variable we asked for here. 329 self.assertEqual(len(variables.trainable_variables()), 1) 330 self.assertEqual(variables.trainable_variables()[0].name, 331 "root/body/two:0") 332 sess.run([variables.global_variables_initializer()]) 333 results = np.array([1, 6, 18, 44, 98, 208]) 334 self.assertAllEqual(results, self.evaluate(r)) 335 336 # Now let's reuse our single variable. 337 varscope.reuse_variables() 338 r = functional_ops.scan(simple_scoped_fn, elems, initializer=2) 339 self.assertEqual(len(variables.trainable_variables()), 1) 340 results = np.array([6, 16, 38, 84, 178, 368]) 341 self.assertAllEqual(results, self.evaluate(r)) 342 343 @test_util.run_in_graph_and_eager_modes() 344 def testScanFoldl_Nested(self): 345 with self.test_session(): 346 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data") 347 inner_elems = constant_op.constant([0.5, 0.5], name="data") 348 349 def r_inner(a, x): 350 return functional_ops.foldl( 351 lambda b, y: b * y * x, inner_elems, initializer=a) 352 353 r = functional_ops.scan(r_inner, elems) 354 355 # t == 0 (returns 1) 356 # t == 1, a == 1, x == 2 (returns 1) 357 # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1 358 # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1 359 # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25) 360 # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5 361 # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5 362 # t == 3, a == 2.25, x == 4 (returns 9) 363 # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5 364 # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9 365 self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r)) 366 367 def testScan_Control(self): 368 with self.test_session() as sess: 369 s = array_ops.placeholder(dtypes.float32, shape=[None]) 370 b = array_ops.placeholder(dtypes.bool) 371 372 with ops.control_dependencies([b]): 373 c = functional_ops.scan(lambda a, x: x * a, s) 374 self.assertAllClose( 375 np.array([1.0, 3.0, 9.0]), sess.run(c, {s: [1, 3, 3], 376 b: True})) 377 378 def testScan_Grad(self): 379 with self.test_session(): 380 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") 381 v = constant_op.constant(2.0, name="v") 382 383 # pylint: disable=unnecessary-lambda 384 r = functional_ops.scan( 385 lambda a, x: math_ops.multiply(a, x), elems, initializer=v) 386 # pylint: enable=unnecessary-lambda 387 r = gradients_impl.gradients(r, v)[0] 388 self.assertAllEqual(873.0, self.evaluate(r)) 389 390 def testScanGradientWithPartStopGradient(self): 391 a = variables.Variable(0.0, name="a") 392 b = variables.Variable(0.0, name="b") 393 elems = array_ops.zeros(5) 394 l0, l1 = functional_ops.scan( 395 lambda elem_, input_: (a, b), elems, initializer=(0., 0.)) 396 loss = l0 + array_ops.stop_gradient(l1) 397 grad = gradients_impl.gradients(ys=[loss], xs=[a, b]) 398 with self.test_session(use_gpu=True) as sess: 399 variables.global_variables_initializer().run() 400 sess.run(grad) 401 402 @test_util.run_in_graph_and_eager_modes() 403 def testFoldShape(self): 404 with self.test_session(): 405 x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) 406 407 def fn(_, current_input): 408 return current_input 409 410 initializer = constant_op.constant([0, 0, 0]) 411 y = functional_ops.foldl(fn, x, initializer=initializer) 412 self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) 413 414 @test_util.run_in_graph_and_eager_modes() 415 def testMapShape(self): 416 with self.test_session(): 417 x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) 418 y = functional_ops.map_fn(lambda e: e, x) 419 self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) 420 421 def testMapUnknownShape(self): 422 x = array_ops.placeholder(dtypes.float32) 423 y = functional_ops.map_fn(lambda e: e, x) 424 self.assertIs(None, y.get_shape().dims) 425 426 @test_util.run_in_graph_and_eager_modes() 427 def testMapEmptyScalar(self): 428 with self.test_session(): 429 map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([])) 430 self.assertAllEqual([0], map_return.get_shape().dims) 431 self.assertAllEqual([0], self.evaluate(map_return).shape) 432 433 # TODO(akshayka): this test fails in eager: the iterable is of length 0 so 434 # so the body of the while loop never executes 435 def testMapEmptyTensor(self): 436 with self.test_session(): 437 map_return = functional_ops.map_fn(lambda x: array_ops.zeros([3, 2]), 438 constant_op.constant([])) 439 self.assertAllEqual([0, 3, 2], map_return.get_shape().dims) 440 self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape) 441 442 @test_util.run_in_graph_and_eager_modes() 443 def testScanShape(self): 444 with self.test_session(): 445 x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) 446 447 def fn(_, current_input): 448 return current_input 449 450 initializer = constant_op.constant([0, 0, 0]) 451 y = functional_ops.scan(fn, x, initializer=initializer) 452 self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) 453 454 # TODO(akshayka): this test fails in eager: the iterable is of length 0 so 455 # so the body of the while loop never executes 456 def testScanEmptyTensor(self): 457 with self.test_session(): 458 x = functional_ops.scan( 459 lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4])) 460 self.assertAllEqual([0, 2, 4], x.get_shape()) 461 self.assertAllEqual(x.get_shape(), self.evaluate(x).shape) 462 463 def testScanUnknownShape(self): 464 x = array_ops.placeholder(dtypes.float32) 465 initializer = array_ops.placeholder(dtypes.float32) 466 467 def fn(_, current_input): 468 return current_input 469 470 y = functional_ops.scan(fn, x, initializer=initializer) 471 self.assertIs(None, y.get_shape().dims) 472 473 def testScanVaryingShape(self): 474 with self.test_session() as sess: 475 x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2]) 476 x_t = array_ops.transpose(x) 477 # scan over dimension 0 (with shape None) 478 result = functional_ops.scan(lambda a, x: a + x, x) 479 # scanned over transposed dimension 0 (with shape 2) 480 result_t = functional_ops.scan(lambda a, x: a + x, x_t, infer_shape=False) 481 # ensure gradients can be calculated 482 result_grad = gradients_impl.gradients(result, [x])[0] 483 result_t_grad = gradients_impl.gradients(result_t, [x_t])[0] 484 485 # smoke test to ensure they all evaluate 486 sess.run([result, result_t, result_grad, result_t_grad], 487 feed_dict={x: [[1.0, 2.0]]}) 488 489 def testRemoteFunction(self): 490 worker_config = config_pb2.ConfigProto() 491 worker_config.device_count["CPU"] = 2 492 worker, _ = test_util.create_local_cluster( 493 1, 1, worker_config=worker_config) 494 495 @function.Defun(dtypes.int32, dtypes.int32) 496 def _remote_fn(a, b): 497 return math_ops.multiply(a, b) 498 499 with ops.device("/job:ps/task:0"): 500 a = variables.Variable(2, dtype=dtypes.int32) 501 b = variables.Variable(3, dtype=dtypes.int32) 502 503 with ops.device("/job:worker/replica:0/task:0/cpu:0"): 504 remote_op = functional_ops.remote_call( 505 args=[a, b], 506 Tout=[dtypes.int32], 507 f=_remote_fn, 508 target="/job:worker/replica:0/task:0/cpu:1") 509 510 with session.Session(worker[0].target) as sess: 511 sess.run(variables.global_variables_initializer()) 512 mul = sess.run(remote_op) 513 self.assertEqual(mul, [6]) 514 515 def testRemoteFunctionDirectSession(self): 516 worker_config = config_pb2.ConfigProto() 517 worker_config.device_count["CPU"] = 2 518 519 @function.Defun(dtypes.int32, dtypes.int32) 520 def _remote_fn(a, b): 521 return math_ops.multiply(a, b) 522 523 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 524 a = variables.Variable(2, dtype=dtypes.int32) 525 b = variables.Variable(3, dtype=dtypes.int32) 526 527 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 528 remote_op = functional_ops.remote_call( 529 args=[a, b], 530 Tout=[dtypes.int32], 531 f=_remote_fn, 532 target="/job:localhost/replica:0/task:0/cpu:1") 533 534 with self.test_session(config=worker_config) as sess: 535 sess.run(variables.global_variables_initializer()) 536 mul = sess.run(remote_op) 537 self.assertEqual(mul, [6]) 538 539 def testRemoteFunctionCPUGPU(self): 540 if not test_util.is_gpu_available(): 541 self.skipTest("No GPU available") 542 543 @function.Defun(dtypes.float32, dtypes.float32) 544 def _remote_fn(a, b): 545 return math_ops.multiply(a, b) 546 547 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 548 a = variables.Variable(2, dtype=dtypes.float32) 549 b = variables.Variable(3, dtype=dtypes.float32) 550 551 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 552 remote_op = functional_ops.remote_call( 553 args=[a, b], 554 Tout=[dtypes.float32], 555 f=_remote_fn, 556 target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0 557 558 with self.test_session() as sess: 559 sess.run(variables.global_variables_initializer()) 560 mul = sess.run(remote_op) 561 self.assertEqual(mul, 9.0) 562 563 def testRemoteFunctionGPUCPU(self): 564 if not test_util.is_gpu_available(): 565 self.skipTest("No GPU available") 566 567 @function.Defun(dtypes.float32, dtypes.float32) 568 def _remote_fn(a, b): 569 return math_ops.multiply(a, b) 570 571 with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): 572 a = variables.Variable(2, dtype=dtypes.float32) 573 b = variables.Variable(3, dtype=dtypes.float32) 574 575 with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): 576 remote_op = functional_ops.remote_call( 577 args=[a, b], 578 Tout=[dtypes.float32], 579 f=_remote_fn, 580 target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0 581 582 with self.test_session() as sess: 583 sess.run(variables.global_variables_initializer()) 584 mul = sess.run(remote_op) 585 self.assertEqual(mul, 9.0) 586 587 def testRemoteFunctionCrossProcess(self): 588 workers, _ = test_util.create_local_cluster(2, 1) 589 590 @function.Defun(dtypes.float32, dtypes.float32) 591 def _remote_fn(a, b): 592 return math_ops.multiply(a, b) 593 594 with ops.device("/job:ps/task:0"): 595 a = variables.Variable(2, dtype=dtypes.float32) 596 b = variables.Variable(3, dtype=dtypes.float32) 597 598 with ops.device("/job:worker/replica:0/task:0/cpu:0"): 599 remote_op = functional_ops.remote_call( 600 args=[a, b], 601 Tout=[dtypes.float32], 602 f=_remote_fn, 603 target="/job:worker/replica:0/task:1/cpu:0")[0] + 3.0 604 605 with session.Session(workers[0].target) as sess: 606 sess.run(variables.global_variables_initializer()) 607 mul = sess.run(remote_op) 608 self.assertEqual(mul, 9) 609 610 611 if __name__ == "__main__": 612 test.main() 613