# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Relu and ReluGrad."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


def _elu_grad_grad(activation):
  """Analytic second derivative of ELU: exp(x) for x < 0, 0 otherwise."""
  if activation < 0:
    return np.exp(activation)
  return 0


class ReluTest(test.TestCase):

  def _npRelu(self, np_features):
    return np.maximum(np_features, np.zeros(np_features.shape))

  def testNpRelu(self):
    self.assertAllClose(
        np.array([[0.0, 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]),
        self._npRelu(
            np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                     0.9]])))

  def _testRelu(self, np_features, use_gpu=False):
    np_relu = self._npRelu(np_features)
    with self.test_session(use_gpu=use_gpu):
      relu = nn_ops.relu(np_features)
      tf_relu = relu.eval()
    self.assertAllClose(np_relu, tf_relu)
    self.assertShapeEqual(np_relu, relu)

  def testNumbers(self):
    for t in [np.int32, np.int64, np.float16, np.float32, np.float64]:
      self._testRelu(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
          use_gpu=False)
      if t in [np.float16, np.float32, np.float64]:
        self._testRelu(
            np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
            use_gpu=True)

  # The gradient test for ReLU is a bit tricky because the derivative is not
  # well defined at zero, so we keep the input values away from zero.
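  # To see why, note that gradient_checker estimates gradients with a central
  # difference. A minimal sketch of the failure mode at the kink, assuming a
  # hypothetical step size `delta` (not part of this test):
  #
  #   delta = 1e-3
  #   numeric = (max(0. + delta, 0.) - max(0. - delta, 0.)) / (2 * delta)
  #
  # numeric == 0.5, which matches neither one-sided derivative (0 or 1), so
  # an input within `delta` of zero would report a spurious gradient error.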
  def testGradientFloat32(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          name="x")
      y = nn_ops.relu(x, name="relu")
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float32,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    print("relu (float32) gradient err = ", err)
    self.assertLess(err, 1e-4)

  def testGradientFloat64(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          dtype=dtypes.float64,
          name="x")
      y = nn_ops.relu(x, name="relu")
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float64,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    print("relu (float64) gradient err = ", err)
    self.assertLess(err, 1e-10)

  def testGradGradFloat32(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          name="x")
      y = nn_ops.relu(x, name="relu")
      z = gradients_impl.gradients(y, x)
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float32,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
    print("relu (float32) gradient of gradient err = ", err)
    self.assertLess(err, 1e-4)

  def testGradGradFloat64(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          dtype=dtypes.float64,
          name="x")
      y = nn_ops.relu(x, name="relu")
      z = gradients_impl.gradients(y, x)
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float64,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
    print("relu (float64) gradient of gradient err = ", err)
    self.assertLess(err, 1e-10)

  def testGradientScalar(self):
    with self.test_session() as sess:
      x = variables.Variable(100.)
      y = nn_ops.relu(x)
      loss = y**2
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.25)
      train_op = optimizer.minimize(loss)
      sess.run(variables.global_variables_initializer())
      sess.run(train_op)
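      # One step of gradient descent: d(loss)/dx = 2 * relu(100.) = 200, so
      # x is updated to 100. - 0.25 * 200 = 50.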
      self.assertAllClose(x.eval(), 50.0)


class Relu6Test(test.TestCase):

  def _npRelu6(self, np_features):
    sixes = np.copy(np_features)
    sixes.fill(6.0)
    return np.minimum(
        np.maximum(np_features, np.zeros(np_features.shape)), sixes)

  def testNpRelu6(self):
    self.assertAllClose(
        np.array([[0.0, 0.7, 0.0, 0.3, 6.0], [0.1, 0.0, 6.0, 0.0, 0.9]]),
        self._npRelu6(
            np.array([[-0.9, 0.7, -0.5, 0.3, 6.0], [0.1, -0.3, 6.5, -0.7,
                                                    0.9]])))

  def _testRelu6(self, np_features, use_gpu=False):
    np_relu6 = self._npRelu6(np_features)
    with self.test_session(use_gpu=use_gpu):
      relu6 = nn_ops.relu6(np_features)
      tf_relu6 = relu6.eval()
    self.assertAllClose(np_relu6, tf_relu6)
    self.assertShapeEqual(np_relu6, relu6)

  def testNumbers(self):
    for t in [np.int32, np.int64, np.float16, np.float32, np.float64]:
      self._testRelu6(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
          use_gpu=False)
      if t in [np.float16, np.float32, np.float64]:
        self._testRelu6(
            np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
            use_gpu=True)

  # The gradient test for ReLU6 is a bit tricky because the derivative is not
  # well defined at zero or at six, so we keep the input values away from
  # those points.
  def testGradientFloat32(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
          shape=[2, 5],
          name="x")
      y = nn_ops.relu6(x, name="relu6")
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [6.1, 6.3, 6.5, 6.7, 6.9]],
          dtype=np.float32,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    print("relu6 (float32) gradient err = ", err)
    self.assertLess(err, 1e-4)

  def testGradientFloat64(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
          shape=[2, 5],
          dtype=dtypes.float64,
          name="x")
      y = nn_ops.relu6(x, name="relu6")
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [6.1, 6.3, 6.5, 6.7, 6.9]],
          dtype=np.float64,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    print("relu6 (float64) gradient err = ", err)
    self.assertLess(err, 1e-10)


class EluTest(test.TestCase):

  def _npElu(self, np_features):
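    # Reference ELU with alpha = 1: exp(x) - 1 for x < 0, identity otherwise.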
    return np.where(np_features < 0, np.exp(np_features) - 1, np_features)

  def testNpElu(self):
    self.assertAllClose(
        np.array([[-0.59343034025, 0.7, -0.39346934028, 0.3, -0.09516258196],
                  [0.1, -0.25918177931, 0.5, -0.5034146962, 0.9]]),
        self._npElu(
            np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                     0.9]])))

  def _testElu(self, np_features, use_gpu=False):
    np_elu = self._npElu(np_features)
    with self.test_session(use_gpu=use_gpu):
      elu = nn_ops.elu(np_features)
      tf_elu = elu.eval()
    self.assertAllClose(np_elu, tf_elu)
    self.assertShapeEqual(np_elu, elu)

  def testNumbers(self):
    for t in [np.float16, np.float32, np.float64]:
      self._testElu(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
          use_gpu=False)
      self._testElu(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
          use_gpu=True)

  def testGradientFloat32(self):
    with self.test_session():
      x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
      x = constant_op.constant(x_val, name="x")
      y = nn_ops.elu(x, name="elu")
      x_init = np.asarray(x_val, dtype=np.float32, order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    print("elu (float32) gradient err = ", err)
    self.assertLess(err, 1e-4)

  def testGradientFloat64(self):
    with self.test_session():
      x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
      x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
      y = nn_ops.elu(x, name="elu")
      x_init = np.asarray(x_val, dtype=np.float64, order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    print("elu (float64) gradient err = ", err)
    self.assertLess(err, 1e-6)

  def testGradGrad(self):
    with self.test_session():
      x = array_ops.placeholder(dtype=dtypes.float32)
      elu = nn_ops.elu(x)
      g, = gradients_impl.gradients(elu, x)
      gg, = gradients_impl.gradients(g, x)

      for x_val in [-1, -0.5, 0.5, 1]:
        err = np.abs(gg.eval(feed_dict={x: x_val}) - _elu_grad_grad(x_val))
        self.assertLess(err, 1e-4)

  def testGradGradFloat32(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          name="x")
      y = nn_ops.elu(x, name="elu")
      z = gradients_impl.gradients(y, x)
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float32,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
    print("elu (float32) gradient of gradient err = ", err)
    self.assertLess(err, 1e-4)

  def testGradGradFloat64(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          dtype=dtypes.float64,
          name="x")
      y = nn_ops.elu(x, name="elu")
      z = gradients_impl.gradients(y, x)
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float64,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
    print("elu (float64) gradient of gradient err = ", err)
    self.assertLess(err, 1e-6)


class SeluTest(test.TestCase):

  def _npSelu(self, np_features):
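    # scale is the SELU lambda constant and scale_alpha is lambda * alpha,
    # with the values derived in Klambauer et al., "Self-Normalizing Neural
    # Networks" (https://arxiv.org/abs/1706.02515).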
    scale = 1.0507009873554804934193349852946
    scale_alpha = 1.7580993408473768599402175208123
    return np.where(np_features < 0, scale_alpha * (np.exp(np_features) - 1),
                    scale * np_features)

  def testNpSelu(self):
    self.assertAllClose(
        np.array([[-1.0433095, 0.73549069, -0.6917582, 0.3152103, -0.16730527],
                  [0.1050701, -0.45566732, 0.5253505, -0.88505305, 0.9456309]]),
        self._npSelu(
            np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                     0.9]])))

  def _testSelu(self, np_features, use_gpu=False):
    np_selu = self._npSelu(np_features)
    with self.test_session(use_gpu=use_gpu):
      selu = nn_ops.selu(np_features)
      tf_selu = selu.eval()
    self.assertAllClose(np_selu, tf_selu)
    self.assertShapeEqual(np_selu, selu)

  def testNumbers(self):
    for t in [np.float16, np.float32, np.float64]:
      self._testSelu(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
          use_gpu=False)
      self._testSelu(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
          use_gpu=True)

  def testGradientFloat32(self):
    with self.test_session():
      x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
      x = constant_op.constant(x_val, name="x")
      y = nn_ops.selu(x, name="selu")
      x_init = np.asarray(x_val, dtype=np.float32, order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    print("selu (float32) gradient err = ", err)
    self.assertLess(err, 1e-4)

  def testGradientFloat64(self):
    with self.test_session():
      x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
      x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
      y = nn_ops.selu(x, name="selu")
      x_init = np.asarray(x_val, dtype=np.float64, order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], y, [2, 5], x_init_value=x_init)
    print("selu (float64) gradient err = ", err)
    self.assertLess(err, 1e-6)

  def testGradGradFloat32(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          name="x")
      y = nn_ops.selu(x, name="selu")
      z = gradients_impl.gradients(y, x)
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float32,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
    print("selu (float32) gradient of gradient err = ", err)
    self.assertLess(err, 1e-4)

  def testGradGradFloat64(self):
    with self.test_session():
      x = constant_op.constant(
          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
          shape=[2, 5],
          dtype=dtypes.float64,
          name="x")
      y = nn_ops.selu(x, name="selu")
      z = gradients_impl.gradients(y, x)
      x_init = np.asarray(
          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
          dtype=np.float64,
          order="F")
      err = gradient_checker.compute_gradient_error(
          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
    print("selu (float64) gradient of gradient err = ", err)
    self.assertLess(err, 1e-6)


class CreluTest(test.TestCase):

  def testCreluShape(self):
    f = random_ops.random_normal([50, 5, 7, 10])
    t = nn_ops.crelu(f)
    self.assertEqual([50, 5, 7, 20], t.get_shape())

  def _testCrelu(self, np_features, use_gpu=False):
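    # CRELU concatenates relu(x) with relu(-x) along the last axis, so the
    # NumPy reference doubles that dimension.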
    np_relu = np.maximum(np_features, np.zeros_like(np_features))
    np_neg_relu = np.maximum(-np_features, np.zeros_like(np_features))
    np_crelu = np.concatenate((np_relu, np_neg_relu),
                              len(np_features.shape) - 1)

    with self.test_session(use_gpu=use_gpu):
      crelu = nn_ops.crelu(np_features)
      tf_relu = crelu.eval()

    self.assertAllClose(np_crelu, tf_relu)
    self.assertShapeEqual(np_crelu, crelu)

  def testNumbers(self):
    for t in [np.int32, np.int64, np.float16, np.float32, np.float64]:
      self._testCrelu(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
          use_gpu=False)
      if t in [np.float16, np.float32, np.float64]:
        self._testCrelu(
            np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
            use_gpu=True)

  def testNumbersWithAxis0(self):
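    # With axis=0, relu(x) and relu(-x) are stacked along the rows: the first
    # two rows of the expected result are relu(x), the last two are relu(-x).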
    with self.test_session():
      crelu = nn_ops.crelu(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=0)
      tf_relu = crelu.eval()
      np_crelu = np.array([[0, 7, 0, 3, 0], [1, 0, 5, 0, 9], [9, 0, 5, 0, 1],
                           [0, 3, 0, 7, 0]])
      self.assertAllEqual(np_crelu, tf_relu)

  def testNumbersWithAxis1(self):
    with self.test_session():
      crelu = nn_ops.crelu(
          np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=1)
      tf_relu = crelu.eval()
      np_crelu = np.array([[0, 7, 0, 3, 0, 9, 0, 5, 0, 1],
                           [1, 0, 5, 0, 9, 0, 3, 0, 7, 0]])
      self.assertAllEqual(np_crelu, tf_relu)


if __name__ == "__main__":
  test.main()