# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RNN cells."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np

from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest


class RNNCellTest(test.TestCase):

  def testCoupledInputForgetGateLSTMCell(self):
    with self.test_session() as sess:
      num_units = 2
      state_size = num_units * 2
      batch_size = 3
      input_size = 4
      expected_output = np.array(
          [[0.121753, 0.121753], [0.103349, 0.103349], [0.100178, 0.100178]],
          dtype=np.float32)
      expected_state = np.array(
          [[0.137523, 0.137523, 0.121753, 0.121753],
           [0.105450, 0.105450, 0.103349, 0.103349],
           [0.100742, 0.100742, 0.100178, 0.100178]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([batch_size, input_size])
        m = array_ops.zeros([batch_size, state_size])
        output, state = contrib_rnn_cell.CoupledInputForgetGateLSTMCell(
            num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [output, state], {
                x.name:
                    np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
                              [3., 3., 3., 3.]]),
                m.name:
                    0.1 * np.ones((batch_size, state_size))
            })
        # This is a smoke test: Only making sure expected values didn't change.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], expected_output)
        self.assertAllClose(res[1], expected_state)

  def testTimeFreqLSTMCell(self):
    with self.test_session() as sess:
      num_units = 8
      state_size = num_units * 2
      batch_size = 3
      input_size = 4
      feature_size = 2
      frequency_skip = 1
      num_shifts = (input_size - feature_size) // frequency_skip + 1
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([batch_size, input_size])
        m = array_ops.zeros([batch_size, state_size * num_shifts])
        output, state = contrib_rnn_cell.TimeFreqLSTMCell(
            num_units=num_units,
            feature_size=feature_size,
            frequency_skip=frequency_skip,
            forget_bias=1.0)(x, m)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [output, state], {
                x.name:
                    np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
                              [3., 3., 3., 3.]]),
                m.name:
                    0.1 * np.ones((batch_size, int(state_size * num_shifts)))
            })
        self.assertEqual(len(res), 2)
        # The numbers in results were not calculated, this is mostly just a
        # smoke test.
        self.assertEqual(res[0].shape, (batch_size, num_units * num_shifts))
        self.assertEqual(res[1].shape, (batch_size, state_size * num_shifts))
        # Different inputs so different outputs and states
        for i in range(1, batch_size):
          self.assertTrue(
              float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6)
          self.assertTrue(
              float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)

  def testGridLSTMCell(self):
    with self.test_session() as sess:
      num_units = 8
      batch_size = 3
      input_size = 4
      feature_size = 2
      frequency_skip = 1
      num_shifts = int((input_size - feature_size) / frequency_skip + 1)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.GridLSTMCell(
            num_units=num_units,
            feature_size=feature_size,
            frequency_skip=frequency_skip,
            forget_bias=1.0,
            num_frequency_blocks=[num_shifts],
            couple_input_forget_gates=True,
            state_is_tuple=True)
        inputs = constant_op.constant(
            np.array(
                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                dtype=np.float32),
            dtype=dtypes.float32)
        state_value = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        init_state = cell.state_tuple_type(*(
            [state_value, state_value] * num_shifts))
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])
        self.assertEqual(len(res), 2)
        # The numbers in results were not calculated, this is mostly just a
        # smoke test.
        self.assertEqual(res[0].shape,
                         (batch_size, num_units * num_shifts * 2))
        for ss in res[1]:
          self.assertEqual(ss.shape, (batch_size, num_units))
        # Different inputs so different outputs and states
        for i in range(1, batch_size):
          self.assertTrue(
              float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6)
          self.assertTrue(
              float(
                  np.linalg.norm((res[1].state_f00_b00_c[0, :] -
                                  res[1].state_f00_b00_c[i, :]))) > 1e-6)

  def testGridLSTMCellWithFrequencyBlocks(self):
    with self.test_session() as sess:
      num_units = 8
      batch_size = 3
      feature_size = 2
      frequency_skip = 1
      num_frequency_blocks = [1, 1]
      total_blocks = num_frequency_blocks[0] + num_frequency_blocks[1]
      start_freqindex_list = [0, 2]
      end_freqindex_list = [2, 4]
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.GridLSTMCell(
            num_units=num_units,
            feature_size=feature_size,
            frequency_skip=frequency_skip,
            forget_bias=1.0,
            num_frequency_blocks=num_frequency_blocks,
            start_freqindex_list=start_freqindex_list,
            end_freqindex_list=end_freqindex_list,
            couple_input_forget_gates=True,
            state_is_tuple=True)
        inputs = constant_op.constant(
            np.array(
                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                dtype=np.float32),
            dtype=dtypes.float32)
        state_value = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        init_state = cell.state_tuple_type(*(
            [state_value, state_value] * total_blocks))
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])
        self.assertEqual(len(res), 2)
        # The numbers in results were not calculated, this is mostly just a
        # smoke test.
        self.assertEqual(res[0].shape,
                         (batch_size, num_units * total_blocks * 2))
        for ss in res[1]:
          self.assertEqual(ss.shape, (batch_size, num_units))
        # Different inputs so different outputs and states
        for i in range(1, batch_size):
          self.assertTrue(
              float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6)
          self.assertTrue(
              float(
                  np.linalg.norm((res[1].state_f00_b00_c[0, :] -
                                  res[1].state_f00_b00_c[i, :]))) > 1e-6)

  def testGridLstmCellWithCoupledInputForgetGates(self):
    num_units = 2
    batch_size = 3
    input_size = 4
    feature_size = 2
    frequency_skip = 1
    num_shifts = int((input_size - feature_size) / frequency_skip + 1)
    expected_output = np.array(
        [[
            0.416383, 0.416383, 0.403238, 0.403238, 0.524020, 0.524020,
            0.565425, 0.565425, 0.557865, 0.557865, 0.609699, 0.609699
        ], [
            0.627331, 0.627331, 0.622393, 0.622393, 0.688342, 0.688342,
            0.708078, 0.708078, 0.694245, 0.694245, 0.715171, 0.715171
        ], [
            0.711050, 0.711050, 0.709197, 0.709197, 0.736533, 0.736533,
            0.744264, 0.744264, 0.737390, 0.737390, 0.745250, 0.745250
        ]],
        dtype=np.float32)
    expected_state = np.array(
        [[
            0.625556, 0.625556, 0.416383, 0.416383, 0.759134, 0.759134,
            0.524020, 0.524020, 0.798795, 0.798795, 0.557865, 0.557865
        ], [
            0.875488, 0.875488, 0.627331, 0.627331, 0.936432, 0.936432,
            0.688342, 0.688342, 0.941961, 0.941961, 0.694245, 0.694245
        ], [
            0.957327, 0.957327, 0.711050, 0.711050, 0.979522, 0.979522,
            0.736533, 0.736533, 0.980245, 0.980245, 0.737390, 0.737390
        ]],
        dtype=np.float32)
    for state_is_tuple in [False, True]:
      with self.test_session() as sess:
        with variable_scope.variable_scope(
            "state_is_tuple" + str(state_is_tuple),
            initializer=init_ops.constant_initializer(0.5)):
          cell = contrib_rnn_cell.GridLSTMCell(
              num_units=num_units,
              feature_size=feature_size,
              frequency_skip=frequency_skip,
              forget_bias=1.0,
              num_frequency_blocks=[num_shifts],
              couple_input_forget_gates=True,
              state_is_tuple=state_is_tuple)
          inputs = constant_op.constant(
              np.array(
                  [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                  dtype=np.float32),
              dtype=dtypes.float32)
          if state_is_tuple:
            state_value = constant_op.constant(
                0.1 * np.ones((batch_size, num_units), dtype=np.float32),
                dtype=dtypes.float32)
            init_state = cell.state_tuple_type(*(
                [state_value, state_value] * num_shifts))
          else:
            init_state = constant_op.constant(
                0.1 * np.ones(
                    (batch_size, num_units * num_shifts * 2),
                    dtype=np.float32),
                dtype=dtypes.float32)
          output, state = cell(inputs, init_state)
          sess.run([variables.global_variables_initializer()])
          res = sess.run([output, state])
          # This is a smoke test: Only making sure expected values didn't
          # change.
          self.assertEqual(len(res), 2)
          self.assertAllClose(res[0], expected_output)
          if not state_is_tuple:
            self.assertAllClose(res[1], expected_state)
          else:
            # There should be num_shifts * 2 states in the tuple.
            self.assertEqual(len(res[1]), num_shifts * 2)
            # Checking the shape of each state to be batch_size * num_units
            for ss in res[1]:
              self.assertEqual(ss.shape[0], batch_size)
              self.assertEqual(ss.shape[1], num_units)
            self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)

  def testBidirectionGridLSTMCell(self):
    with self.test_session() as sess:
      num_units = 2
      batch_size = 3
      input_size = 4
      feature_size = 2
      frequency_skip = 1
      num_shifts = int((input_size - feature_size) / frequency_skip + 1)
      expected_output = np.array(
          [[
              0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
              0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
              0.520789, 0.520789, 0.476968, 0.476968, 0.604341, 0.604341,
              0.760207, 0.760207, 0.635773, 0.635773, 0.850218, 0.850218
          ], [
              0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
              0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
              0.692621, 0.692621, 0.652363, 0.652363, 0.737517, 0.737517,
              0.899558, 0.899558, 0.745984, 0.745984, 0.946840, 0.946840
          ], [
              0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
              0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
              0.759940, 0.759940, 0.720652, 0.720652, 0.778552, 0.778552,
              0.941606, 0.941606, 0.781035, 0.781035, 0.977731, 0.977731
          ]],
          dtype=np.float32)
      expected_state = np.array(
          [[
              0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
              0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
              0.785405, 0.785405, 0.520789, 0.520789, 0.890836, 0.890836,
              0.604341, 0.604341, 0.928512, 0.928512, 0.635773, 0.635773
          ], [
              0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
              0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
              0.993088, 0.993088, 0.692621, 0.692621, 1.040288, 1.040288,
              0.737517, 0.737517, 1.048773, 1.048773, 0.745984, 0.745984
          ], [
              1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
              0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
              1.062455, 1.062455, 0.759940, 0.759940, 1.080101, 1.080101,
              0.778552, 0.778552, 1.082402, 1.082402, 0.781035, 0.781035
          ]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.BidirectionalGridLSTMCell(
            num_units=num_units,
            feature_size=feature_size,
            share_time_frequency_weights=True,
            frequency_skip=frequency_skip,
            forget_bias=1.0,
            num_frequency_blocks=[num_shifts])
        inputs = constant_op.constant(
            np.array(
                [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3],
                 [3.0, 3.1, 3.2, 3.3]],
                dtype=np.float32),
            dtype=dtypes.float32)
        state_value = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        init_state = cell.state_tuple_type(*(
            [state_value, state_value] * num_shifts * 2))
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])
        self.assertEqual(len(res), 2)
        # The numbers in results were not calculated, this is mostly just a
        # smoke test.
        self.assertEqual(res[0].shape,
                         (batch_size, num_units * num_shifts * 4))
        self.assertAllClose(res[0], expected_output)
        # There should be num_shifts * 4 states in the tuple.
        self.assertEqual(len(res[1]), num_shifts * 4)
        # Checking the shape of each state to be batch_size * num_units
        for ss in res[1]:
          self.assertEqual(ss.shape[0], batch_size)
          self.assertEqual(ss.shape[1], num_units)
        self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)

  def testBidirectionGridLSTMCellWithSliceOffset(self):
    with self.test_session() as sess:
      num_units = 2
      batch_size = 3
      input_size = 4
      feature_size = 2
      frequency_skip = 1
      num_shifts = int((input_size - feature_size) / frequency_skip + 1)
      expected_output = np.array(
          [[
              0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
              0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
              0.322645, 0.322645, 0.276068, 0.276068, 0.584654, 0.584654,
              0.690292, 0.690292, 0.640446, 0.640446, 0.840071, 0.840071
          ], [
              0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
              0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
              0.493625, 0.493625, 0.449236, 0.449236, 0.730828, 0.730828,
              0.865996, 0.865996, 0.749429, 0.749429, 0.944958, 0.944958
          ], [
              0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
              0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
              0.608587, 0.608587, 0.566683, 0.566683, 0.777345, 0.777345,
              0.925820, 0.925820, 0.782597, 0.782597, 0.976858, 0.976858
          ]],
          dtype=np.float32)
      expected_state = np.array(
          [[
              0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
              0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
              0.516575, 0.516575, 0.322645, 0.322645, 0.866628, 0.866628,
              0.584654, 0.584654, 0.934002, 0.934002, 0.640446, 0.640446
          ], [
              0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
              0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
              0.749836, 0.749836, 0.493625, 0.493625, 1.033488, 1.033488,
              0.730828, 0.730828, 1.052186, 1.052186, 0.749429, 0.749429
          ], [
              1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
              0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
              0.895999, 0.895999, 0.608587, 0.608587, 1.078978, 1.078978,
              0.777345, 0.777345, 1.083843, 1.083843, 0.782597, 0.782597
          ]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.BidirectionalGridLSTMCell(
            num_units=num_units,
            feature_size=feature_size,
            share_time_frequency_weights=True,
            frequency_skip=frequency_skip,
            forget_bias=1.0,
            num_frequency_blocks=[num_shifts],
            backward_slice_offset=1)
        inputs = constant_op.constant(
            np.array(
                [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3],
                 [3.0, 3.1, 3.2, 3.3]],
                dtype=np.float32),
            dtype=dtypes.float32)
        state_value = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        init_state = cell.state_tuple_type(*(
            [state_value, state_value] * num_shifts * 2))
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])
        self.assertEqual(len(res), 2)
        # The numbers in results were not calculated, this is mostly just a
        # smoke test.
        self.assertEqual(res[0].shape,
                         (batch_size, num_units * num_shifts * 4))
        self.assertAllClose(res[0], expected_output)
        # There should be num_shifts * 4 states in the tuple.
        self.assertEqual(len(res[1]), num_shifts * 4)
        # Checking the shape of each state to be batch_size * num_units
        for ss in res[1]:
          self.assertEqual(ss.shape[0], batch_size)
          self.assertEqual(ss.shape[1], num_units)
        self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)

  def testAttentionCellWrapperFailures(self):
    with self.assertRaisesRegexp(TypeError,
                                 "The parameter cell is not RNNCell."):
      contrib_rnn_cell.AttentionCellWrapper(None, 0)

    num_units = 8
    for state_is_tuple in [False, True]:
      with ops.Graph().as_default():
        lstm_cell = rnn_cell.BasicLSTMCell(
            num_units, state_is_tuple=state_is_tuple)
        with self.assertRaisesRegexp(
            ValueError, "attn_length should be greater than zero, got 0"):
          contrib_rnn_cell.AttentionCellWrapper(
              lstm_cell, 0, state_is_tuple=state_is_tuple)
        with self.assertRaisesRegexp(
            ValueError, "attn_length should be greater than zero, got -1"):
          contrib_rnn_cell.AttentionCellWrapper(
              lstm_cell, -1, state_is_tuple=state_is_tuple)
    with ops.Graph().as_default():
      lstm_cell = rnn_cell.BasicLSTMCell(num_units, state_is_tuple=True)
      with self.assertRaisesRegexp(
          ValueError, "Cell returns tuple of states, but the flag "
          "state_is_tuple is not set. State size is: *"):
        contrib_rnn_cell.AttentionCellWrapper(
            lstm_cell, 4, state_is_tuple=False)

  def testAttentionCellWrapperZeros(self):
    num_units = 8
    attn_length = 16
    batch_size = 3
    input_size = 4
    for state_is_tuple in [False, True]:
      with ops.Graph().as_default():
        with self.test_session() as sess:
          with variable_scope.variable_scope(
              "state_is_tuple_" + str(state_is_tuple)):
            lstm_cell = rnn_cell.BasicLSTMCell(
                num_units, state_is_tuple=state_is_tuple)
            cell = contrib_rnn_cell.AttentionCellWrapper(
                lstm_cell, attn_length, state_is_tuple=state_is_tuple)
            if state_is_tuple:
              zeros = array_ops.zeros(
                  [batch_size, num_units], dtype=np.float32)
              attn_state_zeros = array_ops.zeros(
                  [batch_size, attn_length * num_units], dtype=np.float32)
              zero_state = ((zeros, zeros), zeros, attn_state_zeros)
            else:
              zero_state = array_ops.zeros(
                  [
                      batch_size,
                      num_units * 2 + attn_length * num_units + num_units
                  ],
                  dtype=np.float32)
            inputs = array_ops.zeros(
                [batch_size, input_size], dtype=dtypes.float32)
            output, state = cell(inputs, zero_state)
            self.assertEquals(output.get_shape(), [batch_size, num_units])
            if state_is_tuple:
              self.assertEquals(len(state), 3)
              self.assertEquals(len(state[0]), 2)
              self.assertEquals(state[0][0].get_shape(),
                                [batch_size, num_units])
              self.assertEquals(state[0][1].get_shape(),
                                [batch_size, num_units])
              self.assertEquals(state[1].get_shape(), [batch_size, num_units])
              self.assertEquals(state[2].get_shape(),
                                [batch_size, attn_length * num_units])
              tensors = [output] + list(state)
            else:
              self.assertEquals(state.get_shape(), [
                  batch_size,
                  num_units * 2 + num_units + attn_length * num_units
              ])
              tensors = [output, state]
            zero_result = sum(
                [math_ops.reduce_sum(math_ops.abs(x)) for x in tensors])
            sess.run(variables.global_variables_initializer())
            self.assertTrue(sess.run(zero_result) < 1e-6)

  def testAttentionCellWrapperValues(self):
    num_units = 8
    attn_length = 16
    batch_size = 3
    for state_is_tuple in [False, True]:
      with ops.Graph().as_default():
        with self.test_session() as sess:
          with variable_scope.variable_scope(
              "state_is_tuple_" + str(state_is_tuple)):
            lstm_cell = rnn_cell.BasicLSTMCell(
                num_units, state_is_tuple=state_is_tuple)
            cell = contrib_rnn_cell.AttentionCellWrapper(
                lstm_cell, attn_length, state_is_tuple=state_is_tuple)
            if state_is_tuple:
              zeros = constant_op.constant(
                  0.1 * np.ones([batch_size, num_units], dtype=np.float32),
                  dtype=dtypes.float32)
              attn_state_zeros = constant_op.constant(
                  0.1 * np.ones(
                      [batch_size, attn_length * num_units], dtype=np.float32),
                  dtype=dtypes.float32)
              zero_state = ((zeros, zeros), zeros, attn_state_zeros)
            else:
              zero_state = constant_op.constant(
                  0.1 * np.ones(
                      [
                          batch_size,
                          num_units * 2 + num_units + attn_length * num_units
                      ],
                      dtype=np.float32),
                  dtype=dtypes.float32)
            inputs = constant_op.constant(
                np.array(
                    [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                    dtype=np.float32),
                dtype=dtypes.float32)
            output, state = cell(inputs, zero_state)
            if state_is_tuple:
              concat_state = array_ops.concat(
                  [state[0][0], state[0][1], state[1], state[2]], 1)
            else:
              concat_state = state
            sess.run(variables.global_variables_initializer())
            output, state = sess.run([output, concat_state])
            # Different inputs so different outputs and states
            for i in range(1, batch_size):
              self.assertTrue(
                  float(np.linalg.norm((output[0, :] - output[i, :]))) > 1e-6)
              self.assertTrue(
                  float(np.linalg.norm((state[0, :] - state[i, :]))) > 1e-6)

  def _testAttentionCellWrapperCorrectResult(self):
    num_units = 4
    attn_length = 6
    batch_size = 2
    expected_output = np.array(
        [[1.068372, 0.45496, -0.678277, 0.340538],
         [1.018088, 0.378983, -0.572179, 0.268591]],
        dtype=np.float32)
    expected_state = np.array(
        [[
            0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
            0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
            0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
            0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
            0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
            0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
            0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
            0.51843399
        ], [
            0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
            0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
            0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
            0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
            0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
            0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
            0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
            0.70582712
        ]],
        dtype=np.float32)
    seed = 12345
    random_seed.set_random_seed(seed)
    rnn_scope = None
    for state_is_tuple in [False, True]:
      with session.Session() as sess:
        with variable_scope.variable_scope(
            "state_is_tuple",
            reuse=state_is_tuple,
            initializer=init_ops.glorot_uniform_initializer()):
          lstm_cell = rnn_cell.BasicLSTMCell(
              num_units, state_is_tuple=state_is_tuple)
          cell = contrib_rnn_cell.AttentionCellWrapper(
              lstm_cell, attn_length, state_is_tuple=state_is_tuple)
          # This is legacy behavior to preserve the test. Weight
          # sharing no longer works by creating a new RNNCell in the
          # same variable scope; so here we restore the scope of the
          # RNNCells after the first use below.
          if rnn_scope is not None:
            (cell._scope, lstm_cell._scope) = rnn_scope  # pylint: disable=protected-access,unpacking-non-sequence
          zeros1 = random_ops.random_uniform(
              (batch_size, num_units), 0.0, 1.0, seed=seed + 1)
          zeros2 = random_ops.random_uniform(
              (batch_size, num_units), 0.0, 1.0, seed=seed + 2)
          zeros3 = random_ops.random_uniform(
              (batch_size, num_units), 0.0, 1.0, seed=seed + 3)
          attn_state_zeros = random_ops.random_uniform(
              (batch_size, attn_length * num_units), 0.0, 1.0, seed=seed + 4)
          zero_state = ((zeros1, zeros2), zeros3, attn_state_zeros)
          if not state_is_tuple:
            zero_state = array_ops.concat([
                zero_state[0][0], zero_state[0][1], zero_state[1],
                zero_state[2]
            ], 1)
          inputs = random_ops.random_uniform(
              (batch_size, num_units), 0.0, 1.0, seed=seed + 5)
          output, state = cell(inputs, zero_state)
          # This is legacy behavior to preserve the test. Weight
          # sharing no longer works by creating a new RNNCell in the
          # same variable scope; so here we store the scope of the
          # first RNNCell for reuse above.
          if rnn_scope is None:
            rnn_scope = (cell._scope, lstm_cell._scope)  # pylint: disable=protected-access
          if state_is_tuple:
            state = array_ops.concat(
                [state[0][0], state[0][1], state[1], state[2]], 1)
          sess.run(variables.global_variables_initializer())
          self.assertAllClose(sess.run(output), expected_output)
          self.assertAllClose(sess.run(state), expected_state)

  def testNASCell(self):
    num_units = 6
    batch_size = 3
    expected_output = np.array(
        [[0.576751, 0.576751, 0.576751, 0.576751, 0.576751, 0.576751],
         [0.618936, 0.618936, 0.618936, 0.618936, 0.618936, 0.618936],
         [0.627393, 0.627393, 0.627393, 0.627393, 0.627393, 0.627393]])
    expected_state = np.array([[
        0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772,
        0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087
    ], [
        0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625,
        0.6189357, 0.6189357, 0.61893570, 0.6189357, 0.6189357, 0.6189357
    ], [
        0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
        0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
    ]])
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "nas_test", initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.NASCell(num_units=num_units)
        inputs = constant_op.constant(
            np.array(
                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                dtype=np.float32),
            dtype=dtypes.float32)
        state_value = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        init_state = rnn_cell.LSTMStateTuple(state_value, state_value)
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])

        # This is a smoke test: Only making sure expected values didn't change.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], expected_output)
        # There should be 2 states in the tuple.
        self.assertEqual(len(res[1]), 2)
        # Checking the shape of each state to be batch_size * num_units
        new_c, new_h = res[1]
        self.assertEqual(new_c.shape[0], batch_size)
        self.assertEqual(new_c.shape[1], num_units)
        self.assertEqual(new_h.shape[0], batch_size)
        self.assertEqual(new_h.shape[1], num_units)
        self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)

  def testNASCellProj(self):
    num_units = 6
    batch_size = 3
    num_proj = 5
    expected_output = np.array(
        [[1.697418, 1.697418, 1.697418, 1.697418, 1.697418],
         [1.840037, 1.840037, 1.840037, 1.840037, 1.840037],
         [1.873985, 1.873985, 1.873985, 1.873985, 1.873985]])
    expected_state = np.array([[
        0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207,
        1.69741797, 1.69741797, 1.69741797, 1.69741797, 1.69741797
    ], [
        0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824,
        1.84003687, 1.84003687, 1.84003687, 1.84003687, 1.84003687
    ], [
        0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
        1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
    ]])
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
        inputs = constant_op.constant(
            np.array(
                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                dtype=np.float32),
            dtype=dtypes.float32)
        state_value_c = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        state_value_h = constant_op.constant(
            0.1 * np.ones((batch_size, num_proj), dtype=np.float32),
            dtype=dtypes.float32)
        init_state = rnn_cell.LSTMStateTuple(state_value_c, state_value_h)
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])

        # This is a smoke test: Only making sure expected values didn't change.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], expected_output)
        # There should be 2 states in the tuple.
        self.assertEqual(len(res[1]), 2)
        # Checking the shape of each state to be batch_size * num_units
        new_c, new_h = res[1]
        self.assertEqual(new_c.shape[0], batch_size)
        self.assertEqual(new_c.shape[1], num_units)
        self.assertEqual(new_h.shape[0], batch_size)
        self.assertEqual(new_h.shape[1], num_proj)
        self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)

  def testUGRNNCell(self):
    num_units = 2
    batch_size = 3
    expected_state_and_output = np.array(
        [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
         [0.10074195, 0.10074195]],
        dtype=np.float32)
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
        inputs = constant_op.constant(
            np.array(
                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                dtype=np.float32),
            dtype=dtypes.float32)
        init_state = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])
        # This is a smoke test: Only making sure expected values didn't change.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], expected_state_and_output)
        self.assertAllClose(res[1], expected_state_and_output)

  def testIntersectionRNNCell(self):
    num_units = 2
    batch_size = 3
    expected_state = np.array(
        [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
         [0.10074195, 0.10074195]],
        dtype=np.float32)
    expected_output = np.array(
        [[2.00431061, 2.00431061], [4.00060606, 4.00060606],
         [6.00008249, 6.00008249]],
        dtype=np.float32)
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "intersection_rnn_cell_test",
          initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.IntersectionRNNCell(
            num_units=num_units, num_in_proj=num_units)
        inputs = constant_op.constant(
            np.array(
                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                dtype=np.float32),
            dtype=dtypes.float32)
        init_state = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])
        # This is a smoke test: Only making sure expected values didn't change.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], expected_output)
        self.assertAllClose(res[1], expected_state)

  def testIntersectionRNNCellFailure(self):
    num_units = 2
    batch_size = 3
    cell = contrib_rnn_cell.IntersectionRNNCell(num_units=num_units)
    inputs = constant_op.constant(
        np.array(
            [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
            dtype=np.float32),
        dtype=dtypes.float32)
    init_state = constant_op.constant(
        0.1 * np.ones((batch_size, num_units), dtype=np.float32),
        dtype=dtypes.float32)
    with self.assertRaisesRegexp(ValueError,
                                 "Must have input size == output size for "
                                 "Intersection RNN. To fix, num_in_proj should "
                                 "be set to num_units at cell init."):
      cell(inputs, init_state)

  def testPhasedLSTMCell(self):
    with self.test_session() as sess:
      num_units = 2
      batch_size = 3
      input_size = 4
      expected_state_c = np.array(
          [[6.450831e-04, 4.697885e-04], [9.862894e-05, 7.212213e-04],
           [4.401947e-04, 9.143004e-04]],
          dtype=np.float32)
      expected_state_h = np.array(
          [[4.621217e-04, 3.365449e-04], [7.438179e-05, 5.439147e-04],
           [3.347936e-04, 6.953785e-04]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        t = array_ops.zeros([batch_size, 1], dtype=dtypes.float64)
        x = array_ops.zeros([batch_size, input_size])
        c0 = array_ops.zeros([batch_size, 2])
        h0 = array_ops.zeros([batch_size, 2])
        state0 = rnn_cell.LSTMStateTuple(c0, h0)
        output, state = contrib_rnn_cell.PhasedLSTMCell(num_units=num_units)(
            (t, x), state0)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [output, state], {
                t.name:
                    np.array([[1.], [2.], [3.]]),
                x.name:
                    np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
                              [3., 3., 3., 3.]]),
            })
        # This is a smoke test, making sure expected values are unchanged.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], res[1].h)
        self.assertAllClose(res[1].c, expected_state_c)
        self.assertAllClose(res[1].h, expected_state_h)

  def testConv1DLSTMCell(self):
    with self.test_session() as sess:
      shape = [2, 1]
      filter_size = [3]
      num_features = 1
      batch_size = 2
      expected_state_c = np.array(
          [[[1.4375670191], [1.4375670191]], [[2.7542609292], [2.7542609292]]],
          dtype=np.float32)
      expected_state_h = np.array(
          [[[0.6529865603], [0.6529865603]], [[0.8736877431], [0.8736877431]]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(1.0 / 2.0)):
        x = array_ops.placeholder(dtypes.float32, [None, None, 1])
        cell = contrib_rnn_cell.Conv1DLSTMCell(
            input_shape=shape,
            kernel_shape=filter_size,
            output_channels=num_features)
        hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
        output, state = cell(x, hidden)

        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [output, state], {
                hidden[0].name: np.array([[[1.], [1.]], [[2.], [2.]]]),
                x.name: np.array([[[1.], [1.]], [[2.], [2.]]]),
            })
        # This is a smoke test, making sure expected values are unchanged.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], res[1].h)
        self.assertAllClose(res[1].c, expected_state_c)
        self.assertAllClose(res[1].h, expected_state_h)

  def testConv2DLSTMCell(self):
    with self.test_session() as sess:
      shape = [2, 2, 1]
      filter_size = [3, 3]
      num_features = 1
      batch_size = 2
      expected_state_c = np.array(
          [[[[1.4375670191], [1.4375670191]],
            [[1.4375670191], [1.4375670191]]],
           [[[2.7542609292], [2.7542609292]],
            [[2.7542609292], [2.7542609292]]]],
          dtype=np.float32)
      expected_state_h = np.array(
          [[[[0.6529865603], [0.6529865603]],
            [[0.6529865603], [0.6529865603]]],
           [[[0.8736877431], [0.8736877431]],
            [[0.8736877431], [0.8736877431]]]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(1.0 / 4.0)):
        x = array_ops.placeholder(dtypes.float32, [None, None, None, 1])
        cell = contrib_rnn_cell.Conv2DLSTMCell(
            input_shape=shape,
            kernel_shape=filter_size,
            output_channels=num_features)
        hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
        output, state = cell(x, hidden)

        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [output, state], {
                hidden[0].name:
                    np.array([[[[1.], [1.]], [[1.], [1.]]],
                              [[[2.], [2.]], [[2.], [2.]]]]),
                x.name:
                    np.array([[[[1.], [1.]], [[1.], [1.]]],
                              [[[2.], [2.]], [[2.], [2.]]]]),
            })
        # This is a smoke test, making sure expected values are unchanged.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], res[1].h)
        self.assertAllClose(res[1].c, expected_state_c)
        self.assertAllClose(res[1].h, expected_state_h)

  def testConv3DLSTMCell(self):
    with self.test_session() as sess:
      shape = [2, 2, 2, 1]
      filter_size = [3, 3, 3]
      num_features = 1
      batch_size = 2
      expected_state_c = np.array(
          [[[[[1.4375670191], [1.4375670191]],
             [[1.4375670191], [1.4375670191]]],
            [[[1.4375670191], [1.4375670191]],
             [[1.4375670191], [1.4375670191]]]],
           [[[[2.7542609292], [2.7542609292]],
             [[2.7542609292], [2.7542609292]]],
            [[[2.7542609292], [2.7542609292]],
             [[2.7542609292], [2.7542609292]]]]],
          dtype=np.float32)
      expected_state_h = np.array(
          [[[[[0.6529865603], [0.6529865603]],
             [[0.6529865603], [0.6529865603]]],
            [[[0.6529865603], [0.6529865603]],
             [[0.6529865603], [0.6529865603]]]],
           [[[[0.8736877431], [0.8736877431]],
             [[0.8736877431], [0.8736877431]]],
            [[[0.8736877431], [0.8736877431]],
             [[0.8736877431], [0.8736877431]]]]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(1.0 / 8.0)):
        x = array_ops.placeholder(dtypes.float32, [None, None, None, None, 1])
        cell = contrib_rnn_cell.Conv3DLSTMCell(
            input_shape=shape,
            kernel_shape=filter_size,
            output_channels=num_features)
        hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
        output, state = cell(x, hidden)

        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [output, state], {
                hidden[0].name:
                    np.array([[[[[1.], [1.]], [[1.], [1.]]],
                               [[[1.], [1.]], [[1.], [1.]]]],
                              [[[[2.], [2.]], [[2.], [2.]]],
                               [[[2.], [2.]], [[2.], [2.]]]]]),
                x.name:
                    np.array([[[[[1.], [1.]], [[1.], [1.]]],
                               [[[1.], [1.]], [[1.], [1.]]]],
                              [[[[2.], [2.]], [[2.], [2.]]],
                               [[[2.], [2.]], [[2.], [2.]]]]])
            })
        # This is a smoke test, making sure expected values are unchanged.
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], res[1].h)
        self.assertAllClose(res[1].c, expected_state_c)
        self.assertAllClose(res[1].h, expected_state_h)

  def testHighwayWrapper(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "base_cell", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 3])
        m = array_ops.zeros([1, 3])
        base_cell = rnn_cell.GRUCell(3)
        g, m_new = base_cell(x, m)
      with variable_scope.variable_scope(
          "hw_cell", initializer=init_ops.constant_initializer(0.5)):
        hw_cell = contrib_rnn_cell.HighwayWrapper(
            rnn_cell.GRUCell(3), carry_bias_init=-100.0)
        g_res, m_new_res = hw_cell(x, m)
        sess.run([variables.global_variables_initializer()])
      res = sess.run([g, g_res, m_new, m_new_res], {
          x: np.array([[1., 1., 1.]]),
          m: np.array([[0.1, 0.1, 0.1]])
      })
      # As carry_bias_init is very negative, the carry gate is 'open' and the
      # transform gate is 'closed'. This means the output equals the input.
      self.assertAllClose(res[1], res[0])
      # States are left untouched
      self.assertAllClose(res[2], res[3])

  def testGLSTMCell(self):
    # Ensure that G-LSTM matches LSTM when number_of_groups = 1
    batch_size = 2
    num_units = 4
    number_of_groups = 1

    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root1", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.ones([batch_size, num_units])
        # When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
        gcell = contrib_rnn_cell.GLSTMCell(
            num_units=num_units, number_of_groups=number_of_groups)
        cell = rnn_cell.LSTMCell(num_units=num_units)
        self.assertTrue(isinstance(gcell.state_size, tuple))
        zero_state = gcell.zero_state(
            batch_size=batch_size, dtype=dtypes.float32)
        gh, gs = gcell(x, zero_state)
        h, g = cell(x, zero_state)

        sess.run([variables.global_variables_initializer()])
        glstm_result = sess.run([gh, gs])
        lstm_result = sess.run([h, g])

        self.assertAllClose(glstm_result[0], lstm_result[0], 1e-5)
        self.assertAllClose(glstm_result[1], lstm_result[1], 1e-5)

    # Test that G-LSTM subgroups act like the corresponding sub-LSTMs
    batch_size = 2
    num_units = 4
    number_of_groups = 2

    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root2", initializer=init_ops.constant_initializer(0.5)):
        # input for G-LSTM with 2 groups
        glstm_input = array_ops.ones([batch_size, num_units])
        gcell = contrib_rnn_cell.GLSTMCell(
            num_units=num_units, number_of_groups=number_of_groups)
        gcell_zero_state = gcell.zero_state(
            batch_size=batch_size, dtype=dtypes.float32)
        gh, gs = gcell(glstm_input, gcell_zero_state)

        # input for LSTM cell simulating single G-LSTM group
        lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
        # Note the division by number_of_groups above; this LSTM cell
        # simulates a single G-LSTM group.
        cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups))
        cell_zero_state = cell.zero_state(
            batch_size=batch_size, dtype=dtypes.float32)
        h, g = cell(lstm_input, cell_zero_state)

        sess.run([variables.global_variables_initializer()])
        [gh_res, h_res] = sess.run([gh, h])
        self.assertAllClose(gh_res[:, 0:int(num_units / number_of_groups)],
                            h_res, 1e-5)
        self.assertAllClose(gh_res[:, int(num_units / number_of_groups):],
                            h_res, 1e-5)


class LayerNormBasicLSTMCellTest(test.TestCase):

  # NOTE: all the values in the current test case have been calculated.

  def testBasicLSTMCell(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 2])
        c0 = array_ops.zeros([1, 2])
        h0 = array_ops.zeros([1, 2])
        state0 = rnn_cell.LSTMStateTuple(c0, h0)
        c1 = array_ops.zeros([1, 2])
        h1 = array_ops.zeros([1, 2])
        state1 = rnn_cell.LSTMStateTuple(c1, h1)
        state = (state0, state1)
        single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2)
        cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
        g, out_m = cell(x, state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [g, out_m], {
                x.name: np.array([[1., 1.]]),
                c0.name: 0.1 * np.asarray([[0, 1]]),
                h0.name: 0.1 * np.asarray([[2, 3]]),
                c1.name: 0.1 * np.asarray([[4, 5]]),
                h1.name: 0.1 * np.asarray([[6, 7]]),
            })

        expected_h = np.array([[-0.38079708, 0.38079708]])
        expected_state0_c = np.array([[-1.0, 1.0]])
        expected_state0_h = np.array([[-0.38079708, 0.38079708]])
        expected_state1_c = np.array([[-1.0, 1.0]])
        expected_state1_h = np.array([[-0.38079708, 0.38079708]])

        actual_h = res[0]
        actual_state0_c = res[1][0].c
        actual_state0_h = res[1][0].h
        actual_state1_c = res[1][1].c
        actual_state1_h = res[1][1].h

        self.assertAllClose(actual_h, expected_h, 1e-5)
        self.assertAllClose(expected_state0_c, actual_state0_c, 1e-5)
        self.assertAllClose(expected_state0_h, actual_state0_h, 1e-5)
        self.assertAllClose(expected_state1_c, actual_state1_c, 1e-5)
        self.assertAllClose(expected_state1_h, actual_state1_h, 1e-5)

      with variable_scope.variable_scope(
          "other", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros(
            [1, 3])  # Test BasicLSTMCell with input_size != num_units.
        c = array_ops.zeros([1, 2])
        h = array_ops.zeros([1, 2])
        state = rnn_cell.LSTMStateTuple(c, h)
        cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2)
        g, out_m = cell(x, state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [g, out_m], {
                x.name: np.array([[1., 1., 1.]]),
                c.name: 0.1 * np.asarray([[0, 1]]),
                h.name: 0.1 * np.asarray([[2, 3]]),
            })

        expected_h = np.array([[-0.38079708, 0.38079708]])
        expected_c = np.array([[-1.0, 1.0]])
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], expected_h, 1e-5)
        self.assertAllClose(res[1].c, expected_c, 1e-5)
        self.assertAllClose(res[1].h, expected_h, 1e-5)

  def testBasicLSTMCellWithoutNorm(self):
    """Tests that BasicLSTMCell works with layer_norm=False."""
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 2])
        c0 = array_ops.zeros([1, 2])
        h0 = array_ops.zeros([1, 2])
        state0 = rnn_cell.LSTMStateTuple(c0, h0)
        c1 = array_ops.zeros([1, 2])
        h1 = array_ops.zeros([1, 2])
        state1 = rnn_cell.LSTMStateTuple(c1, h1)
        state = (state0, state1)
        single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(
            2, layer_norm=False)
        cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
        g, out_m = cell(x, state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [g, out_m], {
                x.name: np.array([[1., 1.]]),
                c0.name: 0.1 * np.asarray([[0, 1]]),
                h0.name: 0.1 * np.asarray([[2, 3]]),
                c1.name: 0.1 * np.asarray([[4, 5]]),
                h1.name: 0.1 * np.asarray([[6, 7]]),
            })

        expected_h = np.array([[0.70230919, 0.72581059]])
        expected_state0_c = np.array([[0.8020075, 0.89599884]])
        expected_state0_h = np.array([[0.56668288, 0.60858738]])
        expected_state1_c = np.array([[1.17500675, 1.26892781]])
        expected_state1_h = np.array([[0.70230919, 0.72581059]])

        actual_h = res[0]
        actual_state0_c = res[1][0].c
        actual_state0_h = res[1][0].h
        actual_state1_c = res[1][1].c
        actual_state1_h = res[1][1].h

        self.assertAllClose(actual_h, expected_h, 1e-5)
        self.assertAllClose(expected_state0_c, actual_state0_c, 1e-5)
        self.assertAllClose(expected_state0_h, actual_state0_h, 1e-5)
        self.assertAllClose(expected_state1_c, actual_state1_c, 1e-5)
        self.assertAllClose(expected_state1_h, actual_state1_h, 1e-5)

      with variable_scope.variable_scope(
          "other", initializer=init_ops.constant_initializer(0.5)) as vs:
        x = array_ops.zeros(
            [1, 3])  # Test BasicLSTMCell with input_size != num_units.
        c = array_ops.zeros([1, 2])
        h = array_ops.zeros([1, 2])
        state = rnn_cell.LSTMStateTuple(c, h)
        cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False)
        g, out_m = cell(x, state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [g, out_m], {
                x.name: np.array([[1., 1., 1.]]),
                c.name: 0.1 * np.asarray([[0, 1]]),
                h.name: 0.1 * np.asarray([[2, 3]]),
            })

        expected_h = np.array([[0.64121795, 0.68166804]])
        expected_c = np.array([[0.88477188, 0.98103917]])
        self.assertEqual(len(res), 2)
        self.assertAllClose(res[0], expected_h, 1e-5)
        self.assertAllClose(res[1].c, expected_c, 1e-5)
        self.assertAllClose(res[1].h, expected_h, 1e-5)

  def testBasicLSTMCellWithStateTuple(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 2])
        c0 = array_ops.zeros([1, 2])
        h0 = array_ops.zeros([1, 2])
        state0 = rnn_cell.LSTMStateTuple(c0, h0)
        c1 = array_ops.zeros([1, 2])
        h1 = array_ops.zeros([1, 2])
        state1 = rnn_cell.LSTMStateTuple(c1, h1)
        cell = rnn_cell.MultiRNNCell(
            [contrib_rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
        h, (s0, s1) = cell(x, (state0, state1))
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [h, s0, s1], {
                x.name: np.array([[1., 1.]]),
                c0.name: 0.1 * np.asarray([[0, 1]]),
                h0.name: 0.1 * np.asarray([[2, 3]]),
                c1.name: 0.1 * np.asarray([[4, 5]]),
                h1.name: 0.1 * np.asarray([[6, 7]]),
            })

        expected_h = np.array([[-0.38079708, 0.38079708]])
        expected_h0 = np.array([[-0.38079708, 0.38079708]])
        expected_c0 = np.array([[-1.0, 1.0]])
        expected_h1 = np.array([[-0.38079708, 0.38079708]])
        expected_c1 = np.array([[-1.0, 1.0]])

        self.assertEqual(len(res), 3)
        self.assertAllClose(res[0], expected_h, 1e-5)
        self.assertAllClose(res[1].c, expected_c0, 1e-5)
        self.assertAllClose(res[1].h, expected_h0, 1e-5)
        self.assertAllClose(res[2].c, expected_c1, 1e-5)
        self.assertAllClose(res[2].h, expected_h1, 1e-5)

  def testBasicLSTMCellWithStateTupleLayerNorm(self):
    """The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 2])
        c0 = array_ops.zeros([1, 2])
        h0 = array_ops.zeros([1, 2])
        state0 = rnn_cell_impl.LSTMStateTuple(c0, h0)
        c1 = array_ops.zeros([1, 2])
        h1 = array_ops.zeros([1, 2])
        state1 = rnn_cell_impl.LSTMStateTuple(c1, h1)
        cell = rnn_cell_impl.MultiRNNCell([
            contrib_rnn_cell.LayerNormLSTMCell(
                2, layer_norm=True, norm_gain=1.0, norm_shift=0.0)
            for _ in range(2)
        ])
        h, (s0, s1) = cell(x, (state0, state1))
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [h, s0, s1], {
                x.name: np.array([[1., 1.]]),
                c0.name: 0.1 * np.asarray([[0, 1]]),
                h0.name: 0.1 * np.asarray([[2, 3]]),
                c1.name: 0.1 * np.asarray([[4, 5]]),
                h1.name: 0.1 * np.asarray([[6, 7]]),
            })

        expected_h = np.array([[-0.38079708, 0.38079708]])
        expected_h0 = np.array([[-0.38079708, 0.38079708]])
        expected_c0 = np.array([[-1.0, 1.0]])
        expected_h1 = np.array([[-0.38079708, 0.38079708]])
        expected_c1 = np.array([[-1.0, 1.0]])

        self.assertEqual(len(res), 3)
        self.assertAllClose(res[0], expected_h, 1e-5)
        self.assertAllClose(res[1].c, expected_c0, 1e-5)
        self.assertAllClose(res[1].h, expected_h0, 1e-5)
        self.assertAllClose(res[2].c, expected_c1, 1e-5)
        self.assertAllClose(res[2].h, expected_h1, 1e-5)

  def testBasicLSTMCellWithDropout(self):

    def _is_close(x, y, digits=4):
      delta = x - y
      return delta < 10**(-digits)

    def _is_close_in(x, items, digits=4):
      for i in items:
        if _is_close(x, i, digits):
          return True
      return False

    keep_prob = 0.5
    c_high = 2.9998924946
    c_low = 0.999983298578
    h_low = 0.761552567265
    h_high = 0.995008519604
    num_units = 5
    allowed_low = [1, 2, 3]

    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "other", initializer=init_ops.constant_initializer(1)):
        x = array_ops.zeros([1, 5])
        c = array_ops.zeros([1, 5])
        h = array_ops.zeros([1, 5])
        state = rnn_cell.LSTMStateTuple(c, h)
        cell = contrib_rnn_cell.LayerNormBasicLSTMCell(
            num_units, layer_norm=False, dropout_keep_prob=keep_prob)

        g, s = cell(x, state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            [g, s], {
                x.name: np.ones([1, 5]),
                c.name: np.ones([1, 5]),
                h.name: np.ones([1, 5]),
            })

        # Since the returned tensors are of size [1,n]
        # get the first component right now.
        actual_h = res[0][0]
        actual_state_c = res[1].c[0]
        actual_state_h = res[1].h[0]

        # For each item in `c` (the cell inner state) check that
        # it is equal to one of the allowed values `c_high` (not
        # dropped out) or `c_low` (dropped out) and verify that the
        # corresponding item in `h` (the cell activation) is coherent.
        # Count the dropped activations and check that their number is
        # coherent with the dropout probability.
        dropped_count = 0
        self.assertTrue((actual_h == actual_state_h).all())
        for citem, hitem in zip(actual_state_c, actual_state_h):
          self.assertTrue(_is_close_in(citem, [c_low, c_high]))
          if _is_close(citem, c_low):
            self.assertTrue(_is_close(hitem, h_low))
            dropped_count += 1
          elif _is_close(citem, c_high):
            self.assertTrue(_is_close(hitem, h_high))
        self.assertIn(dropped_count, allowed_low)


def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth, num_layers,
                                max_time, compiled):
  with variable_scope.variable_scope(
      "root",
      initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)):
    inputs = variable_scope.get_variable(
        "inputs",
        initializer=random_ops.random_uniform(
            (max_time, batch_size, input_depth), seed=1))
    maybe_xla = lambda c: contrib_rnn_cell.CompiledWrapper(c) if compiled else c
    cell = rnn_cell.MultiRNNCell(
        [maybe_xla(rnn_cell.LSTMCell(num_units)) for _ in range(num_layers)])
    initial_state = cell.zero_state(batch_size=batch_size, dtype=dtypes.float32)
    outputs, final_state = rnn.dynamic_rnn(
        cell=cell, inputs=inputs, initial_state=initial_state, time_major=True)
    flat_final_state = nest.flatten(final_state)
    trainable_variables = variables.trainable_variables()
    outputs_grad = gradients_impl.gradients(
        [outputs], trainable_variables + [inputs] + nest.flatten(initial_state))
    final_state_grad = gradients_impl.gradients(
        flat_final_state,
        trainable_variables + [inputs] + nest.flatten(initial_state))

    return {
        "outputs": outputs,
        "final_state": flat_final_state,
        "outputs_grad": outputs_grad,
        "final_state_grad": final_state_grad
    }


class CompiledWrapperTest(test.TestCase):

  def testMultiRNNCellWithLSTMCellAndXLA(self):
    # TODO(b/34735319): Don't run this test if XLA is not available.
    batch_size = 16
    num_units = 32
    input_depth = 12
    num_layers = 2
    max_time = 20

    atol = 1e-5

    random_seed.set_random_seed(1234)
    with self.test_session(graph=ops.Graph()) as sess:
      xla_ops = _create_multi_lstm_cell_ops(
          batch_size=batch_size,
          num_units=num_units,
          input_depth=input_depth,
          num_layers=num_layers,
          max_time=max_time,
          compiled=True)
      sess.run([variables.global_variables_initializer()])
      xla_results = sess.run(xla_ops)

    random_seed.set_random_seed(1234)
    with self.test_session(graph=ops.Graph()) as sess:
      non_xla_ops = _create_multi_lstm_cell_ops(
          batch_size=batch_size,
          num_units=num_units,
          input_depth=input_depth,
          num_layers=num_layers,
          max_time=max_time,
          compiled=False)
      sess.run([variables.global_variables_initializer()])
      non_xla_results = sess.run(non_xla_ops)

    self.assertAllClose(
        non_xla_results["outputs"], xla_results["outputs"], atol=atol)

    for xla_value, non_xla_value in zip(xla_results["final_state"],
                                        non_xla_results["final_state"]):
      self.assertAllClose(xla_value, non_xla_value, atol=atol)

    for xla_g, non_xla_g in zip(xla_results["outputs_grad"],
                                non_xla_results["outputs_grad"]):
      self.assertAllClose(xla_g, non_xla_g, atol=atol)

    for xla_g, non_xla_g in zip(xla_results["final_state_grad"],
                                non_xla_results["final_state_grad"]):
      self.assertAllClose(xla_g, non_xla_g, atol=atol)

  def testMultiRNNCellWithStateTuple(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 2])
        m_bad = array_ops.zeros([1, 4])
        m_good = (array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))

        # Test incorrectness of state
        with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
          rnn_cell.MultiRNNCell(
              [rnn_cell.GRUCell(2) for _ in range(2)],
              state_is_tuple=True)(x, m_bad)

        _, ml = rnn_cell.MultiRNNCell(
            [rnn_cell.GRUCell(2) for _ in range(2)],
            state_is_tuple=True)(x, m_good)

        sess.run([variables.global_variables_initializer()])
        res = sess.run(
            ml, {
                x.name: np.array([[1., 1.]]),
                m_good[0].name: np.array([[0.1, 0.1]]),
                m_good[1].name: np.array([[0.1, 0.1]])
            })

        # The numbers in results were not calculated, this is just a
        # smoke test. However, these numbers should match those of
        # the test testMultiRNNCell.
        self.assertAllClose(res[0], [[0.175991, 0.175991]])
        self.assertAllClose(res[1], [[0.13248, 0.13248]])


class BenchmarkLSTMCellXLA(test.Benchmark):

  def benchmarkDynamicRNNWithMultiLSTMCell(self):
    num_layers = 3
    max_time = 50
    print("benchmarkDynamicRNNWithMultiLSTMCell")
    print("\t" + "\t".join([
        "inter_th", "intra_th", "batch_size", "num_units", "input_depth",
        "device", "compiled", "wall_time"
    ]))

    warmup_run = True
    for (threads, device, num_units, batch_size, input_depth,
         compiled) in itertools.product(
             [{"inter": 0, "intra": 0}, {"inter": 1, "intra": 4}],
             ["cpu", "gpu"], [32, 512], [1, 32, 256], [32, 512],
             [False, True]):
      if threads["inter"] != 0:
        # We only care about testing inter/intra op limitations on
        # CPU with small batch size, to mimic embedded devices.
        if device != "cpu" or batch_size != 1:
          continue
      if device == "cpu" and batch_size > 32:
        continue
      random_seed.set_random_seed(1234)
      config = config_pb2.ConfigProto(
          inter_op_parallelism_threads=threads["inter"],
          intra_op_parallelism_threads=threads["intra"],
          allow_soft_placement=False)
      with session.Session(config=config, graph=ops.Graph()) as sess:
        with ops.device("/%s:0" % device):
          ops_dict = _create_multi_lstm_cell_ops(
              batch_size=batch_size,
              num_units=num_units,
              input_depth=input_depth,
              num_layers=num_layers,
              max_time=max_time,
              compiled=compiled)
        sess.run([variables.global_variables_initializer()])
        all_ops = nest.flatten(ops_dict.values())
        all_ops_group = control_flow_ops.group(*all_ops)
        name_suffix = ("inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
                       "_device_%s_xla_%s" %
                       (threads["inter"], threads["intra"], batch_size,
                        num_units, input_depth, device, compiled))
        if warmup_run:
          self.run_op_benchmark(
              sess, all_ops_group, min_iters=30, name="ignore_warmup")
          warmup_run = False
        benchmark_results = self.run_op_benchmark(
            sess,
            all_ops_group,
            min_iters=50,
            name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix)
        print("\t" + "\t".join([
            "%s" % x
            for x in [
                threads["inter"], threads["intra"], batch_size, num_units,
                input_depth, device, compiled, benchmark_results["wall_time"]
            ]
        ]))


class WeightNormLSTMCellTest(test.TestCase):
  """Compares cell output with pre-calculated values."""

  def _cell_output(self, cell):
    """Calculates cell output."""

    with self.test_session() as sess:
      init = init_ops.constant_initializer(0.5)
      with variable_scope.variable_scope("root",
                                         initializer=init):
        x = array_ops.zeros([1, 2])
        c0 = array_ops.zeros([1, 2])
        h0 = array_ops.zeros([1, 2])

        state0 = rnn_cell.LSTMStateTuple(c0, h0)

        xout, sout = cell()(x, state0)

      sess.run([variables.global_variables_initializer()])
      res = sess.run([xout, sout], {
          x.name: np.array([[1., 1.]]),
          c0.name: 0.1 * np.asarray([[0, 1]]),
          h0.name: 0.1 * np.asarray([[2, 3]]),
      })

    actual_state_c = res[1].c
    actual_state_h = res[1].h

    return actual_state_c, actual_state_h

  def testBasicCell(self):
    """Tests cell w/o peepholes and w/o normalisation."""

    def cell():
      return contrib_rnn_cell.WeightNormLSTMCell(
          2, norm=False, use_peepholes=False)

    actual_c, actual_h = self._cell_output(cell)

    expected_c = np.array([[0.65937078, 0.74983585]])
    expected_h = np.array([[0.44923624, 0.49362513]])

    self.assertAllClose(expected_c, actual_c, 1e-5)
    self.assertAllClose(expected_h, actual_h, 1e-5)

  def testNonbasicCell(self):
    """Tests cell with peepholes and w/o normalisation."""

    def cell():
      return contrib_rnn_cell.WeightNormLSTMCell(
          2, norm=False, use_peepholes=True)

    actual_c, actual_h = self._cell_output(cell)

    expected_c = np.array([[0.65937084, 0.7574988]])
    expected_h = np.array([[0.4792085, 0.53470564]])

    self.assertAllClose(expected_c, actual_c, 1e-5)
    self.assertAllClose(expected_h, actual_h, 1e-5)

  def testBasicCellWithNorm(self):
    """Tests cell w/o peepholes and with normalisation."""

    def cell():
      return contrib_rnn_cell.WeightNormLSTMCell(
          2, norm=True, use_peepholes=False)

    actual_c, actual_h = self._cell_output(cell)

    expected_c = np.array([[0.50125383, 0.58805949]])
    expected_h = np.array([[0.32770363, 0.37397948]])

    self.assertAllClose(expected_c, actual_c, 1e-5)
    self.assertAllClose(expected_h, actual_h, 1e-5)

  def testNonBasicCellWithNorm(self):
    """Tests cell with peepholes and with normalisation."""

    def cell():
      return contrib_rnn_cell.WeightNormLSTMCell(
          2, norm=True, use_peepholes=True)

    actual_c, actual_h = self._cell_output(cell)

    expected_c = np.array([[0.50125383, 0.59587258]])
    expected_h = np.array([[0.35041603, 0.40873795]])

    self.assertAllClose(expected_c, actual_c, 1e-5)
    self.assertAllClose(expected_h, actual_h, 1e-5)


if __name__ == "__main__":
  test.main()