# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for functional style sequence-to-sequence models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import math
import random

import numpy as np

from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam


class Seq2SeqTest(test.TestCase):

  def testRNNDecoder(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        _, enc_state = rnn.static_rnn(
            rnn_cell.GRUCell(2), inp, dtype=dtypes.float32)
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
        cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
        dec, mem = seq2seq_lib.rnn_decoder(dec_inp, enc_state, cell)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].shape)

  def testBasicRNNSeq2Seq(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
        cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
        dec, mem = seq2seq_lib.basic_rnn_seq2seq(inp, dec_inp, cell)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].shape)

  def testTiedRNNSeq2Seq(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
        cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
        dec, mem = seq2seq_lib.tied_rnn_seq2seq(inp, dec_inp, cell)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        res = sess.run([mem])
        self.assertEqual(1, len(res))
        self.assertEqual((2, 2), res[0].shape)

  def testEmbeddingRNNDecoder(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
        cell = cell_fn()
        _, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
        dec_inp = [
            constant_op.constant(
                i, dtypes.int32, shape=[2]) for i in range(3)
        ]
        # Use a new cell instance since the decoder uses a different
        # variable scope.
        dec, mem = seq2seq_lib.embedding_rnn_decoder(
            dec_inp, enc_state, cell_fn(), num_symbols=4, embedding_size=2)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 2), res[0].shape)

        res = sess.run([mem])
        self.assertEqual(1, len(res))
        self.assertEqual((2, 2), res[0].c.shape)
        self.assertEqual((2, 2), res[0].h.shape)

  def testEmbeddingRNNSeq2Seq(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        enc_inp = [
            constant_op.constant(
                1, dtypes.int32, shape=[2]) for i in range(2)
        ]
        dec_inp = [
            constant_op.constant(
                i, dtypes.int32, shape=[2]) for i in range(3)
        ]
        cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
        cell = cell_fn()
        dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
            enc_inp,
            dec_inp,
            cell,
            num_encoder_symbols=2,
            num_decoder_symbols=5,
            embedding_size=2)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 5), res[0].shape)

        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].c.shape)
        self.assertEqual((2, 2), res[0].h.shape)

        # Test with state_is_tuple=False.
        with variable_scope.variable_scope("no_tuple"):
          cell_nt = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
          dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
              enc_inp,
              dec_inp,
              cell_nt,
              num_encoder_symbols=2,
              num_decoder_symbols=5,
              embedding_size=2)
          sess.run([variables.global_variables_initializer()])
          res = sess.run(dec)
          self.assertEqual(3, len(res))
          self.assertEqual((2, 5), res[0].shape)

          res = sess.run([mem])
          self.assertEqual((2, 4), res[0].shape)

        # Test externally provided output projection.
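        # With an external projection, the decoder returns the raw cell
        # outputs (size 2 here) rather than projected logits (size 5), as
        # the shape assertion below checks.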
        w = variable_scope.get_variable("proj_w", [2, 5])
        b = variable_scope.get_variable("proj_b", [5])
        with variable_scope.variable_scope("proj_seq2seq"):
          dec, _ = seq2seq_lib.embedding_rnn_seq2seq(
              enc_inp,
              dec_inp,
              cell_fn(),
              num_encoder_symbols=2,
              num_decoder_symbols=5,
              embedding_size=2,
              output_projection=(w, b))
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 2), res[0].shape)

        # Test that previous-feeding model ignores inputs after the first.
        dec_inp2 = [
            constant_op.constant(
                0, dtypes.int32, shape=[2]) for _ in range(3)
        ]
        with variable_scope.variable_scope("other"):
          d3, _ = seq2seq_lib.embedding_rnn_seq2seq(
              enc_inp,
              dec_inp2,
              cell_fn(),
              num_encoder_symbols=2,
              num_decoder_symbols=5,
              embedding_size=2,
              feed_previous=constant_op.constant(True))
        with variable_scope.variable_scope("other_2"):
          d1, _ = seq2seq_lib.embedding_rnn_seq2seq(
              enc_inp,
              dec_inp,
              cell_fn(),
              num_encoder_symbols=2,
              num_decoder_symbols=5,
              embedding_size=2,
              feed_previous=True)
        with variable_scope.variable_scope("other_3"):
          d2, _ = seq2seq_lib.embedding_rnn_seq2seq(
              enc_inp,
              dec_inp2,
              cell_fn(),
              num_encoder_symbols=2,
              num_decoder_symbols=5,
              embedding_size=2,
              feed_previous=True)
        sess.run([variables.global_variables_initializer()])
        res1 = sess.run(d1)
        res2 = sess.run(d2)
        res3 = sess.run(d3)
        self.assertAllClose(res1, res2)
        self.assertAllClose(res1, res3)

  def testEmbeddingTiedRNNSeq2Seq(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        enc_inp = [
            constant_op.constant(
                1, dtypes.int32, shape=[2]) for i in range(2)
        ]
        dec_inp = [
            constant_op.constant(
                i, dtypes.int32, shape=[2]) for i in range(3)
        ]
        cell = functools.partial(rnn_cell.BasicLSTMCell, 2, state_is_tuple=True)
        dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
            enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 5), res[0].shape)

        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].c.shape)
        self.assertEqual((2, 2), res[0].h.shape)

        # Test that when num_decoder_symbols is provided, the size of the
        # decoder output is num_decoder_symbols.
        with variable_scope.variable_scope("decoder_symbols_seq2seq"):
          dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
              enc_inp,
              dec_inp,
              cell(),
              num_symbols=5,
              num_decoder_symbols=3,
              embedding_size=2)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 3), res[0].shape)

        # Test externally provided output projection.
        w = variable_scope.get_variable("proj_w", [2, 5])
        b = variable_scope.get_variable("proj_b", [5])
        with variable_scope.variable_scope("proj_seq2seq"):
          dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
              enc_inp,
              dec_inp,
              cell(),
              num_symbols=5,
              embedding_size=2,
              output_projection=(w, b))
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 2), res[0].shape)

        # Test that previous-feeding model ignores inputs after the first.
        dec_inp2 = [constant_op.constant(0, dtypes.int32, shape=[2])] * 3
        with variable_scope.variable_scope("other"):
          d3, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
              enc_inp,
              dec_inp2,
              cell(),
              num_symbols=5,
              embedding_size=2,
              feed_previous=constant_op.constant(True))
        with variable_scope.variable_scope("other_2"):
          d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
              enc_inp,
              dec_inp,
              cell(),
              num_symbols=5,
              embedding_size=2,
              feed_previous=True)
        with variable_scope.variable_scope("other_3"):
          d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
              enc_inp,
              dec_inp2,
              cell(),
              num_symbols=5,
              embedding_size=2,
              feed_previous=True)
        sess.run([variables.global_variables_initializer()])
        res1 = sess.run(d1)
        res2 = sess.run(d2)
        res3 = sess.run(d3)
        self.assertAllClose(res1, res2)
        self.assertAllClose(res1, res3)

  def testAttentionDecoder1(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell_fn = lambda: rnn_cell.GRUCell(2)
        cell = cell_fn()
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
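        # Stack the per-step encoder outputs into a single
        # [batch, num_steps, output_size] tensor of attention states.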
        attn_states = array_ops.concat([
            array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
        ], 1)
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3

        # Create a new cell instance for the decoder, since it uses a
        # different variable scope.
        dec, mem = seq2seq_lib.attention_decoder(
            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].shape)

  def testAttentionDecoder2(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell_fn = lambda: rnn_cell.GRUCell(2)
        cell = cell_fn()
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
        attn_states = array_ops.concat([
            array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
        ], 1)
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3

        # Use a new cell instance since the attention decoder uses a
        # different variable scope.
        dec, mem = seq2seq_lib.attention_decoder(
            dec_inp, enc_state, attn_states, cell_fn(),
            output_size=4, num_heads=2)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].shape)

  def testDynamicAttentionDecoder1(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell_fn = lambda: rnn_cell.GRUCell(2)
        cell = cell_fn()
        inp = constant_op.constant(0.5, shape=[2, 2, 2])
        enc_outputs, enc_state = rnn.dynamic_rnn(
            cell, inp, dtype=dtypes.float32)
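        # dynamic_rnn already returns outputs with shape
        # [batch, num_steps, output_size], so they serve directly as
        # attention states.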
        attn_states = enc_outputs
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3

        # Use a new cell instance since the attention decoder uses a
        # different variable scope.
        dec, mem = seq2seq_lib.attention_decoder(
            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].shape)

  def testDynamicAttentionDecoder2(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell_fn = lambda: rnn_cell.GRUCell(2)
        cell = cell_fn()
        inp = constant_op.constant(0.5, shape=[2, 2, 2])
        enc_outputs, enc_state = rnn.dynamic_rnn(
            cell, inp, dtype=dtypes.float32)
        attn_states = enc_outputs
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3

        # Use a new cell instance since the attention decoder uses a
        # different variable scope.
        dec, mem = seq2seq_lib.attention_decoder(
            dec_inp, enc_state, attn_states, cell_fn(),
            output_size=4, num_heads=2)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].shape)

  def testAttentionDecoderStateIsTuple(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        single_cell = lambda: rnn_cell.BasicLSTMCell(  # pylint: disable=g-long-lambda
            2, state_is_tuple=True)
        cell_fn = lambda: rnn_cell.MultiRNNCell(  # pylint: disable=g-long-lambda
            cells=[single_cell() for _ in range(2)], state_is_tuple=True)
        cell = cell_fn()
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
        attn_states = array_ops.concat([
            array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
        ], 1)
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3

        # Use a new cell instance since the attention decoder uses a
        # different variable scope.
        dec, mem = seq2seq_lib.attention_decoder(
            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        res = sess.run([mem])
        self.assertEqual(2, len(res[0]))
        self.assertEqual((2, 2), res[0][0].c.shape)
        self.assertEqual((2, 2), res[0][0].h.shape)
        self.assertEqual((2, 2), res[0][1].c.shape)
        self.assertEqual((2, 2), res[0][1].h.shape)

  def testDynamicAttentionDecoderStateIsTuple(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell_fn = lambda: rnn_cell.MultiRNNCell(  # pylint: disable=g-long-lambda
            cells=[rnn_cell.BasicLSTMCell(2) for _ in range(2)])
        cell = cell_fn()
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
        attn_states = array_ops.concat([
            array_ops.reshape(e, [-1, 1, cell.output_size])
            for e in enc_outputs
        ], 1)
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3

        # Use a new cell instance since the attention decoder uses a
        # different variable scope.
        dec, mem = seq2seq_lib.attention_decoder(
            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        res = sess.run([mem])
        self.assertEqual(2, len(res[0]))
        self.assertEqual((2, 2), res[0][0].c.shape)
        self.assertEqual((2, 2), res[0][0].h.shape)
        self.assertEqual((2, 2), res[0][1].c.shape)
        self.assertEqual((2, 2), res[0][1].h.shape)

  def testEmbeddingAttentionDecoder(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
        cell_fn = lambda: rnn_cell.GRUCell(2)
        cell = cell_fn()
        enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
        attn_states = array_ops.concat([
            array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
        ], 1)
        dec_inp = [
            constant_op.constant(
                i, dtypes.int32, shape=[2]) for i in range(3)
        ]

        # Use a new cell instance since the attention decoder uses a
        # different variable scope.
        dec, mem = seq2seq_lib.embedding_attention_decoder(
            dec_inp,
            enc_state,
            attn_states,
            cell_fn(),
            num_symbols=4,
            embedding_size=2,
            output_size=3)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 3), res[0].shape)

        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].shape)

  def testEmbeddingAttentionSeq2Seq(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        enc_inp = [
            constant_op.constant(
                1, dtypes.int32, shape=[2]) for i in range(2)
        ]
        dec_inp = [
            constant_op.constant(
                i, dtypes.int32, shape=[2]) for i in range(3)
        ]
        cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
        cell = cell_fn()
        dec, mem = seq2seq_lib.embedding_attention_seq2seq(
            enc_inp,
            dec_inp,
            cell,
            num_encoder_symbols=2,
            num_decoder_symbols=5,
            embedding_size=2)
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 5), res[0].shape)

        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].c.shape)
        self.assertEqual((2, 2), res[0].h.shape)

        # Test with state_is_tuple=False.
        with variable_scope.variable_scope("no_tuple"):
          cell_fn = functools.partial(
              rnn_cell.BasicLSTMCell, 2, state_is_tuple=False)
          cell_nt = cell_fn()
          dec, mem = seq2seq_lib.embedding_attention_seq2seq(
              enc_inp,
              dec_inp,
              cell_nt,
              num_encoder_symbols=2,
              num_decoder_symbols=5,
              embedding_size=2)
          sess.run([variables.global_variables_initializer()])
          res = sess.run(dec)
          self.assertEqual(3, len(res))
          self.assertEqual((2, 5), res[0].shape)

          res = sess.run([mem])
          self.assertEqual((2, 4), res[0].shape)

        # Test externally provided output projection.
        w = variable_scope.get_variable("proj_w", [2, 5])
        b = variable_scope.get_variable("proj_b", [5])
        with variable_scope.variable_scope("proj_seq2seq"):
          dec, _ = seq2seq_lib.embedding_attention_seq2seq(
              enc_inp,
              dec_inp,
              cell_fn(),
              num_encoder_symbols=2,
              num_decoder_symbols=5,
              embedding_size=2,
              output_projection=(w, b))
        sess.run([variables.global_variables_initializer()])
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 2), res[0].shape)

        # TODO(ebrevdo, lukaszkaiser): Re-enable once RNNCells allow reuse
        # within a variable scope that already has a weights tensor.
        #
        # # Test that previous-feeding model ignores inputs after the first.
        # dec_inp2 = [
        #     constant_op.constant(
        #         0, dtypes.int32, shape=[2]) for _ in range(3)
        # ]
        # with variable_scope.variable_scope("other"):
        #   d3, _ = seq2seq_lib.embedding_attention_seq2seq(
        #       enc_inp,
        #       dec_inp2,
        #       cell_fn(),
        #       num_encoder_symbols=2,
        #       num_decoder_symbols=5,
        #       embedding_size=2,
        #       feed_previous=constant_op.constant(True))
        # sess.run([variables.global_variables_initializer()])
        # variable_scope.get_variable_scope().reuse_variables()
        # cell = cell_fn()
        # d1, _ = seq2seq_lib.embedding_attention_seq2seq(
        #     enc_inp,
        #     dec_inp,
        #     cell,
        #     num_encoder_symbols=2,
        #     num_decoder_symbols=5,
        #     embedding_size=2,
        #     feed_previous=True)
        # d2, _ = seq2seq_lib.embedding_attention_seq2seq(
        #     enc_inp,
        #     dec_inp2,
        #     cell,
        #     num_encoder_symbols=2,
        #     num_decoder_symbols=5,
        #     embedding_size=2,
        #     feed_previous=True)
        # res1 = sess.run(d1)
        # res2 = sess.run(d2)
        # res3 = sess.run(d3)
        # self.assertAllClose(res1, res2)
        # self.assertAllClose(res1, res3)

  def testOne2ManyRNNSeq2Seq(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        enc_inp = [
            constant_op.constant(
                1, dtypes.int32, shape=[2]) for i in range(2)
        ]
        dec_inp_dict = {}
        dec_inp_dict["0"] = [
            constant_op.constant(
                i, dtypes.int32, shape=[2]) for i in range(3)
        ]
        dec_inp_dict["1"] = [
            constant_op.constant(
                i, dtypes.int32, shape=[2]) for i in range(4)
        ]
        dec_symbols_dict = {"0": 5, "1": 6}
        def EncCellFn():
          return rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
        def DecCellsFn():
          return dict((k, rnn_cell.BasicLSTMCell(2, state_is_tuple=True))
                      for k in dec_symbols_dict)
        outputs_dict, state_dict = (seq2seq_lib.one2many_rnn_seq2seq(
            enc_inp, dec_inp_dict, EncCellFn(), DecCellsFn(),
            2, dec_symbols_dict, embedding_size=2))

        sess.run([variables.global_variables_initializer()])
        res = sess.run(outputs_dict["0"])
        self.assertEqual(3, len(res))
        self.assertEqual((2, 5), res[0].shape)
        res = sess.run(outputs_dict["1"])
        self.assertEqual(4, len(res))
        self.assertEqual((2, 6), res[0].shape)
        res = sess.run([state_dict["0"]])
        self.assertEqual((2, 2), res[0].c.shape)
        self.assertEqual((2, 2), res[0].h.shape)
        res = sess.run([state_dict["1"]])
        self.assertEqual((2, 2), res[0].c.shape)
        self.assertEqual((2, 2), res[0].h.shape)

        # Test that previous-feeding model ignores inputs after the first, i.e.
        # dec_inp_dict2 has different inputs from dec_inp_dict after the first
        # time-step.
        dec_inp_dict2 = {}
        dec_inp_dict2["0"] = [
            constant_op.constant(
                0, dtypes.int32, shape=[2]) for _ in range(3)
        ]
        dec_inp_dict2["1"] = [
            constant_op.constant(
                0, dtypes.int32, shape=[2]) for _ in range(4)
        ]
        with variable_scope.variable_scope("other"):
          outputs_dict3, _ = seq2seq_lib.one2many_rnn_seq2seq(
              enc_inp,
              dec_inp_dict2,
              EncCellFn(),
              DecCellsFn(),
              2,
              dec_symbols_dict,
              embedding_size=2,
              feed_previous=constant_op.constant(True))
        with variable_scope.variable_scope("other_2"):
          outputs_dict1, _ = seq2seq_lib.one2many_rnn_seq2seq(
              enc_inp,
              dec_inp_dict,
              EncCellFn(),
              DecCellsFn(),
              2,
              dec_symbols_dict,
              embedding_size=2,
              feed_previous=True)
        with variable_scope.variable_scope("other_3"):
          outputs_dict2, _ = seq2seq_lib.one2many_rnn_seq2seq(
              enc_inp,
              dec_inp_dict2,
              EncCellFn(),
              DecCellsFn(),
              2,
              dec_symbols_dict,
              embedding_size=2,
              feed_previous=True)
        sess.run([variables.global_variables_initializer()])
        res1 = sess.run(outputs_dict1["0"])
        res2 = sess.run(outputs_dict2["0"])
        res3 = sess.run(outputs_dict3["0"])
        self.assertAllClose(res1, res2)
        self.assertAllClose(res1, res3)

  def testSequenceLoss(self):
    with self.test_session() as sess:
      logits = [constant_op.constant(i + 0.5, shape=[2, 5]) for i in range(3)]
      targets = [
          constant_op.constant(
              i, dtypes.int32, shape=[2]) for i in range(3)
      ]
      weights = [constant_op.constant(1.0, shape=[2]) for i in range(3)]

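      # Each logits row is constant across the 5 classes, so the softmax is
      # uniform and the per-step cross-entropy is log(5) ~= 1.609438. Summing
      # over the 3 time steps gives ~4.828314, and additionally summing over
      # the batch of 2 gives ~9.656628.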
      average_loss_per_example = seq2seq_lib.sequence_loss(
          logits,
          targets,
          weights,
          average_across_timesteps=True,
          average_across_batch=True)
      res = sess.run(average_loss_per_example)
      self.assertAllClose(1.60944, res)

      average_loss_per_sequence = seq2seq_lib.sequence_loss(
          logits,
          targets,
          weights,
          average_across_timesteps=False,
          average_across_batch=True)
      res = sess.run(average_loss_per_sequence)
      self.assertAllClose(4.828314, res)

      total_loss = seq2seq_lib.sequence_loss(
          logits,
          targets,
          weights,
          average_across_timesteps=False,
          average_across_batch=False)
      res = sess.run(total_loss)
      self.assertAllClose(9.656628, res)

  def testSequenceLossByExample(self):
    with self.test_session() as sess:
      output_classes = 5
      logits = [
          constant_op.constant(
              i + 0.5, shape=[2, output_classes]) for i in range(3)
      ]
      targets = [
          constant_op.constant(
              i, dtypes.int32, shape=[2]) for i in range(3)
      ]
      weights = [constant_op.constant(1.0, shape=[2]) for i in range(3)]

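      # As above, the uniform logits give a per-step cross-entropy of log(5),
      # so each example's time-averaged loss is ~1.609438 and its summed loss
      # is ~4.828314.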
      average_loss_per_example = (seq2seq_lib.sequence_loss_by_example(
          logits, targets, weights, average_across_timesteps=True))
      res = sess.run(average_loss_per_example)
      self.assertAllClose(np.asarray([1.609438, 1.609438]), res)

      loss_per_sequence = seq2seq_lib.sequence_loss_by_example(
          logits, targets, weights, average_across_timesteps=False)
      res = sess.run(loss_per_sequence)
      self.assertAllClose(np.asarray([4.828314, 4.828314]), res)

  # TODO(ebrevdo, lukaszkaiser): Re-enable once RNNCells allow reuse
  # within a variable scope that already has a weights tensor.
  #
  # def testModelWithBucketsScopeAndLoss(self):
  #   """Test variable scope reuse is not reset after model_with_buckets."""
  #   classes = 10
  #   buckets = [(4, 4), (8, 8)]

  #   with self.test_session():
  #     # Here comes a sample Seq2Seq model using GRU cells.
  #     def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss):
  #       """Example sequence-to-sequence model that uses GRU cells."""

  #       def GRUSeq2Seq(enc_inp, dec_inp):
  #         cell = rnn_cell.MultiRNNCell(
  #             [rnn_cell.GRUCell(24) for _ in range(2)])
  #         return seq2seq_lib.embedding_attention_seq2seq(
  #             enc_inp,
  #             dec_inp,
  #             cell,
  #             num_encoder_symbols=classes,
  #             num_decoder_symbols=classes,
  #             embedding_size=24)

  #       targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0]
  #       return seq2seq_lib.model_with_buckets(
  #           enc_inp,
  #           dec_inp,
  #           targets,
  #           weights,
  #           buckets,
  #           GRUSeq2Seq,
  #           per_example_loss=per_example_loss)

  #     # Now we construct the copy model.
  #     inp = [
  #         array_ops.placeholder(
  #             dtypes.int32, shape=[None]) for _ in range(8)
  #     ]
  #     out = [
  #         array_ops.placeholder(
  #             dtypes.int32, shape=[None]) for _ in range(8)
  #     ]
  #     weights = [
  #         array_ops.ones_like(
  #             inp[0], dtype=dtypes.float32) for _ in range(8)
  #     ]
  #     with variable_scope.variable_scope("root"):
  #       _, losses1 = SampleGRUSeq2Seq(
  #           inp, out, weights, per_example_loss=False)
  #       # Now check that we did not accidentally set reuse.
  #       self.assertEqual(False, variable_scope.get_variable_scope().reuse)
  #     with variable_scope.variable_scope("new"):
  #       _, losses2 = SampleGRUSeq2Seq(
  #           inp, out, weights, per_example_loss=True)
  #       # First loss is scalar, the second one is a 1-dimensional tensor.
  #       self.assertEqual([], losses1[0].get_shape().as_list())
  #       self.assertEqual([None], losses2[0].get_shape().as_list())

  def testModelWithBuckets(self):
    """Larger test that does full sequence-to-sequence model training."""
    # We learn to copy 10 symbols in 2 buckets: length 4 and length 8.
    classes = 10
    buckets = [(4, 4), (8, 8)]
    perplexities = [[], []]  # Results for each bucket.
    random_seed.set_random_seed(111)
    random.seed(111)
    np.random.seed(111)

    with self.test_session() as sess:
      # We use sampled softmax so we keep output projection separate.
      w = variable_scope.get_variable("proj_w", [24, classes])
      w_t = array_ops.transpose(w)
      b = variable_scope.get_variable("proj_b", [classes])

      # Here comes a sample Seq2Seq model using GRU cells.
      def SampleGRUSeq2Seq(enc_inp, dec_inp, weights):
        """Example sequence-to-sequence model that uses GRU cells."""

        def GRUSeq2Seq(enc_inp, dec_inp):
          cell = rnn_cell.MultiRNNCell(
              [rnn_cell.GRUCell(24) for _ in range(2)], state_is_tuple=True)
          return seq2seq_lib.embedding_attention_seq2seq(
              enc_inp,
              dec_inp,
              cell,
              num_encoder_symbols=classes,
              num_decoder_symbols=classes,
              embedding_size=24,
              output_projection=(w, b))

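        # Targets are the decoder inputs shifted left by one time step,
        # padded with 0 at the end.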
        targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0]

        def SampledLoss(labels, logits):
          labels = array_ops.reshape(labels, [-1, 1])
          return nn_impl.sampled_softmax_loss(
              weights=w_t,
              biases=b,
              labels=labels,
              inputs=logits,
              num_sampled=8,
              num_classes=classes)

        return seq2seq_lib.model_with_buckets(
            enc_inp,
            dec_inp,
            targets,
            weights,
            buckets,
            GRUSeq2Seq,
            softmax_loss_function=SampledLoss)

      # Now we construct the copy model.
      batch_size = 8
      inp = [
          array_ops.placeholder(
              dtypes.int32, shape=[None]) for _ in range(8)
      ]
      out = [
          array_ops.placeholder(
              dtypes.int32, shape=[None]) for _ in range(8)
      ]
      weights = [
          array_ops.ones_like(
              inp[0], dtype=dtypes.float32) for _ in range(8)
      ]
      with variable_scope.variable_scope("root"):
        _, losses = SampleGRUSeq2Seq(inp, out, weights)
        updates = []
        params = variables.global_variables()
        optimizer = adam.AdamOptimizer(0.03, epsilon=1e-5)
        for i in range(len(buckets)):
          full_grads = gradients_impl.gradients(losses[i], params)
          grads, _ = clip_ops.clip_by_global_norm(full_grads, 30.0)
          update = optimizer.apply_gradients(zip(grads, params))
          updates.append(update)
        sess.run([variables.global_variables_initializer()])
      steps = 6
      for _ in range(steps):
        bucket = random.choice(np.arange(len(buckets)))
        length = buckets[bucket][0]
        i = [
            np.array(
                [np.random.randint(9) + 1 for _ in range(batch_size)],
                dtype=np.int32) for _ in range(length)
        ]
        # 0 is our "GO" symbol here.
        o = [np.array([0] * batch_size, dtype=np.int32)] + i
        feed = {}
        for i1, i2, o1, o2 in zip(inp[:length], i[:length], out[:length],
                                  o[:length]):
          feed[i1.name] = i2
          feed[o1.name] = o2
        if length < 8:  # For the 4-bucket, we need the 5th as target.
          feed[out[length].name] = o[length]
        res = sess.run([updates[bucket], losses[bucket]], feed)
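        # Perplexity is the exponential of the cross-entropy loss.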
        perplexities[bucket].append(math.exp(float(res[1])))
      for bucket in range(len(buckets)):
        if len(perplexities[bucket]) > 1:  # Assert that perplexity went down.
          self.assertLess(perplexities[bucket][-1],  # 20% margin of error.
                          1.2 * perplexities[bucket][0])

  def testModelWithBooleanFeedPrevious(self):
    """Test the model behavior when feed_previous is True.

    For example, the following two cases have the same effect:
      - Train `embedding_rnn_seq2seq` with `feed_previous=True`, which contains
        an `embedding_rnn_decoder` with `feed_previous=True` and
        `update_embedding_for_previous=True`. The decoder is fed with "<Go>"
        and outputs "A, B, C".
      - Train `embedding_rnn_seq2seq` with `feed_previous=False`. The decoder
        is fed with "<Go>, A, B".
    """
    num_encoder_symbols = 3
    num_decoder_symbols = 5
    batch_size = 2
    num_enc_timesteps = 2
    num_dec_timesteps = 3

    def TestModel(seq2seq):
      with self.test_session(graph=ops.Graph()) as sess:
        random_seed.set_random_seed(111)
        random.seed(111)
        np.random.seed(111)

        enc_inp = [
            constant_op.constant(
                i + 1, dtypes.int32, shape=[batch_size])
            for i in range(num_enc_timesteps)
        ]
        dec_inp_fp_true = [
            constant_op.constant(
                i, dtypes.int32, shape=[batch_size])
            for i in range(num_dec_timesteps)
        ]
        dec_inp_holder_fp_false = [
            array_ops.placeholder(
                dtypes.int32, shape=[batch_size])
            for _ in range(num_dec_timesteps)
        ]
        targets = [
            constant_op.constant(
                i + 1, dtypes.int32, shape=[batch_size])
            for i in range(num_dec_timesteps)
        ]
        weights = [
            constant_op.constant(
                1.0, shape=[batch_size]) for i in range(num_dec_timesteps)
        ]

        def ForwardBackward(enc_inp, dec_inp, feed_previous):
          scope_name = "fp_{}".format(feed_previous)
          with variable_scope.variable_scope(scope_name):
            dec_op, _ = seq2seq(enc_inp, dec_inp, feed_previous=feed_previous)
            net_variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                               scope_name)
          optimizer = adam.AdamOptimizer(0.03, epsilon=1e-5)
          update_op = optimizer.minimize(
              seq2seq_lib.sequence_loss(dec_op, targets, weights),
              var_list=net_variables)
          return dec_op, update_op, net_variables

        dec_op_fp_true, update_fp_true, variables_fp_true = ForwardBackward(
            enc_inp, dec_inp_fp_true, feed_previous=True)
        _, update_fp_false, variables_fp_false = ForwardBackward(
            enc_inp, dec_inp_holder_fp_false, feed_previous=False)

        sess.run(variables.global_variables_initializer())

        # We only check consistency between the variables that exist in both
        # models (feed_previous=True and feed_previous=False). Variables
        # created by the loop_function in the feed_previous=True model are
        # ignored.
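        # Variables are matched across the two scopes by their name suffix,
        # i.e. the part after the "fp_True" / "fp_False" scope prefix.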
        v_false_name_dict = {
            v.name.split("/", 1)[-1]: v
            for v in variables_fp_false
        }
        matched_variables = [(v, v_false_name_dict[v.name.split("/", 1)[-1]])
                             for v in variables_fp_true]
        for v_true, v_false in matched_variables:
          sess.run(state_ops.assign(v_false, v_true))

        # Take the symbols generated by the decoder with feed_previous=True as
        # the true input symbols for the decoder with feed_previous=False.
        dec_fp_true = sess.run(dec_op_fp_true)
        output_symbols_fp_true = np.argmax(dec_fp_true, axis=2)
        dec_inp_fp_false = np.vstack((dec_inp_fp_true[0].eval(),
                                      output_symbols_fp_true[:-1]))
        sess.run(update_fp_true)
        sess.run(update_fp_false, {
            holder: inp
            for holder, inp in zip(dec_inp_holder_fp_false, dec_inp_fp_false)
        })

        for v_true, v_false in matched_variables:
          self.assertAllClose(v_true.eval(), v_false.eval())

    def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous):
      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
      return seq2seq_lib.embedding_rnn_seq2seq(
          enc_inp,
          dec_inp,
          cell,
          num_encoder_symbols,
          num_decoder_symbols,
          embedding_size=2,
          feed_previous=feed_previous)

    def EmbeddingRNNSeq2SeqNoTupleF(enc_inp, dec_inp, feed_previous):
      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
      return seq2seq_lib.embedding_rnn_seq2seq(
          enc_inp,
          dec_inp,
          cell,
          num_encoder_symbols,
          num_decoder_symbols,
          embedding_size=2,
          feed_previous=feed_previous)

    def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous):
      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
      return seq2seq_lib.embedding_tied_rnn_seq2seq(
          enc_inp,
          dec_inp,
          cell,
          num_decoder_symbols,
          embedding_size=2,
          feed_previous=feed_previous)

    def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
      return seq2seq_lib.embedding_tied_rnn_seq2seq(
          enc_inp,
          dec_inp,
          cell,
          num_decoder_symbols,
          embedding_size=2,
          feed_previous=feed_previous)

    def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous):
      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
      return seq2seq_lib.embedding_attention_seq2seq(
          enc_inp,
          dec_inp,
          cell,
          num_encoder_symbols,
          num_decoder_symbols,
          embedding_size=2,
          feed_previous=feed_previous)

    def EmbeddingAttentionSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
      return seq2seq_lib.embedding_attention_seq2seq(
          enc_inp,
          dec_inp,
          cell,
          num_encoder_symbols,
          num_decoder_symbols,
          embedding_size=2,
          feed_previous=feed_previous)

    for model in (EmbeddingRNNSeq2SeqF, EmbeddingRNNSeq2SeqNoTupleF,
                  EmbeddingTiedRNNSeq2Seq, EmbeddingTiedRNNSeq2SeqNoTuple,
                  EmbeddingAttentionSeq2Seq, EmbeddingAttentionSeq2SeqNoTuple):
      TestModel(model)


if __name__ == "__main__":
  test.main()