Home | History | Annotate | Download | only in spinn
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests for SPINN data module."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     21 import os
     22 import shutil
     23 import tempfile
     25 import numpy as np
     26 import tensorflow as tf
     28 from tensorflow.contrib.eager.python.examples.spinn import data
     31 class DataTest(tf.test.TestCase):
     33   def setUp(self):
     34     super(DataTest, self).setUp()
     35     self._temp_data_dir = tempfile.mkdtemp()
     37   def tearDown(self):
     38     shutil.rmtree(self._temp_data_dir)
     39     super(DataTest, self).tearDown()
     41   def testGenNonParenthesisWords(self):
     42     seq_with_parse = (
     43         "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and "
     44         ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )")
     45     self.assertEqual(
     46         ["man", "wearing", "pass", "on", "a", "lanyard", "and", "standing",
     47          "in", "a", "crowd", "of", "people", "."],
     48         data.get_non_parenthesis_words(seq_with_parse.split(" ")))
     50   def testGetShiftReduce(self):
     51     seq_with_parse = (
     52         "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and "
     53         ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )")
     54     self.assertEqual(
     55         [3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2,
     56          3, 2, 2], data.get_shift_reduce(seq_with_parse.split(" ")))
     58   def testPadAndReverseWordIds(self):
     59     id_sequences = [[0, 2, 3, 4, 5],
     60                     [6, 7, 8],
     61                     [9, 10, 11, 12, 13, 14, 15, 16]]
     62     self.assertAllClose(
     63         [[1, 1, 1, 1, 5, 4, 3, 2, 0],
     64          [1, 1, 1, 1, 1, 1, 8, 7, 6],
     65          [1, 16, 15, 14, 13, 12, 11, 10, 9]],
     66         data.pad_and_reverse_word_ids(id_sequences))
     68   def testPadTransitions(self):
     69     unpadded = [[3, 3, 3, 2, 2, 2, 2],
     70                 [3, 3, 2, 2, 2]]
     71     self.assertAllClose(
     72         [[3, 3, 3, 2, 2, 2, 2],
     73          [3, 3, 2, 2, 2, 1, 1]],
     74         data.pad_transitions(unpadded))
     76   def testCalculateBins(self):
     77     length2count = {
     78         1: 10,
     79         2: 15,
     80         3: 25,
     81         4: 40,
     82         5: 35,
     83         6: 10}
     84     self.assertEqual([2, 3, 4, 5, 6],
     85                      data.calculate_bins(length2count, 20))
     86     self.assertEqual([3, 4, 6], data.calculate_bins(length2count, 40))
     87     self.assertEqual([4, 6], data.calculate_bins(length2count, 60))
     89   def testLoadVoacbulary(self):
     90     snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
     91     fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
     92     fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt")
     93     os.makedirs(snli_1_0_dir)
     95     with open(fake_train_file, "wt") as f:
     96       f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
     97               "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
     98               "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
     99       f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t"
    100               "DummySentence1Parse\tDummySentence2Parse\t"
    101               "Foo bar.\tfoo baz.\t"
    102               "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
    103               "neutral\tentailment\tneutral\tneutral\tneutral\n")
    104     with open(fake_dev_file, "wt") as f:
    105       f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
    106               "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
    107               "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
    108       f.write("neutral\t( ( Quux quuz ) ? )\t( ( Corge grault ) ! )\t"
    109               "DummySentence1Parse\tDummySentence2Parse\t"
    110               "Quux quuz?\t.Corge grault!\t"
    111               "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
    112               "neutral\tentailment\tneutral\tneutral\tneutral\n")
    114     vocab = data.load_vocabulary(self._temp_data_dir)
    115     self.assertSetEqual(
    116         {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"},
    117         vocab)
    119   def testLoadVoacbularyWithoutFileRaisesError(self):
    120     with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
    121       data.load_vocabulary(self._temp_data_dir)
    123     os.makedirs(os.path.join(self._temp_data_dir, "snli"))
    124     with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
    125       data.load_vocabulary(self._temp_data_dir)
    127     os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0"))
    128     with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
    129       data.load_vocabulary(self._temp_data_dir)
    131   def testLoadWordVectors(self):
    132     glove_dir = os.path.join(self._temp_data_dir, "glove")
    133     os.makedirs(glove_dir)
    134     glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
    136     words = [".", ",", "foo", "bar", "baz"]
    137     with open(glove_file, "wt") as f:
    138       for i, word in enumerate(words):
    139         f.write("%s " % word)
    140         for j in range(data.WORD_VECTOR_LEN):
    141           f.write("%.5f" % (i * 0.1))
    142           if j < data.WORD_VECTOR_LEN - 1:
    143             f.write(" ")
    144           else:
    145             f.write("\n")
    147     vocab = {"foo", "bar", "baz", "qux", "."}
    148     # Notice that "qux" is not present in `words`.
    149     word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)
    151     self.assertEqual(6, len(word2index))
    152     self.assertEqual(0, word2index["<unk>"])
    153     self.assertEqual(1, word2index["<pad>"])
    154     self.assertEqual(2, word2index["."])
    155     self.assertEqual(3, word2index["foo"])
    156     self.assertEqual(4, word2index["bar"])
    157     self.assertEqual(5, word2index["baz"])
    158     self.assertEqual((6, data.WORD_VECTOR_LEN), embed.shape)
    159     self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[0, :])
    160     self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[1, :])
    161     self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[2, :])
    162     self.assertAllClose([0.2] * data.WORD_VECTOR_LEN, embed[3, :])
    163     self.assertAllClose([0.3] * data.WORD_VECTOR_LEN, embed[4, :])
    164     self.assertAllClose([0.4] * data.WORD_VECTOR_LEN, embed[5, :])
    166   def testLoadWordVectorsWithoutFileRaisesError(self):
    167     vocab = {"foo", "bar", "baz", "qux", "."}
    168     with self.assertRaisesRegexp(
    169         ValueError, "Cannot find GloVe embedding file at"):
    170       data.load_word_vectors(self._temp_data_dir, vocab)
    172     os.makedirs(os.path.join(self._temp_data_dir, "glove"))
    173     with self.assertRaisesRegexp(
    174         ValueError, "Cannot find GloVe embedding file at"):
    175       data.load_word_vectors(self._temp_data_dir, vocab)
    177   def _createFakeSnliData(self, fake_snli_file):
    178     # Four sentences in total.
    179     with open(fake_snli_file, "wt") as f:
    180       f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
    181               "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
    182               "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
    183       f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
    184               "DummySentence1Parse\tDummySentence2Parse\t"
    185               "Foo bar.\tfoo baz.\t"
    186               "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
    187               "neutral\tentailment\tneutral\tneutral\tneutral\n")
    188       f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
    189               "DummySentence1Parse\tDummySentence2Parse\t"
    190               "Foo bar.\tfoo baz.\t"
    191               "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
    192               "neutral\tentailment\tneutral\tneutral\tneutral\n")
    193       f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
    194               "DummySentence1Parse\tDummySentence2Parse\t"
    195               "Foo bar.\tfoo baz.\t"
    196               "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
    197               "neutral\tentailment\tneutral\tneutral\tneutral\n")
    198       f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
    199               "DummySentence1Parse\tDummySentence2Parse\t"
    200               "Foo bar.\tfoo baz.\t"
    201               "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
    202               "neutral\tentailment\tneutral\tneutral\tneutral\n")
    204   def _createFakeGloveData(self, glove_file):
    205     words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
    206     with open(glove_file, "wt") as f:
    207       for i, word in enumerate(words):
    208         f.write("%s " % word)
    209         for j in range(data.WORD_VECTOR_LEN):
    210           f.write("%.5f" % (i * 0.1))
    211           if j < data.WORD_VECTOR_LEN - 1:
    212             f.write(" ")
    213           else:
    214             f.write("\n")
    216   def testEncodeSingleSentence(self):
    217     snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    218     fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    219     os.makedirs(snli_1_0_dir)
    220     self._createFakeSnliData(fake_train_file)
    221     vocab = data.load_vocabulary(self._temp_data_dir)
    222     glove_dir = os.path.join(self._temp_data_dir, "glove")
    223     os.makedirs(glove_dir)
    224     glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
    225     self._createFakeGloveData(glove_file)
    226     word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)
    228     sentence_variants = [
    229         "( Foo ( ( bar baz ) . ) )",
    230         " ( Foo ( ( bar baz ) . ) ) ",
    231         "( Foo ( ( bar baz ) . )  )"]
    232     for sentence in sentence_variants:
    233       word_indices, shift_reduce = data.encode_sentence(sentence, word2index)
    234       self.assertEqual(np.int64, word_indices.dtype)
    235       self.assertEqual((5, 1), word_indices.shape)
    236       self.assertAllClose(
    237           np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T, shift_reduce)
    239   def testSnliData(self):
    240     snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    241     fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    242     os.makedirs(snli_1_0_dir)
    243     self._createFakeSnliData(fake_train_file)
    245     glove_dir = os.path.join(self._temp_data_dir, "glove")
    246     os.makedirs(glove_dir)
    247     glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
    248     self._createFakeGloveData(glove_file)
    250     vocab = data.load_vocabulary(self._temp_data_dir)
    251     word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)
    253     train_data = data.SnliData(fake_train_file, word2index)
    254     self.assertEqual(4, train_data.num_batches(1))
    255     self.assertEqual(2, train_data.num_batches(2))
    256     self.assertEqual(2, train_data.num_batches(3))
    257     self.assertEqual(1, train_data.num_batches(4))
    259     generator = train_data.get_generator(2)()
    260     for _ in range(2):
    261       label, prem, prem_trans, hypo, hypo_trans = next(generator)
    262       self.assertEqual(2, len(label))
    263       self.assertEqual((4, 2), prem.shape)
    264       self.assertEqual((5, 2), prem_trans.shape)
    265       self.assertEqual((3, 2), hypo.shape)
    266       self.assertEqual((3, 2), hypo_trans.shape)
# Run all test cases in this module when it is executed as a script.
if __name__ == "__main__":
  tf.test.main()