Home | History | Annotate | Download | only in python
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the Python extension-based XLA client."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 import threading
     23 
     24 import numpy as np
     25 
     26 from tensorflow.compiler.xla.python import xla_client
     27 import unittest
     28 
     29 
     30 class LocalComputationTest(unittest.TestCase):
     31   """Base class for running an XLA Computation through the local client."""
     32 
     33   def _NewComputation(self, name=None):
     34     if name is None:
     35       name = self.id()
     36     return xla_client.ComputationBuilder(name)
     37 
     38   def _Execute(self, c, arguments):
     39     compiled_c = c.Build().CompileWithExampleArguments(arguments)
     40     return compiled_c.Execute(arguments)
     41 
     42   def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
     43     assert expected is not None
     44     result = self._Execute(c, arguments)
     45     # Numpy's comparison methods are a bit too lenient by treating inputs as
     46     # "array-like", meaning that scalar 4 will be happily compared equal to
     47     # [[4]]. We'd like to be more strict so assert shapes as well.
     48     self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape)
     49     assert_func(result, expected)
     50 
     51   def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
     52     self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)
     53 
     54   def _ExecuteAndCompareClose(self, c, arguments=(), expected=None):
     55     self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments,
     56                                expected)
     57 
     58 
     59 def NumpyArrayF32(*args, **kwargs):
     60   """Convenience wrapper to create Numpy arrays with a np.float32 dtype."""
     61   return np.array(*args, dtype=np.float32, **kwargs)
     62 
     63 
     64 def NumpyArrayF64(*args, **kwargs):
     65   """Convenience wrapper to create Numpy arrays with a np.float64 dtype."""
     66   return np.array(*args, dtype=np.float64, **kwargs)
     67 
     68 
     69 def NumpyArrayS32(*args, **kwargs):
     70   """Convenience wrapper to create Numpy arrays with a np.int32 dtype."""
     71   return np.array(*args, dtype=np.int32, **kwargs)
     72 
     73 
     74 def NumpyArrayS64(*args, **kwargs):
     75   """Convenience wrapper to create Numpy arrays with a np.int64 dtype."""
     76   return np.array(*args, dtype=np.int64, **kwargs)
     77 
     78 
     79 def NumpyArrayBool(*args, **kwargs):
     80   """Convenience wrapper to create Numpy arrays with a np.bool dtype."""
     81   return np.array(*args, dtype=np.bool, **kwargs)
     82 
     83 
     84 class ComputationsWithConstantsTest(LocalComputationTest):
     85   """Tests focusing on Constant ops."""
     86 
     87   def testConstantScalarSumF32(self):
     88     c = self._NewComputation()
     89     root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
     90     self.assertEqual(c.GetShape(root), c.GetReturnValueShape())
     91     self._ExecuteAndCompareClose(c, expected=4.25)
     92 
     93   def testConstantScalarSumF64(self):
     94     c = self._NewComputation()
     95     c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14))
     96     self._ExecuteAndCompareClose(c, expected=4.25)
     97 
     98   def testConstantScalarSumS32(self):
     99     c = self._NewComputation()
    100     c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2))
    101     self._ExecuteAndCompareClose(c, expected=3)
    102 
    103   def testConstantScalarSumS64(self):
    104     c = self._NewComputation()
    105     c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2))
    106     self._ExecuteAndCompareClose(c, expected=3)
    107 
    108   def testConstantVectorMulF32(self):
    109     c = self._NewComputation()
    110     c.Mul(
    111         c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])),
    112         c.Constant(NumpyArrayF32([-1.2, 2, -2, -3])))
    113     self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])
    114 
    115   def testConstantVectorMulF64(self):
    116     c = self._NewComputation()
    117     c.Mul(
    118         c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])),
    119         c.Constant(NumpyArrayF64([-1.2, 2, -2, -3])))
    120     self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])
    121 
    122   def testConstantVectorScalarDivF32(self):
    123     c = self._NewComputation()
    124     c.Div(
    125         c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])),
    126         c.ConstantF32Scalar(2.0))
    127     self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])
    128 
    129   def testConstantVectorScalarDivF64(self):
    130     c = self._NewComputation()
    131     c.Div(
    132         c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])),
    133         c.ConstantF64Scalar(2.0))
    134     self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])
    135 
    136   def testConstantVectorScalarPowF32(self):
    137     c = self._NewComputation()
    138     c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.))
    139     self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
    140 
    141   def testConstantVectorScalarPowF64(self):
    142     c = self._NewComputation()
    143     c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.))
    144     self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
    145 
    146   def testBooleanAnd(self):
    147     c = self._NewComputation()
    148     c.And(
    149         c.Constant(NumpyArrayBool([True, False, True, False])),
    150         c.Constant(NumpyArrayBool([True, True, False, False])))
    151     self._ExecuteAndCompareExact(c, expected=[True, False, False, False])
    152 
    153   def testBooleanOr(self):
    154     c = self._NewComputation()
    155     c.Or(
    156         c.Constant(NumpyArrayBool([True, False, True, False])),
    157         c.Constant(NumpyArrayBool([True, True, False, False])))
    158     self._ExecuteAndCompareExact(c, expected=[True, True, True, False])
    159 
    160   def testSum2DF32(self):
    161     c = self._NewComputation()
    162     c.Add(
    163         c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])),
    164         c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]])))
    165     self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
    166 
    167   def testSum2DF64(self):
    168     c = self._NewComputation()
    169     c.Add(
    170         c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])),
    171         c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]])))
    172     self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
    173 
    174   def testSum2DWith1DBroadcastDim0F32(self):
    175     # sum of a 2D array with a 1D array where the latter is replicated across
    176     # dimension 0 to match the former's shape.
    177     c = self._NewComputation()
    178     c.Add(
    179         c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
    180         c.Constant(NumpyArrayF32([10, 20, 30])),
    181         broadcast_dimensions=(0,))
    182     self._ExecuteAndCompareClose(
    183         c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])
    184 
    185   def testSum2DWith1DBroadcastDim0F64(self):
    186     # sum of a 2D array with a 1D array where the latter is replicated across
    187     # dimension 0 to match the former's shape.
    188     c = self._NewComputation()
    189     c.Add(
    190         c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
    191         c.Constant(NumpyArrayF64([10, 20, 30])),
    192         broadcast_dimensions=(0,))
    193     self._ExecuteAndCompareClose(
    194         c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])
    195 
    196   def testSum2DWith1DBroadcastDim1F32(self):
    197     # sum of a 2D array with a 1D array where the latter is replicated across
    198     # dimension 1 to match the former's shape.
    199     c = self._NewComputation()
    200     c.Add(
    201         c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
    202         c.Constant(NumpyArrayF32([10, 20, 30])),
    203         broadcast_dimensions=(1,))
    204     self._ExecuteAndCompareClose(
    205         c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])
    206 
    207   def testSum2DWith1DBroadcastDim1F64(self):
    208     # sum of a 2D array with a 1D array where the latter is replicated across
    209     # dimension 1 to match the former's shape.
    210     c = self._NewComputation()
    211     c.Add(
    212         c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
    213         c.Constant(NumpyArrayF64([10, 20, 30])),
    214         broadcast_dimensions=(1,))
    215     self._ExecuteAndCompareClose(
    216         c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])
    217 
    218   def testConstantAxpyF32(self):
    219     c = self._NewComputation()
    220     c.Add(
    221         c.Mul(
    222             c.ConstantF32Scalar(2),
    223             c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))),
    224         c.Constant(NumpyArrayF32([100, -100, 200, -200])))
    225     self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])
    226 
    227   def testConstantAxpyF64(self):
    228     c = self._NewComputation()
    229     c.Add(
    230         c.Mul(
    231             c.ConstantF64Scalar(2),
    232             c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))),
    233         c.Constant(NumpyArrayF64([100, -100, 200, -200])))
    234     self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])
    235 
    236 
    237 class ParametersTest(LocalComputationTest):
    238   """Tests focusing on Parameter ops and argument-passing."""
    239 
    240   def setUp(self):
    241     self.f32_scalar_2 = NumpyArrayF32(2.0)
    242     self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3])
    243     self.f64_scalar_2 = NumpyArrayF64(2.0)
    244     self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3])
    245     self.s32_scalar_3 = NumpyArrayS32(3)
    246     self.s32_4vector = NumpyArrayS32([10, 15, -2, 7])
    247     self.s64_scalar_3 = NumpyArrayS64(3)
    248     self.s64_4vector = NumpyArrayS64([10, 15, -2, 7])
    249 
    250   def testScalarTimesVectorAutonumberF32(self):
    251     c = self._NewComputation()
    252     p0 = c.ParameterFromNumpy(self.f32_scalar_2)
    253     p1 = c.ParameterFromNumpy(self.f32_4vector)
    254     c.Mul(p0, p1)
    255     self._ExecuteAndCompareClose(
    256         c,
    257         arguments=[self.f32_scalar_2, self.f32_4vector],
    258         expected=[-4.6, 6.6, -8.6, 10.6])
    259 
    260   def testScalarTimesVectorAutonumberF64(self):
    261     c = self._NewComputation()
    262     p0 = c.ParameterFromNumpy(self.f64_scalar_2)
    263     p1 = c.ParameterFromNumpy(self.f64_4vector)
    264     c.Mul(p0, p1)
    265     self._ExecuteAndCompareClose(
    266         c,
    267         arguments=[self.f64_scalar_2, self.f64_4vector],
    268         expected=[-4.6, 6.6, -8.6, 10.6])
    269 
    270   def testScalarTimesVectorS32(self):
    271     c = self._NewComputation()
    272     p0 = c.ParameterFromNumpy(self.s32_scalar_3)
    273     p1 = c.ParameterFromNumpy(self.s32_4vector)
    274     c.Mul(p0, p1)
    275     self._ExecuteAndCompareExact(
    276         c,
    277         arguments=[self.s32_scalar_3, self.s32_4vector],
    278         expected=[30, 45, -6, 21])
    279 
    280   def testScalarTimesVectorS64(self):
    281     c = self._NewComputation()
    282     p0 = c.ParameterFromNumpy(self.s64_scalar_3)
    283     p1 = c.ParameterFromNumpy(self.s64_4vector)
    284     c.Mul(p0, p1)
    285     self._ExecuteAndCompareExact(
    286         c,
    287         arguments=[self.s64_scalar_3, self.s64_4vector],
    288         expected=[30, 45, -6, 21])
    289 
    290   def testScalarMinusVectorExplicitNumberingF32(self):
    291     # Use explicit numbering and pass parameter_num first. Sub is used since
    292     # it's not commutative and can help catch parameter reversal within the
    293     # computation.
    294     c = self._NewComputation()
    295     p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1)
    296     p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0)
    297     c.Sub(p1, p0)
    298     self._ExecuteAndCompareClose(
    299         c,
    300         arguments=[self.f32_scalar_2, self.f32_4vector],
    301         expected=[-4.3, 1.3, -6.3, 3.3])
    302 
    303   def testScalarMinusVectorExplicitNumberingF64(self):
    304     # Use explicit numbering and pass parameter_num first. Sub is used since
    305     # it's not commutative and can help catch parameter reversal within the
    306     # computation.
    307     c = self._NewComputation()
    308     p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1)
    309     p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0)
    310     c.Sub(p1, p0)
    311     self._ExecuteAndCompareClose(
    312         c,
    313         arguments=[self.f64_scalar_2, self.f64_4vector],
    314         expected=[-4.3, 1.3, -6.3, 3.3])
    315 
    316 
    317 class LocalBufferTest(LocalComputationTest):
    318   """Tests focusing on execution with LocalBuffers."""
    319 
    320   def _Execute(self, c, arguments):
    321     compiled_c = c.Build().CompileWithExampleArguments(arguments)
    322     arg_buffers = [xla_client.LocalBuffer.from_py(arg) for arg in arguments]
    323     result_buffer = compiled_c.ExecuteWithLocalBuffers(arg_buffers)
    324     return result_buffer.to_py()
    325 
    326   def testConstantSum(self):
    327     c = self._NewComputation()
    328     c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
    329     self._ExecuteAndCompareClose(c, expected=4.25)
    330 
    331   def testOneParameterSum(self):
    332     c = self._NewComputation()
    333     c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14))
    334     self._ExecuteAndCompareClose(
    335         c,
    336         arguments=[NumpyArrayF32(1.11)],
    337         expected=4.25)
    338 
    339   def testTwoParameterSum(self):
    340     c = self._NewComputation()
    341     c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)),
    342           c.ParameterFromNumpy(NumpyArrayF32(0.)))
    343     self._ExecuteAndCompareClose(
    344         c,
    345         arguments=[NumpyArrayF32(1.11), NumpyArrayF32(3.14)],
    346         expected=4.25)
    347 
    348   def testCannotCallWithDeletedBuffers(self):
    349     c = self._NewComputation()
    350     c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14))
    351     arg = NumpyArrayF32(1.11)
    352     compiled_c = c.Build().CompileWithExampleArguments([arg])
    353     arg_buffer = xla_client.LocalBuffer.from_py(arg)
    354     arg_buffer.delete()
    355     with self.assertRaises(ValueError):
    356       compiled_c.ExecuteWithLocalBuffers([arg_buffer])
    357 
    358 
    359 class SingleOpTest(LocalComputationTest):
    360   """Tests for single ops.
    361 
    362   The goal here is smoke testing - to exercise the most basic functionality of
  single XLA ops. As few additional ops as possible are added around the op
  being tested.
    365   """
    366 
    367   def testConcatenateF32(self):
    368     c = self._NewComputation()
    369     c.Concatenate(
    370         (c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
    371          c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))),
    372         dimension=0)
    373     self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    374 
    375   def testConcatenateF64(self):
    376     c = self._NewComputation()
    377     c.Concatenate(
    378         (c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
    379          c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))),
    380         dimension=0)
    381     self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    382 
    383   def testConvertElementType(self):
    384     xla_types = {
    385         np.bool: xla_client.xla_data_pb2.PRED,
    386         np.int32: xla_client.xla_data_pb2.S32,
    387         np.int64: xla_client.xla_data_pb2.S64,
    388         np.float32: xla_client.xla_data_pb2.F32,
    389         np.float64: xla_client.xla_data_pb2.F64,
    390     }
    391 
    392     def _ConvertAndTest(template, src_dtype, dst_dtype):
    393       c = self._NewComputation()
    394       x = c.Constant(np.array(template, dtype=src_dtype))
    395       c.ConvertElementType(x, xla_types[dst_dtype])
    396 
    397       result = c.Build().Compile().Execute()
    398       expected = np.array(template, dtype=dst_dtype)
    399 
    400       self.assertEqual(result.shape, expected.shape)
    401       self.assertEqual(result.dtype, expected.dtype)
    402       np.testing.assert_equal(result, expected)
    403 
    404     x = [0, 1, 0, 0, 1]
    405     for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
    406       _ConvertAndTest(x, src_dtype, dst_dtype)
    407 
    408   def testCrossReplicaSumOneReplica(self):
    409     samples = [
    410         NumpyArrayF32(42.0),
    411         NumpyArrayF32([97.0]),
    412         NumpyArrayF32([64.0, 117.0]),
    413         NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
    414     ]
    415     for lhs in samples:
    416       c = self._NewComputation()
    417       c.CrossReplicaSum(c.Constant(lhs))
    418       self._ExecuteAndCompareExact(c, expected=lhs)
    419 
    420   def testDotMatrixVectorF32(self):
    421     c = self._NewComputation()
    422     lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
    423     rhs = NumpyArrayF32([[10.0], [20.0]])
    424     c.Dot(c.Constant(lhs), c.Constant(rhs))
    425     self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
    426 
    427   def testDotMatrixVectorF64(self):
    428     c = self._NewComputation()
    429     lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
    430     rhs = NumpyArrayF64([[10.0], [20.0]])
    431     c.Dot(c.Constant(lhs), c.Constant(rhs))
    432     self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
    433 
    434   def testDotMatrixMatrixF32(self):
    435     c = self._NewComputation()
    436     lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
    437     rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]])
    438     c.Dot(c.Constant(lhs), c.Constant(rhs))
    439     self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
    440 
    441   def testDotMatrixMatrixF64(self):
    442     c = self._NewComputation()
    443     lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
    444     rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]])
    445     c.Dot(c.Constant(lhs), c.Constant(rhs))
    446     self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
    447 
    448   def testDotGeneral(self):
    449     c = self._NewComputation()
    450     rng = np.random.RandomState(0)
    451     lhs = NumpyArrayF32(rng.randn(10, 3, 4))
    452     rhs = NumpyArrayF32(rng.randn(10, 4, 5))
    453     dimension_numbers = (([2], [1]), ([0], [0]))
    454     c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
    455     self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
    456 
    457   def testDotGeneralWithDotDimensionNumbersProto(self):
    458     c = self._NewComputation()
    459     rng = np.random.RandomState(0)
    460     lhs = NumpyArrayF32(rng.randn(10, 3, 4))
    461     rhs = NumpyArrayF32(rng.randn(10, 4, 5))
    462 
    463     dimension_numbers = xla_client.xla_data_pb2.DotDimensionNumbers()
    464     dimension_numbers.lhs_contracting_dimensions.append(2)
    465     dimension_numbers.rhs_contracting_dimensions.append(1)
    466     dimension_numbers.lhs_batch_dimensions.append(0)
    467     dimension_numbers.rhs_batch_dimensions.append(0)
    468 
    469     c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
    470     self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
    471 
    472   def testConvF32Same(self):
    473     c = self._NewComputation()
    474     a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
    475     lhs = a(1, 2, 3, 4)
    476     rhs = a(1, 2, 1, 2) * 10
    477     c.Conv(c.Constant(lhs), c.Constant(rhs),
    478            [1, 1], xla_client.PaddingType.SAME)
    479     result = np.array([[[[640., 700., 760., 300.],
    480                          [880., 940., 1000., 380.],
    481                          [1120., 1180., 1240., 460.]]]])
    482     self._ExecuteAndCompareClose(c, expected=result)
    483 
    484   def testConvF32Valid(self):
    485     c = self._NewComputation()
    486     a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
    487     lhs = a(1, 2, 3, 4)
    488     rhs = a(1, 2, 1, 2) * 10
    489     c.Conv(c.Constant(lhs), c.Constant(rhs),
    490            [2, 1], xla_client.PaddingType.VALID)
    491     result = np.array([[[[640., 700., 760.],
    492                          [1120., 1180., 1240.]]]])
    493     self._ExecuteAndCompareClose(c, expected=result)
    494 
    495   def testConvWithGeneralPaddingF32(self):
    496     c = self._NewComputation()
    497     a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
    498     lhs = a(1, 1, 2, 3)
    499     rhs = a(1, 1, 1, 2) * 10
    500     strides = [1, 1]
    501     pads = [(1, 0), (0, 1)]
    502     lhs_dilation = (2, 1)
    503     rhs_dilation = (1, 1)
    504     c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs),
    505                              strides, pads, lhs_dilation, rhs_dilation)
    506     result = np.array([[[[0., 0., 0.],
    507                          [10., 20., 0.],
    508                          [0., 0., 0.],
    509                          [40., 50., 0.]]]])
    510     self._ExecuteAndCompareClose(c, expected=result)
    511 
    512   def testBooleanNot(self):
    513     c = self._NewComputation()
    514     arr = NumpyArrayBool([True, False, True])
    515     c.Not(c.Constant(arr))
    516     self._ExecuteAndCompareClose(c, expected=~arr)
    517 
    518   def testExp(self):
    519     c = self._NewComputation()
    520     arr = NumpyArrayF32([3.3, 12.1])
    521     c.Exp(c.Constant(arr))
    522     self._ExecuteAndCompareClose(c, expected=np.exp(arr))
    523 
    524   def testRound(self):
    525     c = self._NewComputation()
    526     arr = NumpyArrayF32([3.3, 12.1])
    527     c.Round(c.Constant(arr))
    528     self._ExecuteAndCompareClose(c, expected=np.round(arr))
    529 
    530   def testLog(self):
    531     c = self._NewComputation()
    532     arr = NumpyArrayF32([3.3, 12.1])
    533     c.Log(c.Constant(arr))
    534     self._ExecuteAndCompareClose(c, expected=np.log(arr))
    535 
    536   def testNeg(self):
    537     c = self._NewComputation()
    538     arr = NumpyArrayF32([3.3, 12.1])
    539     c.Neg(c.Constant(arr))
    540     self._ExecuteAndCompareClose(c, expected=-arr)
    541 
    542   def testFloor(self):
    543     c = self._NewComputation()
    544     arr = NumpyArrayF32([3.3, 12.1])
    545     c.Floor(c.Constant(arr))
    546     self._ExecuteAndCompareClose(c, expected=np.floor(arr))
    547 
    548   def testCeil(self):
    549     c = self._NewComputation()
    550     arr = NumpyArrayF32([3.3, 12.1])
    551     c.Ceil(c.Constant(arr))
    552     self._ExecuteAndCompareClose(c, expected=np.ceil(arr))
    553 
    554   def testAbs(self):
    555     c = self._NewComputation()
    556     arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
    557     c.Abs(c.Constant(arr))
    558     self._ExecuteAndCompareClose(c, expected=np.abs(arr))
    559 
    560   def testTanh(self):
    561     c = self._NewComputation()
    562     arr = NumpyArrayF32([3.3, 12.1])
    563     c.Tanh(c.Constant(arr))
    564     self._ExecuteAndCompareClose(c, expected=np.tanh(arr))
    565 
    566   def testTrans(self):
    567 
    568     def _TransposeAndTest(array):
    569       c = self._NewComputation()
    570       c.Trans(c.Constant(array))
    571       self._ExecuteAndCompareClose(c, expected=array.T)
    572 
    573     # Test square and non-square matrices in both default (C) and F orders.
    574     for array_fun in [NumpyArrayF32, NumpyArrayF64]:
    575       _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]]))
    576       _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F"))
    577       _TransposeAndTest(array_fun([[1, 2], [4, 5]]))
    578       _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F"))
    579 
    580   def testTranspose(self):
    581 
    582     def _TransposeAndTest(array, permutation):
    583       c = self._NewComputation()
    584       c.Transpose(c.Constant(array), permutation)
    585       expected = np.transpose(array, permutation)
    586       self._ExecuteAndCompareClose(c, expected=expected)
    587 
    588     _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
    589     _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
    590     _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
    591     _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])
    592 
    593     arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
    594     for permutation in itertools.permutations(range(arr.ndim)):
    595       _TransposeAndTest(arr, permutation)
    596       _TransposeAndTest(np.asfortranarray(arr), permutation)
    597 
    598   def testEq(self):
    599     c = self._NewComputation()
    600     c.Eq(
    601         c.Constant(NumpyArrayS32([1, 2, 3, 4])),
    602         c.Constant(NumpyArrayS32([4, 2, 3, 1])))
    603     self._ExecuteAndCompareExact(c, expected=[False, True, True, False])
    604 
    605   def testNe(self):
    606     c = self._NewComputation()
    607     c.Ne(
    608         c.Constant(NumpyArrayS32([1, 2, 3, 4])),
    609         c.Constant(NumpyArrayS32([4, 2, 3, 1])))
    610     self._ExecuteAndCompareExact(c, expected=[True, False, False, True])
    611 
    612     c.Ne(
    613         c.Constant(NumpyArrayF32([-2.0, 0.0,
    614                                   float("nan"),
    615                                   float("nan")])),
    616         c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")])))
    617     self._ExecuteAndAssertWith(
    618         np.testing.assert_allclose, c, (), expected=[True, False, True, True])
    619 
    620   def testGt(self):
    621     c = self._NewComputation()
    622     c.Gt(
    623         c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
    624         c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
    625     self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False])
    626 
    627   def testGe(self):
    628     c = self._NewComputation()
    629     c.Ge(
    630         c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
    631         c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
    632     self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False])
    633 
    634   def testLt(self):
    635     c = self._NewComputation()
    636     c.Lt(
    637         c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
    638         c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
    639     self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True])
    640 
    641   def testLe(self):
    642     c = self._NewComputation()
    643     c.Le(
    644         c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
    645         c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
    646     self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True])
    647 
    648   def testMax(self):
    649     c = self._NewComputation()
    650     c.Max(
    651         c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
    652         c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
    653     self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0])
    654 
    655   def testMaxExplicitBroadcastDim0(self):
    656     c = self._NewComputation()
    657     c.Max(
    658         c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
    659         c.Constant(NumpyArrayF32([3, 4, 5])),
    660         broadcast_dimensions=(0,))
    661     self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]])
    662 
    663   def testMaxExplicitBroadcastDim1(self):
    664     c = self._NewComputation()
    665     c.Max(
    666         c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
    667         c.Constant(NumpyArrayF32([3, 4, 5])),
    668         broadcast_dimensions=(1,))
    669     self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]])
    670 
    671   def testMin(self):
    672     c = self._NewComputation()
    673     c.Min(
    674         c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
    675         c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
    676     self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0])
    677 
    678   def testPad(self):
    679     c = self._NewComputation()
    680     c.Pad(
    681         c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
    682         c.Constant(NumpyArrayF32(0.0)),
    683         [(1, 2, 1), (0, 1, 0)])
    684     self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0],
    685                                               [1.0, 2.0, 0.0],
    686                                               [0.0, 0.0, 0.0],
    687                                               [3.0, 4.0, 0.0],
    688                                               [0.0, 0.0, 0.0],
    689                                               [0.0, 0.0, 0.0]])
    690 
    691   def testPadWithPaddingConfig(self):
    692     c = self._NewComputation()
    693     padding_config = xla_client.xla_data_pb2.PaddingConfig()
    694     for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
    695       dimension = padding_config.dimensions.add()
    696       dimension.edge_padding_low = lo
    697       dimension.edge_padding_high = hi
    698       dimension.interior_padding = interior
    699     c.Pad(
    700         c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
    701         c.Constant(NumpyArrayF32(0.0)),
    702         padding_config)
    703     self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0],
    704                                               [1.0, 2.0, 0.0],
    705                                               [0.0, 0.0, 0.0],
    706                                               [3.0, 4.0, 0.0],
    707                                               [0.0, 0.0, 0.0],
    708                                               [0.0, 0.0, 0.0]])
    709 
    710   def testReshape(self):
    711     c = self._NewComputation()
    712     c.Reshape(
    713         c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
    714         dimensions=[0, 1],
    715         new_sizes=[2, 3])
    716     self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]])
    717 
    718   def testCollapse(self):
    719     c = self._NewComputation()
    720     c.Collapse(
    721         c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
    722         dimensions=[1, 2])
    723     self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]])
    724 
    725   def testRev(self):
    726     c = self._NewComputation()
    727     c.Rev(
    728         c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
    729         dimensions=[0, 2])
    730     self._ExecuteAndCompareExact(
    731         c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]])
    732 
    733   def testClampF32(self):
    734     c = self._NewComputation()
    735     c.Clamp(
    736         c.Constant(NumpyArrayF32(-1)),
    737         c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
    738         c.Constant(NumpyArrayF32(2)))
    739     self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])
    740 
  # TODO(b/72689392): re-enable once the S32 Clamp bug is resolved.
    742   def DISABLED_testClampS32(self):
    743     c = self._NewComputation()
    744     c.Clamp(
    745         c.Constant(NumpyArrayS32(-1)),
    746         c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
    747         c.Constant(NumpyArrayS32(2)))
    748     self._ExecuteAndCompareExact(c, expected=[-1, 0, 1, 2, 2])
    749 
    750   def testSelect(self):
    751     c = self._NewComputation()
    752     c.Select(
    753         c.Constant(NumpyArrayBool([True, False, False, True, False])),
    754         c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])),
    755         c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5])))
    756     self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5])
    757 
    758   def testSlice(self):
    759     c = self._NewComputation()
    760     c.Slice(
    761         c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0],
    762         [3, 2])
    763     self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
    764 
    765   def testDynamicSlice(self):
    766     c = self._NewComputation()
    767     c.DynamicSlice(
    768         c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
    769         c.Constant(NumpyArrayS32([1, 0])), [2, 2])
    770     self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
    771 
    772   def testDynamicUpdateSlice(self):
    773     c = self._NewComputation()
    774     c.DynamicUpdateSlice(
    775         c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
    776         c.Constant(NumpyArrayS32([[1, 2], [3, 4]])),
    777         c.Constant(NumpyArrayS32([1, 1])))
    778     self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]])
    779 
    780   def testTuple(self):
    781     c = self._NewComputation()
    782     c.Tuple(
    783         c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
    784         c.Constant(NumpyArrayBool([True, False, False, True])))
    785     result = c.Build().Compile().Execute()
    786     self.assertIsInstance(result, tuple)
    787     np.testing.assert_equal(result[0], 42)
    788     np.testing.assert_allclose(result[1], [1.0, 2.0])
    789     np.testing.assert_equal(result[2], [True, False, False, True])
    790 
    791   def testGetTupleElement(self):
    792     c = self._NewComputation()
    793     c.GetTupleElement(
    794         c.Tuple(
    795             c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
    796             c.Constant(NumpyArrayBool([True, False, False, True]))), 1)
    797     self._ExecuteAndCompareClose(c, expected=[1.0, 2.0])
    798 
    799   def testBroadcast(self):
    800     c = self._NewComputation()
    801     c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
    802     self._ExecuteAndCompareExact(
    803         c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]])
    804 
    805   def testRngNormal(self):
    806     shape = (2, 3)
    807     c = self._NewComputation()
    808     c.RngNormal(c.Constant(NumpyArrayF32(0.)), c.Constant(NumpyArrayF32(1.)),
    809                 dims=shape)
    810     result = c.Build().Compile().Execute()
    811     # since the result is random, we just check shape and uniqueness
    812     self.assertEqual(result.shape, shape)
    813     self.assertEqual(len(np.unique(result)), np.prod(shape))
    814 
    815   def testRngUniformF32(self):
    816     lo, hi = 2., 4.
    817     shape = (2, 3)
    818     c = self._NewComputation()
    819     c.RngUniform(c.Constant(NumpyArrayF32(lo)), c.Constant(NumpyArrayF32(hi)),
    820                  dims=shape)
    821     result = c.Build().Compile().Execute()
    822     # since the result is random, we just check shape, uniqueness, and range
    823     self.assertEqual(result.shape, shape)
    824     self.assertEqual(len(np.unique(result)), np.prod(shape))
    825     self.assertTrue(np.all(lo <= result))
    826     self.assertTrue(np.all(result < hi))
    827 
    828   def testRngUniformS32(self):
    829     lo, hi = 2, 4
    830     shape = (2, 3)
    831     c = self._NewComputation()
    832     c.RngUniform(c.Constant(NumpyArrayS32(lo)), c.Constant(NumpyArrayS32(hi)),
    833                  dims=shape)
    834     result = c.Build().Compile().Execute()
    835     # since the result is random, we just check shape, integrality, and range
    836     self.assertEqual(result.shape, shape)
    837     self.assertEqual(result.dtype, np.int32)
    838     self.assertTrue(np.all(lo <= result))
    839     self.assertTrue(np.all(result < hi))
    840 
    841 
    842 class EmbeddedComputationsTest(LocalComputationTest):
    843   """Tests for XLA graphs with embedded computations (such as maps)."""
    844 
    845   def _CreateConstantS32Computation(self):
    846     """Computation (f32) -> s32 that returns a constant 1 for any input."""
    847     c = self._NewComputation("constant_s32_one")
    848     # TODO(eliben): consider adding a nicer way to create new parameters without
    849     # having to create dummy Numpy arrays or populating Shape messages. Perhaps
    850     # we need our own (Python-client-own) way to represent Shapes conveniently.
    851     c.ParameterFromNumpy(NumpyArrayF32(0))
    852     c.ConstantS32Scalar(1)
    853     return c.Build()
    854 
    855   def _CreateConstantS64Computation(self):
    856     """Computation (f64) -> s64 that returns a constant 1 for any input."""
    857     c = self._NewComputation("constant_s64_one")
    858     # TODO(eliben): consider adding a nicer way to create new parameters without
    859     # having to create dummy Numpy arrays or populating Shape messages. Perhaps
    860     # we need our own (Python-client-own) way to represent Shapes conveniently.
    861     c.ParameterFromNumpy(NumpyArrayF64(0))
    862     c.ConstantS64Scalar(1)
    863     return c.Build()
    864 
    865   def _CreateConstantF32Computation(self):
    866     """Computation (f32) -> f32 that returns a constant 1.0 for any input."""
    867     c = self._NewComputation("constant_f32_one")
    868     c.ParameterFromNumpy(NumpyArrayF32(0))
    869     c.ConstantF32Scalar(1.0)
    870     return c.Build()
    871 
    872   def _CreateConstantF64Computation(self):
    873     """Computation (f64) -> f64 that returns a constant 1.0 for any input."""
    874     c = self._NewComputation("constant_f64_one")
    875     c.ParameterFromNumpy(NumpyArrayF64(0))
    876     c.ConstantF64Scalar(1.0)
    877     return c.Build()
    878 
    879   def _CreateMulF32By2Computation(self):
    880     """Computation (f32) -> f32 that multiplies its parameter by 2."""
    881     c = self._NewComputation("mul_f32_by2")
    882     c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0))
    883     return c.Build()
    884 
    885   def _CreateMulF32ByParamComputation(self):
    886     """Computation (f32) -> f32 that multiplies one parameter by the other."""
    887     c = self._NewComputation("mul_f32_by_param")
    888     c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)),
    889           c.ParameterFromNumpy(NumpyArrayF32(0)))
    890     return c.Build()
    891 
    892   def _CreateMulF64By2Computation(self):
    893     """Computation (f64) -> f64 that multiplies its parameter by 2."""
    894     c = self._NewComputation("mul_f64_by2")
    895     c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0))
    896     return c.Build()
    897 
    898   def _CreateBinaryAddF32Computation(self):
    899     """Computation (f32, f32) -> f32 that adds its two parameters."""
    900     c = self._NewComputation("add_param0_by_param1")
    901     c.Add(
    902         c.ParameterFromNumpy(NumpyArrayF32(0)),
    903         c.ParameterFromNumpy(NumpyArrayF32(0)))
    904     return c.Build()
    905 
    906   def _CreateBinaryAddF64Computation(self):
    907     """Computation (f64, f64) -> f64 that adds its two parameters."""
    908     c = self._NewComputation("add_param0_by_param1")
    909     c.Add(
    910         c.ParameterFromNumpy(NumpyArrayF64(0)),
    911         c.ParameterFromNumpy(NumpyArrayF64(0)))
    912     return c.Build()
    913 
    914   def _CreateBinaryDivF32Computation(self):
    915     """Computation (f32, f32) -> f32 that divides its two parameters."""
    916     c = self._NewComputation("div_param0_by_param1")
    917     c.Div(
    918         c.ParameterFromNumpy(NumpyArrayF32(0)),
    919         c.ParameterFromNumpy(NumpyArrayF32(0)))
    920     return c.Build()
    921 
    922   def _CreateBinaryDivF64Computation(self):
    923     """Computation (f64, f64) -> f64 that divides its two parameters."""
    924     c = self._NewComputation("div_param0_by_param1")
    925     c.Div(
    926         c.ParameterFromNumpy(NumpyArrayF64(0)),
    927         c.ParameterFromNumpy(NumpyArrayF64(0)))
    928     return c.Build()
    929 
    930   def _CreateTestF32Lt10Computation(self):
    931     """Computation (f32) -> bool that tests if its parameter is less than 10."""
    932     c = self._NewComputation("test_f32_lt_10")
    933     c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.))
    934     return c.Build()
    935 
    936   def _CreateTestF64Lt10Computation(self):
    937     """Computation (f64) -> bool that tests if its parameter is less than 10."""
    938     c = self._NewComputation("test_f64_lt_10")
    939     c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.))
    940     return c.Build()
    941 
    942   def _CreateBinaryGeF32Computation(self):
    943     """Computation (f32, f32) -> bool that tests first_param >= second_param."""
    944     c = self._NewComputation("param0_lt_param1")
    945     c.Ge(c.ParameterFromNumpy(NumpyArrayF32(0)),
    946          c.ParameterFromNumpy(NumpyArrayF32(0)))
    947     return c.Build()
    948 
    949   def _CreateBinaryGeF64Computation(self):
    950     """Computation (f64, f64) -> bool that tests first_param >= second_param."""
    951     c = self._NewComputation("param0_lt_param1")
    952     c.Ge(c.ParameterFromNumpy(NumpyArrayF64(0)),
    953          c.ParameterFromNumpy(NumpyArrayF64(0)))
    954     return c.Build()
    955 
    956   def _MakeSample3DArrayF32(self):
    957     return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
    958                           [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
    959 
    960   def _MakeSample3DArrayF64(self):
    961     return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
    962                           [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
    963 
    964   def testCallF32(self):
    965     c = self._NewComputation()
    966     c.Call(
    967         self._CreateMulF32By2Computation(),
    968         operands=(c.ConstantF32Scalar(5.0),))
    969     self._ExecuteAndCompareClose(c, expected=10.0)
    970 
    971   def testCallF64(self):
    972     c = self._NewComputation()
    973     c.Call(
    974         self._CreateMulF64By2Computation(),
    975         operands=(c.ConstantF64Scalar(5.0),))
    976     self._ExecuteAndCompareClose(c, expected=10.0)
    977 
    978   def testMapEachElementToS32Constant(self):
    979     c = self._NewComputation()
    980     c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
    981           self._CreateConstantS32Computation(), [0])
    982     self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
    983 
    984   def testMapEachElementToS64Constant(self):
    985     c = self._NewComputation()
    986     c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
    987           self._CreateConstantS64Computation(), [0])
    988     self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
    989 
    990   def testMapMulBy2F32(self):
    991     c = self._NewComputation()
    992     c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
    993           self._CreateMulF32By2Computation(), [0])
    994     self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
    995 
    996   def testMapMulBy2F64(self):
    997     c = self._NewComputation()
    998     c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
    999           self._CreateMulF64By2Computation(), [0])
   1000     self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
   1001 
   1002   def testSimpleMapChainF32(self):
   1003     # Chains a map of constant-f32 with a map of mul-by-2
   1004     c = self._NewComputation()
   1005     const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
   1006                       self._CreateConstantF32Computation(), [0])
   1007     c.Map([const_f32], self._CreateMulF32By2Computation(), [0])
   1008     self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
   1009 
   1010   def testSimpleMapChainF64(self):
   1011     # Chains a map of constant-f64 with a map of mul-by-2
   1012     c = self._NewComputation()
   1013     const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
   1014                       self._CreateConstantF64Computation(), [0])
   1015     c.Map([const_f64], self._CreateMulF64By2Computation(), [0])
   1016     self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
   1017 
   1018   def testDivVectorsWithMapF32(self):
   1019     c = self._NewComputation()
   1020     c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
   1021            c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))),
   1022           self._CreateBinaryDivF32Computation(), [0])
   1023     self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
   1024 
   1025   def testDivVectorsWithMapF64(self):
   1026     c = self._NewComputation()
   1027     c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
   1028            c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))),
   1029           self._CreateBinaryDivF64Computation(), [0])
   1030     self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
   1031 
   1032   def DISABLED_testMapWithStaticOperands(self):
   1033     c = self._NewComputation()
   1034     factor = c.ConstantF32Scalar(3.0)
   1035     c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
   1036           self._CreateMulF32ByParamComputation(), [0],
   1037           static_operands=[factor])
   1038     self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0])
   1039 
   1040   def testSelectAndScatterF32(self):
   1041     c = self._NewComputation()
   1042     c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),
   1043                        select=self._CreateBinaryGeF32Computation(),
   1044                        window_dimensions=(2, 1),
   1045                        window_strides=(1, 2),
   1046                        padding=xla_client.PaddingType.VALID,
   1047                        source=c.Constant(NumpyArrayF32([[0.1, 0.2]])),
   1048                        init_value=c.Constant(NumpyArrayF32(1)),
   1049                        scatter=self._CreateBinaryAddF32Computation())
   1050     self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
   1051 
   1052   def testSelectAndScatterF64(self):
   1053     c = self._NewComputation()
   1054     c.SelectAndScatter(c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])),
   1055                        select=self._CreateBinaryGeF64Computation(),
   1056                        window_dimensions=(2, 1),
   1057                        window_strides=(1, 2),
   1058                        padding=xla_client.PaddingType.VALID,
   1059                        source=c.Constant(NumpyArrayF64([[0.1, 0.2]])),
   1060                        init_value=c.Constant(NumpyArrayF64(1)),
   1061                        scatter=self._CreateBinaryAddF64Computation())
   1062     self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
   1063 
   1064   def testReduce1DtoScalarF32(self):
   1065     c = self._NewComputation()
   1066     c.Reduce(
   1067         operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
   1068         init_value=c.ConstantF32Scalar(0),
   1069         computation_to_apply=self._CreateBinaryAddF32Computation(),
   1070         dimensions=[0])
   1071     self._ExecuteAndCompareClose(c, expected=10)
   1072 
   1073   def testReduce1DtoScalarF64(self):
   1074     c = self._NewComputation()
   1075     c.Reduce(
   1076         operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
   1077         init_value=c.ConstantF64Scalar(0),
   1078         computation_to_apply=self._CreateBinaryAddF64Computation(),
   1079         dimensions=[0])
   1080     self._ExecuteAndCompareClose(c, expected=10)
   1081 
   1082   def testReduce2DTo1DDim0F32(self):
   1083     input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
   1084     c = self._NewComputation()
   1085     c.Reduce(
   1086         operand=c.Constant(input_array),
   1087         init_value=c.ConstantF32Scalar(0),
   1088         computation_to_apply=self._CreateBinaryAddF32Computation(),
   1089         dimensions=[0])
   1090     self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
   1091 
   1092   def testReduce2DTo1DDim0F64(self):
   1093     input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
   1094     c = self._NewComputation()
   1095     c.Reduce(
   1096         operand=c.Constant(input_array),
   1097         init_value=c.ConstantF64Scalar(0),
   1098         computation_to_apply=self._CreateBinaryAddF64Computation(),
   1099         dimensions=[0])
   1100     self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
   1101 
   1102   def testReduce2DTo1DDim1F32(self):
   1103     input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
   1104     c = self._NewComputation()
   1105     c.Reduce(
   1106         operand=c.Constant(input_array),
   1107         init_value=c.ConstantF32Scalar(0),
   1108         computation_to_apply=self._CreateBinaryAddF32Computation(),
   1109         dimensions=[1])
   1110     self._ExecuteAndCompareClose(c, expected=[6, 15])
   1111 
   1112   def testReduce2DTo1DDim1F64(self):
   1113     input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
   1114     c = self._NewComputation()
   1115     c.Reduce(
   1116         operand=c.Constant(input_array),
   1117         init_value=c.ConstantF64Scalar(0),
   1118         computation_to_apply=self._CreateBinaryAddF64Computation(),
   1119         dimensions=[1])
   1120     self._ExecuteAndCompareClose(c, expected=[6, 15])
   1121 
   1122   def testReduce3DAllPossibleWaysF32(self):
   1123     input_array = self._MakeSample3DArrayF32()
   1124 
   1125     def _ReduceAndTest(*dims):
   1126       c = self._NewComputation()
   1127       c.Reduce(
   1128           operand=c.Constant(input_array),
   1129           init_value=c.ConstantF32Scalar(0),
   1130           computation_to_apply=self._CreateBinaryAddF32Computation(),
   1131           dimensions=dims)
   1132       self._ExecuteAndCompareClose(
   1133           c, expected=np.sum(input_array, axis=tuple(dims)))
   1134 
   1135     _ReduceAndTest(0)
   1136     _ReduceAndTest(0)
   1137     _ReduceAndTest(0, 1)
   1138     _ReduceAndTest(0, 2)
   1139     _ReduceAndTest(1, 2)
   1140     _ReduceAndTest(0, 1, 2)
   1141 
   1142   def testReduce3DAllPossibleWaysF64(self):
   1143     input_array = self._MakeSample3DArrayF64()
   1144 
   1145     def _ReduceAndTest(*dims):
   1146       c = self._NewComputation()
   1147       c.Reduce(
   1148           operand=c.Constant(input_array),
   1149           init_value=c.ConstantF64Scalar(0),
   1150           computation_to_apply=self._CreateBinaryAddF64Computation(),
   1151           dimensions=dims)
   1152       self._ExecuteAndCompareClose(
   1153           c, expected=np.sum(input_array, axis=tuple(dims)))
   1154 
   1155     _ReduceAndTest(0)
   1156     _ReduceAndTest(0)
   1157     _ReduceAndTest(0, 1)
   1158     _ReduceAndTest(0, 2)
   1159     _ReduceAndTest(1, 2)
   1160     _ReduceAndTest(0, 1, 2)
   1161 
   1162   def testReduceWindowValidUnitStridesF32(self):
   1163     input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
   1164     c = self._NewComputation()
   1165     c.ReduceWindow(operand=c.Constant(input_array),
   1166                    init_value=c.ConstantF32Scalar(0),
   1167                    computation_to_apply=self._CreateBinaryAddF32Computation(),
   1168                    window_dimensions=(2, 1), window_strides=(1, 1),
   1169                    padding=xla_client.PaddingType.VALID)
   1170     self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
   1171 
   1172   def testReduceWindowSameUnitStridesF32(self):
   1173     input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
   1174     c = self._NewComputation()
   1175     c.ReduceWindow(operand=c.Constant(input_array),
   1176                    init_value=c.ConstantF32Scalar(0),
   1177                    computation_to_apply=self._CreateBinaryAddF32Computation(),
   1178                    window_dimensions=(2, 1), window_strides=(1, 1),
   1179                    padding=xla_client.PaddingType.SAME)
   1180     self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
   1181 
   1182   def testReduceWindowValidGeneralStridesF32(self):
   1183     input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
   1184     c = self._NewComputation()
   1185     c.ReduceWindow(operand=c.Constant(input_array),
   1186                    init_value=c.ConstantF32Scalar(0),
   1187                    computation_to_apply=self._CreateBinaryAddF32Computation(),
   1188                    window_dimensions=(2, 1), window_strides=(1, 2),
   1189                    padding=xla_client.PaddingType.VALID)
   1190     self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
   1191 
   1192   def testReduceWindowValidUnitStridesF64(self):
   1193     input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
   1194     c = self._NewComputation()
   1195     c.ReduceWindow(operand=c.Constant(input_array),
   1196                    init_value=c.ConstantF64Scalar(0),
   1197                    computation_to_apply=self._CreateBinaryAddF64Computation(),
   1198                    window_dimensions=(2, 1), window_strides=(1, 1),
   1199                    padding=xla_client.PaddingType.VALID)
   1200     self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
   1201 
   1202   def testReduceWindowSameUnitStridesF64(self):
   1203     input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
   1204     c = self._NewComputation()
   1205     c.ReduceWindow(operand=c.Constant(input_array),
   1206                    init_value=c.ConstantF64Scalar(0),
   1207                    computation_to_apply=self._CreateBinaryAddF64Computation(),
   1208                    window_dimensions=(2, 1), window_strides=(1, 1),
   1209                    padding=xla_client.PaddingType.SAME)
   1210     self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
   1211 
   1212   def testReduceWindowValidGeneralStridesF64(self):
   1213     input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
   1214     c = self._NewComputation()
   1215     c.ReduceWindow(operand=c.Constant(input_array),
   1216                    init_value=c.ConstantF64Scalar(0),
   1217                    computation_to_apply=self._CreateBinaryAddF64Computation(),
   1218                    window_dimensions=(2, 1), window_strides=(1, 2),
   1219                    padding=xla_client.PaddingType.VALID)
   1220     self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
   1221 
   1222   def testWhileF32(self):
   1223     cond = self._CreateTestF32Lt10Computation()
   1224     body = self._CreateMulF32By2Computation()
   1225     c = self._NewComputation()
   1226     init = c.ConstantF32Scalar(1.)
   1227     c.While(cond, body, init)
   1228     self._ExecuteAndCompareClose(c, expected=16.)
   1229 
   1230   def testWhileF64(self):
   1231     cond = self._CreateTestF64Lt10Computation()
   1232     body = self._CreateMulF64By2Computation()
   1233     c = self._NewComputation()
   1234     init = c.ConstantF64Scalar(1.)
   1235     c.While(cond, body, init)
   1236     self._ExecuteAndCompareClose(c, expected=16.)
   1237 
   1238   def testConditionalTrue(self):
   1239     c = self._NewComputation()
   1240     pred = c.ConstantPredScalar(True)
   1241     true_operand = c.ConstantF32Scalar(3.)
   1242     true_computation = self._CreateMulF32By2Computation()
   1243     false_operand = c.ConstantF32Scalar(2.)
   1244     false_computation = self._CreateConstantF32Computation()
   1245     c.Conditional(pred, true_operand, true_computation, false_operand,
   1246                   false_computation)
   1247     self._ExecuteAndCompareClose(c, expected=6.)
   1248 
   1249   def testConditionalFalse(self):
   1250     c = self._NewComputation()
   1251     pred = c.ConstantPredScalar(False)
   1252     true_operand = c.ConstantF32Scalar(3.)
   1253     true_computation = self._CreateMulF32By2Computation()
   1254     false_operand = c.ConstantF32Scalar(2.)
   1255     false_computation = self._CreateConstantF32Computation()
   1256     c.Conditional(pred, true_operand, true_computation, false_operand,
   1257                   false_computation)
   1258     self._ExecuteAndCompareClose(c, expected=1.)
   1259 
   1260   def testInfeedS32Values(self):
   1261     to_infeed = NumpyArrayS32([1, 2, 3, 4])
   1262     c = self._NewComputation()
   1263     c.Infeed(xla_client.Shape.from_numpy(to_infeed[0]))
   1264     compiled_c = c.Build().CompileWithExampleArguments()
   1265     for item in to_infeed:
   1266       xla_client.transfer_to_infeed(item)
   1267 
   1268     for item in to_infeed:
   1269       result = compiled_c.Execute()
   1270       self.assertEqual(result, item)
   1271 
   1272   def testInfeedThenOutfeedS32(self):
   1273     to_round_trip = NumpyArrayS32([1, 2, 3, 4])
   1274     c = self._NewComputation()
   1275     x = c.Infeed(xla_client.Shape.from_numpy(to_round_trip[0]))
   1276     c.Outfeed(x)
   1277 
   1278     compiled_c = c.Build().CompileWithExampleArguments()
   1279 
   1280     for want in to_round_trip:
   1281       execution = threading.Thread(target=compiled_c.Execute)
   1282       execution.start()
   1283       xla_client.transfer_to_infeed(want)
   1284       got = xla_client.transfer_from_outfeed(
   1285           xla_client.Shape.from_numpy(to_round_trip[0]))
   1286       execution.join()
   1287       self.assertEqual(want, got)
   1288 
   1289 
   1290 class ErrorTest(LocalComputationTest):
   1291 
   1292   def setUp(self):
   1293     self.f32_scalar_2 = NumpyArrayF32(2.0)
   1294     self.s32_scalar_2 = NumpyArrayS32(2)
   1295 
   1296   def testInvokeWithWrongElementType(self):
   1297     c = self._NewComputation()
   1298     c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
   1299     c.ParameterFromNumpy(self.s32_scalar_2)
   1300     c.ClearOpMetadata()
   1301     self.assertRaisesRegexp(
   1302         RuntimeError, r"Invalid argument shape.*xla_client_test.py.*"
   1303         r"expected s32\[\], got f32\[\]",
   1304         lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2]))
   1305 
   1306 
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()
   1309