Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
"""Tests for tensorflow.ops.array_ops.where."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 import sys
     23 
     24 import numpy as np
     25 
     26 from tensorflow.python.client import session
     27 from tensorflow.python.framework import constant_op
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.ops import array_ops
     31 from tensorflow.python.ops import random_ops
     32 from tensorflow.python.ops import resource_variable_ops
     33 from tensorflow.python.platform import test
     34 
     35 
     36 class WhereOpTest(test.TestCase):
     37 
     38   def _testWhere(self, x, truth, expected_err_re=None):
     39     with self.test_session(use_gpu=True):
     40       ans = array_ops.where(x)
     41       self.assertEqual([None, x.ndim], ans.get_shape().as_list())
     42       if expected_err_re is None:
     43         tf_ans = ans.eval()
     44         self.assertAllClose(tf_ans, truth, atol=1e-10)
     45       else:
     46         with self.assertRaisesOpError(expected_err_re):
     47           ans.eval()
     48 
     49   def testWrongNumbers(self):
     50     with self.test_session(use_gpu=True):
     51       with self.assertRaises(ValueError):
     52         array_ops.where([False, True], [1, 2], None)
     53       with self.assertRaises(ValueError):
     54         array_ops.where([False, True], None, [1, 2])
     55 
     56   def testBasicVec(self):
     57     x = np.asarray([True, False])
     58     truth = np.asarray([[0]], dtype=np.int64)
     59     self._testWhere(x, truth)
     60 
     61     x = np.asarray([False, True, False])
     62     truth = np.asarray([[1]], dtype=np.int64)
     63     self._testWhere(x, truth)
     64 
     65     x = np.asarray([False, False, True, False, True])
     66     truth = np.asarray([[2], [4]], dtype=np.int64)
     67     self._testWhere(x, truth)
     68 
     69   def testRandomVec(self):
     70     x = np.random.rand(1000000) > 0.5
     71     truth = np.vstack([np.where(x)[0].astype(np.int64)]).T
     72     self._testWhere(x, truth)
     73 
     74   def testBasicMat(self):
     75     x = np.asarray([[True, False], [True, False]])
     76 
     77     # Ensure RowMajor mode
     78     truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64)
     79 
     80     self._testWhere(x, truth)
     81 
     82   def testBasic3Tensor(self):
     83     x = np.asarray([[[True, False], [True, False]],
     84                     [[False, True], [False, True]],
     85                     [[False, False], [False, True]]])
     86 
     87     # Ensure RowMajor mode
     88     truth = np.asarray(
     89         [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]], dtype=np.int64)
     90 
     91     self._testWhere(x, truth)
     92 
     93   def _testRandom(self, dtype, expected_err_re=None):
     94     shape = [127, 33, 53]
     95     x = np.random.randn(*shape) + 1j * np.random.randn(*shape)
     96     x = (np.random.randn(*shape) > 0).astype(dtype)
     97     truth = np.where(np.abs(x) > 0)  # Tuples of indices by axis.
     98     truth = np.vstack(truth).T  # Convert to [num_true, indices].
     99     self._testWhere(x, truth, expected_err_re)
    100 
    101   def testRandomBool(self):
    102     self._testRandom(np.bool)
    103 
    104   def testRandomInt32(self):
    105     self._testRandom(np.int32)
    106 
    107   def testRandomInt64(self):
    108     self._testRandom(np.int64)
    109 
    110   def testRandomFloat(self):
    111     self._testRandom(np.float32)
    112 
    113   def testRandomDouble(self):
    114     self._testRandom(np.float64)
    115 
    116   def testRandomComplex64(self):
    117     self._testRandom(np.complex64)
    118 
    119   def testRandomComplex128(self):
    120     self._testRandom(np.complex128)
    121 
    122   def testRandomUint8(self):
    123     self._testRandom(np.uint8)
    124 
    125   def testRandomInt8(self):
    126     self._testRandom(np.int8)
    127 
    128   def testRandomInt16(self):
    129     self._testRandom(np.int16)
    130 
    131   def testThreeArgument(self):
    132     x = np.array([[-2, 3, -1], [1, -3, -3]])
    133     np_val = np.where(x > 0, x * x, -x)
    134     with self.test_session(use_gpu=True):
    135       tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval()
    136     self.assertAllEqual(tf_val, np_val)
    137 
    138 
    139 class WhereBenchmark(test.Benchmark):
    140 
    141   def benchmarkWhere(self):
    142     for (m, n, p, use_gpu) in itertools.product(
    143         [10],
    144         [10, 100, 1000, 10000, 100000, 1000000],
    145         [0.01, 0.5, 0.99],
    146         [False, True]):
    147       name = "m_%d_n_%d_p_%g_use_gpu_%s" % (m, n, p, use_gpu)
    148       device = "/%s:0" % ("gpu" if use_gpu else "cpu")
    149       with ops.Graph().as_default():
    150         with ops.device(device):
    151           x = random_ops.random_uniform((m, n), dtype=dtypes.float32) <= p
    152           v = resource_variable_ops.ResourceVariable(x)
    153           op = array_ops.where(v)
    154         with session.Session() as sess:
    155           v.initializer.run()
    156           r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
    157           gb_processed_input = m * n / 1.0e9
    158           # approximate size of output: m*n*p int64s for each axis.
    159           gb_processed_output = 2 * 8 * m * n * p / 1.0e9
    160           gb_processed = gb_processed_input + gb_processed_output
    161           throughput = gb_processed / r["wall_time"]
    162           print("Benchmark: %s \t wall_time: %0.03g s \t "
    163                 "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
    164           sys.stdout.flush()
    165 
if __name__ == "__main__":
  # Entry point when this file is executed directly: hand off to the
  # TensorFlow test runner.
  test.main()
    168