1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the Python extension-based XLA client.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import itertools 22 import threading 23 24 import numpy as np 25 26 from tensorflow.compiler.xla.python import xla_client 27 import unittest 28 29 30 class LocalComputationTest(unittest.TestCase): 31 """Base class for running an XLA Computation through the local client.""" 32 33 def _NewComputation(self, name=None): 34 if name is None: 35 name = self.id() 36 return xla_client.ComputationBuilder(name) 37 38 def _Execute(self, c, arguments): 39 compiled_c = c.Build().CompileWithExampleArguments(arguments) 40 return compiled_c.Execute(arguments) 41 42 def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): 43 assert expected is not None 44 result = self._Execute(c, arguments) 45 # Numpy's comparison methods are a bit too lenient by treating inputs as 46 # "array-like", meaning that scalar 4 will be happily compared equal to 47 # [[4]]. We'd like to be more strict so assert shapes as well. 48 self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape) 49 assert_func(result, expected) 50 51 def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): 52 self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) 53 54 def _ExecuteAndCompareClose(self, c, arguments=(), expected=None): 55 self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments, 56 expected) 57 58 59 def NumpyArrayF32(*args, **kwargs): 60 """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" 61 return np.array(*args, dtype=np.float32, **kwargs) 62 63 64 def NumpyArrayF64(*args, **kwargs): 65 """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" 66 return np.array(*args, dtype=np.float64, **kwargs) 67 68 69 def NumpyArrayS32(*args, **kwargs): 70 """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" 71 return np.array(*args, dtype=np.int32, **kwargs) 72 73 74 def NumpyArrayS64(*args, **kwargs): 75 """Convenience wrapper to create Numpy arrays with a np.int64 dtype.""" 76 return np.array(*args, dtype=np.int64, **kwargs) 77 78 79 def NumpyArrayBool(*args, **kwargs): 80 """Convenience wrapper to create Numpy arrays with a np.bool dtype.""" 81 return np.array(*args, dtype=np.bool, **kwargs) 82 83 84 class ComputationsWithConstantsTest(LocalComputationTest): 85 """Tests focusing on Constant ops.""" 86 87 def testConstantScalarSumF32(self): 88 c = self._NewComputation() 89 root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) 90 self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) 91 self._ExecuteAndCompareClose(c, expected=4.25) 92 93 def testConstantScalarSumF64(self): 94 c = self._NewComputation() 95 c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14)) 96 self._ExecuteAndCompareClose(c, expected=4.25) 97 98 def testConstantScalarSumS32(self): 99 c = self._NewComputation() 100 c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2)) 101 self._ExecuteAndCompareClose(c, expected=3) 102 103 def testConstantScalarSumS64(self): 104 c = self._NewComputation() 105 c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2)) 106 self._ExecuteAndCompareClose(c, expected=3) 107 108 def testConstantVectorMulF32(self): 109 c = self._NewComputation() 110 c.Mul( 111 c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])), 112 c.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))) 113 self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) 114 115 def testConstantVectorMulF64(self): 116 c = self._NewComputation() 117 c.Mul( 118 c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])), 119 c.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))) 120 self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) 121 122 def testConstantVectorScalarDivF32(self): 123 c = self._NewComputation() 124 c.Div( 125 c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])), 126 c.ConstantF32Scalar(2.0)) 127 self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) 128 129 def testConstantVectorScalarDivF64(self): 130 c = self._NewComputation() 131 c.Div( 132 c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])), 133 c.ConstantF64Scalar(2.0)) 134 self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) 135 136 def testConstantVectorScalarPowF32(self): 137 c = self._NewComputation() 138 c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.)) 139 self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) 140 141 def testConstantVectorScalarPowF64(self): 142 c = self._NewComputation() 143 c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) 144 self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) 145 146 def testBooleanAnd(self): 147 c = self._NewComputation() 148 c.And( 149 c.Constant(NumpyArrayBool([True, False, True, False])), 150 c.Constant(NumpyArrayBool([True, True, False, False]))) 151 self._ExecuteAndCompareExact(c, expected=[True, False, False, False]) 152 153 def testBooleanOr(self): 154 c = self._NewComputation() 155 c.Or( 156 c.Constant(NumpyArrayBool([True, False, True, False])), 157 c.Constant(NumpyArrayBool([True, True, False, False]))) 158 self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) 159 160 def testSum2DF32(self): 161 c = self._NewComputation() 162 c.Add( 163 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), 164 c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) 165 self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) 166 167 def testSum2DF64(self): 168 c = self._NewComputation() 169 c.Add( 170 c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])), 171 c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))) 172 self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) 173 174 def testSum2DWith1DBroadcastDim0F32(self): 175 # sum of a 2D array with a 1D array where the latter is replicated across 176 # dimension 0 to match the former's shape. 177 c = self._NewComputation() 178 c.Add( 179 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 180 c.Constant(NumpyArrayF32([10, 20, 30])), 181 broadcast_dimensions=(0,)) 182 self._ExecuteAndCompareClose( 183 c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) 184 185 def testSum2DWith1DBroadcastDim0F64(self): 186 # sum of a 2D array with a 1D array where the latter is replicated across 187 # dimension 0 to match the former's shape. 188 c = self._NewComputation() 189 c.Add( 190 c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 191 c.Constant(NumpyArrayF64([10, 20, 30])), 192 broadcast_dimensions=(0,)) 193 self._ExecuteAndCompareClose( 194 c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) 195 196 def testSum2DWith1DBroadcastDim1F32(self): 197 # sum of a 2D array with a 1D array where the latter is replicated across 198 # dimension 1 to match the former's shape. 199 c = self._NewComputation() 200 c.Add( 201 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 202 c.Constant(NumpyArrayF32([10, 20, 30])), 203 broadcast_dimensions=(1,)) 204 self._ExecuteAndCompareClose( 205 c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) 206 207 def testSum2DWith1DBroadcastDim1F64(self): 208 # sum of a 2D array with a 1D array where the latter is replicated across 209 # dimension 1 to match the former's shape. 210 c = self._NewComputation() 211 c.Add( 212 c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 213 c.Constant(NumpyArrayF64([10, 20, 30])), 214 broadcast_dimensions=(1,)) 215 self._ExecuteAndCompareClose( 216 c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) 217 218 def testConstantAxpyF32(self): 219 c = self._NewComputation() 220 c.Add( 221 c.Mul( 222 c.ConstantF32Scalar(2), 223 c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))), 224 c.Constant(NumpyArrayF32([100, -100, 200, -200]))) 225 self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) 226 227 def testConstantAxpyF64(self): 228 c = self._NewComputation() 229 c.Add( 230 c.Mul( 231 c.ConstantF64Scalar(2), 232 c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))), 233 c.Constant(NumpyArrayF64([100, -100, 200, -200]))) 234 self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) 235 236 237 class ParametersTest(LocalComputationTest): 238 """Tests focusing on Parameter ops and argument-passing.""" 239 240 def setUp(self): 241 self.f32_scalar_2 = NumpyArrayF32(2.0) 242 self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3]) 243 self.f64_scalar_2 = NumpyArrayF64(2.0) 244 self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3]) 245 self.s32_scalar_3 = NumpyArrayS32(3) 246 self.s32_4vector = NumpyArrayS32([10, 15, -2, 7]) 247 self.s64_scalar_3 = NumpyArrayS64(3) 248 self.s64_4vector = NumpyArrayS64([10, 15, -2, 7]) 249 250 def testScalarTimesVectorAutonumberF32(self): 251 c = self._NewComputation() 252 p0 = c.ParameterFromNumpy(self.f32_scalar_2) 253 p1 = c.ParameterFromNumpy(self.f32_4vector) 254 c.Mul(p0, p1) 255 self._ExecuteAndCompareClose( 256 c, 257 arguments=[self.f32_scalar_2, self.f32_4vector], 258 expected=[-4.6, 6.6, -8.6, 10.6]) 259 260 def testScalarTimesVectorAutonumberF64(self): 261 c = self._NewComputation() 262 p0 = c.ParameterFromNumpy(self.f64_scalar_2) 263 p1 = c.ParameterFromNumpy(self.f64_4vector) 264 c.Mul(p0, p1) 265 self._ExecuteAndCompareClose( 266 c, 267 arguments=[self.f64_scalar_2, self.f64_4vector], 268 expected=[-4.6, 6.6, -8.6, 10.6]) 269 270 def testScalarTimesVectorS32(self): 271 c = self._NewComputation() 272 p0 = c.ParameterFromNumpy(self.s32_scalar_3) 273 p1 = c.ParameterFromNumpy(self.s32_4vector) 274 c.Mul(p0, p1) 275 self._ExecuteAndCompareExact( 276 c, 277 arguments=[self.s32_scalar_3, self.s32_4vector], 278 expected=[30, 45, -6, 21]) 279 280 def testScalarTimesVectorS64(self): 281 c = self._NewComputation() 282 p0 = c.ParameterFromNumpy(self.s64_scalar_3) 283 p1 = c.ParameterFromNumpy(self.s64_4vector) 284 c.Mul(p0, p1) 285 self._ExecuteAndCompareExact( 286 c, 287 arguments=[self.s64_scalar_3, self.s64_4vector], 288 expected=[30, 45, -6, 21]) 289 290 def testScalarMinusVectorExplicitNumberingF32(self): 291 # Use explicit numbering and pass parameter_num first. Sub is used since 292 # it's not commutative and can help catch parameter reversal within the 293 # computation. 294 c = self._NewComputation() 295 p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1) 296 p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0) 297 c.Sub(p1, p0) 298 self._ExecuteAndCompareClose( 299 c, 300 arguments=[self.f32_scalar_2, self.f32_4vector], 301 expected=[-4.3, 1.3, -6.3, 3.3]) 302 303 def testScalarMinusVectorExplicitNumberingF64(self): 304 # Use explicit numbering and pass parameter_num first. Sub is used since 305 # it's not commutative and can help catch parameter reversal within the 306 # computation. 307 c = self._NewComputation() 308 p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1) 309 p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0) 310 c.Sub(p1, p0) 311 self._ExecuteAndCompareClose( 312 c, 313 arguments=[self.f64_scalar_2, self.f64_4vector], 314 expected=[-4.3, 1.3, -6.3, 3.3]) 315 316 317 class LocalBufferTest(LocalComputationTest): 318 """Tests focusing on execution with LocalBuffers.""" 319 320 def _Execute(self, c, arguments): 321 compiled_c = c.Build().CompileWithExampleArguments(arguments) 322 arg_buffers = [xla_client.LocalBuffer.from_py(arg) for arg in arguments] 323 result_buffer = compiled_c.ExecuteWithLocalBuffers(arg_buffers) 324 return result_buffer.to_py() 325 326 def testConstantSum(self): 327 c = self._NewComputation() 328 c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) 329 self._ExecuteAndCompareClose(c, expected=4.25) 330 331 def testOneParameterSum(self): 332 c = self._NewComputation() 333 c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) 334 self._ExecuteAndCompareClose( 335 c, 336 arguments=[NumpyArrayF32(1.11)], 337 expected=4.25) 338 339 def testTwoParameterSum(self): 340 c = self._NewComputation() 341 c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), 342 c.ParameterFromNumpy(NumpyArrayF32(0.))) 343 self._ExecuteAndCompareClose( 344 c, 345 arguments=[NumpyArrayF32(1.11), NumpyArrayF32(3.14)], 346 expected=4.25) 347 348 def testCannotCallWithDeletedBuffers(self): 349 c = self._NewComputation() 350 c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) 351 arg = NumpyArrayF32(1.11) 352 compiled_c = c.Build().CompileWithExampleArguments([arg]) 353 arg_buffer = xla_client.LocalBuffer.from_py(arg) 354 arg_buffer.delete() 355 with self.assertRaises(ValueError): 356 compiled_c.ExecuteWithLocalBuffers([arg_buffer]) 357 358 359 class SingleOpTest(LocalComputationTest): 360 """Tests for single ops. 361 362 The goal here is smoke testing - to exercise the most basic functionality of 363 single XLA ops. As minimal as possible number of additional ops are added 364 around the op being tested. 365 """ 366 367 def testConcatenateF32(self): 368 c = self._NewComputation() 369 c.Concatenate( 370 (c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])), 371 c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))), 372 dimension=0) 373 self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 374 375 def testConcatenateF64(self): 376 c = self._NewComputation() 377 c.Concatenate( 378 (c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])), 379 c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))), 380 dimension=0) 381 self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 382 383 def testConvertElementType(self): 384 xla_types = { 385 np.bool: xla_client.xla_data_pb2.PRED, 386 np.int32: xla_client.xla_data_pb2.S32, 387 np.int64: xla_client.xla_data_pb2.S64, 388 np.float32: xla_client.xla_data_pb2.F32, 389 np.float64: xla_client.xla_data_pb2.F64, 390 } 391 392 def _ConvertAndTest(template, src_dtype, dst_dtype): 393 c = self._NewComputation() 394 x = c.Constant(np.array(template, dtype=src_dtype)) 395 c.ConvertElementType(x, xla_types[dst_dtype]) 396 397 result = c.Build().Compile().Execute() 398 expected = np.array(template, dtype=dst_dtype) 399 400 self.assertEqual(result.shape, expected.shape) 401 self.assertEqual(result.dtype, expected.dtype) 402 np.testing.assert_equal(result, expected) 403 404 x = [0, 1, 0, 0, 1] 405 for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): 406 _ConvertAndTest(x, src_dtype, dst_dtype) 407 408 def testCrossReplicaSumOneReplica(self): 409 samples = [ 410 NumpyArrayF32(42.0), 411 NumpyArrayF32([97.0]), 412 NumpyArrayF32([64.0, 117.0]), 413 NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), 414 ] 415 for lhs in samples: 416 c = self._NewComputation() 417 c.CrossReplicaSum(c.Constant(lhs)) 418 self._ExecuteAndCompareExact(c, expected=lhs) 419 420 def testDotMatrixVectorF32(self): 421 c = self._NewComputation() 422 lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) 423 rhs = NumpyArrayF32([[10.0], [20.0]]) 424 c.Dot(c.Constant(lhs), c.Constant(rhs)) 425 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 426 427 def testDotMatrixVectorF64(self): 428 c = self._NewComputation() 429 lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) 430 rhs = NumpyArrayF64([[10.0], [20.0]]) 431 c.Dot(c.Constant(lhs), c.Constant(rhs)) 432 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 433 434 def testDotMatrixMatrixF32(self): 435 c = self._NewComputation() 436 lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) 437 rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]]) 438 c.Dot(c.Constant(lhs), c.Constant(rhs)) 439 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 440 441 def testDotMatrixMatrixF64(self): 442 c = self._NewComputation() 443 lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) 444 rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]]) 445 c.Dot(c.Constant(lhs), c.Constant(rhs)) 446 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 447 448 def testDotGeneral(self): 449 c = self._NewComputation() 450 rng = np.random.RandomState(0) 451 lhs = NumpyArrayF32(rng.randn(10, 3, 4)) 452 rhs = NumpyArrayF32(rng.randn(10, 4, 5)) 453 dimension_numbers = (([2], [1]), ([0], [0])) 454 c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) 455 self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) 456 457 def testDotGeneralWithDotDimensionNumbersProto(self): 458 c = self._NewComputation() 459 rng = np.random.RandomState(0) 460 lhs = NumpyArrayF32(rng.randn(10, 3, 4)) 461 rhs = NumpyArrayF32(rng.randn(10, 4, 5)) 462 463 dimension_numbers = xla_client.xla_data_pb2.DotDimensionNumbers() 464 dimension_numbers.lhs_contracting_dimensions.append(2) 465 dimension_numbers.rhs_contracting_dimensions.append(1) 466 dimension_numbers.lhs_batch_dimensions.append(0) 467 dimension_numbers.rhs_batch_dimensions.append(0) 468 469 c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) 470 self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) 471 472 def testConvF32Same(self): 473 c = self._NewComputation() 474 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 475 lhs = a(1, 2, 3, 4) 476 rhs = a(1, 2, 1, 2) * 10 477 c.Conv(c.Constant(lhs), c.Constant(rhs), 478 [1, 1], xla_client.PaddingType.SAME) 479 result = np.array([[[[640., 700., 760., 300.], 480 [880., 940., 1000., 380.], 481 [1120., 1180., 1240., 460.]]]]) 482 self._ExecuteAndCompareClose(c, expected=result) 483 484 def testConvF32Valid(self): 485 c = self._NewComputation() 486 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 487 lhs = a(1, 2, 3, 4) 488 rhs = a(1, 2, 1, 2) * 10 489 c.Conv(c.Constant(lhs), c.Constant(rhs), 490 [2, 1], xla_client.PaddingType.VALID) 491 result = np.array([[[[640., 700., 760.], 492 [1120., 1180., 1240.]]]]) 493 self._ExecuteAndCompareClose(c, expected=result) 494 495 def testConvWithGeneralPaddingF32(self): 496 c = self._NewComputation() 497 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 498 lhs = a(1, 1, 2, 3) 499 rhs = a(1, 1, 1, 2) * 10 500 strides = [1, 1] 501 pads = [(1, 0), (0, 1)] 502 lhs_dilation = (2, 1) 503 rhs_dilation = (1, 1) 504 c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs), 505 strides, pads, lhs_dilation, rhs_dilation) 506 result = np.array([[[[0., 0., 0.], 507 [10., 20., 0.], 508 [0., 0., 0.], 509 [40., 50., 0.]]]]) 510 self._ExecuteAndCompareClose(c, expected=result) 511 512 def testBooleanNot(self): 513 c = self._NewComputation() 514 arr = NumpyArrayBool([True, False, True]) 515 c.Not(c.Constant(arr)) 516 self._ExecuteAndCompareClose(c, expected=~arr) 517 518 def testExp(self): 519 c = self._NewComputation() 520 arr = NumpyArrayF32([3.3, 12.1]) 521 c.Exp(c.Constant(arr)) 522 self._ExecuteAndCompareClose(c, expected=np.exp(arr)) 523 524 def testRound(self): 525 c = self._NewComputation() 526 arr = NumpyArrayF32([3.3, 12.1]) 527 c.Round(c.Constant(arr)) 528 self._ExecuteAndCompareClose(c, expected=np.round(arr)) 529 530 def testLog(self): 531 c = self._NewComputation() 532 arr = NumpyArrayF32([3.3, 12.1]) 533 c.Log(c.Constant(arr)) 534 self._ExecuteAndCompareClose(c, expected=np.log(arr)) 535 536 def testNeg(self): 537 c = self._NewComputation() 538 arr = NumpyArrayF32([3.3, 12.1]) 539 c.Neg(c.Constant(arr)) 540 self._ExecuteAndCompareClose(c, expected=-arr) 541 542 def testFloor(self): 543 c = self._NewComputation() 544 arr = NumpyArrayF32([3.3, 12.1]) 545 c.Floor(c.Constant(arr)) 546 self._ExecuteAndCompareClose(c, expected=np.floor(arr)) 547 548 def testCeil(self): 549 c = self._NewComputation() 550 arr = NumpyArrayF32([3.3, 12.1]) 551 c.Ceil(c.Constant(arr)) 552 self._ExecuteAndCompareClose(c, expected=np.ceil(arr)) 553 554 def testAbs(self): 555 c = self._NewComputation() 556 arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) 557 c.Abs(c.Constant(arr)) 558 self._ExecuteAndCompareClose(c, expected=np.abs(arr)) 559 560 def testTanh(self): 561 c = self._NewComputation() 562 arr = NumpyArrayF32([3.3, 12.1]) 563 c.Tanh(c.Constant(arr)) 564 self._ExecuteAndCompareClose(c, expected=np.tanh(arr)) 565 566 def testTrans(self): 567 568 def _TransposeAndTest(array): 569 c = self._NewComputation() 570 c.Trans(c.Constant(array)) 571 self._ExecuteAndCompareClose(c, expected=array.T) 572 573 # Test square and non-square matrices in both default (C) and F orders. 574 for array_fun in [NumpyArrayF32, NumpyArrayF64]: 575 _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]])) 576 _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F")) 577 _TransposeAndTest(array_fun([[1, 2], [4, 5]])) 578 _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F")) 579 580 def testTranspose(self): 581 582 def _TransposeAndTest(array, permutation): 583 c = self._NewComputation() 584 c.Transpose(c.Constant(array), permutation) 585 expected = np.transpose(array, permutation) 586 self._ExecuteAndCompareClose(c, expected=expected) 587 588 _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) 589 _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) 590 _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) 591 _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) 592 593 arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) 594 for permutation in itertools.permutations(range(arr.ndim)): 595 _TransposeAndTest(arr, permutation) 596 _TransposeAndTest(np.asfortranarray(arr), permutation) 597 598 def testEq(self): 599 c = self._NewComputation() 600 c.Eq( 601 c.Constant(NumpyArrayS32([1, 2, 3, 4])), 602 c.Constant(NumpyArrayS32([4, 2, 3, 1]))) 603 self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) 604 605 def testNe(self): 606 c = self._NewComputation() 607 c.Ne( 608 c.Constant(NumpyArrayS32([1, 2, 3, 4])), 609 c.Constant(NumpyArrayS32([4, 2, 3, 1]))) 610 self._ExecuteAndCompareExact(c, expected=[True, False, False, True]) 611 612 c.Ne( 613 c.Constant(NumpyArrayF32([-2.0, 0.0, 614 float("nan"), 615 float("nan")])), 616 c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")]))) 617 self._ExecuteAndAssertWith( 618 np.testing.assert_allclose, c, (), expected=[True, False, True, True]) 619 620 def testGt(self): 621 c = self._NewComputation() 622 c.Gt( 623 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 624 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 625 self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False]) 626 627 def testGe(self): 628 c = self._NewComputation() 629 c.Ge( 630 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 631 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 632 self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False]) 633 634 def testLt(self): 635 c = self._NewComputation() 636 c.Lt( 637 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 638 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 639 self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True]) 640 641 def testLe(self): 642 c = self._NewComputation() 643 c.Le( 644 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 645 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 646 self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True]) 647 648 def testMax(self): 649 c = self._NewComputation() 650 c.Max( 651 c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), 652 c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) 653 self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0]) 654 655 def testMaxExplicitBroadcastDim0(self): 656 c = self._NewComputation() 657 c.Max( 658 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 659 c.Constant(NumpyArrayF32([3, 4, 5])), 660 broadcast_dimensions=(0,)) 661 self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]]) 662 663 def testMaxExplicitBroadcastDim1(self): 664 c = self._NewComputation() 665 c.Max( 666 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 667 c.Constant(NumpyArrayF32([3, 4, 5])), 668 broadcast_dimensions=(1,)) 669 self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]]) 670 671 def testMin(self): 672 c = self._NewComputation() 673 c.Min( 674 c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), 675 c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) 676 self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0]) 677 678 def testPad(self): 679 c = self._NewComputation() 680 c.Pad( 681 c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), 682 c.Constant(NumpyArrayF32(0.0)), 683 [(1, 2, 1), (0, 1, 0)]) 684 self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], 685 [1.0, 2.0, 0.0], 686 [0.0, 0.0, 0.0], 687 [3.0, 4.0, 0.0], 688 [0.0, 0.0, 0.0], 689 [0.0, 0.0, 0.0]]) 690 691 def testPadWithPaddingConfig(self): 692 c = self._NewComputation() 693 padding_config = xla_client.xla_data_pb2.PaddingConfig() 694 for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: 695 dimension = padding_config.dimensions.add() 696 dimension.edge_padding_low = lo 697 dimension.edge_padding_high = hi 698 dimension.interior_padding = interior 699 c.Pad( 700 c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), 701 c.Constant(NumpyArrayF32(0.0)), 702 padding_config) 703 self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], 704 [1.0, 2.0, 0.0], 705 [0.0, 0.0, 0.0], 706 [3.0, 4.0, 0.0], 707 [0.0, 0.0, 0.0], 708 [0.0, 0.0, 0.0]]) 709 710 def testReshape(self): 711 c = self._NewComputation() 712 c.Reshape( 713 c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), 714 dimensions=[0, 1], 715 new_sizes=[2, 3]) 716 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]]) 717 718 def testCollapse(self): 719 c = self._NewComputation() 720 c.Collapse( 721 c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), 722 dimensions=[1, 2]) 723 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]]) 724 725 def testRev(self): 726 c = self._NewComputation() 727 c.Rev( 728 c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), 729 dimensions=[0, 2]) 730 self._ExecuteAndCompareExact( 731 c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]) 732 733 def testClampF32(self): 734 c = self._NewComputation() 735 c.Clamp( 736 c.Constant(NumpyArrayF32(-1)), 737 c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])), 738 c.Constant(NumpyArrayF32(2))) 739 self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) 740 741 # TODO(b/72689392): re-enable when bug S32 resolved 742 def DISABLED_testClampS32(self): 743 c = self._NewComputation() 744 c.Clamp( 745 c.Constant(NumpyArrayS32(-1)), 746 c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])), 747 c.Constant(NumpyArrayS32(2))) 748 self._ExecuteAndCompareExact(c, expected=[-1, 0, 1, 2, 2]) 749 750 def testSelect(self): 751 c = self._NewComputation() 752 c.Select( 753 c.Constant(NumpyArrayBool([True, False, False, True, False])), 754 c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])), 755 c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5]))) 756 self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5]) 757 758 def testSlice(self): 759 c = self._NewComputation() 760 c.Slice( 761 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0], 762 [3, 2]) 763 self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) 764 765 def testDynamicSlice(self): 766 c = self._NewComputation() 767 c.DynamicSlice( 768 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 769 c.Constant(NumpyArrayS32([1, 0])), [2, 2]) 770 self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) 771 772 def testDynamicUpdateSlice(self): 773 c = self._NewComputation() 774 c.DynamicUpdateSlice( 775 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 776 c.Constant(NumpyArrayS32([[1, 2], [3, 4]])), 777 c.Constant(NumpyArrayS32([1, 1]))) 778 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]]) 779 780 def testTuple(self): 781 c = self._NewComputation() 782 c.Tuple( 783 c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), 784 c.Constant(NumpyArrayBool([True, False, False, True]))) 785 result = c.Build().Compile().Execute() 786 self.assertIsInstance(result, tuple) 787 np.testing.assert_equal(result[0], 42) 788 np.testing.assert_allclose(result[1], [1.0, 2.0]) 789 np.testing.assert_equal(result[2], [True, False, False, True]) 790 791 def testGetTupleElement(self): 792 c = self._NewComputation() 793 c.GetTupleElement( 794 c.Tuple( 795 c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), 796 c.Constant(NumpyArrayBool([True, False, False, True]))), 1) 797 self._ExecuteAndCompareClose(c, expected=[1.0, 2.0]) 798 799 def testBroadcast(self): 800 c = self._NewComputation() 801 c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) 802 self._ExecuteAndCompareExact( 803 c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]) 804 805 def testRngNormal(self): 806 shape = (2, 3) 807 c = self._NewComputation() 808 c.RngNormal(c.Constant(NumpyArrayF32(0.)), c.Constant(NumpyArrayF32(1.)), 809 dims=shape) 810 result = c.Build().Compile().Execute() 811 # since the result is random, we just check shape and uniqueness 812 self.assertEqual(result.shape, shape) 813 self.assertEqual(len(np.unique(result)), np.prod(shape)) 814 815 def testRngUniformF32(self): 816 lo, hi = 2., 4. 817 shape = (2, 3) 818 c = self._NewComputation() 819 c.RngUniform(c.Constant(NumpyArrayF32(lo)), c.Constant(NumpyArrayF32(hi)), 820 dims=shape) 821 result = c.Build().Compile().Execute() 822 # since the result is random, we just check shape, uniqueness, and range 823 self.assertEqual(result.shape, shape) 824 self.assertEqual(len(np.unique(result)), np.prod(shape)) 825 self.assertTrue(np.all(lo <= result)) 826 self.assertTrue(np.all(result < hi)) 827 828 def testRngUniformS32(self): 829 lo, hi = 2, 4 830 shape = (2, 3) 831 c = self._NewComputation() 832 c.RngUniform(c.Constant(NumpyArrayS32(lo)), c.Constant(NumpyArrayS32(hi)), 833 dims=shape) 834 result = c.Build().Compile().Execute() 835 # since the result is random, we just check shape, integrality, and range 836 self.assertEqual(result.shape, shape) 837 self.assertEqual(result.dtype, np.int32) 838 self.assertTrue(np.all(lo <= result)) 839 self.assertTrue(np.all(result < hi)) 840 841 842 class EmbeddedComputationsTest(LocalComputationTest): 843 """Tests for XLA graphs with embedded computations (such as maps).""" 844 845 def _CreateConstantS32Computation(self): 846 """Computation (f32) -> s32 that returns a constant 1 for any input.""" 847 c = self._NewComputation("constant_s32_one") 848 # TODO(eliben): consider adding a nicer way to create new parameters without 849 # having to create dummy Numpy arrays or populating Shape messages. Perhaps 850 # we need our own (Python-client-own) way to represent Shapes conveniently. 851 c.ParameterFromNumpy(NumpyArrayF32(0)) 852 c.ConstantS32Scalar(1) 853 return c.Build() 854 855 def _CreateConstantS64Computation(self): 856 """Computation (f64) -> s64 that returns a constant 1 for any input.""" 857 c = self._NewComputation("constant_s64_one") 858 # TODO(eliben): consider adding a nicer way to create new parameters without 859 # having to create dummy Numpy arrays or populating Shape messages. Perhaps 860 # we need our own (Python-client-own) way to represent Shapes conveniently. 861 c.ParameterFromNumpy(NumpyArrayF64(0)) 862 c.ConstantS64Scalar(1) 863 return c.Build() 864 865 def _CreateConstantF32Computation(self): 866 """Computation (f32) -> f32 that returns a constant 1.0 for any input.""" 867 c = self._NewComputation("constant_f32_one") 868 c.ParameterFromNumpy(NumpyArrayF32(0)) 869 c.ConstantF32Scalar(1.0) 870 return c.Build() 871 872 def _CreateConstantF64Computation(self): 873 """Computation (f64) -> f64 that returns a constant 1.0 for any input.""" 874 c = self._NewComputation("constant_f64_one") 875 c.ParameterFromNumpy(NumpyArrayF64(0)) 876 c.ConstantF64Scalar(1.0) 877 return c.Build() 878 879 def _CreateMulF32By2Computation(self): 880 """Computation (f32) -> f32 that multiplies its parameter by 2.""" 881 c = self._NewComputation("mul_f32_by2") 882 c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0)) 883 return c.Build() 884 885 def _CreateMulF32ByParamComputation(self): 886 """Computation (f32) -> f32 that multiplies one parameter by the other.""" 887 c = self._NewComputation("mul_f32_by_param") 888 c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), 889 c.ParameterFromNumpy(NumpyArrayF32(0))) 890 return c.Build() 891 892 def _CreateMulF64By2Computation(self): 893 """Computation (f64) -> f64 that multiplies its parameter by 2.""" 894 c = self._NewComputation("mul_f64_by2") 895 c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) 896 return c.Build() 897 898 def _CreateBinaryAddF32Computation(self): 899 """Computation (f32, f32) -> f32 that adds its two parameters.""" 900 c = self._NewComputation("add_param0_by_param1") 901 c.Add( 902 c.ParameterFromNumpy(NumpyArrayF32(0)), 903 c.ParameterFromNumpy(NumpyArrayF32(0))) 904 return c.Build() 905 906 def _CreateBinaryAddF64Computation(self): 907 """Computation (f64, f64) -> f64 that adds its two parameters.""" 908 c = self._NewComputation("add_param0_by_param1") 909 c.Add( 910 c.ParameterFromNumpy(NumpyArrayF64(0)), 911 c.ParameterFromNumpy(NumpyArrayF64(0))) 912 return c.Build() 913 914 def _CreateBinaryDivF32Computation(self): 915 """Computation (f32, f32) -> f32 that divides its two parameters.""" 916 c = self._NewComputation("div_param0_by_param1") 917 c.Div( 918 c.ParameterFromNumpy(NumpyArrayF32(0)), 919 c.ParameterFromNumpy(NumpyArrayF32(0))) 920 return c.Build() 921 922 def _CreateBinaryDivF64Computation(self): 923 """Computation (f64, f64) -> f64 that divides its two parameters.""" 924 c = self._NewComputation("div_param0_by_param1") 925 c.Div( 926 c.ParameterFromNumpy(NumpyArrayF64(0)), 927 c.ParameterFromNumpy(NumpyArrayF64(0))) 928 return c.Build() 929 930 def _CreateTestF32Lt10Computation(self): 931 """Computation (f32) -> bool that tests if its parameter is less than 10.""" 932 c = self._NewComputation("test_f32_lt_10") 933 c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.)) 934 return c.Build() 935 936 def _CreateTestF64Lt10Computation(self): 937 """Computation (f64) -> bool that tests if its parameter is less than 10.""" 938 c = self._NewComputation("test_f64_lt_10") 939 c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.)) 940 return c.Build() 941 942 def _CreateBinaryGeF32Computation(self): 943 """Computation (f32, f32) -> bool that tests first_param >= second_param.""" 944 c = self._NewComputation("param0_lt_param1") 945 c.Ge(c.ParameterFromNumpy(NumpyArrayF32(0)), 946 c.ParameterFromNumpy(NumpyArrayF32(0))) 947 return c.Build() 948 949 def _CreateBinaryGeF64Computation(self): 950 """Computation (f64, f64) -> bool that tests first_param >= second_param.""" 951 c = self._NewComputation("param0_lt_param1") 952 c.Ge(c.ParameterFromNumpy(NumpyArrayF64(0)), 953 c.ParameterFromNumpy(NumpyArrayF64(0))) 954 return c.Build() 955 956 def _MakeSample3DArrayF32(self): 957 return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], 958 [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) 959 960 def _MakeSample3DArrayF64(self): 961 return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], 962 [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) 963 964 def testCallF32(self): 965 c = self._NewComputation() 966 c.Call( 967 self._CreateMulF32By2Computation(), 968 operands=(c.ConstantF32Scalar(5.0),)) 969 self._ExecuteAndCompareClose(c, expected=10.0) 970 971 def testCallF64(self): 972 c = self._NewComputation() 973 c.Call( 974 self._CreateMulF64By2Computation(), 975 operands=(c.ConstantF64Scalar(5.0),)) 976 self._ExecuteAndCompareClose(c, expected=10.0) 977 978 def testMapEachElementToS32Constant(self): 979 c = self._NewComputation() 980 c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], 981 self._CreateConstantS32Computation(), [0]) 982 self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) 983 984 def testMapEachElementToS64Constant(self): 985 c = self._NewComputation() 986 c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], 987 self._CreateConstantS64Computation(), [0]) 988 self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) 989 990 def testMapMulBy2F32(self): 991 c = self._NewComputation() 992 c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], 993 self._CreateMulF32By2Computation(), [0]) 994 self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) 995 996 def testMapMulBy2F64(self): 997 c = self._NewComputation() 998 c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], 999 self._CreateMulF64By2Computation(), [0]) 1000 self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) 1001 1002 def testSimpleMapChainF32(self): 1003 # Chains a map of constant-f32 with a map of mul-by-2 1004 c = self._NewComputation() 1005 const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], 1006 self._CreateConstantF32Computation(), [0]) 1007 c.Map([const_f32], self._CreateMulF32By2Computation(), [0]) 1008 self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) 1009 1010 def testSimpleMapChainF64(self): 1011 # Chains a map of constant-f64 with a map of mul-by-2 1012 c = self._NewComputation() 1013 const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], 1014 self._CreateConstantF64Computation(), [0]) 1015 c.Map([const_f64], self._CreateMulF64By2Computation(), [0]) 1016 self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) 1017 1018 def testDivVectorsWithMapF32(self): 1019 c = self._NewComputation() 1020 c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), 1021 c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))), 1022 self._CreateBinaryDivF32Computation(), [0]) 1023 self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) 1024 1025 def testDivVectorsWithMapF64(self): 1026 c = self._NewComputation() 1027 c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), 1028 c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))), 1029 self._CreateBinaryDivF64Computation(), [0]) 1030 self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) 1031 1032 def DISABLED_testMapWithStaticOperands(self): 1033 c = self._NewComputation() 1034 factor = c.ConstantF32Scalar(3.0) 1035 c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], 1036 self._CreateMulF32ByParamComputation(), [0], 1037 static_operands=[factor]) 1038 self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0]) 1039 1040 def testSelectAndScatterF32(self): 1041 c = self._NewComputation() 1042 c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), 1043 select=self._CreateBinaryGeF32Computation(), 1044 window_dimensions=(2, 1), 1045 window_strides=(1, 2), 1046 padding=xla_client.PaddingType.VALID, 1047 source=c.Constant(NumpyArrayF32([[0.1, 0.2]])), 1048 init_value=c.Constant(NumpyArrayF32(1)), 1049 scatter=self._CreateBinaryAddF32Computation()) 1050 self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) 1051 1052 def testSelectAndScatterF64(self): 1053 c = self._NewComputation() 1054 c.SelectAndScatter(c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])), 1055 select=self._CreateBinaryGeF64Computation(), 1056 window_dimensions=(2, 1), 1057 window_strides=(1, 2), 1058 padding=xla_client.PaddingType.VALID, 1059 source=c.Constant(NumpyArrayF64([[0.1, 0.2]])), 1060 init_value=c.Constant(NumpyArrayF64(1)), 1061 scatter=self._CreateBinaryAddF64Computation()) 1062 self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) 1063 1064 def testReduce1DtoScalarF32(self): 1065 c = self._NewComputation() 1066 c.Reduce( 1067 operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), 1068 init_value=c.ConstantF32Scalar(0), 1069 computation_to_apply=self._CreateBinaryAddF32Computation(), 1070 dimensions=[0]) 1071 self._ExecuteAndCompareClose(c, expected=10) 1072 1073 def testReduce1DtoScalarF64(self): 1074 c = self._NewComputation() 1075 c.Reduce( 1076 operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), 1077 init_value=c.ConstantF64Scalar(0), 1078 computation_to_apply=self._CreateBinaryAddF64Computation(), 1079 dimensions=[0]) 1080 self._ExecuteAndCompareClose(c, expected=10) 1081 1082 def testReduce2DTo1DDim0F32(self): 1083 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1084 c = self._NewComputation() 1085 c.Reduce( 1086 operand=c.Constant(input_array), 1087 init_value=c.ConstantF32Scalar(0), 1088 computation_to_apply=self._CreateBinaryAddF32Computation(), 1089 dimensions=[0]) 1090 self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) 1091 1092 def testReduce2DTo1DDim0F64(self): 1093 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1094 c = self._NewComputation() 1095 c.Reduce( 1096 operand=c.Constant(input_array), 1097 init_value=c.ConstantF64Scalar(0), 1098 computation_to_apply=self._CreateBinaryAddF64Computation(), 1099 dimensions=[0]) 1100 self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) 1101 1102 def testReduce2DTo1DDim1F32(self): 1103 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1104 c = self._NewComputation() 1105 c.Reduce( 1106 operand=c.Constant(input_array), 1107 init_value=c.ConstantF32Scalar(0), 1108 computation_to_apply=self._CreateBinaryAddF32Computation(), 1109 dimensions=[1]) 1110 self._ExecuteAndCompareClose(c, expected=[6, 15]) 1111 1112 def testReduce2DTo1DDim1F64(self): 1113 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1114 c = self._NewComputation() 1115 c.Reduce( 1116 operand=c.Constant(input_array), 1117 init_value=c.ConstantF64Scalar(0), 1118 computation_to_apply=self._CreateBinaryAddF64Computation(), 1119 dimensions=[1]) 1120 self._ExecuteAndCompareClose(c, expected=[6, 15]) 1121 1122 def testReduce3DAllPossibleWaysF32(self): 1123 input_array = self._MakeSample3DArrayF32() 1124 1125 def _ReduceAndTest(*dims): 1126 c = self._NewComputation() 1127 c.Reduce( 1128 operand=c.Constant(input_array), 1129 init_value=c.ConstantF32Scalar(0), 1130 computation_to_apply=self._CreateBinaryAddF32Computation(), 1131 dimensions=dims) 1132 self._ExecuteAndCompareClose( 1133 c, expected=np.sum(input_array, axis=tuple(dims))) 1134 1135 _ReduceAndTest(0) 1136 _ReduceAndTest(0) 1137 _ReduceAndTest(0, 1) 1138 _ReduceAndTest(0, 2) 1139 _ReduceAndTest(1, 2) 1140 _ReduceAndTest(0, 1, 2) 1141 1142 def testReduce3DAllPossibleWaysF64(self): 1143 input_array = self._MakeSample3DArrayF64() 1144 1145 def _ReduceAndTest(*dims): 1146 c = self._NewComputation() 1147 c.Reduce( 1148 operand=c.Constant(input_array), 1149 init_value=c.ConstantF64Scalar(0), 1150 computation_to_apply=self._CreateBinaryAddF64Computation(), 1151 dimensions=dims) 1152 self._ExecuteAndCompareClose( 1153 c, expected=np.sum(input_array, axis=tuple(dims))) 1154 1155 _ReduceAndTest(0) 1156 _ReduceAndTest(0) 1157 _ReduceAndTest(0, 1) 1158 _ReduceAndTest(0, 2) 1159 _ReduceAndTest(1, 2) 1160 _ReduceAndTest(0, 1, 2) 1161 1162 def testReduceWindowValidUnitStridesF32(self): 1163 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1164 c = self._NewComputation() 1165 c.ReduceWindow(operand=c.Constant(input_array), 1166 init_value=c.ConstantF32Scalar(0), 1167 computation_to_apply=self._CreateBinaryAddF32Computation(), 1168 window_dimensions=(2, 1), window_strides=(1, 1), 1169 padding=xla_client.PaddingType.VALID) 1170 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) 1171 1172 def testReduceWindowSameUnitStridesF32(self): 1173 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1174 c = self._NewComputation() 1175 c.ReduceWindow(operand=c.Constant(input_array), 1176 init_value=c.ConstantF32Scalar(0), 1177 computation_to_apply=self._CreateBinaryAddF32Computation(), 1178 window_dimensions=(2, 1), window_strides=(1, 1), 1179 padding=xla_client.PaddingType.SAME) 1180 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) 1181 1182 def testReduceWindowValidGeneralStridesF32(self): 1183 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1184 c = self._NewComputation() 1185 c.ReduceWindow(operand=c.Constant(input_array), 1186 init_value=c.ConstantF32Scalar(0), 1187 computation_to_apply=self._CreateBinaryAddF32Computation(), 1188 window_dimensions=(2, 1), window_strides=(1, 2), 1189 padding=xla_client.PaddingType.VALID) 1190 self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) 1191 1192 def testReduceWindowValidUnitStridesF64(self): 1193 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1194 c = self._NewComputation() 1195 c.ReduceWindow(operand=c.Constant(input_array), 1196 init_value=c.ConstantF64Scalar(0), 1197 computation_to_apply=self._CreateBinaryAddF64Computation(), 1198 window_dimensions=(2, 1), window_strides=(1, 1), 1199 padding=xla_client.PaddingType.VALID) 1200 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) 1201 1202 def testReduceWindowSameUnitStridesF64(self): 1203 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1204 c = self._NewComputation() 1205 c.ReduceWindow(operand=c.Constant(input_array), 1206 init_value=c.ConstantF64Scalar(0), 1207 computation_to_apply=self._CreateBinaryAddF64Computation(), 1208 window_dimensions=(2, 1), window_strides=(1, 1), 1209 padding=xla_client.PaddingType.SAME) 1210 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) 1211 1212 def testReduceWindowValidGeneralStridesF64(self): 1213 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1214 c = self._NewComputation() 1215 c.ReduceWindow(operand=c.Constant(input_array), 1216 init_value=c.ConstantF64Scalar(0), 1217 computation_to_apply=self._CreateBinaryAddF64Computation(), 1218 window_dimensions=(2, 1), window_strides=(1, 2), 1219 padding=xla_client.PaddingType.VALID) 1220 self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) 1221 1222 def testWhileF32(self): 1223 cond = self._CreateTestF32Lt10Computation() 1224 body = self._CreateMulF32By2Computation() 1225 c = self._NewComputation() 1226 init = c.ConstantF32Scalar(1.) 1227 c.While(cond, body, init) 1228 self._ExecuteAndCompareClose(c, expected=16.) 1229 1230 def testWhileF64(self): 1231 cond = self._CreateTestF64Lt10Computation() 1232 body = self._CreateMulF64By2Computation() 1233 c = self._NewComputation() 1234 init = c.ConstantF64Scalar(1.) 1235 c.While(cond, body, init) 1236 self._ExecuteAndCompareClose(c, expected=16.) 1237 1238 def testConditionalTrue(self): 1239 c = self._NewComputation() 1240 pred = c.ConstantPredScalar(True) 1241 true_operand = c.ConstantF32Scalar(3.) 1242 true_computation = self._CreateMulF32By2Computation() 1243 false_operand = c.ConstantF32Scalar(2.) 1244 false_computation = self._CreateConstantF32Computation() 1245 c.Conditional(pred, true_operand, true_computation, false_operand, 1246 false_computation) 1247 self._ExecuteAndCompareClose(c, expected=6.) 1248 1249 def testConditionalFalse(self): 1250 c = self._NewComputation() 1251 pred = c.ConstantPredScalar(False) 1252 true_operand = c.ConstantF32Scalar(3.) 1253 true_computation = self._CreateMulF32By2Computation() 1254 false_operand = c.ConstantF32Scalar(2.) 1255 false_computation = self._CreateConstantF32Computation() 1256 c.Conditional(pred, true_operand, true_computation, false_operand, 1257 false_computation) 1258 self._ExecuteAndCompareClose(c, expected=1.) 1259 1260 def testInfeedS32Values(self): 1261 to_infeed = NumpyArrayS32([1, 2, 3, 4]) 1262 c = self._NewComputation() 1263 c.Infeed(xla_client.Shape.from_numpy(to_infeed[0])) 1264 compiled_c = c.Build().CompileWithExampleArguments() 1265 for item in to_infeed: 1266 xla_client.transfer_to_infeed(item) 1267 1268 for item in to_infeed: 1269 result = compiled_c.Execute() 1270 self.assertEqual(result, item) 1271 1272 def testInfeedThenOutfeedS32(self): 1273 to_round_trip = NumpyArrayS32([1, 2, 3, 4]) 1274 c = self._NewComputation() 1275 x = c.Infeed(xla_client.Shape.from_numpy(to_round_trip[0])) 1276 c.Outfeed(x) 1277 1278 compiled_c = c.Build().CompileWithExampleArguments() 1279 1280 for want in to_round_trip: 1281 execution = threading.Thread(target=compiled_c.Execute) 1282 execution.start() 1283 xla_client.transfer_to_infeed(want) 1284 got = xla_client.transfer_from_outfeed( 1285 xla_client.Shape.from_numpy(to_round_trip[0])) 1286 execution.join() 1287 self.assertEqual(want, got) 1288 1289 1290 class ErrorTest(LocalComputationTest): 1291 1292 def setUp(self): 1293 self.f32_scalar_2 = NumpyArrayF32(2.0) 1294 self.s32_scalar_2 = NumpyArrayS32(2) 1295 1296 def testInvokeWithWrongElementType(self): 1297 c = self._NewComputation() 1298 c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) 1299 c.ParameterFromNumpy(self.s32_scalar_2) 1300 c.ClearOpMetadata() 1301 self.assertRaisesRegexp( 1302 RuntimeError, r"Invalid argument shape.*xla_client_test.py.*" 1303 r"expected s32\[\], got f32\[\]", 1304 lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) 1305 1306 1307 if __name__ == "__main__": 1308 unittest.main() 1309