# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Utilities of SNLI data and GloVe word vectors for SPINN model.
     16 
     17 See more details about the SNLI data set at:
     18   https://nlp.stanford.edu/projects/snli/
     19 
     20 See more details about the GloVe pretrained word embeddings at:
     21   https://nlp.stanford.edu/projects/glove/
     22 """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import math
import os
import random

import numpy as np

POSSIBLE_LABELS = ("entailment", "contradiction", "neutral")

UNK_CODE = 0   # Code for unknown word tokens.
PAD_CODE = 1   # Code for padding tokens.

SHIFT_CODE = 3   # Code for shift transitions (push the next word).
REDUCE_CODE = 2  # Code for reduce transitions (compose the top two entries).

WORD_VECTOR_LEN = 300  # Embedding dimensions.

LEFT_PAREN = "("
RIGHT_PAREN = ")"
PARENTHESES = (LEFT_PAREN, RIGHT_PAREN)


def get_non_parenthesis_words(items):
  """Get the non-parenthesis items from a SNLI parsed sentence.

  Args:
    items: Data items from a parsed SNLI sentence, with parentheses. E.g.,
      ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ...

  Returns:
    A list of non-parenthesis word items, all converted to lower case. E.g.,
      ["man", "wearing", "pass", ...
  """
  return [x.lower() for x in items if x not in PARENTHESES and x]
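
# Illustrative sketch (added; not part of the original module): the helper
# drops the parentheses and lower-cases the remaining word tokens.
#   >>> get_non_parenthesis_words(["(", "(", "The", "dog", ")", "barks", ")"])
#   ['the', 'dog', 'barks']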


def get_shift_reduce(items):
  """Obtain shift-reduce vector from a list of items from the SNLI data.

  Args:
    items: Data items as a list of str, e.g.,
      ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ...

  Returns:
    A list of shift-reduce transitions, encoded as `SHIFT_CODE` for shift and
      `REDUCE_CODE` for reduce. See code above for the values of `SHIFT_CODE`
      and `REDUCE_CODE`.
  """
  trans = []
  for item in items:
    if item == LEFT_PAREN:
      continue
    elif item == RIGHT_PAREN:
      trans.append(REDUCE_CODE)
    else:
      trans.append(SHIFT_CODE)
  return trans
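
# Illustrative sketch (added; not part of the original module): a binary parse
# with n words yields n SHIFT_CODEs and n - 1 REDUCE_CODEs.
#   >>> get_shift_reduce(["(", "(", "The", "dog", ")", "barks", ")"])
#   [3, 3, 2, 3, 2]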


def pad_and_reverse_word_ids(sentences):
  """Pad a list of sentences to the common maximum length + 1.

  Args:
    sentences: A list of sentences as a list of lists of integers. Each
      integer is a word ID, and each inner list corresponds to one sentence.

  Returns:
    A numpy.ndarray of shape (num_sentences, max_length + 1), wherein
      max_length is the maximum sentence length (in # of words). Each sentence
      is reversed and then padded with an extra `1` at the head, as required
      by the model.
  """
  max_len = max(len(sent) for sent in sentences)
  for sent in sentences:
    if len(sent) < max_len:
      sent.extend([PAD_CODE] * (max_len - len(sent)))
  # Reverse in time order and pad an extra one at the head.
  sentences = np.fliplr(np.array(sentences, dtype=np.int64))
  sentences = np.concatenate(
      [np.ones([sentences.shape[0], 1], dtype=np.int64), sentences], axis=1)
  return sentences
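
# Illustrative sketch (added; not part of the original module): the shorter
# sentence is padded in place with PAD_CODE (1), each row is reversed, and a
# leading 1 is prepended to every row.
#   >>> pad_and_reverse_word_ids([[5, 6, 7], [8, 9]]).tolist()
#   [[1, 7, 6, 5], [1, 1, 9, 8]]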


def pad_transitions(sentences_transitions):
  """Pad a list of shift-reduce transitions to the maximum length."""
  max_len = max(len(transitions) for transitions in sentences_transitions)
  for transitions in sentences_transitions:
    if len(transitions) < max_len:
      transitions.extend([PAD_CODE] * (max_len - len(transitions)))
  return np.array(sentences_transitions, dtype=np.int64)
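
# Illustrative sketch (added; not part of the original module): transitions
# are padded in place with PAD_CODE (1) up to the longest sequence given.
#   >>> pad_transitions([[3, 3, 2, 3, 2], [3, 3, 2]]).tolist()
#   [[3, 3, 2, 3, 2], [3, 3, 2, 1, 1]]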


def load_vocabulary(data_root):
  """Load vocabulary from SNLI data files.

  Args:
    data_root: Root directory of the data. It is assumed that the SNLI data
      files have been downloaded and extracted to the "snli/snli_1.0"
      subdirectory of it.

  Returns:
    Vocabulary as a set of strings.

  Raises:
    ValueError: If SNLI data files cannot be found.
  """
  snli_path = os.path.join(data_root, "snli")
  snli_glob_pattern = os.path.join(snli_path, "snli_1.0/snli_1.0_*.txt")
  file_names = glob.glob(snli_glob_pattern)
  if not file_names:
    raise ValueError(
        "Cannot find SNLI data files at %s. "
        "Please download and extract SNLI data first." % snli_glob_pattern)

  print("Loading vocabulary...")
  vocab = set()
  for file_name in file_names:
    # glob.glob() already returns full paths, so open file_name directly.
    with open(file_name, "rt") as f:
      for i, line in enumerate(f):
        if i == 0:
          # Skip the header line.
          continue
        items = line.split("\t")
        premise_words = get_non_parenthesis_words(items[1].split(" "))
        hypothesis_words = get_non_parenthesis_words(items[2].split(" "))
        vocab.update(premise_words)
        vocab.update(hypothesis_words)
  return vocab
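
# Usage sketch (added; hypothetical data_root): the loader expects files
# matching <data_root>/snli/snli_1.0/snli_1.0_*.txt and returns a plain
# Python set of lower-cased word strings.
#   vocab = load_vocabulary("/tmp/spinn-data")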


def load_word_vectors(data_root, vocab):
  """Load GloVe word vectors for words present in the vocabulary.

  Args:
    data_root: Data root directory. It is assumed that the GloVe file
      has been downloaded and extracted at the "glove/" subdirectory of it.
    vocab: A `set` of words, representing the vocabulary.

  Returns:
    1. word2index: A dict from lower-case word to row index in the embedding
       matrix, i.e., `embed` below.
    2. embed: The embedding matrix as a float32 numpy array. Its shape is
       [embedding_vocab_size, WORD_VECTOR_LEN], where embedding_vocab_size is
       the number of vocabulary words found in the GloVe file plus 2 (for the
       special <unk> and <pad> rows). WORD_VECTOR_LEN is the embedding
       dimension (300).

  Raises:
    ValueError: If the GloVe embedding file cannot be found.
  """
  glove_path = os.path.join(data_root, "glove/glove.42B.300d.txt")
  if not os.path.isfile(glove_path):
    raise ValueError(
        "Cannot find GloVe embedding file at %s. "
        "Please download and extract GloVe embeddings first." % glove_path)

  print("Loading word vectors...")

  word2index = dict()
  embed = []

  # Rows 0 and 1 are reserved for the all-zero <unk> and <pad> vectors.
  embed.append([0] * WORD_VECTOR_LEN)  # <unk>
  embed.append([0] * WORD_VECTOR_LEN)  # <pad>
  word2index["<unk>"] = UNK_CODE
  word2index["<pad>"] = PAD_CODE

  with open(glove_path, "rt") as f:
    for line in f:
      items = line.split(" ")
      word = items[0]
      if word in vocab and word not in word2index:
        word2index[word] = len(embed)
        vector = np.array([float(item) for item in items[1:]])
        assert (WORD_VECTOR_LEN,) == vector.shape
        embed.append(vector)
  embed = np.array(embed, dtype=np.float32)
  return word2index, embed
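
# Usage sketch (added; hypothetical data_root): the loader expects the file
# <data_root>/glove/glove.42B.300d.txt. Rows 0 and 1 of `embed` are the
# all-zero <unk> and <pad> vectors, so embed.shape[0] == len(word2index).
#   word2index, embed = load_word_vectors("/tmp/spinn-data", vocab)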


def calculate_bins(length2count, min_bin_size):
  """Calculate bin boundaries given a histogram of lengths and minimum bin size.

  Args:
    length2count: A `dict` mapping length to sentence count.
    min_bin_size: Minimum bin size in terms of total number of sentence pairs
      in the bin.

  Returns:
    A `list` representing the right bin boundaries, starting from the inclusive
    right boundary of the first bin. For example, if the output is
      [10, 20, 35],
    it means there are three bins: [1, 10], [11, 20] and [21, 35].
  """
  bounds = []
  lengths = sorted(length2count.keys())
  cum_count = 0
  for length in lengths:
    cum_count += length2count[length]
    if cum_count >= min_bin_size:
      bounds.append(length)
      cum_count = 0
  if bounds[-1] != lengths[-1]:
    bounds.append(lengths[-1])
  return bounds
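
# Worked example (added; not part of the original module): with at least 20
# pairs per bin, lengths 1 and 2 are merged into one bin and the remaining
# tail (length 3) becomes its own bin.
#   >>> calculate_bins({1: 10, 2: 15, 3: 5}, 20)
#   [2, 3]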


def encode_sentence(sentence, word2index):
  """Encode a single sentence as word indices and shift-reduce code.

  Args:
    sentence: The sentence with added binary parse information, represented as
      a string, with all the word items and parentheses separated by spaces.
      E.g., '( ( The dog ) ( ( is ( playing toys ) ) . ) )'.
    word2index: A `dict` mapping words to their word indices.

  Returns:
    1. Word indices as a numpy array, with shape `(sequence_len, 1)`, where
       sequence_len is the number of words in the sentence plus 1 (for the
       extra padding element prepended by `pad_and_reverse_word_ids()`).
    2. Shift-reduce sequence as a numpy array, with shape
       `(sequence_len * 2 - 3, 1)`.
  """
  items = [w for w in sentence.split(" ") if w]
  words = get_non_parenthesis_words(items)
  shift_reduce = get_shift_reduce(items)
  word_indices = pad_and_reverse_word_ids(
      [[word2index.get(word, UNK_CODE) for word in words]]).T
  return (word_indices,
          np.expand_dims(np.array(shift_reduce, dtype=np.int64), -1))
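
# Illustrative sketch (added; not part of the original module), using a small
# hypothetical word2index (indices 0 and 1 are reserved for <unk> and <pad>):
#   >>> w2i = {"the": 2, "dog": 3, "barks": 4}
#   >>> word_ids, transitions = encode_sentence("( ( The dog ) barks )", w2i)
#   >>> word_ids[:, 0].tolist(), transitions[:, 0].tolist()
#   ([1, 4, 3, 2], [3, 3, 2, 3, 2])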


class SnliData(object):
  """A split of SNLI data."""

  def __init__(self, data_file, word2index, sentence_len_limit=-1):
    """SnliData constructor.

    Args:
      data_file: Full path to the data file, e.g.,
        "/tmp/spinn-data/snli/snli_1.0/snli_1.0_train.txt"
      word2index: A dict from lower-case word to row index in the embedding
        matrix (see `load_word_vectors()` for details).
      sentence_len_limit: Maximum allowed sentence length (# of words).
        A value of <= 0 means unlimited. Sentences longer than this limit
        are currently discarded, not truncated.
    """

    self._labels = []
    self._premises = []
    self._premise_transitions = []
    self._hypotheses = []
    self._hypothesis_transitions = []

    with open(data_file, "rt") as f:
      for i, line in enumerate(f):
        if i == 0:
          # Skip header line.
          continue
        items = line.split("\t")
        if items[0] not in POSSIBLE_LABELS:
          continue

        premise_items = items[1].split(" ")
        hypothesis_items = items[2].split(" ")
        premise_words = get_non_parenthesis_words(premise_items)
        hypothesis_words = get_non_parenthesis_words(hypothesis_items)

        if (sentence_len_limit > 0 and
            (len(premise_words) > sentence_len_limit or
             len(hypothesis_words) > sentence_len_limit)):
          # TODO(cais): Maybe truncate; do not discard.
          continue

        premise_ids = [
            word2index.get(word, UNK_CODE) for word in premise_words]
        hypothesis_ids = [
            word2index.get(word, UNK_CODE) for word in hypothesis_words]

        self._premises.append(premise_ids)
        self._hypotheses.append(hypothesis_ids)
        self._premise_transitions.append(get_shift_reduce(premise_items))
        self._hypothesis_transitions.append(get_shift_reduce(hypothesis_items))
        assert (len(self._premise_transitions[-1]) ==
                2 * len(premise_words) - 1)
        assert (len(self._hypothesis_transitions[-1]) ==
                2 * len(hypothesis_words) - 1)

        # Labels are coded as 1 (entailment), 2 (contradiction), 3 (neutral).
        self._labels.append(POSSIBLE_LABELS.index(items[0]) + 1)

    assert len(self._labels) == len(self._premises)
    assert len(self._labels) == len(self._hypotheses)
    assert len(self._labels) == len(self._premise_transitions)
    assert len(self._labels) == len(self._hypothesis_transitions)

  def num_batches(self, batch_size):
    """Calculate number of batches given batch size."""
    return int(math.ceil(len(self._labels) / batch_size))

  def get_generator(self, batch_size):
    """Obtain a generator for batched data.

    All examples of this SnliData object are randomly shuffled, sorted
    according to the maximum sentence length of the premise and hypothesis
    sentences in the pair, and batched.

    Args:
      batch_size: Desired batch size.

    Returns:
      A generator function for batched data. Each call to the returned
      function creates a generator that yields 5-tuples:
        label: A sequence of `int` label codes, of length batch_size.
        premise: An array of shape (max_premise_len, batch_size), wherein
          max_premise_len is the maximum length of the (padded) premise
          sentences in the batch.
        premise_transitions: An array of shape (2 * max_premise_len - 3,
          batch_size).
        hypothesis: Same as `premise`, but for hypothesis sentences.
        hypothesis_transitions: Same as `premise_transitions`, but for
          hypothesis sentences.
      All numpy arrays in the 5-tuple have dtype `int64`.
    """
    # Randomly shuffle examples.
    zipped = list(zip(
        self._labels, self._premises, self._premise_transitions,
        self._hypotheses, self._hypothesis_transitions))
    random.shuffle(zipped)
    # Then sort the examples by the maximum of the premise and hypothesis
    # sentence lengths in the pair. During training, the batches are expected
    # to be shuffled, so it is okay to leave them sorted by max length here.
    (labels, premises, premise_transitions, hypotheses,
     hypothesis_transitions) = zip(
         *sorted(zipped, key=lambda x: max(len(x[1]), len(x[3]))))

    def _generator():
      begin = 0
      while begin < len(labels):
        # The sorting above and the batching here make sure that sentences of
        # similar max lengths are batched together, minimizing the
        # inefficiency due to uneven max lengths. The sentences are batched
        # differently in each call to get_generator() due to the shuffling
        # before sorting above. The pad_and_reverse_word_ids() and
        # pad_transitions() functions take care of any remaining unevenness
        # of the max sentence lengths.
        end = min(begin + batch_size, len(labels))
        # Transpose, because the SPINN model requires time-major, instead of
        # batch-major.
        yield (labels[begin:end],
               pad_and_reverse_word_ids(premises[begin:end]).T,
               pad_transitions(premise_transitions[begin:end]).T,
               pad_and_reverse_word_ids(hypotheses[begin:end]).T,
               pad_transitions(hypothesis_transitions[begin:end]).T)
        begin = end
    return _generator
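
# End-to-end usage sketch (added; hypothetical paths, assuming the SNLI and
# GloVe files have already been downloaded and extracted):
#   vocab = load_vocabulary("/tmp/spinn-data")
#   word2index, embed = load_word_vectors("/tmp/spinn-data", vocab)
#   train_data = SnliData(
#       "/tmp/spinn-data/snli/snli_1.0/snli_1.0_train.txt", word2index)
#   # get_generator() returns a generator *function*; call it to iterate.
#   for (label, premise, premise_trans, hypothesis,
#        hypothesis_trans) in train_data.get_generator(128)():
#     pass  # Feed the time-major batches to the SPINN model here.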
    374