1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Unit tests for SPINN data module.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 import shutil 23 import tempfile 24 25 import numpy as np 26 import tensorflow as tf 27 28 from tensorflow.contrib.eager.python.examples.spinn import data 29 30 31 class DataTest(tf.test.TestCase): 32 33 def setUp(self): 34 super(DataTest, self).setUp() 35 self._temp_data_dir = tempfile.mkdtemp() 36 37 def tearDown(self): 38 shutil.rmtree(self._temp_data_dir) 39 super(DataTest, self).tearDown() 40 41 def testGenNonParenthesisWords(self): 42 seq_with_parse = ( 43 "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " 44 ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )") 45 self.assertEqual( 46 ["man", "wearing", "pass", "on", "a", "lanyard", "and", "standing", 47 "in", "a", "crowd", "of", "people", "."], 48 data.get_non_parenthesis_words(seq_with_parse.split(" "))) 49 50 def testGetShiftReduce(self): 51 seq_with_parse = ( 52 "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " 53 ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )") 54 self.assertEqual( 55 [3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2, 56 3, 2, 2], data.get_shift_reduce(seq_with_parse.split(" "))) 57 58 def testPadAndReverseWordIds(self): 59 id_sequences = [[0, 2, 3, 4, 5], 60 [6, 7, 8], 61 [9, 10, 11, 12, 13, 14, 15, 16]] 62 self.assertAllClose( 63 [[1, 1, 1, 1, 5, 4, 3, 2, 0], 64 [1, 1, 1, 1, 1, 1, 8, 7, 6], 65 [1, 16, 15, 14, 13, 12, 11, 10, 9]], 66 data.pad_and_reverse_word_ids(id_sequences)) 67 68 def testPadTransitions(self): 69 unpadded = [[3, 3, 3, 2, 2, 2, 2], 70 [3, 3, 2, 2, 2]] 71 self.assertAllClose( 72 [[3, 3, 3, 2, 2, 2, 2], 73 [3, 3, 2, 2, 2, 1, 1]], 74 data.pad_transitions(unpadded)) 75 76 def testCalculateBins(self): 77 length2count = { 78 1: 10, 79 2: 15, 80 3: 25, 81 4: 40, 82 5: 35, 83 6: 10} 84 self.assertEqual([2, 3, 4, 5, 6], 85 data.calculate_bins(length2count, 20)) 86 self.assertEqual([3, 4, 6], data.calculate_bins(length2count, 40)) 87 self.assertEqual([4, 6], data.calculate_bins(length2count, 60)) 88 89 def testLoadVoacbulary(self): 90 snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") 91 fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") 92 fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt") 93 os.makedirs(snli_1_0_dir) 94 95 with open(fake_train_file, "wt") as f: 96 f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" 97 "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" 98 "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") 99 f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t" 100 "DummySentence1Parse\tDummySentence2Parse\t" 101 "Foo bar.\tfoo baz.\t" 102 "4705552913.jpg#2\t4705552913.jpg#2r1n\t" 103 "neutral\tentailment\tneutral\tneutral\tneutral\n") 104 with open(fake_dev_file, "wt") as f: 105 f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" 106 "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" 107 "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") 108 f.write("neutral\t( ( Quux quuz ) ? )\t( ( Corge grault ) ! )\t" 109 "DummySentence1Parse\tDummySentence2Parse\t" 110 "Quux quuz?\t.Corge grault!\t" 111 "4705552913.jpg#2\t4705552913.jpg#2r1n\t" 112 "neutral\tentailment\tneutral\tneutral\tneutral\n") 113 114 vocab = data.load_vocabulary(self._temp_data_dir) 115 self.assertSetEqual( 116 {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"}, 117 vocab) 118 119 def testLoadVoacbularyWithoutFileRaisesError(self): 120 with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): 121 data.load_vocabulary(self._temp_data_dir) 122 123 os.makedirs(os.path.join(self._temp_data_dir, "snli")) 124 with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): 125 data.load_vocabulary(self._temp_data_dir) 126 127 os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0")) 128 with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): 129 data.load_vocabulary(self._temp_data_dir) 130 131 def testLoadWordVectors(self): 132 glove_dir = os.path.join(self._temp_data_dir, "glove") 133 os.makedirs(glove_dir) 134 glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") 135 136 words = [".", ",", "foo", "bar", "baz"] 137 with open(glove_file, "wt") as f: 138 for i, word in enumerate(words): 139 f.write("%s " % word) 140 for j in range(data.WORD_VECTOR_LEN): 141 f.write("%.5f" % (i * 0.1)) 142 if j < data.WORD_VECTOR_LEN - 1: 143 f.write(" ") 144 else: 145 f.write("\n") 146 147 vocab = {"foo", "bar", "baz", "qux", "."} 148 # Notice that "qux" is not present in `words`. 149 word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) 150 151 self.assertEqual(6, len(word2index)) 152 self.assertEqual(0, word2index["<unk>"]) 153 self.assertEqual(1, word2index["<pad>"]) 154 self.assertEqual(2, word2index["."]) 155 self.assertEqual(3, word2index["foo"]) 156 self.assertEqual(4, word2index["bar"]) 157 self.assertEqual(5, word2index["baz"]) 158 self.assertEqual((6, data.WORD_VECTOR_LEN), embed.shape) 159 self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[0, :]) 160 self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[1, :]) 161 self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[2, :]) 162 self.assertAllClose([0.2] * data.WORD_VECTOR_LEN, embed[3, :]) 163 self.assertAllClose([0.3] * data.WORD_VECTOR_LEN, embed[4, :]) 164 self.assertAllClose([0.4] * data.WORD_VECTOR_LEN, embed[5, :]) 165 166 def testLoadWordVectorsWithoutFileRaisesError(self): 167 vocab = {"foo", "bar", "baz", "qux", "."} 168 with self.assertRaisesRegexp( 169 ValueError, "Cannot find GloVe embedding file at"): 170 data.load_word_vectors(self._temp_data_dir, vocab) 171 172 os.makedirs(os.path.join(self._temp_data_dir, "glove")) 173 with self.assertRaisesRegexp( 174 ValueError, "Cannot find GloVe embedding file at"): 175 data.load_word_vectors(self._temp_data_dir, vocab) 176 177 def _createFakeSnliData(self, fake_snli_file): 178 # Four sentences in total. 179 with open(fake_snli_file, "wt") as f: 180 f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" 181 "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" 182 "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") 183 f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" 184 "DummySentence1Parse\tDummySentence2Parse\t" 185 "Foo bar.\tfoo baz.\t" 186 "4705552913.jpg#2\t4705552913.jpg#2r1n\t" 187 "neutral\tentailment\tneutral\tneutral\tneutral\n") 188 f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t" 189 "DummySentence1Parse\tDummySentence2Parse\t" 190 "Foo bar.\tfoo baz.\t" 191 "4705552913.jpg#2\t4705552913.jpg#2r1n\t" 192 "neutral\tentailment\tneutral\tneutral\tneutral\n") 193 f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" 194 "DummySentence1Parse\tDummySentence2Parse\t" 195 "Foo bar.\tfoo baz.\t" 196 "4705552913.jpg#2\t4705552913.jpg#2r1n\t" 197 "neutral\tentailment\tneutral\tneutral\tneutral\n") 198 f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t" 199 "DummySentence1Parse\tDummySentence2Parse\t" 200 "Foo bar.\tfoo baz.\t" 201 "4705552913.jpg#2\t4705552913.jpg#2r1n\t" 202 "neutral\tentailment\tneutral\tneutral\tneutral\n") 203 204 def _createFakeGloveData(self, glove_file): 205 words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] 206 with open(glove_file, "wt") as f: 207 for i, word in enumerate(words): 208 f.write("%s " % word) 209 for j in range(data.WORD_VECTOR_LEN): 210 f.write("%.5f" % (i * 0.1)) 211 if j < data.WORD_VECTOR_LEN - 1: 212 f.write(" ") 213 else: 214 f.write("\n") 215 216 def testEncodeSingleSentence(self): 217 snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") 218 fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") 219 os.makedirs(snli_1_0_dir) 220 self._createFakeSnliData(fake_train_file) 221 vocab = data.load_vocabulary(self._temp_data_dir) 222 glove_dir = os.path.join(self._temp_data_dir, "glove") 223 os.makedirs(glove_dir) 224 glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") 225 self._createFakeGloveData(glove_file) 226 word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) 227 228 sentence_variants = [ 229 "( Foo ( ( bar baz ) . ) )", 230 " ( Foo ( ( bar baz ) . ) ) ", 231 "( Foo ( ( bar baz ) . ) )"] 232 for sentence in sentence_variants: 233 word_indices, shift_reduce = data.encode_sentence(sentence, word2index) 234 self.assertEqual(np.int64, word_indices.dtype) 235 self.assertEqual((5, 1), word_indices.shape) 236 self.assertAllClose( 237 np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T, shift_reduce) 238 239 def testSnliData(self): 240 snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") 241 fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") 242 os.makedirs(snli_1_0_dir) 243 self._createFakeSnliData(fake_train_file) 244 245 glove_dir = os.path.join(self._temp_data_dir, "glove") 246 os.makedirs(glove_dir) 247 glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") 248 self._createFakeGloveData(glove_file) 249 250 vocab = data.load_vocabulary(self._temp_data_dir) 251 word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) 252 253 train_data = data.SnliData(fake_train_file, word2index) 254 self.assertEqual(4, train_data.num_batches(1)) 255 self.assertEqual(2, train_data.num_batches(2)) 256 self.assertEqual(2, train_data.num_batches(3)) 257 self.assertEqual(1, train_data.num_batches(4)) 258 259 generator = train_data.get_generator(2)() 260 for _ in range(2): 261 label, prem, prem_trans, hypo, hypo_trans = next(generator) 262 self.assertEqual(2, len(label)) 263 self.assertEqual((4, 2), prem.shape) 264 self.assertEqual((5, 2), prem_trans.shape) 265 self.assertEqual((3, 2), hypo.shape) 266 self.assertEqual((3, 2), hypo_trans.shape) 267 268 269 if __name__ == "__main__": 270 tf.test.main() 271