# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Tests for tensorflow.kernels.bcast_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


def simple_scoped_fn(a, x):
  """Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope."""
  with variable_scope.variable_scope("body"):
    # Dummy variable, just to check that scoping works as intended.
    two = variable_scope.get_variable(
        "two", [],
        dtype=dtypes.int32,
        initializer=init_ops.constant_initializer(2))
    return math_ops.multiply(math_ops.add(a, x), two)


class FunctionalOpsTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes()
  def testFoldl_Simple(self):
    with self.test_session():
      elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

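      # foldl folds left-to-right with a = 2 * (a + x). Starting from
      # a = elems[0] = 1 this gives 6, 18, 44, 98, 208; starting from
      # initializer=10 it gives 22, 48, 102, 212, 434, 880.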
      r = functional_ops.foldl(
          lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
          elems)
      self.assertAllEqual(208, self.evaluate(r))

      r = functional_ops.foldl(
          lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
          elems,
          initializer=10)
      self.assertAllEqual(880, self.evaluate(r))

  def testFoldl_Scoped(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.foldl(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(208, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.foldl(simple_scoped_fn, elems, initializer=10)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(880, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes()
  def testFoldr_Simple(self):
    with self.test_session():
      elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

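      # foldr folds right-to-left. Starting from a = elems[-1] = 6 this
      # gives 22, 52, 110, 224, 450; starting from initializer=10 it gives
      # 32, 74, 156, 318, 640, 1282.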
      r = functional_ops.foldr(
          lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
          elems)
      self.assertAllEqual(450, self.evaluate(r))

      r = functional_ops.foldr(
          lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
          elems,
          initializer=10)
      self.assertAllEqual(1282, self.evaluate(r))

  def testFoldr_Scoped(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.foldr(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(450, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.foldr(simple_scoped_fn, elems, initializer=10)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(1282, self.evaluate(r))

  # pylint: disable=unnecessary-lambda
  def testFold_Grad(self):
    with self.test_session():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = constant_op.constant(2.0, name="v")
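      # Both folds compute v * 1*2*3*4*5*6 = 720*v, so dr/dv == 720 in
      # either direction.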
      r = functional_ops.foldl(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(720.0, self.evaluate(r))

      r = functional_ops.foldr(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(720.0, self.evaluate(r))
  # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes()
  def testMap_Simple(self):
    with self.test_session():
      nums = [1, 2, 3, 4, 5, 6]
      elems = constant_op.constant(nums, name="data")
      r = functional_ops.map_fn(
          lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems)
      self.assertAllEqual(
          np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))

  def testMapSparseTensor(self):
    with self.test_session():
      with self.assertRaises(TypeError):
        functional_ops.map_fn(
            lambda x: x,
            sparse_tensor.SparseTensor(
                indices=[[0, 0], [0, 1], [1, 0]],
                values=constant_op.constant([0, 1, 2]),
                dense_shape=[2, 2]))

  def testMap_Scoped(self):
    with self.test_session() as sess:

      def double_scoped(x):
        """2x with a dummy 2 that is scoped."""
        with variable_scope.variable_scope("body"):
          # Dummy variable, just to check that scoping works as intended.
          two = variable_scope.get_variable(
              "two", [],
              dtype=dtypes.int32,
              initializer=init_ops.constant_initializer(2))
          return math_ops.multiply(x, two)

      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
        doubles = np.array([2 * x for x in [1, 2, 3, 4, 5, 6]])

        r = functional_ops.map_fn(double_scoped, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(doubles, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.map_fn(double_scoped, elems)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(doubles, self.evaluate(r))

  def testMap_Grad(self):
    with self.test_session():
      param = constant_op.constant(2.0)
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
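      # y_i = param * x_i**2, so dy/dparam = sum(x_i**2) = 91 and
      # dy/dx_i = 2 * param * x_i.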
      y = functional_ops.map_fn(
          lambda x: math_ops.multiply(math_ops.square(x), param), elems)
      r = gradients_impl.gradients(y, param)[0]
      self.assertAllEqual(91.0, self.evaluate(r))
      r = gradients_impl.gradients(y, elems)[0]
      self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes()
  def testMap_SimpleNotTensor(self):
    with self.test_session():
      nums = np.array([1, 2, 3, 4, 5, 6])
      r = functional_ops.map_fn(
          lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums)
      self.assertAllEqual(
          np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes()
  def testMap_SingleInputMultiOutput(self):
    with self.test_session():
      nums = np.array([1, 2, 3, 4, 5, 6])
      r = functional_ops.map_fn(
          lambda x: ((x + 3) * 2, -(x + 3) * 2),
          nums,
          dtype=(dtypes.int64, dtypes.int64))
      self.assertEqual(2, len(r))
      self.assertEqual((6,), r[0].get_shape())
      self.assertEqual((6,), r[1].get_shape())
      received = self.evaluate(r)
      self.assertAllEqual((nums + 3) * 2, received[0])
      self.assertAllEqual(-(nums + 3) * 2, received[1])

  @test_util.run_in_graph_and_eager_modes()
  def testMap_MultiOutputMismatchedDtype(self):
    with self.test_session():
      nums = np.array([1, 2, 3, 4, 5, 6])
      with self.assertRaisesRegexp(
          TypeError, r"two structures don't have the same sequence type."):
        # lambda emits tuple, but dtype is a list
        functional_ops.map_fn(
            lambda x: ((x + 3) * 2, -(x + 3) * 2),
            nums,
            dtype=[dtypes.int64, dtypes.int64])

  @test_util.run_in_graph_and_eager_modes()
  def testMap_MultiInputSingleOutput(self):
    with self.test_session():
      nums = np.array([1, 2, 3, 4, 5, 6])
      r = functional_ops.map_fn(
          lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)),
          dtype=dtypes.int64)
      self.assertEqual((6,), r.get_shape())
      received = self.evaluate(r)
      self.assertAllEqual(nums * nums + (-nums), received)

  @test_util.run_in_graph_and_eager_modes()
  def testMap_MultiInputSameStructureOutput(self):
    with self.test_session():
      nums = np.array([1, 2, 3, 4, 5, 6])
      r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])),
                                (nums, (2 * nums, -nums)))
      r = [r[0], r[1][0], r[1][1]]
      self.assertEqual((6,), r[0].get_shape())
      self.assertEqual((6,), r[1].get_shape())
      self.assertEqual((6,), r[2].get_shape())
      received = self.evaluate(r)
      self.assertAllEqual(2 * nums, received[0])
      self.assertAllEqual(-nums, received[1])
      self.assertAllEqual(nums, received[2])

  @test_util.run_in_graph_and_eager_modes()
  def testScan_Simple(self):
    with self.test_session():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = constant_op.constant(2.0, name="v")

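      # scan emits the running accumulator after every step: the cumulative
      # products [1, 2, 6, 24, 120, 720], or twice that when seeded with v.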
      # pylint: disable=unnecessary-lambda
      r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
      self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))

      r = functional_ops.scan(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
      # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes()
  def testScan_SingleInputMultiOutput(self):
    with self.test_session():
      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
      initializer = (np.array(1.0), np.array(-1.0))
      r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
                              initializer)
      r_value = self.evaluate(r)

      self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
      self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])

  @test_util.run_in_graph_and_eager_modes()
  def testScan_MultiInputSingleOutput(self):
    with self.test_session():
      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
      initializer = np.array(1.0)
      # Multiply a * 1 each time
      r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
                              (elems + 1, -elems), initializer)
      self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes()
  def testScan_MultiInputSameTypeOutput(self):
    with self.test_session():
      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
      r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
                              (elems, -elems))
      r_value = self.evaluate(r)
      self.assertAllEqual(np.cumsum(elems), r_value[0])
      self.assertAllEqual(np.cumsum(-elems), r_value[1])

  @test_util.run_in_graph_and_eager_modes()
  def testScan_MultiOutputMismatchedInitializer(self):
    with self.test_session():
      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
      initializer = np.array(1.0)
      # fn emits a pair of outputs but the initializer is a single scalar,
      # so the two structures do not match.
      with self.assertRaisesRegexp(
          ValueError, "two structures don't have the same number of elements"):
        functional_ops.scan(lambda a, x: (a, -a), elems, initializer)

  def testScan_Scoped(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.scan(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
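        # Expected running values of a = 2 * (a + x), starting from
        # a = elems[0] = 1.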
        results = np.array([1, 6, 18, 44, 98, 208])
        self.assertAllEqual(results, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.scan(simple_scoped_fn, elems, initializer=2)
        self.assertEqual(len(variables.trainable_variables()), 1)
        results = np.array([6, 16, 38, 84, 178, 368])
        self.assertAllEqual(results, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes()
  def testScanFoldl_Nested(self):
    with self.test_session():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
      inner_elems = constant_op.constant([0.5, 0.5], name="data")

      def r_inner(a, x):
        return functional_ops.foldl(
            lambda b, y: b * y * x, inner_elems, initializer=a)

      r = functional_ops.scan(r_inner, elems)

      # t == 0 (returns 1)
      # t == 1, a == 1, x == 2 (returns 1)
      #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
      #   t_1 == 1, b == 1,      y == 0.5, returns b * y * x = 1
      # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
      #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
      #   t_1 == 1, b == 1.5,    y == 0.5, returns b * y * x = 1.5*1.5
      # t == 3, a == 2.25, x == 4 (returns 9)
      #   t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
      #   t_1 == 1, b == 4.5,       y == 0.5, returns b * y * x = 9
      self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))

  def testScan_Control(self):
    with self.test_session() as sess:
      s = array_ops.placeholder(dtypes.float32, shape=[None])
      b = array_ops.placeholder(dtypes.bool)

      with ops.control_dependencies([b]):
        c = functional_ops.scan(lambda a, x: x * a, s)
      self.assertAllClose(
          np.array([1.0, 3.0, 9.0]), sess.run(c, {s: [1, 3, 3],
                                                  b: True}))

  def testScan_Grad(self):
    with self.test_session():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = constant_op.constant(2.0, name="v")

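      # scan outputs [v, 2v, 6v, 24v, 120v, 720v]; gradients sums the
      # contributions of all outputs, so dr/dv = 1 + 2 + 6 + 24 + 120 + 720
      # = 873.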
      # pylint: disable=unnecessary-lambda
      r = functional_ops.scan(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      # pylint: enable=unnecessary-lambda
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(873.0, self.evaluate(r))

  def testScanGradientWithPartStopGradient(self):
    a = variables.Variable(0.0, name="a")
    b = variables.Variable(0.0, name="b")
    elems = array_ops.zeros(5)
    l0, l1 = functional_ops.scan(
        lambda elem_, input_: (a, b), elems, initializer=(0., 0.))
    loss = l0 + array_ops.stop_gradient(l1)
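    # stop_gradient cuts the backprop path through l1; the scan gradient
    # must still be computable for both variables.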
    grad = gradients_impl.gradients(ys=[loss], xs=[a, b])
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      sess.run(grad)

  @test_util.run_in_graph_and_eager_modes()
  def testFoldShape(self):
    with self.test_session():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])

      def fn(_, current_input):
        return current_input

      initializer = constant_op.constant([0, 0, 0])
      y = functional_ops.foldl(fn, x, initializer=initializer)
      self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)

  @test_util.run_in_graph_and_eager_modes()
  def testMapShape(self):
    with self.test_session():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      y = functional_ops.map_fn(lambda e: e, x)
      self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)

  def testMapUnknownShape(self):
    x = array_ops.placeholder(dtypes.float32)
    y = functional_ops.map_fn(lambda e: e, x)
    self.assertIs(None, y.get_shape().dims)

  @test_util.run_in_graph_and_eager_modes()
  def testMapEmptyScalar(self):
    with self.test_session():
      map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([]))
      self.assertAllEqual([0], map_return.get_shape().dims)
      self.assertAllEqual([0], self.evaluate(map_return).shape)

  # TODO(akshayka): this test fails in eager: the iterable is of length 0,
  # so the body of the while loop never executes.
  def testMapEmptyTensor(self):
    with self.test_session():
      map_return = functional_ops.map_fn(lambda x: array_ops.zeros([3, 2]),
                                         constant_op.constant([]))
      self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
      self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape)

  @test_util.run_in_graph_and_eager_modes()
  def testScanShape(self):
    with self.test_session():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])

      def fn(_, current_input):
        return current_input

      initializer = constant_op.constant([0, 0, 0])
      y = functional_ops.scan(fn, x, initializer=initializer)
      self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)

  # TODO(akshayka): this test fails in eager: the iterable is of length 0,
  # so the body of the while loop never executes.
  def testScanEmptyTensor(self):
    with self.test_session():
      x = functional_ops.scan(
          lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4]))
      self.assertAllEqual([0, 2, 4], x.get_shape())
      self.assertAllEqual(x.get_shape(), self.evaluate(x).shape)

  def testScanUnknownShape(self):
    x = array_ops.placeholder(dtypes.float32)
    initializer = array_ops.placeholder(dtypes.float32)

    def fn(_, current_input):
      return current_input

    y = functional_ops.scan(fn, x, initializer=initializer)
    self.assertIs(None, y.get_shape().dims)

  def testScanVaryingShape(self):
    with self.test_session() as sess:
      x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2])
      x_t = array_ops.transpose(x)
      # scan over dimension 0 (with shape None)
      result = functional_ops.scan(lambda a, x: a + x, x)
      # scanned over transposed dimension 0 (with shape 2)
      result_t = functional_ops.scan(lambda a, x: a + x, x_t, infer_shape=False)
      # ensure gradients can be calculated
      result_grad = gradients_impl.gradients(result, [x])[0]
      result_t_grad = gradients_impl.gradients(result_t, [x_t])[0]

      # smoke test to ensure they all evaluate
      sess.run([result, result_t, result_grad, result_t_grad],
               feed_dict={x: [[1.0, 2.0]]})

  def testRemoteFunction(self):
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

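    # remote_call runs the Defun-wrapped function on the target device
    # (here the worker's second CPU) while the calling op is placed on the
    # first.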
    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:ps/task:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target="/job:worker/replica:0/task:0/cpu:1")

    with session.Session(worker[0].target) as sess:
      sess.run(variables.global_variables_initializer())
      mul = sess.run(remote_op)
      self.assertEqual(mul, [6])

  def testRemoteFunctionDirectSession(self):
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2

    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/cpu:1")

    with self.test_session(config=worker_config) as sess:
      sess.run(variables.global_variables_initializer())
      mul = sess.run(remote_op)
      self.assertEqual(mul, [6])

  def testRemoteFunctionCPUGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0

    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      mul = sess.run(remote_op)
      self.assertEqual(mul, 9.0)

  def testRemoteFunctionGPUCPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0

    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      mul = sess.run(remote_op)
      self.assertEqual(mul, 9.0)

  def testRemoteFunctionCrossProcess(self):
    workers, _ = test_util.create_local_cluster(2, 1)

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:ps/task:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:worker/replica:0/task:1/cpu:0")[0] + 3.0

    with session.Session(workers[0].target) as sess:
      sess.run(variables.global_variables_initializer())
      mul = sess.run(remote_op)
      self.assertEqual(mul, 9)


if __name__ == "__main__":
  test.main()