Home | History | Annotate | Download | only in keras
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for Keras backend."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from absl.testing import parameterized
     21 import numpy as np
     22 import scipy.sparse
     23 
     24 from tensorflow.core.protobuf import config_pb2
     25 from tensorflow.python import keras
     26 from tensorflow.python.eager import context
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import errors_impl
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.framework import sparse_tensor
     31 from tensorflow.python.framework import test_util
     32 from tensorflow.python.ops import array_ops
     33 from tensorflow.python.ops import nn
     34 from tensorflow.python.ops import variables
     35 from tensorflow.python.platform import test
     36 from tensorflow.python.util import tf_inspect
     37 
     38 
     39 def compare_single_input_op_to_numpy(keras_op,
     40                                      np_op,
     41                                      input_shape,
     42                                      dtype='float32',
     43                                      negative_values=True,
     44                                      keras_args=None,
     45                                      keras_kwargs=None,
     46                                      np_args=None,
     47                                      np_kwargs=None):
     48   keras_args = keras_args or []
     49   keras_kwargs = keras_kwargs or {}
     50   np_args = np_args or []
     51   np_kwargs = np_kwargs or {}
     52   inputs = 2. * np.random.random(input_shape)
     53   if negative_values:
     54     inputs -= 1.
     55   keras_output = keras_op(keras.backend.variable(inputs, dtype=dtype),
     56                           *keras_args, **keras_kwargs)
     57   keras_output = keras.backend.eval(keras_output)
     58   np_output = np_op(inputs.astype(dtype), *np_args, **np_kwargs)
     59   try:
     60     np.testing.assert_allclose(keras_output, np_output, atol=1e-4)
     61   except AssertionError:
     62     raise AssertionError('Test for op `' + str(keras_op.__name__) + '` failed; '
     63                          'Expected ' + str(np_output) + ' but got ' +
     64                          str(keras_output))
     65 
     66 
     67 def compare_two_inputs_op_to_numpy(keras_op,
     68                                    np_op,
     69                                    input_shape_a,
     70                                    input_shape_b,
     71                                    dtype='float32',
     72                                    keras_args=None,
     73                                    keras_kwargs=None,
     74                                    np_args=None,
     75                                    np_kwargs=None):
     76   keras_args = keras_args or []
     77   keras_kwargs = keras_kwargs or {}
     78   np_args = np_args or []
     79   np_kwargs = np_kwargs or {}
     80   input_a = np.random.random(input_shape_a)
     81   input_b = np.random.random(input_shape_b)
     82   keras_output = keras_op(keras.backend.variable(input_a, dtype=dtype),
     83                           keras.backend.variable(input_b, dtype=dtype),
     84                           *keras_args, **keras_kwargs)
     85   keras_output = keras.backend.eval(keras_output)
     86   np_output = np_op(input_a.astype(dtype), input_b.astype(dtype),
     87                     *np_args, **np_kwargs)
     88   try:
     89     np.testing.assert_allclose(keras_output, np_output, atol=1e-4)
     90   except AssertionError:
     91     raise AssertionError('Test for op `' + str(keras_op.__name__) + '` failed; '
     92                          'Expected ' + str(np_output) + ' but got ' +
     93                          str(keras_output))
     94 
     95 
     96 @test_util.run_all_in_graph_and_eager_modes
     97 class BackendUtilsTest(test.TestCase):
     98 
     99   def test_backend(self):
    100     self.assertEqual(keras.backend.backend(), 'tensorflow')
    101 
    102   def test_get_reset_uids(self):
    103     self.assertEqual(keras.backend.get_uid('foo'), 1)
    104     self.assertEqual(keras.backend.get_uid('foo'), 2)
    105 
    106     keras.backend.reset_uids()
    107     self.assertEqual(keras.backend.get_uid('foo'), 1)
    108 
    109   def test_learning_phase(self):
    110     with self.cached_session() as sess:
    111       with self.assertRaises(ValueError):
    112         keras.backend.set_learning_phase(2)
    113 
    114       # Test running with a learning-phase-consuming layer
    115       with keras.backend.learning_phase_scope(0):
    116         x = keras.Input((3,))
    117         y = keras.layers.BatchNormalization()(x)
    118         if not context.executing_eagerly():
    119           self.evaluate(variables.global_variables_initializer())
    120           sess.run(y, feed_dict={x: np.random.random((2, 3))})
    121 
    122   def test_learning_phase_name(self):
    123     with ops.name_scope('test_scope'):
    124       # Test that outer name scopes do not affect the learning phase's name.
    125       lp = keras.backend.symbolic_learning_phase()
    126     self.assertEqual(lp.name, 'keras_learning_phase:0')
    127 
    128   def test_learning_phase_scope(self):
    129     initial_learning_phase = keras.backend.learning_phase()
    130     with keras.backend.learning_phase_scope(1):
    131       self.assertEqual(keras.backend.learning_phase(), 1)
    132     self.assertEqual(keras.backend.learning_phase(), initial_learning_phase)
    133     with keras.backend.learning_phase_scope(0):
    134       self.assertEqual(keras.backend.learning_phase(), 0)
    135     self.assertEqual(keras.backend.learning_phase(), initial_learning_phase)
    136     with self.assertRaises(ValueError):
    137       with keras.backend.learning_phase_scope(None):
    138         pass
    139     self.assertEqual(keras.backend.learning_phase(), initial_learning_phase)
    140 
    141     new_learning_phase = 0
    142     keras.backend.set_learning_phase(new_learning_phase)
    143     self.assertEqual(keras.backend.learning_phase(), new_learning_phase)
    144     with keras.backend.learning_phase_scope(1):
    145       self.assertEqual(keras.backend.learning_phase(), 1)
    146     self.assertEqual(keras.backend.learning_phase(), new_learning_phase)
    147 
    148   def test_learning_phase_scope_in_graph(self):
    149     initial_learning_phase_outside_graph = keras.backend.learning_phase()
    150     with keras.backend.get_graph().as_default():
    151       initial_learning_phase_in_graph = keras.backend.learning_phase()
    152 
    153     self.assertEqual(keras.backend.learning_phase(),
    154                      initial_learning_phase_outside_graph)
    155     with keras.backend.learning_phase_scope(1):
    156       self.assertEqual(keras.backend.learning_phase(), 1)
    157     self.assertEqual(keras.backend.learning_phase(),
    158                      initial_learning_phase_outside_graph)
    159 
    160     with keras.backend.get_graph().as_default():
    161       self.assertEqual(keras.backend.learning_phase(),
    162                        initial_learning_phase_in_graph)
    163 
    164     self.assertEqual(keras.backend.learning_phase(),
    165                      initial_learning_phase_outside_graph)
    166 
    167   def test_int_shape(self):
    168     x = keras.backend.ones(shape=(3, 4))
    169     self.assertEqual(keras.backend.int_shape(x), (3, 4))
    170 
    171     if not context.executing_eagerly():
    172       x = keras.backend.placeholder(shape=(None, 4))
    173       self.assertEqual(keras.backend.int_shape(x), (None, 4))
    174 
    175   def test_in_train_phase(self):
    176     y1 = keras.backend.variable(1)
    177     y2 = keras.backend.variable(2)
    178     if context.executing_eagerly():
    179       with keras.backend.learning_phase_scope(0):
    180         y_val_test = keras.backend.in_train_phase(y1, y2).numpy()
    181       with keras.backend.learning_phase_scope(1):
    182         y_val_train = keras.backend.in_train_phase(y1, y2).numpy()
    183     else:
    184       y = keras.backend.in_train_phase(y1, y2)
    185       f = keras.backend.function([keras.backend.learning_phase()], [y])
    186       y_val_test = f([0])[0]
    187       y_val_train = f([1])[0]
    188     self.assertAllClose(y_val_test, 2)
    189     self.assertAllClose(y_val_train, 1)
    190 
    191   def test_is_keras_tensor(self):
    192     x = keras.backend.variable(1)
    193     self.assertEqual(keras.backend.is_keras_tensor(x), False)
    194     x = keras.Input(shape=(1,))
    195     self.assertEqual(keras.backend.is_keras_tensor(x), True)
    196     with self.assertRaises(ValueError):
    197       keras.backend.is_keras_tensor(0)
    198 
    199   def test_stop_gradient(self):
    200     x = keras.backend.variable(1)
    201     y = keras.backend.stop_gradient(x)
    202     if not context.executing_eagerly():
    203       self.assertEqual(y.op.name[:12], 'StopGradient')
    204 
    205     xs = [keras.backend.variable(1) for _ in range(3)]
    206     ys = keras.backend.stop_gradient(xs)
    207     if not context.executing_eagerly():
    208       for y in ys:
    209         self.assertEqual(y.op.name[:12], 'StopGradient')
    210 
    211 
    212 @test_util.run_all_in_graph_and_eager_modes
    213 class BackendVariableTest(test.TestCase):
    214 
    215   def test_zeros(self):
    216     x = keras.backend.zeros((3, 4))
    217     val = keras.backend.eval(x)
    218     self.assertAllClose(val, np.zeros((3, 4)))
    219 
    220   def test_ones(self):
    221     x = keras.backend.ones((3, 4))
    222     val = keras.backend.eval(x)
    223     self.assertAllClose(val, np.ones((3, 4)))
    224 
    225   def test_eye(self):
    226     x = keras.backend.eye(4)
    227     val = keras.backend.eval(x)
    228     self.assertAllClose(val, np.eye(4))
    229 
    230   def test_zeros_like(self):
    231     x = keras.backend.zeros((3, 4))
    232     y = keras.backend.zeros_like(x)
    233     val = keras.backend.eval(y)
    234     self.assertAllClose(val, np.zeros((3, 4)))
    235 
    236   def test_ones_like(self):
    237     x = keras.backend.zeros((3, 4))
    238     y = keras.backend.ones_like(x)
    239     val = keras.backend.eval(y)
    240     self.assertAllClose(val, np.ones((3, 4)))
    241 
    242   def test_random_uniform_variable(self):
    243     x = keras.backend.random_uniform_variable((30, 20), low=1, high=2, seed=0)
    244     val = keras.backend.eval(x)
    245     self.assertAllClose(val.mean(), 1.5, atol=1e-1)
    246     self.assertAllClose(val.max(), 2., atol=1e-1)
    247     self.assertAllClose(val.min(), 1., atol=1e-1)
    248 
    249   def test_random_normal_variable(self):
    250     x = keras.backend.random_normal_variable((30, 20), 1., 0.5, seed=0)
    251     val = keras.backend.eval(x)
    252     self.assertAllClose(val.mean(), 1., atol=1e-1)
    253     self.assertAllClose(val.std(), 0.5, atol=1e-1)
    254 
    255   def test_count_params(self):
    256     x = keras.backend.zeros((4, 5))
    257     val = keras.backend.count_params(x)
    258     self.assertAllClose(val, 20)
    259 
    260   def test_constant(self):
    261     ref_val = np.random.random((3, 4)).astype('float32')
    262     x = keras.backend.constant(ref_val)
    263     val = keras.backend.eval(x)
    264     self.assertAllClose(val, ref_val)
    265 
    266   def test_sparse_variable(self):
    267     val = scipy.sparse.eye(10)
    268     x = keras.backend.variable(val)
    269     self.assertTrue(isinstance(x, sparse_tensor.SparseTensor))
    270 
    271     y = keras.backend.to_dense(x)
    272     self.assertFalse(keras.backend.is_sparse(y))
    273 
    274 
    275 @test_util.run_all_in_graph_and_eager_modes
    276 class BackendLinearAlgebraTest(test.TestCase):
    277 
    278   def test_dot(self):
    279     x = keras.backend.ones(shape=(2, 3))
    280     y = keras.backend.ones(shape=(3, 4))
    281     xy = keras.backend.dot(x, y)
    282     self.assertEqual(xy.shape.as_list(), [2, 4])
    283 
    284     x = keras.backend.ones(shape=(32, 28, 3))
    285     y = keras.backend.ones(shape=(3, 4))
    286     xy = keras.backend.dot(x, y)
    287     self.assertEqual(xy.shape.as_list(), [32, 28, 4])
    288 
    289   def test_batch_dot(self):
    290     x = keras.backend.ones(shape=(32, 20, 1))
    291     y = keras.backend.ones(shape=(32, 30, 20))
    292     xy = keras.backend.batch_dot(x, y, axes=[1, 2])
    293     self.assertEqual(xy.shape.as_list(), [32, 1, 30])
    294 
    295     # TODO(fchollet): insufficiently tested.
    296 
    297   def test_reduction_ops(self):
    298     ops_to_test = [
    299         (keras.backend.max, np.max),
    300         (keras.backend.min, np.min),
    301         (keras.backend.sum, np.sum),
    302         (keras.backend.prod, np.prod),
    303         (keras.backend.var, np.var),
    304         (keras.backend.std, np.std),
    305         (keras.backend.mean, np.mean),
    306         (keras.backend.argmin, np.argmin),
    307         (keras.backend.argmax, np.argmax),
    308     ]
    309     for keras_op, np_op in ops_to_test:
    310       compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
    311                                        keras_kwargs={'axis': 1},
    312                                        np_kwargs={'axis': 1})
    313       compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
    314                                        keras_kwargs={'axis': -1},
    315                                        np_kwargs={'axis': -1})
    316       if 'keepdims' in tf_inspect.getargspec(keras_op).args:
    317         compare_single_input_op_to_numpy(keras_op, np_op,
    318                                          input_shape=(4, 7, 5),
    319                                          keras_kwargs={'axis': 1,
    320                                                        'keepdims': True},
    321                                          np_kwargs={'axis': 1,
    322                                                     'keepdims': True})
    323 
    324   def test_elementwise_ops(self):
    325     ops_to_test = [
    326         (keras.backend.square, np.square),
    327         (keras.backend.abs, np.abs),
    328         (keras.backend.round, np.round),
    329         (keras.backend.sign, np.sign),
    330         (keras.backend.sin, np.sin),
    331         (keras.backend.cos, np.cos),
    332         (keras.backend.exp, np.exp),
    333     ]
    334     for keras_op, np_op in ops_to_test:
    335       compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7))
    336 
    337     ops_to_test = [
    338         (keras.backend.sqrt, np.sqrt),
    339         (keras.backend.log, np.log),
    340     ]
    341     for keras_op, np_op in ops_to_test:
    342       compare_single_input_op_to_numpy(keras_op, np_op,
    343                                        input_shape=(4, 7),
    344                                        negative_values=False)
    345 
    346     compare_single_input_op_to_numpy(
    347         keras.backend.clip, np.clip,
    348         input_shape=(6, 4),
    349         keras_kwargs={'min_value': 0.1, 'max_value': 2.4},
    350         np_kwargs={'a_min': 0.1, 'a_max': 1.4})
    351 
    352     compare_single_input_op_to_numpy(
    353         keras.backend.pow, np.power,
    354         input_shape=(6, 4),
    355         keras_args=[3],
    356         np_args=[3])
    357 
    358   def test_two_tensor_ops(self):
    359     ops_to_test = [
    360         (keras.backend.equal, np.equal),
    361         (keras.backend.not_equal, np.not_equal),
    362         (keras.backend.greater, np.greater),
    363         (keras.backend.greater_equal, np.greater_equal),
    364         (keras.backend.less, np.less),
    365         (keras.backend.less_equal, np.less_equal),
    366         (keras.backend.maximum, np.maximum),
    367         (keras.backend.minimum, np.minimum),
    368     ]
    369     for keras_op, np_op in ops_to_test:
    370       compare_two_inputs_op_to_numpy(keras_op, np_op,
    371                                      input_shape_a=(4, 7),
    372                                      input_shape_b=(4, 7))
    373 
    374   def test_relu(self):
    375     x = ops.convert_to_tensor([[-4, 0], [2, 7]], 'float32')
    376 
    377     # standard relu
    378     relu_op = keras.backend.relu(x)
    379     self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]])
    380 
    381     # alpha (leaky relu used)
    382     relu_op = keras.backend.relu(x, alpha=0.5)
    383     if not context.executing_eagerly():
    384       self.assertTrue('LeakyRelu' in relu_op.name)
    385     self.assertAllClose(keras.backend.eval(relu_op), [[-2, 0], [2, 7]])
    386 
    387     # max_value < some elements
    388     relu_op = keras.backend.relu(x, max_value=5)
    389     self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 5]])
    390 
    391     # nn.relu6 used
    392     relu_op = keras.backend.relu(x, max_value=6)
    393     if not context.executing_eagerly():
    394       self.assertTrue('Relu6' in relu_op.name)  # uses tf.nn.relu6
    395     self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 6]])
    396 
    397     # max value > 6
    398     relu_op = keras.backend.relu(x, max_value=10)
    399     self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]])
    400 
    401     # max value is float
    402     relu_op = keras.backend.relu(x, max_value=4.3)
    403     self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 4.3]])
    404 
    405     # max value == 0
    406     relu_op = keras.backend.relu(x, max_value=0)
    407     self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [0, 0]])
    408 
    409     # alpha and max_value
    410     relu_op = keras.backend.relu(x, alpha=0.25, max_value=3)
    411     self.assertAllClose(keras.backend.eval(relu_op), [[-1, 0], [2, 3]])
    412 
    413     # threshold
    414     relu_op = keras.backend.relu(x, threshold=3)
    415     self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [0, 7]])
    416 
    417     # threshold is float
    418     relu_op = keras.backend.relu(x, threshold=1.5)
    419     self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]])
    420 
    421     # threshold is negative
    422     relu_op = keras.backend.relu(x, threshold=-5)
    423     self.assertAllClose(keras.backend.eval(relu_op), [[-4, 0], [2, 7]])
    424 
    425     # threshold and max_value
    426     relu_op = keras.backend.relu(x, threshold=3, max_value=5)
    427     self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [0, 5]])
    428 
    429     # threshold and alpha
    430     relu_op = keras.backend.relu(x, alpha=0.25, threshold=4)
    431     self.assertAllClose(keras.backend.eval(relu_op), [[-2, -1], [-0.5, 7]])
    432 
    433     # threshold, alpha, and max_value
    434     relu_op = keras.backend.relu(x, alpha=0.25, threshold=4, max_value=5)
    435     self.assertAllClose(keras.backend.eval(relu_op), [[-2, -1], [-0.5, 5]])
    436 
    437 
    438 @test_util.run_all_in_graph_and_eager_modes
    439 class BackendShapeOpsTest(test.TestCase):
    440 
    441   def test_reshape(self):
    442     compare_single_input_op_to_numpy(keras.backend.reshape, np.reshape,
    443                                      input_shape=(4, 7),
    444                                      keras_args=[(2, 14)],
    445                                      np_args=[(2, 14)])
    446 
    447   def test_concatenate(self):
    448     a = keras.backend.variable(np.ones((1, 2, 3)))
    449     b = keras.backend.variable(np.ones((1, 2, 2)))
    450     y = keras.backend.concatenate([a, b], axis=-1)
    451     self.assertEqual(y.shape.as_list(), [1, 2, 5])
    452 
    453   def test_permute_dimensions(self):
    454     compare_single_input_op_to_numpy(keras.backend.permute_dimensions,
    455                                      np.transpose,
    456                                      input_shape=(4, 7),
    457                                      keras_args=[(1, 0)],
    458                                      np_args=[(1, 0)])
    459 
    460   def test_resize_images(self):
    461     height_factor = 2
    462     width_factor = 2
    463     data_format = 'channels_last'
    464     x = keras.backend.variable(np.ones((1, 2, 2, 3)))
    465     y = keras.backend.resize_images(x,
    466                                     height_factor,
    467                                     width_factor,
    468                                     data_format)
    469     self.assertEqual(y.shape.as_list(), [1, 4, 4, 3])
    470 
    471     data_format = 'channels_first'
    472     x = keras.backend.variable(np.ones((1, 3, 2, 2)))
    473     y = keras.backend.resize_images(x,
    474                                     height_factor,
    475                                     width_factor,
    476                                     data_format)
    477     self.assertEqual(y.shape.as_list(), [1, 3, 4, 4])
    478 
    479     # Invalid use:
    480     with self.assertRaises(ValueError):
    481       keras.backend.resize_images(x,
    482                                   height_factor,
    483                                   width_factor,
    484                                   data_format='unknown')
    485 
    486   def test_resize_volumes(self):
    487     height_factor = 2
    488     width_factor = 2
    489     depth_factor = 2
    490     data_format = 'channels_last'
    491     x = keras.backend.variable(np.ones((1, 2, 2, 2, 3)))
    492     y = keras.backend.resize_volumes(x,
    493                                      depth_factor,
    494                                      height_factor,
    495                                      width_factor,
    496                                      data_format)
    497     self.assertEqual(y.shape.as_list(), [1, 4, 4, 4, 3])
    498 
    499     data_format = 'channels_first'
    500     x = keras.backend.variable(np.ones((1, 3, 2, 2, 2)))
    501     y = keras.backend.resize_volumes(x,
    502                                      depth_factor,
    503                                      height_factor,
    504                                      width_factor,
    505                                      data_format)
    506     self.assertEqual(y.shape.as_list(), [1, 3, 4, 4, 4])
    507 
    508     # Invalid use:
    509     with self.assertRaises(ValueError):
    510       keras.backend.resize_volumes(x,
    511                                    depth_factor,
    512                                    height_factor,
    513                                    width_factor,
    514                                    data_format='unknown')
    515 
    516   def test_repeat_elements(self):
    517     x = keras.backend.variable(np.ones((1, 3, 2)))
    518     y = keras.backend.repeat_elements(x, 3, axis=1)
    519     self.assertEqual(y.shape.as_list(), [1, 9, 2])
    520 
    521     # Use with a dynamic axis:
    522     if not context.executing_eagerly():
    523       x = keras.backend.placeholder(shape=(2, None, 2))
    524       y = keras.backend.repeat_elements(x, 3, axis=1)
    525       self.assertEqual(y.shape.as_list(), [2, None, 2])
    526 
    527   def test_repeat(self):
    528     x = keras.backend.variable(np.ones((1, 3)))
    529     y = keras.backend.repeat(x, 2)
    530     self.assertEqual(y.shape.as_list(), [1, 2, 3])
    531 
    532   def test_flatten(self):
    533     compare_single_input_op_to_numpy(keras.backend.flatten,
    534                                      np.reshape,
    535                                      input_shape=(4, 7, 6),
    536                                      np_args=[(4 * 7 * 6,)])
    537 
    538   def test_batch_flatten(self):
    539     compare_single_input_op_to_numpy(keras.backend.batch_flatten,
    540                                      np.reshape,
    541                                      input_shape=(4, 7, 6),
    542                                      np_args=[(4, 7 * 6)])
    543 
    544   def test_temporal_padding(self):
    545 
    546     def ref_op(x, padding):
    547       shape = list(x.shape)
    548       shape[1] += padding[0] + padding[1]
    549       y = np.zeros(tuple(shape))
    550       y[:, padding[0]:-padding[1], :] = x
    551       return y
    552 
    553     compare_single_input_op_to_numpy(keras.backend.temporal_padding,
    554                                      ref_op,
    555                                      input_shape=(4, 7, 6),
    556                                      keras_args=[(2, 3)],
    557                                      np_args=[(2, 3)])
    558 
    559   def test_spatial_2d_padding(self):
    560 
    561     def ref_op(x, padding, data_format='channels_last'):
    562       shape = list(x.shape)
    563       if data_format == 'channels_last':
    564         shape[1] += padding[0][0] + padding[0][1]
    565         shape[2] += padding[1][0] + padding[1][1]
    566         y = np.zeros(tuple(shape))
    567         y[:, padding[0][0]:-padding[0][1], padding[1][0]:-padding[1][1], :] = x
    568       else:
    569         shape[2] += padding[0][0] + padding[0][1]
    570         shape[3] += padding[1][0] + padding[1][1]
    571         y = np.zeros(tuple(shape))
    572         y[:, :, padding[0][0]:-padding[0][1], padding[1][0]:-padding[1][1]] = x
    573       return y
    574 
    575     compare_single_input_op_to_numpy(
    576         keras.backend.spatial_2d_padding,
    577         ref_op,
    578         input_shape=(2, 3, 2, 3),
    579         keras_args=[((2, 3), (1, 2))],
    580         keras_kwargs={'data_format': 'channels_last'},
    581         np_args=[((2, 3), (1, 2))],
    582         np_kwargs={'data_format': 'channels_last'})
    583     compare_single_input_op_to_numpy(
    584         keras.backend.spatial_2d_padding,
    585         ref_op,
    586         input_shape=(2, 3, 2, 3),
    587         keras_args=[((2, 3), (1, 2))],
    588         keras_kwargs={'data_format': 'channels_first'},
    589         np_args=[((2, 3), (1, 2))],
    590         np_kwargs={'data_format': 'channels_first'})
    591 
    592   def test_spatial_3d_padding(self):
    593 
    594     def ref_op(x, padding, data_format='channels_last'):
    595       shape = list(x.shape)
    596       if data_format == 'channels_last':
    597         shape[1] += padding[0][0] + padding[0][1]
    598         shape[2] += padding[1][0] + padding[1][1]
    599         shape[3] += padding[2][0] + padding[2][1]
    600         y = np.zeros(tuple(shape))
    601         y[:,
    602           padding[0][0]:-padding[0][1],
    603           padding[1][0]:-padding[1][1],
    604           padding[2][0]:-padding[2][1],
    605           :] = x
    606       else:
    607         shape[2] += padding[0][0] + padding[0][1]
    608         shape[3] += padding[1][0] + padding[1][1]
    609         shape[4] += padding[2][0] + padding[2][1]
    610         y = np.zeros(tuple(shape))
    611         y[:, :,
    612           padding[0][0]:-padding[0][1],
    613           padding[1][0]:-padding[1][1],
    614           padding[2][0]:-padding[2][1]] = x
    615       return y
    616 
    617     compare_single_input_op_to_numpy(
    618         keras.backend.spatial_3d_padding,
    619         ref_op,
    620         input_shape=(2, 3, 2, 3, 2),
    621         keras_args=[((2, 3), (1, 2), (2, 3))],
    622         keras_kwargs={'data_format': 'channels_last'},
    623         np_args=[((2, 3), (1, 2), (2, 3))],
    624         np_kwargs={'data_format': 'channels_last'})
    625     compare_single_input_op_to_numpy(
    626         keras.backend.spatial_3d_padding,
    627         ref_op,
    628         input_shape=(2, 3, 2, 3, 2),
    629         keras_args=[((2, 3), (1, 2), (2, 3))],
    630         keras_kwargs={'data_format': 'channels_first'},
    631         np_args=[((2, 3), (1, 2), (2, 3))],
    632         np_kwargs={'data_format': 'channels_first'})
    633 
    634 
    635 @test_util.run_all_in_graph_and_eager_modes
    636 class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
    637 
    638   def test_bias_add(self):
    639     keras_op = keras.backend.bias_add
    640     np_op = np.add
    641     compare_two_inputs_op_to_numpy(keras_op, np_op,
    642                                    input_shape_a=(4, 7),
    643                                    input_shape_b=(7,))
    644     compare_two_inputs_op_to_numpy(keras_op, np_op,
    645                                    input_shape_a=(4, 3, 7),
    646                                    input_shape_b=(7,))
    647     compare_two_inputs_op_to_numpy(keras_op, np_op,
    648                                    input_shape_a=(4, 3, 5, 7),
    649                                    input_shape_b=(7,))
    650     compare_two_inputs_op_to_numpy(keras_op, np_op,
    651                                    input_shape_a=(4, 3, 5, 2, 7),
    652                                    input_shape_b=(7,))
    653 
    654     with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
    655       x = keras.backend.variable((3, 4))
    656       b = keras.backend.variable((3, 4))
    657       keras.backend.bias_add(x, b)
    658     with self.assertRaises(ValueError):
    659       x = keras.backend.variable((3, 4))
    660       b = keras.backend.variable((4,))
    661       keras.backend.bias_add(x, b, data_format='unknown')
    662 
    663   def test_bias_add_channels_first(self):
    664 
    665     def keras_op(x, b):
    666       return keras.backend.bias_add(x, b, data_format='channels_first')
    667 
    668     def np_op(x, b):
    669       if x.ndim == 3:
    670         b = b.reshape((1, b.shape[0], 1))
    671       if x.ndim == 4:
    672         b = b.reshape((1, b.shape[0], 1, 1))
    673       return x + b
    674 
    675     compare_two_inputs_op_to_numpy(keras_op, np_op,
    676                                    input_shape_a=(4, 3, 7),
    677                                    input_shape_b=(3,))
    678     compare_two_inputs_op_to_numpy(keras_op, np_op,
    679                                    input_shape_a=(4, 3, 5, 7),
    680                                    input_shape_b=(3,))
    681 
    682   def test_pool2d(self):
    683     val = np.random.random((10, 3, 10, 10))
    684     x = keras.backend.variable(val)
    685     y = keras.backend.pool2d(x, (2, 2), strides=(1, 1),
    686                              padding='valid', data_format='channels_first',
    687                              pool_mode='max')
    688     self.assertEqual(y.shape.as_list(), [10, 3, 9, 9])
    689 
    690     y = keras.backend.pool2d(x, (2, 2), strides=(1, 1),
    691                              padding='valid', data_format='channels_first',
    692                              pool_mode='avg')
    693     self.assertEqual(y.shape.as_list(), [10, 3, 9, 9])
    694 
    695     val = np.random.random((10, 10, 10, 3))
    696     x = keras.backend.variable(val)
    697     y = keras.backend.pool2d(x, (2, 2), strides=(1, 1),
    698                              padding='valid', data_format='channels_last')
    699     self.assertEqual(y.shape.as_list(), [10, 9, 9, 3])
    700 
    701     val = np.random.random((10, 10, 10, 3))
    702     x = keras.backend.variable(val)
    703     y = keras.backend.pool2d(x, (2, 2), strides=(1, 1),
    704                              padding='same', data_format='channels_last')
    705     self.assertEqual(y.shape.as_list(), [10, 10, 10, 3])
    706 
    707     val = np.random.random((10, 10, 10, 3))
    708     x = keras.backend.variable(val)
    709     y = keras.backend.pool2d(x, (2, 2), strides=(2, 2),
    710                              padding='same', data_format='channels_last')
    711     self.assertEqual(y.shape.as_list(), [10, 5, 5, 3])
    712 
    713     with self.assertRaises(ValueError):
    714       y = keras.backend.pool2d(x, (2, 2), strides=(2, 2),
    715                                padding='other', data_format='channels_last')
    716     with self.assertRaises(ValueError):
    717       y = keras.backend.pool2d(x, (2, 2), strides=(2, 2),
    718                                data_format='other')
    719     with self.assertRaises(ValueError):
    720       y = keras.backend.pool2d(x, (2, 2, 2), strides=(2, 2))
    721     with self.assertRaises(ValueError):
    722       y = keras.backend.pool2d(x, (2, 2), strides=(2, 2, 2))
    723     with self.assertRaises(ValueError):
    724       y = keras.backend.pool2d(x, (2, 2), strides=(2, 2), pool_mode='other')
    725 
    726   def test_pool3d(self):
    727     val = np.random.random((10, 3, 10, 10, 10))
    728     x = keras.backend.variable(val)
    729     y = keras.backend.pool3d(x, (2, 2, 2), strides=(1, 1, 1),
    730                              padding='valid', data_format='channels_first',
    731                              pool_mode='max')
    732     self.assertEqual(y.shape.as_list(), [10, 3, 9, 9, 9])
    733 
    734     y = keras.backend.pool3d(x, (2, 2, 2), strides=(1, 1, 1),
    735                              padding='valid', data_format='channels_first',
    736                              pool_mode='avg')
    737     self.assertEqual(y.shape.as_list(), [10, 3, 9, 9, 9])
    738 
    739     val = np.random.random((10, 10, 10, 10, 3))
    740     x = keras.backend.variable(val)
    741     y = keras.backend.pool3d(x, (2, 2, 2), strides=(1, 1, 1),
    742                              padding='valid', data_format='channels_last')
    743     self.assertEqual(y.shape.as_list(), [10, 9, 9, 9, 3])
    744 
    745     val = np.random.random((10, 10, 10, 10, 3))
    746     x = keras.backend.variable(val)
    747     y = keras.backend.pool3d(x, (2, 2, 2), strides=(1, 1, 1),
    748                              padding='same', data_format='channels_last')
    749     self.assertEqual(y.shape.as_list(), [10, 10, 10, 10, 3])
    750 
    751     val = np.random.random((10, 10, 10, 10, 3))
    752     x = keras.backend.variable(val)
    753     y = keras.backend.pool3d(x, (2, 2, 2), strides=(2, 2, 2),
    754                              padding='same', data_format='channels_last')
    755     self.assertEqual(y.shape.as_list(), [10, 5, 5, 5, 3])
    756 
    757   def test_conv1d(self):
    758     val = np.random.random((10, 4, 10))
    759     x = keras.backend.variable(val)
    760     kernel_val = np.random.random((3, 4, 5))
    761     k = keras.backend.variable(kernel_val)
    762     y = keras.backend.conv1d(x, k, strides=(1,),
    763                              padding='valid', data_format='channels_first')
    764     self.assertEqual(y.shape.as_list(), [10, 5, 8])
    765 
    766     val = np.random.random((10, 10, 4))
    767     x = keras.backend.variable(val)
    768     y = keras.backend.conv1d(x, k, strides=(1,),
    769                              padding='valid', data_format='channels_last')
    770     self.assertEqual(y.shape.as_list(), [10, 8, 5])
    771 
    772     val = np.random.random((10, 10, 4))
    773     x = keras.backend.variable(val)
    774     y = keras.backend.conv1d(x, k, strides=(1,),
    775                              padding='same', data_format='channels_last')
    776     self.assertEqual(y.shape.as_list(), [10, 10, 5])
    777 
    778     val = np.random.random((10, 10, 4))
    779     x = keras.backend.variable(val)
    780     y = keras.backend.conv1d(x, k, strides=(2,),
    781                              padding='same', data_format='channels_last')
    782     self.assertEqual(y.shape.as_list(), [10, 5, 5])
    783 
    784   def test_local_conv_channels_dim(self):
    785     filters = 3
    786     batch_size = 2
    787 
    788     for input_shape in [(3, 5), (2, 3, 5), (2, 5, 3, 4)]:
    789       channels_in = input_shape[0]
    790       input_spatial_shape = input_shape[1:]
    791       dim = len(input_spatial_shape)
    792 
    793       inputs = np.random.normal(0, 1, (batch_size,) + input_shape)
    794       inputs_cf = keras.backend.variable(inputs)
    795 
    796       for kernel_size in [1, 2]:
    797         for stride in [1, 2]:
    798           kernel_sizes = (kernel_size,) * dim
    799           strides = (stride,) * dim
    800 
    801           output_shape = tuple([(i - kernel_size + stride) // stride
    802                                 for i in input_spatial_shape])
    803 
    804           kernel_shape = (np.prod(output_shape),
    805                           np.prod(kernel_sizes) * channels_in,
    806                           filters)
    807 
    808           kernel = np.random.normal(
    809               0,
    810               1,
    811               output_shape + (channels_in, np.prod(kernel_sizes), filters)
    812           )
    813 
    814           kernel_cf = np.reshape(kernel, kernel_shape)
    815           kernel_cf = keras.backend.variable(kernel_cf)
    816 
    817           conv_cf = keras.backend.local_conv(inputs_cf,
    818                                              kernel_cf,
    819                                              kernel_sizes,
    820                                              strides,
    821                                              output_shape,
    822                                              'channels_first')
    823 
    824           inputs_cl = np.transpose(inputs, [0, 2] + list(range(3, dim + 2)) +
    825                                    [1])
    826           inputs_cl = keras.backend.variable(inputs_cl)
    827 
    828           kernel_cl = np.reshape(
    829               np.transpose(kernel, list(range(dim)) + [dim + 1, dim, dim + 2]),
    830               kernel_shape
    831           )
    832           kernel_cl = keras.backend.variable(kernel_cl)
    833 
    834           conv_cl = keras.backend.local_conv(inputs_cl,
    835                                              kernel_cl,
    836                                              kernel_sizes,
    837                                              strides,
    838                                              output_shape,
    839                                              'channels_last')
    840 
    841           conv_cf = keras.backend.eval(conv_cf)
    842           conv_cl = keras.backend.eval(conv_cl)
    843 
    844           self.assertAllCloseAccordingToType(
    845               conv_cf,
    846               np.transpose(conv_cl,
    847                            [0, dim + 1] + list(range(1, dim + 1))),
    848               atol=1e-5
    849           )
    850 
    851   @parameterized.named_parameters(
    852       ('local_conv1d', (5, 6), (3,), (1,), (3,)),
    853       ('local_conv2d', (4, 5, 6), (3, 3), (1, 1), (2, 3)))
    854   def test_local_conv_1d_and_2d(self,
    855                                 input_shape,
    856                                 kernel_sizes,
    857                                 strides,
    858                                 output_shape):
    859     filters = 3
    860     batch_size = 2
    861 
    862     inputs = np.random.normal(0, 1, (batch_size,) + input_shape)
    863     inputs = keras.backend.variable(inputs)
    864 
    865     kernel = np.random.normal(0, 1, (np.prod(output_shape),
    866                                      np.prod(kernel_sizes) * input_shape[-1],
    867                                      filters))
    868     kernel = keras.backend.variable(kernel)
    869 
    870     local_conv = keras.backend.local_conv(inputs,
    871                                           kernel,
    872                                           kernel_sizes,
    873                                           strides,
    874                                           output_shape,
    875                                           'channels_last')
    876     if len(output_shape) == 1:
    877       local_conv_dim = keras.backend.local_conv1d(inputs,
    878                                                   kernel,
    879                                                   kernel_sizes,
    880                                                   strides,
    881                                                   'channels_last')
    882     else:
    883       local_conv_dim = keras.backend.local_conv2d(inputs,
    884                                                   kernel,
    885                                                   kernel_sizes,
    886                                                   strides,
    887                                                   output_shape,
    888                                                   'channels_last')
    889 
    890     local_conv = keras.backend.eval(local_conv)
    891     local_conv_dim = keras.backend.eval(local_conv_dim)
    892 
    893     self.assertAllCloseAccordingToType(local_conv, local_conv_dim)
    894 
    895   def test_conv2d(self):
    896     val = np.random.random((10, 4, 10, 10))
    897     x = keras.backend.variable(val)
    898     kernel_val = np.random.random((3, 3, 4, 5))
    899     k = keras.backend.variable(kernel_val)
    900     y = keras.backend.conv2d(x, k,
    901                              padding='valid', data_format='channels_first')
    902     self.assertEqual(y.shape.as_list(), [10, 5, 8, 8])
    903 
    904     val = np.random.random((10, 10, 10, 4))
    905     x = keras.backend.variable(val)
    906     y = keras.backend.conv2d(x, k, strides=(1, 1),
    907                              padding='valid', data_format='channels_last')
    908     self.assertEqual(y.shape.as_list(), [10, 8, 8, 5])
    909 
    910     val = np.random.random((10, 10, 10, 4))
    911     x = keras.backend.variable(val)
    912     y = keras.backend.conv2d(x, k, strides=(1, 1),
    913                              padding='same', data_format='channels_last')
    914     self.assertEqual(y.shape.as_list(), [10, 10, 10, 5])
    915 
    916     val = np.random.random((10, 10, 10, 4))
    917     x = keras.backend.variable(val)
    918     y = keras.backend.conv2d(x, k, strides=(2, 2),
    919                              padding='same', data_format='channels_last')
    920     self.assertEqual(y.shape.as_list(), [10, 5, 5, 5])
    921     with self.assertRaises(ValueError):
    922       y = keras.backend.conv2d(x, k, (2, 2),
    923                                padding='other', data_format='channels_last')
    924     with self.assertRaises(ValueError):
    925       y = keras.backend.conv2d(x, k, (2, 2),
    926                                data_format='other')
    927     with self.assertRaises(ValueError):
    928       y = keras.backend.conv2d(x, k, (2, 2, 2))
    929 
    930   def test_separable_conv2d(self):
    931     val = np.random.random((10, 4, 10, 10))
    932     x = keras.backend.variable(val)
    933     depthwise_kernel_val = np.random.random((3, 3, 4, 1))
    934     pointwise_kernel_val = np.random.random((1, 1, 4, 5))
    935     dk = keras.backend.variable(depthwise_kernel_val)
    936     pk = keras.backend.variable(pointwise_kernel_val)
    937     y = keras.backend.separable_conv2d(
    938         x, dk, pk, padding='valid', data_format='channels_first')
    939     self.assertEqual(y.shape.as_list(), [10, 5, 8, 8])
    940 
    941     val = np.random.random((10, 10, 10, 4))
    942     x = keras.backend.variable(val)
    943     y = keras.backend.separable_conv2d(
    944         x, dk, pk, strides=(1, 1), padding='valid', data_format='channels_last')
    945     self.assertEqual(y.shape.as_list(), [10, 8, 8, 5])
    946 
    947     val = np.random.random((10, 10, 10, 4))
    948     x = keras.backend.variable(val)
    949     y = keras.backend.separable_conv2d(
    950         x, dk, pk, strides=(1, 1), padding='same', data_format='channels_last')
    951     self.assertEqual(y.shape.as_list(), [10, 10, 10, 5])
    952 
    953     val = np.random.random((10, 10, 10, 4))
    954     x = keras.backend.variable(val)
    955     y = keras.backend.separable_conv2d(
    956         x, dk, pk, strides=(2, 2), padding='same', data_format='channels_last')
    957     self.assertEqual(y.shape.as_list(), [10, 5, 5, 5])
    958     with self.assertRaises(ValueError):
    959       y = keras.backend.separable_conv2d(
    960           x, dk, pk, (2, 2), padding='other', data_format='channels_last')
    961     with self.assertRaises(ValueError):
    962       y = keras.backend.separable_conv2d(
    963           x, dk, pk, (2, 2), data_format='other')
    964     with self.assertRaises(ValueError):
    965       y = keras.backend.separable_conv2d(x, dk, pk, (2, 2, 2))
    966 
    967   def test_conv3d(self):
    968     val = np.random.random((10, 4, 10, 10, 10))
    969     x = keras.backend.variable(val)
    970     kernel_val = np.random.random((3, 3, 3, 4, 5))
    971     k = keras.backend.variable(kernel_val)
    972     y = keras.backend.conv3d(x, k,
    973                              padding='valid', data_format='channels_first')
    974     self.assertEqual(y.shape.as_list(), [10, 5, 8, 8, 8])
    975 
    976     val = np.random.random((10, 10, 10, 10, 4))
    977     x = keras.backend.variable(val)
    978     y = keras.backend.conv3d(x, k, strides=(1, 1, 1),
    979                              padding='valid', data_format='channels_last')
    980     self.assertEqual(y.shape.as_list(), [10, 8, 8, 8, 5])
    981 
    982     val = np.random.random((10, 10, 10, 10, 4))
    983     x = keras.backend.variable(val)
    984     y = keras.backend.conv3d(x, k, strides=(1, 1, 1),
    985                              padding='same', data_format='channels_last')
    986     self.assertEqual(y.shape.as_list(), [10, 10, 10, 10, 5])
    987 
    988     val = np.random.random((10, 10, 10, 10, 4))
    989     x = keras.backend.variable(val)
    990     y = keras.backend.conv3d(x, k, strides=(2, 2, 2),
    991                              padding='same', data_format='channels_last')
    992     self.assertEqual(y.shape.as_list(), [10, 5, 5, 5, 5])
    993     with self.assertRaises(ValueError):
    994       y = keras.backend.conv3d(x, k, (2, 2, 2),
    995                                padding='other', data_format='channels_last')
    996     with self.assertRaises(ValueError):
    997       y = keras.backend.conv3d(x, k, (2, 2, 2),
    998                                data_format='other')
    999     with self.assertRaises(ValueError):
   1000       y = keras.backend.conv3d(x, k, (2, 2))
   1001 
   1002   def test_rnn(self):
   1003     # implement a simple RNN
   1004     num_samples = 4
   1005     input_dim = 5
   1006     output_dim = 3
   1007     timesteps = 6
   1008 
   1009     input_val = np.random.random(
   1010         (num_samples, timesteps, input_dim)).astype(np.float32)
   1011     init_state_val = np.random.random(
   1012         (num_samples, output_dim)).astype(np.float32)
   1013     w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32)
   1014     w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32)
   1015     np_mask = np.random.randint(2, size=(num_samples, timesteps))
   1016 
   1017     def rnn_step_fn():
   1018       w_i = keras.backend.variable(w_i_val)
   1019       w_o = keras.backend.variable(w_o_val)
   1020 
   1021       def step_function(x, states):
   1022         assert len(states) == 1
   1023         prev_output = states[0]
   1024         output = keras.backend.dot(x, w_i) + keras.backend.dot(prev_output, w_o)
   1025         return output, [output]
   1026 
   1027       return step_function
   1028 
   1029     # test default setup
   1030     last_output_list = [[], [], [], [], [], []]
   1031     outputs_list = [[], [], [], [], [], []]
   1032     state_list = [[], [], [], [], [], []]
   1033 
   1034     rnn_fn = rnn_step_fn()
   1035     inputs = keras.backend.variable(input_val)
   1036     initial_states = [keras.backend.variable(init_state_val)]
   1037     mask = keras.backend.variable(np_mask)
   1038 
   1039     kwargs_list = [
   1040         {'go_backwards': False, 'mask': None},
   1041         {'go_backwards': False, 'mask': None, 'unroll': True},
   1042         {'go_backwards': True, 'mask': None},
   1043         {'go_backwards': True, 'mask': None, 'unroll': True},
   1044         {'go_backwards': False, 'mask': mask},
   1045         {'go_backwards': False, 'mask': mask, 'unroll': True},
   1046     ]
   1047     for i, kwargs in enumerate(kwargs_list):
   1048       last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
   1049                                                            initial_states,
   1050                                                            **kwargs)
   1051       # check static shape inference
   1052       self.assertEqual(last_output.shape.as_list(), [num_samples, output_dim])
   1053       self.assertEqual(outputs.shape.as_list(),
   1054                        [num_samples, timesteps, output_dim])
   1055       for state in new_states:
   1056         self.assertEqual(state.shape.as_list(), [num_samples, output_dim])
   1057 
   1058       last_output_list[i].append(keras.backend.eval(last_output))
   1059       outputs_list[i].append(keras.backend.eval(outputs))
   1060       self.assertLen(new_states, 1)
   1061       state_list[i].append(keras.backend.eval(new_states[0]))
   1062 
   1063       def assert_list_pairwise(z_list, atol=1e-05):
   1064         for (z1, z2) in zip(z_list[1:], z_list[:-1]):
   1065           self.assertAllClose(z1, z2, atol=atol)
   1066 
   1067       assert_list_pairwise(last_output_list[0], atol=1e-04)
   1068       assert_list_pairwise(outputs_list[0], atol=1e-04)
   1069       assert_list_pairwise(state_list[0], atol=1e-04)
   1070       assert_list_pairwise(last_output_list[2], atol=1e-04)
   1071       assert_list_pairwise(outputs_list[2], atol=1e-04)
   1072       assert_list_pairwise(state_list[2], atol=1e-04)
   1073 
   1074       for l, u_l in zip(last_output_list[0], last_output_list[1]):
   1075         self.assertAllClose(l, u_l, atol=1e-04)
   1076 
   1077       for o, u_o in zip(outputs_list[0], outputs_list[1]):
   1078         self.assertAllClose(o, u_o, atol=1e-04)
   1079 
   1080       for s, u_s in zip(state_list[0], state_list[1]):
   1081         self.assertAllClose(s, u_s, atol=1e-04)
   1082 
   1083       for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]):
   1084         self.assertAllClose(b_l, b_u_l, atol=1e-04)
   1085 
   1086       for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]):
   1087         self.assertAllClose(b_o, b_u_o, atol=1e-04)
   1088 
   1089       for b_s, b_u_s in zip(state_list[2], state_list[3]):
   1090         self.assertAllClose(b_s, b_u_s, atol=1e-04)
   1091 
   1092   def test_rnn_additional_states(self):
   1093     # implement a simple RNN
   1094     num_samples = 4
   1095     input_dim = 5
   1096     output_dim = 3
   1097     timesteps = 6
   1098 
   1099     input_val = np.random.random(
   1100         (num_samples, timesteps, input_dim)).astype(np.float32)
   1101     init_state_val = np.random.random(
   1102         (num_samples, output_dim)).astype(np.float32)
   1103     w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32)
   1104     w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32)
   1105     np_mask = np.random.randint(2, size=(num_samples, timesteps))
   1106 
   1107     def rnn_step_fn():
   1108       w_i = keras.backend.variable(w_i_val)
   1109       w_o = keras.backend.variable(w_o_val)
   1110 
   1111       def step_function(x, states):
   1112         assert len(states) == 2
   1113         prev_output = states[0]
   1114         output = keras.backend.dot(x, w_i) + keras.backend.dot(prev_output, w_o)
   1115         return output, [output,
   1116                         keras.backend.concatenate([output, output], axis=-1)]
   1117 
   1118       return step_function
   1119 
   1120     # test default setup
   1121     last_output_list = [[], [], [], [], [], []]
   1122     outputs_list = [[], [], [], [], [], []]
   1123     state_list = [[], [], [], [], [], []]
   1124     additional_state_list = [[], [], [], [], [], []]
   1125 
   1126     rnn_fn = rnn_step_fn()
   1127     inputs = keras.backend.variable(input_val)
   1128     initial_states = [
   1129         keras.backend.variable(init_state_val),
   1130         ops.convert_to_tensor(
   1131             np.concatenate([init_state_val, init_state_val], axis=-1))
   1132     ]
   1133     mask = keras.backend.variable(np_mask)
   1134 
   1135     kwargs_list = [
   1136         {'go_backwards': False, 'mask': None},
   1137         {'go_backwards': False, 'mask': None, 'unroll': True},
   1138         {'go_backwards': True, 'mask': None},
   1139         {'go_backwards': True, 'mask': None, 'unroll': True},
   1140         {'go_backwards': False, 'mask': mask},
   1141         {'go_backwards': False, 'mask': mask, 'unroll': True},
   1142     ]
   1143     for i, kwargs in enumerate(kwargs_list):
   1144       last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
   1145                                                            initial_states,
   1146                                                            **kwargs)
   1147       # check static shape inference
   1148       self.assertEqual(last_output.shape.as_list(), [num_samples, output_dim])
   1149       self.assertEqual(outputs.shape.as_list(),
   1150                        [num_samples, timesteps, output_dim])
   1151       # for state in new_states:
   1152       #   self.assertEqual(state.shape.as_list(),
   1153       #                     [num_samples, output_dim])
   1154       self.assertEqual(new_states[0].shape.as_list(), [num_samples, output_dim])
   1155       self.assertEqual(new_states[1].shape.as_list(),
   1156                        [num_samples, 2 * output_dim])
   1157 
   1158       last_output_list[i].append(keras.backend.eval(last_output))
   1159       outputs_list[i].append(keras.backend.eval(outputs))
   1160       self.assertLen(new_states, 2)
   1161       state_list[i].append(keras.backend.eval(new_states[0]))
   1162       additional_state_list[i].append(keras.backend.eval(new_states[1]))
   1163 
   1164       def assert_list_pairwise(z_list, atol=1e-05):
   1165         for (z1, z2) in zip(z_list[1:], z_list[:-1]):
   1166           self.assertAllClose(z1, z2, atol=atol)
   1167 
   1168       assert_list_pairwise(last_output_list[0], atol=1e-04)
   1169       assert_list_pairwise(outputs_list[0], atol=1e-04)
   1170       assert_list_pairwise(state_list[0], atol=1e-04)
   1171       assert_list_pairwise(additional_state_list[0], atol=1e-04)
   1172       assert_list_pairwise(last_output_list[2], atol=1e-04)
   1173       assert_list_pairwise(outputs_list[2], atol=1e-04)
   1174       assert_list_pairwise(state_list[2], atol=1e-04)
   1175       assert_list_pairwise(additional_state_list[2], atol=1e-04)
   1176 
   1177       for l, u_l in zip(last_output_list[0], last_output_list[1]):
   1178         self.assertAllClose(l, u_l, atol=1e-04)
   1179 
   1180       for o, u_o in zip(outputs_list[0], outputs_list[1]):
   1181         self.assertAllClose(o, u_o, atol=1e-04)
   1182 
   1183       for s, u_s in zip(state_list[0], state_list[1]):
   1184         self.assertAllClose(s, u_s, atol=1e-04)
   1185 
   1186       for s, u_s in zip(additional_state_list[0], additional_state_list[1]):
   1187         self.assertAllClose(s, u_s, atol=1e-04)
   1188 
   1189       for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]):
   1190         self.assertAllClose(b_l, b_u_l, atol=1e-04)
   1191 
   1192       for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]):
   1193         self.assertAllClose(b_o, b_u_o, atol=1e-04)
   1194 
   1195       for b_s, b_u_s in zip(state_list[2], state_list[3]):
   1196         self.assertAllClose(b_s, b_u_s, atol=1e-04)
   1197 
   1198       for s, u_s in zip(additional_state_list[2], additional_state_list[3]):
   1199         self.assertAllClose(s, u_s, atol=1e-04)
   1200 
   1201   def test_rnn_output_and_state_masking_independent(self):
   1202     num_samples = 2
   1203     num_timesteps = 4
   1204     state_and_io_size = 2
   1205     mask_last_num_timesteps = 2  # for second sample only
   1206 
   1207     # a step function that just outputs inputs,
   1208     # but increments states +1 per timestep
   1209     def step_function(inputs, states):
   1210       return inputs, [s + 1 for s in states]
   1211 
   1212     inputs_vals = np.random.random((num_samples, num_timesteps,
   1213                                     state_and_io_size))
   1214     initial_state_vals = np.random.random((num_samples, state_and_io_size))
   1215     # masking of two last timesteps for second sample only
   1216     mask_vals = np.ones((num_samples, num_timesteps))
   1217     mask_vals[1, -mask_last_num_timesteps:] = 0
   1218 
   1219     # outputs expected to be same as inputs for the first sample
   1220     expected_outputs = inputs_vals.copy()
   1221     # but for the second sample all outputs in masked region should be the same
   1222     # as last output before masked region
   1223     expected_outputs[1, -mask_last_num_timesteps:] = \
   1224         expected_outputs[1, -(mask_last_num_timesteps + 1)]
   1225 
   1226     expected_last_state = initial_state_vals.copy()
   1227     # first state should be incremented for every timestep (no masking)
   1228     expected_last_state[0] += num_timesteps
   1229     # second state should not be incremented for last two timesteps
   1230     expected_last_state[1] += (num_timesteps - mask_last_num_timesteps)
   1231 
   1232     # verify same expected output for `unroll=true/false`
   1233     inputs = keras.backend.variable(inputs_vals)
   1234     initial_states = [keras.backend.variable(initial_state_vals)]
   1235     mask = keras.backend.variable(mask_vals)
   1236     for unroll in [True, False]:
   1237       _, outputs, last_states = keras.backend.rnn(
   1238           step_function,
   1239           inputs,
   1240           initial_states,
   1241           mask=mask,
   1242           unroll=unroll,
   1243           input_length=num_timesteps if unroll else None)
   1244 
   1245       self.assertAllClose(keras.backend.eval(outputs), expected_outputs)
   1246       self.assertAllClose(
   1247           keras.backend.eval(last_states[0]), expected_last_state)
   1248 
   1249   def test_rnn_output_num_dim_larger_than_2_masking(self):
   1250     num_samples = 3
   1251     num_timesteps = 4
   1252     num_features = 5
   1253 
   1254     def step_function(inputs, states):
   1255       outputs = keras.backend.tile(keras.backend.expand_dims(inputs), [1, 1, 2])
   1256       return outputs, [keras.backend.identity(s) for s in states]
   1257       # Note: cannot just return states (which can be a problem) ->
   1258       # tensorflow/python/ops/resource_variable_ops.py", line 824, in set_shape
   1259       # NotImplementedError: ResourceVariable does not implement set_shape()
   1260 
   1261     inputs_vals = np.random.random((num_samples, num_timesteps, num_features))
   1262     initial_state_vals = np.random.random((num_samples, 6))
   1263     mask_vals = np.ones((num_samples, num_timesteps))
   1264     mask_vals[-1, -1] = 0  # final timestep masked for last sample
   1265 
   1266     expected_outputs = np.repeat(inputs_vals[..., None], repeats=2, axis=-1)
   1267     # for the last sample, the final timestep (in masked region) should be the
   1268     # same as the second to final output (before masked region)
   1269     expected_outputs[-1, -1] = expected_outputs[-1, -2]
   1270 
   1271     inputs = keras.backend.variable(inputs_vals)
   1272     initial_states = [keras.backend.variable(initial_state_vals)]
   1273     mask = keras.backend.variable(mask_vals)
   1274     for unroll in [True, False]:
   1275       _, outputs, _ = keras.backend.rnn(
   1276           step_function,
   1277           inputs,
   1278           initial_states,
   1279           mask=mask,
   1280           unroll=unroll,
   1281           input_length=num_timesteps if unroll else None)
   1282 
   1283       self.assertAllClose(keras.backend.eval(outputs), expected_outputs)
   1284 
   1285   def test_rnn_state_num_dim_larger_than_2_masking(self):
   1286     num_samples = 3
   1287     num_timesteps = 4
   1288 
   1289     def step_function(inputs, states):
   1290       return inputs, [s + 1 for s in states]
   1291 
   1292     inputs_vals = np.random.random((num_samples, num_timesteps, 5))
   1293     initial_state_vals = np.random.random((num_samples, 6, 7))
   1294     mask_vals = np.ones((num_samples, num_timesteps))
   1295     mask_vals[0, -2:] = 0  # final two timesteps masked for first sample
   1296 
   1297     expected_last_state = initial_state_vals.copy()
   1298     expected_last_state[0] += (num_timesteps - 2)
   1299     expected_last_state[1:] += num_timesteps
   1300 
   1301     inputs = keras.backend.variable(inputs_vals)
   1302     initial_states = [keras.backend.variable(initial_state_vals)]
   1303     mask = keras.backend.variable(mask_vals)
   1304     for unroll in [True, False]:
   1305       _, _, last_states = keras.backend.rnn(
   1306           step_function,
   1307           inputs,
   1308           initial_states,
   1309           mask=mask,
   1310           unroll=unroll,
   1311           input_length=num_timesteps if unroll else None)
   1312 
   1313       self.assertAllClose(
   1314           keras.backend.eval(last_states[0]), expected_last_state)
   1315 
   1316   def test_normalize_batch_in_training(self):
   1317     val = np.random.random((10, 3, 10, 10))
   1318     x = keras.backend.variable(val)
   1319     reduction_axes = (0, 2, 3)
   1320 
   1321     g_val = np.random.random((3,))
   1322     b_val = np.random.random((3,))
   1323     gamma = keras.backend.variable(g_val)
   1324     beta = keras.backend.variable(b_val)
   1325     normed, mean, var = keras.backend.normalize_batch_in_training(
   1326         x, gamma, beta, reduction_axes, epsilon=1e-3)
   1327     self.assertEqual(normed.shape.as_list(), [10, 3, 10, 10])
   1328     self.assertEqual(mean.shape.as_list(), [
   1329         3,
   1330     ])
   1331     self.assertEqual(var.shape.as_list(), [
   1332         3,
   1333     ])
   1334 
   1335     # case: gamma=None
   1336     gamma = None
   1337     normed, mean, var = keras.backend.normalize_batch_in_training(
   1338         x, gamma, beta, reduction_axes, epsilon=1e-3)
   1339     self.assertEqual(normed.shape.as_list(), [10, 3, 10, 10])
   1340     self.assertEqual(mean.shape.as_list(), [
   1341         3,
   1342     ])
   1343     self.assertEqual(var.shape.as_list(), [
   1344         3,
   1345     ])
   1346 
   1347     # case: beta=None
   1348     beta = None
   1349     normed, mean, var = keras.backend.normalize_batch_in_training(
   1350         x, gamma, beta, reduction_axes, epsilon=1e-3)
   1351     self.assertEqual(normed.shape.as_list(), [10, 3, 10, 10])
   1352     self.assertEqual(mean.shape.as_list(), [
   1353         3,
   1354     ])
   1355     self.assertEqual(var.shape.as_list(), [
   1356         3,
   1357     ])
   1358 
   1359   def test_dropout(self):
   1360     inputs = array_ops.ones((200, 200))
   1361     outputs = keras.backend.dropout(inputs, 0.2)
   1362     outputs_val = keras.backend.eval(outputs)
   1363     self.assertEqual(np.min(outputs_val), 0)
   1364     self.assertAllClose(np.count_nonzero(outputs_val), 32000, atol=1000)
   1365     # Test noise shape
   1366     outputs = keras.backend.dropout(inputs, 0.2, noise_shape=(200, 1))
   1367     outputs_val = keras.backend.eval(outputs)
   1368     self.assertAllClose(outputs_val[2, :], outputs_val[3, :], atol=1e-5)
   1369 
   1370 
   1371 @test_util.run_all_in_graph_and_eager_modes
   1372 class TestCTC(test.TestCase):
   1373 
   1374   def test_ctc_decode(self):
   1375     depth = 6
   1376     seq_len_0 = 5
   1377     input_prob_matrix_0 = np.asarray(
   1378         [[0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
   1379          [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
   1380          [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
   1381          [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
   1382          [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
   1383          # Random entry added in at time=5
   1384          [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]],
   1385         dtype=np.float32)
   1386 
   1387     # len max_time_steps array of batch_size x depth matrices
   1388     inputs = ([input_prob_matrix_0[t, :][np.newaxis, :]
   1389                for t in range(seq_len_0)] +  # Pad to max_time_steps = 8
   1390               2 * [np.zeros((1, depth), dtype=np.float32)])
   1391 
   1392     inputs = keras.backend.variable(np.asarray(inputs).transpose((1, 0, 2)))
   1393 
   1394     # batch_size length vector of sequence_lengths
   1395     input_length = keras.backend.variable(
   1396         np.array([seq_len_0], dtype=np.int32))
   1397     # batch_size length vector of negative log probabilities
   1398     log_prob_truth = np.array([
   1399         -3.5821197,  # output beam 0
   1400         -3.777835    # output beam 1
   1401     ], np.float32)[np.newaxis, :]
   1402 
   1403     decode_truth = [np.array([1, 0]), np.array([0, 1, 0])]
   1404     beam_width = 2
   1405     top_paths = 2
   1406 
   1407     decode_pred_tf, log_prob_pred_tf = keras.backend.ctc_decode(
   1408         inputs,
   1409         input_length,
   1410         greedy=False,
   1411         beam_width=beam_width,
   1412         top_paths=top_paths)
   1413 
   1414     self.assertEqual(len(decode_pred_tf), top_paths)
   1415     log_prob_pred = keras.backend.eval(log_prob_pred_tf)
   1416     for i in range(top_paths):
   1417       self.assertTrue(
   1418           np.alltrue(
   1419               decode_truth[i] == keras.backend.eval(decode_pred_tf[i])))
   1420     self.assertAllClose(log_prob_truth, log_prob_pred)
   1421 
   1422   @test_util.run_v1_only('b/120545219')
   1423   def test_ctc_batch_cost(self):
   1424     with self.cached_session():
   1425       label_lens = np.expand_dims(np.asarray([5, 4]), 1)
   1426       input_lens = np.expand_dims(np.asarray([5, 5]), 1)  # number of timesteps
   1427       loss_log_probs = [3.34211, 5.42262]
   1428 
   1429       # dimensions are batch x time x categories
   1430       labels = np.asarray([[0, 1, 2, 1, 0], [0, 1, 1, 0, -1]])
   1431       inputs = np.asarray(
   1432           [[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
   1433             [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
   1434             [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
   1435             [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
   1436             [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
   1437            [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
   1438             [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
   1439             [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
   1440             [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
   1441             [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]]],
   1442           dtype=np.float32)
   1443 
   1444       labels = keras.backend.variable(labels, dtype='int32')
   1445       inputs = keras.backend.variable(inputs, dtype='float32')
   1446       input_lens = keras.backend.variable(input_lens, dtype='int32')
   1447       label_lens = keras.backend.variable(label_lens, dtype='int32')
   1448       res = keras.backend.eval(
   1449           keras.backend.ctc_batch_cost(labels, inputs, input_lens, label_lens))
   1450       self.assertAllClose(res[:, 0], loss_log_probs, atol=1e-05)
   1451 
   1452       # test when batch_size = 1, that is, one sample only
   1453       ref = [3.34211]
   1454       input_lens = np.expand_dims(np.asarray([5]), 1)
   1455       label_lens = np.expand_dims(np.asarray([5]), 1)
   1456 
   1457       labels = np.asarray([[0, 1, 2, 1, 0]])
   1458       inputs = np.asarray(
   1459           [[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553], [
   1460               0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436
   1461           ], [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
   1462             [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
   1463             [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]]
   1464           ],
   1465           dtype=np.float32)
   1466 
   1467       k_labels = keras.backend.variable(labels, dtype='int32')
   1468       k_inputs = keras.backend.variable(inputs, dtype='float32')
   1469       k_input_lens = keras.backend.variable(input_lens, dtype='int32')
   1470       k_label_lens = keras.backend.variable(label_lens, dtype='int32')
   1471       res = keras.backend.eval(
   1472           keras.backend.ctc_batch_cost(k_labels, k_inputs, k_input_lens,
   1473                                        k_label_lens))
   1474       self.assertAllClose(res[:, 0], ref, atol=1e-05)
   1475 
   1476 
   1477 @test_util.run_all_in_graph_and_eager_modes
   1478 class TestRandomOps(test.TestCase):
   1479 
   1480   def test_random_binomial(self):
   1481     np.random.seed(123)
   1482     x = keras.backend.random_binomial((1000, 1000), p=0.5)
   1483     self.assertAllClose(np.mean(keras.backend.eval(x)), 0.5, atol=0.1)
   1484 
   1485   def test_truncated_normal(self):
   1486     np.random.seed(123)
   1487     x = keras.backend.truncated_normal((1000, 1000), mean=0.0, stddev=1.0)
   1488     y = keras.backend.eval(x)
   1489     self.assertAllClose(np.mean(y), 0., atol=0.1)
   1490     self.assertAllClose(np.std(y), 0.88, atol=0.1)
   1491     self.assertAllClose(np.max(y), 2., atol=0.1)
   1492     self.assertAllClose(np.min(y), -2., atol=0.1)
   1493 
   1494   def test_string_input(self):
   1495     seq = keras.Sequential([
   1496         keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
   1497         keras.layers.Lambda(lambda x: x[0])
   1498     ])
   1499     preds = seq.predict([['tensorflow eager']])
   1500     self.assertEqual(preds.shape, (1,))
   1501 
   1502 
   1503 class BackendGraphTests(test.TestCase):
   1504 
   1505   @test_util.run_deprecated_v1
   1506   def test_is_placeholder(self):
   1507     x = keras.backend.placeholder(shape=(1,))
   1508     self.assertEqual(keras.backend.is_placeholder(x), True)
   1509     # Test with TF placeholder
   1510     x = keras.backend.array_ops.placeholder(dtype='float32', shape=(1,))
   1511     self.assertEqual(keras.backend.is_placeholder(x), True)
   1512     x = keras.backend.variable(1)
   1513     self.assertEqual(keras.backend.is_placeholder(x), False)
   1514 
   1515   @test_util.run_in_graph_and_eager_modes
   1516   def test_function_basics(self):
   1517     x1 = keras.backend.placeholder(shape=(), dtype='float32')
   1518     x2 = keras.backend.placeholder(shape=(), dtype='int32')
   1519     v = keras.backend.variable(10.)
   1520     with keras.backend.get_graph().as_default():
   1521       y1 = x1 + keras.backend.cast(x2, 'float32') + v
   1522       y2 = x1 * keras.backend.cast(x2, 'float32')
   1523       with ops.control_dependencies([y1]):
   1524         u = keras.backend.update(v, 5.)
   1525     f = keras.backend.function([x1, x2], [y1, y2], updates=[u])
   1526     output_values = f([2, 3])
   1527     self.assertEqual(output_values, [15., 6.])
   1528     self.assertEqual(keras.backend.eval(v), 5.)
   1529 
   1530   @test_util.run_in_graph_and_eager_modes
   1531   def test_function_placeholder_with_default(self):
   1532     with keras.backend.get_graph().as_default():
   1533       x1 = array_ops.placeholder_with_default(
   1534           np.array(2., dtype='float32'), shape=())
   1535       x2 = array_ops.placeholder_with_default(
   1536           np.array(3, dtype='int32'), shape=())
   1537     y1 = x1 + keras.backend.cast(x2, 'float32')
   1538     y2 = x1 * keras.backend.cast(x2, 'float32')
   1539     f = keras.backend.function([x1, x2], [y1, y2])
   1540     output_values = f([4, 5])
   1541     self.assertEqual(output_values, [9., 20.])
   1542     output_values = f([None, None])
   1543     self.assertEqual(output_values, [5., 6.])
   1544 
   1545   @test_util.run_deprecated_v1
   1546   def test_function_tf_feed_symbols(self):
   1547     # Test Keras backend functions with TF tensor inputs.
   1548     with self.cached_session():
   1549       # Test feeding a resource variable to `function`.
   1550       x1 = keras.backend.placeholder(shape=())
   1551       x2 = keras.backend.placeholder(shape=())
   1552       lr = keras.backend.learning_phase()  # Include a placeholder_with_default.
   1553 
   1554       y1 = keras.backend.variable(10.)
   1555       y2 = 3
   1556 
   1557       f = keras.backend.function(
   1558           inputs=[x1, x2, lr],
   1559           outputs=[x1 + 1, keras.backend.in_train_phase(x2 + 2, x2 - 1)])
   1560       outs = f([y1, y2, None])  # Use default learning_phase value.
   1561       self.assertEqual(outs, [11., 2.])
   1562       outs = f([y1, y2, 1])  # Set learning phase value.
   1563       self.assertEqual(outs, [11., 5.])
   1564 
   1565       # Test triggering a callable refresh by changing the input.
   1566       y3 = keras.backend.constant(20.)  # Test with tensor
   1567       outs = f([y3, y2, None])
   1568       self.assertEqual(outs, [21., 2.])
   1569 
   1570       y4 = 4  # Test with non-symbol
   1571       outs = f([y4, y2, None])
   1572       self.assertEqual(outs, [5., 2.])
   1573 
   1574       # Test with a different dtype
   1575       y5 = keras.backend.constant(10., dtype='float64')
   1576       outs = f([y5, y2, None])
   1577       self.assertEqual(outs, [11., 2.])
   1578 
   1579   @test_util.run_deprecated_v1
   1580   def test_function_tf_fetches(self):
   1581     # Additional operations can be passed to tf.Session().run() via its
   1582     # `fetches` arguments. In contrast to `updates` argument of
   1583     # keras.backend.function() these do not have control dependency on `outputs`
   1584     # so they can run in parallel. Also they should not contribute to output of
   1585     # keras.backend.function().
   1586     with self.cached_session():
   1587       x = keras.backend.variable(0.)
   1588       y = keras.backend.variable(0.)
   1589       x_placeholder = keras.backend.placeholder(shape=())
   1590       y_placeholder = keras.backend.placeholder(shape=())
   1591 
   1592       f = keras.backend.function(
   1593           inputs=[x_placeholder, y_placeholder],
   1594           outputs=[x_placeholder + y_placeholder],
   1595           updates=[(x, x_placeholder + 1.)],
   1596           fetches=[keras.backend.update(y, 5.)])
   1597       output = f([10., 20.])
   1598       self.assertEqual(output, [30.])
   1599       self.assertEqual(keras.backend.get_session().run(fetches=[x, y]),
   1600                        [11., 5.])
   1601 
   1602   @test_util.run_deprecated_v1
   1603   def test_function_tf_feed_dict(self):
   1604     # Additional substitutions can be passed to `tf.Session().run()` via its
   1605     # `feed_dict` arguments. Note that the feed_dict is passed once in the
   1606     # constructor but we can modify the values in the dictionary. Through
   1607     # this feed_dict we can provide additional substitutions besides Keras
   1608     # inputs.
   1609     with self.cached_session():
   1610       x = keras.backend.variable(0.)
   1611       y = keras.backend.variable(0.)
   1612       x_placeholder = keras.backend.placeholder(shape=())
   1613       y_placeholder = keras.backend.placeholder(shape=())
   1614 
   1615       feed_dict = {y_placeholder: 3.}
   1616       fetches = [keras.backend.update(y, y_placeholder * 10.)]
   1617       f = keras.backend.function(
   1618           inputs=[x_placeholder],
   1619           outputs=[x_placeholder + 1.],
   1620           updates=[(x, x_placeholder + 10.)],
   1621           feed_dict=feed_dict,
   1622           fetches=fetches)
   1623       output = f([10.])
   1624       self.assertEqual(output, [11.])
   1625       self.assertEqual(keras.backend.get_session().run(fetches=[x, y]),
   1626                        [20., 30.])
   1627 
   1628       # updated value in feed_dict will be modified within the K.function()
   1629       feed_dict[y_placeholder] = 4.
   1630       output = f([20.])
   1631       self.assertEqual(output, [21.])
   1632       self.assertEqual(keras.backend.get_session().run(fetches=[x, y]),
   1633                        [30., 40.])
   1634 
   1635   @test_util.run_deprecated_v1
   1636   def test_function_tf_run_options_with_run_metadata(self):
   1637     with self.cached_session():
   1638       x_placeholder = keras.backend.placeholder(shape=())
   1639       y_placeholder = keras.backend.placeholder(shape=())
   1640 
   1641       run_options = config_pb2.RunOptions(output_partition_graphs=True)
   1642       run_metadata = config_pb2.RunMetadata()
   1643       # enable run_options.
   1644       f = keras.backend.function(
   1645           inputs=[x_placeholder, y_placeholder],
   1646           outputs=[x_placeholder + y_placeholder],
   1647           options=run_options,
   1648           run_metadata=run_metadata)
   1649       output = f([10., 20.])
   1650       self.assertEqual(output, [30.])
   1651       self.assertGreater(len(run_metadata.partition_graphs), 0)
   1652       # disable run_options.
   1653       f1 = keras.backend.function(
   1654           inputs=[x_placeholder, y_placeholder],
   1655           outputs=[x_placeholder + y_placeholder],
   1656           run_metadata=run_metadata)
   1657       output1 = f1([10., 20.])
   1658       self.assertEqual(output1, [30.])
   1659       self.assertEqual(len(run_metadata.partition_graphs), 0)
   1660 
   1661   @test_util.run_deprecated_v1
   1662   def test_function_fetch_callbacks(self):
   1663 
   1664     class CallbackStub(object):
   1665 
   1666       def __init__(self):
   1667         self.times_called = 0
   1668         self.callback_result = 0
   1669 
   1670       def _fetch_callback(self, result):
   1671         self.times_called += 1
   1672         self.callback_result = result
   1673 
   1674     with self.cached_session():
   1675       callback = CallbackStub()
   1676       x_placeholder = keras.backend.placeholder(shape=())
   1677       y_placeholder = keras.backend.placeholder(shape=())
   1678 
   1679       callback_op = x_placeholder * y_placeholder
   1680 
   1681       f = keras.backend.function(
   1682           inputs=[x_placeholder, y_placeholder],
   1683           outputs=[x_placeholder + y_placeholder])
   1684       f.fetches.append(callback_op)
   1685       f.fetch_callbacks[callback_op] = callback._fetch_callback
   1686 
   1687       _ = f([10., 20.])
   1688 
   1689       self.assertEqual(callback.times_called, 1)
   1690       self.assertEqual(callback.callback_result, 200)
   1691 
   1692   @test_util.run_in_graph_and_eager_modes
   1693   def test_function_dict_outputs(self):
   1694     x_ph = keras.backend.placeholder(shape=(), name='x')
   1695     y_ph = keras.backend.placeholder(shape=(), name='y')
   1696     outputs = {'x*y': y_ph * x_ph, 'x*x': x_ph * x_ph}
   1697 
   1698     f = keras.backend.function(inputs=[x_ph, y_ph], outputs=outputs)
   1699     x, y = 2., 5.
   1700     results = f([x, y])
   1701 
   1702     self.assertEqual(results['x*y'], 10.)
   1703     self.assertEqual(results['x*x'], 4)
   1704 
   1705   @test_util.run_in_graph_and_eager_modes
   1706   def test_function_dict_inputs(self):
   1707     placeholders = {
   1708         'x': keras.backend.placeholder(shape=()),
   1709         'y': keras.backend.placeholder(shape=())
   1710     }
   1711     outputs = [placeholders['x'] * placeholders['y']]
   1712 
   1713     f = keras.backend.function(inputs=placeholders, outputs=outputs)
   1714     results = f({'x': 2., 'y': 3.})
   1715     self.assertEqual(results[0], 6.)
   1716 
   1717   @test_util.run_in_graph_and_eager_modes
   1718   def test_function_single_input_output(self):
   1719     x_ph = keras.backend.placeholder(shape=(), name='x')
   1720     output = x_ph * x_ph
   1721     f = keras.backend.function(x_ph, output)
   1722     result = f(2.)
   1723     self.assertEqual(result, 4.)
   1724 
   1725   def test_placeholder(self):
   1726     x = keras.backend.placeholder(shape=(3, 4))
   1727     self.assertEqual(x.shape.as_list(), [3, 4])
   1728     x = keras.backend.placeholder(shape=(3, 4), sparse=True)
   1729     self.assertEqual(x.shape.as_list(), [3, 4])
   1730 
   1731   @test_util.run_deprecated_v1
   1732   def test_batch_normalization(self):
   1733     # No eager CPU kernel.
   1734     g_val = np.random.random((3,))
   1735     b_val = np.random.random((3,))
   1736     gamma = keras.backend.variable(g_val)
   1737     beta = keras.backend.variable(b_val)
   1738 
   1739     # 3D NHC case
   1740     val = np.random.random((10, 5, 3))
   1741     x = keras.backend.variable(val)
   1742     mean, var = nn.moments(x, (0, 1), None, None, False)
   1743     normed = keras.backend.batch_normalization(
   1744         x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
   1745     self.assertEqual(normed.shape.as_list(), [10, 5, 3])
   1746 
   1747     # 4D NHWC case
   1748     val = np.random.random((10, 5, 5, 3))
   1749     x = keras.backend.variable(val)
   1750     mean, var = nn.moments(x, (0, 1, 2), None, None, False)
   1751     normed = keras.backend.batch_normalization(
   1752         x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
   1753     self.assertEqual(normed.shape.as_list(), [10, 5, 5, 3])
   1754 
   1755     # 4D NCHW case
   1756     val = np.random.random((10, 3, 5, 5))
   1757     x = keras.backend.variable(val)
   1758     mean, var = nn.moments(x, (0, 2, 3), None, None, False)
   1759     normed = keras.backend.batch_normalization(
   1760         x, mean, var, beta, gamma, axis=1, epsilon=1e-3)
   1761     self.assertEqual(normed.shape.as_list(), [10, 3, 5, 5])
   1762 
   1763   def test_get_session_different_graphs(self):
   1764     with ops.Graph().as_default():
   1765       x = keras.backend.constant(1)
   1766       session = keras.backend.get_session()
   1767       self.assertIs(session, keras.backend.get_session((x,)))
   1768       self.assertIs(session, keras.backend.get_session())
   1769     with ops.Graph().as_default():
   1770       self.assertIs(session, keras.backend.get_session((x,)))
   1771       self.assertIsNot(session, keras.backend.get_session())
   1772 
   1773 
if __name__ == '__main__':
  # Hand off to TensorFlow's test runner (tensorflow.python.platform.test).
  test.main()
   1776