# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for SPINN data module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import tempfile

import numpy as np
import tensorflow as tf

from tensorflow.contrib.eager.python.examples.spinn import data


class DataTest(tf.test.TestCase):

  def setUp(self):
    super(DataTest, self).setUp()
    self._temp_data_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self._temp_data_dir)
    super(DataTest, self).tearDown()

  def testGenNonParenthesisWords(self):
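    # Per the assertion below, data.get_non_parenthesis_words is expected to
    # drop the "(" and ")" tokens of a binary parse and lower-case the
    # remaining words.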
    seq_with_parse = (
        "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and "
        ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )")
    self.assertEqual(
        ["man", "wearing", "pass", "on", "a", "lanyard", "and", "standing",
         "in", "a", "crowd", "of", "people", "."],
        data.get_non_parenthesis_words(seq_with_parse.split(" ")))

  def testGetShiftReduce(self):
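    # Per the expected output below, data.get_shift_reduce is expected to
    # emit one 3 (shift) per word and one 2 (reduce) per closing parenthesis,
    # while opening parentheses contribute no transition.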
    seq_with_parse = (
        "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and "
        ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )")
    self.assertEqual(
        [3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2,
         3, 2, 2], data.get_shift_reduce(seq_with_parse.split(" ")))

  def testPadAndReverseWordIds(self):
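    # Per the expected output below, each id sequence is reversed and then
    # left-padded with the <pad> id 1, out to one more than the longest
    # sequence length.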
    id_sequences = [[0, 2, 3, 4, 5],
                    [6, 7, 8],
                    [9, 10, 11, 12, 13, 14, 15, 16]]
    self.assertAllClose(
        [[1, 1, 1, 1, 5, 4, 3, 2, 0],
         [1, 1, 1, 1, 1, 1, 8, 7, 6],
         [1, 16, 15, 14, 13, 12, 11, 10, 9]],
        data.pad_and_reverse_word_ids(id_sequences))

  def testPadTransitions(self):
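    # Per the expected output below, shorter transition sequences are
    # right-padded with 1 to match the longest one.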
    unpadded = [[3, 3, 3, 2, 2, 2, 2],
                [3, 3, 2, 2, 2]]
    self.assertAllClose(
        [[3, 3, 3, 2, 2, 2, 2],
         [3, 3, 2, 2, 2, 1, 1]],
        data.pad_transitions(unpadded))

  def testCalculateBins(self):
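    # Per the assertions below, data.calculate_bins accumulates the
    # per-length counts into bins holding at least the requested number of
    # examples and returns each bin's upper length boundary; the final bin
    # may fall short of the minimum.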
    length2count = {
        1: 10,
        2: 15,
        3: 25,
        4: 40,
        5: 35,
        6: 10}
    self.assertEqual([2, 3, 4, 5, 6],
                     data.calculate_bins(length2count, 20))
    self.assertEqual([3, 4, 6], data.calculate_bins(length2count, 40))
    self.assertEqual([4, 6], data.calculate_bins(length2count, 60))

  def testLoadVocabulary(self):
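    # Builds a minimal fake SNLI directory layout. The vocabulary is expected
    # to be collected from the binary-parse columns of both the train and dev
    # files: lower-cased words with the parenthesis tokens excluded.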
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt")
    os.makedirs(snli_1_0_dir)

    with open(fake_train_file, "wt") as f:
      f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
              "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
              "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
      f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
    with open(fake_dev_file, "wt") as f:
      f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
              "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
              "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
      f.write("neutral\t( ( Quux quuz ) ? )\t( ( Corge grault ) ! )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Quux quuz?\t.Corge grault!\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")

    vocab = data.load_vocabulary(self._temp_data_dir)
    self.assertSetEqual(
        {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"},
        vocab)

  def testLoadVocabularyWithoutFileRaisesError(self):
    with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
      data.load_vocabulary(self._temp_data_dir)

    os.makedirs(os.path.join(self._temp_data_dir, "snli"))
    with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
      data.load_vocabulary(self._temp_data_dir)

    os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0"))
    with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
      data.load_vocabulary(self._temp_data_dir)

  def testLoadWordVectors(self):
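    # Builds a fake GloVe file in which every component of the i-th word's
    # vector equals i * 0.1. Indices 0 (<unk>) and 1 (<pad>) are reserved and
    # expected to map to all-zero embeddings.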
    glove_dir = os.path.join(self._temp_data_dir, "glove")
    os.makedirs(glove_dir)
    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")

    words = [".", ",", "foo", "bar", "baz"]
    with open(glove_file, "wt") as f:
      for i, word in enumerate(words):
        f.write("%s " % word)
        for j in range(data.WORD_VECTOR_LEN):
          f.write("%.5f" % (i * 0.1))
          if j < data.WORD_VECTOR_LEN - 1:
            f.write(" ")
          else:
            f.write("\n")

    vocab = {"foo", "bar", "baz", "qux", "."}
    # Notice that "qux" is not present in `words`.
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    self.assertEqual(6, len(word2index))
    self.assertEqual(0, word2index["<unk>"])
    self.assertEqual(1, word2index["<pad>"])
    self.assertEqual(2, word2index["."])
    self.assertEqual(3, word2index["foo"])
    self.assertEqual(4, word2index["bar"])
    self.assertEqual(5, word2index["baz"])
    self.assertEqual((6, data.WORD_VECTOR_LEN), embed.shape)
    self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[0, :])
    self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[1, :])
    self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[2, :])
    self.assertAllClose([0.2] * data.WORD_VECTOR_LEN, embed[3, :])
    self.assertAllClose([0.3] * data.WORD_VECTOR_LEN, embed[4, :])
    self.assertAllClose([0.4] * data.WORD_VECTOR_LEN, embed[5, :])

  def testLoadWordVectorsWithoutFileRaisesError(self):
    vocab = {"foo", "bar", "baz", "qux", "."}
    with self.assertRaisesRegexp(
        ValueError, "Cannot find GloVe embedding file at"):
      data.load_word_vectors(self._temp_data_dir, vocab)

    os.makedirs(os.path.join(self._temp_data_dir, "glove"))
    with self.assertRaisesRegexp(
        ValueError, "Cannot find GloVe embedding file at"):
      data.load_word_vectors(self._temp_data_dir, vocab)

  def _createFakeSnliData(self, fake_snli_file):
    # Four labeled sentence pairs in total.
    with open(fake_snli_file, "wt") as f:
      f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
              "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
              "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
      f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")

  def _createFakeGloveData(self, glove_file):
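    # Each row is a word followed by data.WORD_VECTOR_LEN space-separated
    # components, all equal to i * 0.1 for the i-th word in `words`.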
    words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
    with open(glove_file, "wt") as f:
      for i, word in enumerate(words):
        f.write("%s " % word)
        for j in range(data.WORD_VECTOR_LEN):
          f.write("%.5f" % (i * 0.1))
          if j < data.WORD_VECTOR_LEN - 1:
            f.write(" ")
          else:
            f.write("\n")

  def testEncodeSingleSentence(self):
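    # The sentence variants constructed below differ only in surrounding and
    # internal whitespace; data.encode_sentence is expected to produce
    # identical encodings for all of them.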
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    os.makedirs(snli_1_0_dir)
    self._createFakeSnliData(fake_train_file)
    vocab = data.load_vocabulary(self._temp_data_dir)
    glove_dir = os.path.join(self._temp_data_dir, "glove")
    os.makedirs(glove_dir)
    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
    self._createFakeGloveData(glove_file)
    word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)

    sentence_variants = [
        "( Foo ( ( bar baz ) . ) )",
        " ( Foo ( ( bar baz ) . ) ) ",
        "( Foo ( ( bar baz ) . )  )"]
    for sentence in sentence_variants:
      word_indices, shift_reduce = data.encode_sentence(sentence, word2index)
      self.assertEqual(np.int64, word_indices.dtype)
      self.assertEqual((5, 1), word_indices.shape)
      self.assertAllClose(
          np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T, shift_reduce)

  def testSnliData(self):
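    # With four examples in the fake training file, a batch size of b is
    # expected to yield ceil(4 / b) batches.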
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    os.makedirs(snli_1_0_dir)
    self._createFakeSnliData(fake_train_file)

    glove_dir = os.path.join(self._temp_data_dir, "glove")
    os.makedirs(glove_dir)
    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
    self._createFakeGloveData(glove_file)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)

    train_data = data.SnliData(fake_train_file, word2index)
    self.assertEqual(4, train_data.num_batches(1))
    self.assertEqual(2, train_data.num_batches(2))
    self.assertEqual(2, train_data.num_batches(3))
    self.assertEqual(1, train_data.num_batches(4))

    generator = train_data.get_generator(2)()
    for _ in range(2):
      label, prem, prem_trans, hypo, hypo_trans = next(generator)
      self.assertEqual(2, len(label))
      self.assertEqual((4, 2), prem.shape)
      self.assertEqual((5, 2), prem_trans.shape)
      self.assertEqual((3, 2), hypo.shape)
      self.assertEqual((3, 2), hypo_trans.shape)


if __name__ == "__main__":
  tf.test.main()