# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Module for constructing a linear-chain CRF.
     16 
     17 The following snippet is an example of a CRF layer on top of a batched sequence
     18 of unary scores (logits for every word). This example also decodes the most
     19 likely sequence at test time. There are two ways to do decoding. One
     20 is using crf_decode to do decoding in Tensorflow , and the other one is using
     21 viterbi_decode in Numpy.
     22 
     23 log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
     24     unary_scores, gold_tags, sequence_lengths)
     25 
     26 loss = tf.reduce_mean(-log_likelihood)
     27 train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
     28 
     29 # Decoding in Tensorflow.
     30 viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
     31     unary_scores, transition_params, sequence_lengths)
     32 
     33 tf_viterbi_sequence, tf_viterbi_score, _ = session.run(
     34     [viterbi_sequence, viterbi_score, train_op])
     35 
     36 # Decoding in Numpy.
     37 tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
     38     [unary_scores, sequence_lengths, transition_params, train_op])
     39 for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
     40                                                  tf_sequence_lengths):
     41 # Remove padding.
     42 tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]
     43 
     44 # Compute the highest score and its tag sequence.
     45 tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode(
     46     tf_unary_scores_, tf_transition_params)
     47 """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope as vs

__all__ = [
    "crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
    "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
    "viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
    "CrfDecodeBackwardRnnCell"
]


def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
  """Computes the unnormalized score for a tag sequence.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the unnormalized score.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    sequence_scores: A [batch_size] vector of unnormalized sequence scores.
  """
  # If max_seq_len is 1, we skip the score calculation and simply gather the
  # unary potentials of the single tag.
  def _single_seq_fn():
    batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
    example_inds = array_ops.reshape(
        math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
    return array_ops.gather_nd(
        array_ops.squeeze(inputs, [1]),
        array_ops.concat([example_inds, tag_indices], axis=1))

  def _multi_seq_fn():
    # Compute the scores of the given tag sequence.
    unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
    binary_scores = crf_binary_score(tag_indices, sequence_lengths,
                                     transition_params)
    sequence_scores = unary_scores + binary_scores
    return sequence_scores

  return utils.smart_cond(
      pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
                          1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)
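
# Example usage of `crf_sequence_score` (a minimal sketch; the tensors below
# are hypothetical placeholders, not part of this module, and assume
# `import tensorflow as tf`):
#
#   unary = tf.random_normal([2, 7, 5])       # [batch_size, max_seq_len, num_tags]
#   tags = tf.zeros([2, 7], dtype=tf.int32)   # gold tag indices
#   lens = tf.constant([7, 4])                # true sequence lengths
#   transitions = tf.get_variable("transitions", [5, 5])
#   scores = crf_sequence_score(unary, tags, lens, transitions)  # shape [2]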


def crf_log_norm(inputs, sequence_lengths, transition_params):
  """Computes the normalization for a CRF.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    log_norm: A [batch_size] vector of normalizers for a CRF.
  """
  # Split up the first and rest of the inputs in preparation for the forward
  # algorithm.
  first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
  first_input = array_ops.squeeze(first_input, [1])

  # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over
  # the "initial state" (the unary potentials).
  def _single_seq_fn():
    return math_ops.reduce_logsumexp(first_input, [1])

  def _multi_seq_fn():
    """Forward computation of alpha values."""
    rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])

    # Compute the alpha values in the forward algorithm in order to get the
    # partition function.
    forward_cell = CrfForwardRnnCell(transition_params)
    _, alphas = rnn.dynamic_rnn(
        cell=forward_cell,
        inputs=rest_of_input,
        sequence_length=sequence_lengths - 1,
        initial_state=first_input,
        dtype=dtypes.float32)
    log_norm = math_ops.reduce_logsumexp(alphas, [1])
    return log_norm

  max_seq_len = array_ops.shape(inputs)[1]
  return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1),
                               true_fn=_single_seq_fn,
                               false_fn=_multi_seq_fn)
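
# Example usage of `crf_log_norm` (a minimal sketch; `unary`, `tags`, `lens`,
# and `transitions` are the hypothetical placeholder tensors from the example
# above):
#
#   log_norm = crf_log_norm(unary, lens, transitions)  # shape [batch_size]
#   # Together with `crf_sequence_score`, this yields the log-likelihood:
#   # crf_sequence_score(unary, tags, lens, transitions) - log_norm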


def crf_log_likelihood(inputs,
                       tag_indices,
                       sequence_lengths,
                       transition_params=None):
  """Computes the log-likelihood of tag sequences in a CRF.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the log-likelihood.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix, if available.
  Returns:
    log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
      each example, given the sequence of tag indices.
    transition_params: A [num_tags, num_tags] transition matrix. This is either
        provided by the caller or created in this function.
  """
  # Get shape information.
  num_tags = inputs.get_shape()[2].value

  # Get the transition matrix if not provided.
  if transition_params is None:
    transition_params = vs.get_variable("transitions", [num_tags, num_tags])

  sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
                                       transition_params)
  log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)

  # Normalize the scores to get the log-likelihood per example.
  log_likelihood = sequence_scores - log_norm
  return log_likelihood, transition_params
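
# Example training usage of `crf_log_likelihood` (a minimal sketch; `unary`,
# `tags`, and `lens` are hypothetical placeholder tensors, assuming
# `import tensorflow as tf`):
#
#   log_likelihood, transitions = crf_log_likelihood(unary, tags, lens)
#   loss = tf.reduce_mean(-log_likelihood)
#   train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)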


def crf_unary_score(tag_indices, sequence_lengths, inputs):
  """Computes the unary scores of tag sequences.

  Args:
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.
  Returns:
    unary_scores: A [batch_size] vector of unary scores.
  """
  batch_size = array_ops.shape(inputs)[0]
  max_seq_len = array_ops.shape(inputs)[1]
  num_tags = array_ops.shape(inputs)[2]

  flattened_inputs = array_ops.reshape(inputs, [-1])

  offsets = array_ops.expand_dims(
      math_ops.range(batch_size) * max_seq_len * num_tags, 1)
  offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)
  # Use int32 or int64 based on tag_indices' dtype.
  if tag_indices.dtype == dtypes.int64:
    offsets = math_ops.to_int64(offsets)
  flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1])

  unary_scores = array_ops.reshape(
      array_ops.gather(flattened_inputs, flattened_tag_indices),
      [batch_size, max_seq_len])

  masks = array_ops.sequence_mask(sequence_lengths,
                                  maxlen=array_ops.shape(tag_indices)[1],
                                  dtype=dtypes.float32)

  unary_scores = math_ops.reduce_sum(unary_scores * masks, 1)
  return unary_scores
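
# The flattened gather above is equivalent to the following NumPy sketch
# (illustrative only; the arrays and shapes are hypothetical):
#
#   import numpy as np
#   inputs = np.random.randn(2, 3, 4)           # [batch_size, max_seq_len, num_tags]
#   tags = np.array([[0, 2, 1], [3, 3, 0]])     # [batch_size, max_seq_len]
#   lens = np.array([3, 2])
#   picked = inputs[np.arange(2)[:, None], np.arange(3)[None, :], tags]
#   mask = np.arange(3)[None, :] < lens[:, None]
#   unary_scores = (picked * mask).sum(axis=1)  # [batch_size]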


def crf_binary_score(tag_indices, sequence_lengths, transition_params):
  """Computes the binary scores of tag sequences.

  Args:
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] matrix of binary potentials.
  Returns:
    binary_scores: A [batch_size] vector of binary scores.
  """
  # Get shape information.
  num_tags = transition_params.get_shape()[0]
  num_transitions = array_ops.shape(tag_indices)[1] - 1

  # Truncate by one on each side of the sequence to get the start and end
  # indices of each transition.
  start_tag_indices = array_ops.slice(tag_indices, [0, 0],
                                      [-1, num_transitions])
  end_tag_indices = array_ops.slice(tag_indices, [0, 1], [-1, num_transitions])

  # Encode the indices in a flattened representation.
  flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
  flattened_transition_params = array_ops.reshape(transition_params, [-1])

  # Get the binary scores based on the flattened representation.
  binary_scores = array_ops.gather(flattened_transition_params,
                                   flattened_transition_indices)

  masks = array_ops.sequence_mask(sequence_lengths,
                                  maxlen=array_ops.shape(tag_indices)[1],
                                  dtype=dtypes.float32)
  truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
  binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
  return binary_scores
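
# NumPy sketch of the masked transition-score sum computed above (illustrative
# only; the arrays and shapes are hypothetical):
#
#   import numpy as np
#   tags = np.array([[0, 2, 1], [3, 3, 0]])     # [batch_size, max_seq_len]
#   lens = np.array([3, 2])
#   transitions = np.random.randn(4, 4)         # [num_tags, num_tags]
#   trans_scores = transitions[tags[:, :-1], tags[:, 1:]]  # [batch_size, max_seq_len - 1]
#   mask = np.arange(1, 3)[None, :] < lens[:, None]
#   binary_scores = (trans_scores * mask).sum(axis=1)      # [batch_size]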


class CrfForwardRnnCell(rnn_cell.RNNCell):
  """Computes the alpha values in a linear-chain CRF.

  See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
  """

  def __init__(self, transition_params):
    """Initialize the CrfForwardRnnCell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
          This matrix is expanded into a [1, num_tags, num_tags] tensor in
          preparation for the broadcast summation occurring within the cell.
    """
    self._transition_params = array_ops.expand_dims(transition_params, 0)
    self._num_tags = transition_params.get_shape()[0].value

  @property
  def state_size(self):
    return self._num_tags

  @property
  def output_size(self):
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Build the CrfForwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous alpha
          values.
      scope: Unused variable scope of this cell.

    Returns:
      new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices
          containing the new alpha values.
    """
    state = array_ops.expand_dims(state, 2)

    # This addition op broadcasts self._transition_params along the zeroth
    # dimension and state along the second dimension. This performs the
    # multiplication of previous alpha values and the current binary potentials
    # in log space.
    transition_scores = state + self._transition_params
    new_alphas = inputs + math_ops.reduce_logsumexp(transition_scores, [1])

    # Both the state and the output of this RNN cell contain the alpha values.
    # The output value is currently unused and simply satisfies the RNN API.
    # This could be useful in the future if we need to compute marginal
    # probabilities, which would require the accumulated alpha values at every
    # time step.
    return new_alphas, new_alphas
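
# NumPy sketch of one step of the log-space forward recurrence performed by
# CrfForwardRnnCell (illustrative only; `alphas`, `unary_t`, and `transitions`
# are hypothetical arrays):
#
#   import numpy as np
#   from scipy.special import logsumexp
#   # alphas: [batch, tags], unary_t: [batch, tags], transitions: [tags, tags]
#   scores = alphas[:, :, None] + transitions[None, :, :]  # [batch, tags, tags]
#   new_alphas = unary_t + logsumexp(scores, axis=1)        # [batch, tags]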


def viterbi_decode(score, transition_params):
  """Decode the highest scoring sequence of tags outside of TensorFlow.

  This should only be used at test time.

  Args:
    score: A [seq_len, num_tags] matrix of unary potentials.
    transition_params: A [num_tags, num_tags] matrix of binary potentials.

  Returns:
    viterbi: A [seq_len] list of integers containing the highest scoring tag
        indices.
    viterbi_score: A float containing the score for the Viterbi sequence.
  """
  trellis = np.zeros_like(score)
  backpointers = np.zeros_like(score, dtype=np.int32)
  trellis[0] = score[0]

  for t in range(1, score.shape[0]):
    v = np.expand_dims(trellis[t - 1], 1) + transition_params
    trellis[t] = score[t] + np.max(v, 0)
    backpointers[t] = np.argmax(v, 0)

  viterbi = [np.argmax(trellis[-1])]
  for bp in reversed(backpointers[1:]):
    viterbi.append(bp[viterbi[-1]])
  viterbi.reverse()

  viterbi_score = np.max(trellis[-1])
  return viterbi, viterbi_score
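
# Example usage of `viterbi_decode` (a minimal sketch with hypothetical values):
#
#   import numpy as np
#   score = np.array([[4.0, 1.0, 0.0],
#                     [0.0, 3.0, 1.0]])           # [seq_len, num_tags]
#   transitions = np.zeros([3, 3], dtype=np.float32)
#   tags, best = viterbi_decode(score, transitions)
#   # tags == [0, 1]; best == 7.0 (with all-zero transitions)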


class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
  """Computes the forward decoding in a linear-chain CRF."""

  def __init__(self, transition_params):
    """Initialize the CrfDecodeForwardRnnCell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary
        potentials. This matrix is expanded into a
        [1, num_tags, num_tags] tensor in preparation for the broadcast
        summation occurring within the cell.
    """
    self._transition_params = array_ops.expand_dims(transition_params, 0)
    self._num_tags = transition_params.get_shape()[0].value

  @property
  def state_size(self):
    return self._num_tags

  @property
  def output_size(self):
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeForwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous step's
          score values.
      scope: Unused variable scope of this cell.

    Returns:
      backpointers: A [batch_size, num_tags] matrix of backpointers.
      new_state: A [batch_size, num_tags] matrix of new score values.
    """
    # For simplicity, in shape comments, denote:
    # 'batch_size' by 'B', 'max_seq_len' by 'T', 'num_tags' by 'O' (output).
    state = array_ops.expand_dims(state, 2)                         # [B, O, 1]

    # This addition op broadcasts self._transition_params along the zeroth
    # dimension and state along the second dimension.
    # [B, O, 1] + [1, O, O] -> [B, O, O]
    transition_scores = state + self._transition_params             # [B, O, O]
    new_state = inputs + math_ops.reduce_max(transition_scores, [1])  # [B, O]
    backpointers = math_ops.argmax(transition_scores, 1)
    backpointers = math_ops.cast(backpointers, dtype=dtypes.int32)    # [B, O]
    return backpointers, new_state
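
# NumPy sketch of one step of the max/argmax recurrence performed by
# CrfDecodeForwardRnnCell (illustrative only; `scores`, `unary_t`, and
# `transitions` are hypothetical arrays):
#
#   import numpy as np
#   # scores: [batch, tags], unary_t: [batch, tags], transitions: [tags, tags]
#   trans_scores = scores[:, :, None] + transitions[None, :, :]  # [batch, tags, tags]
#   new_scores = unary_t + trans_scores.max(axis=1)              # [batch, tags]
#   backpointers = trans_scores.argmax(axis=1)                   # [batch, tags]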


class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
  """Computes backward decoding in a linear-chain CRF."""

  def __init__(self, num_tags):
    """Initialize the CrfDecodeBackwardRnnCell.

    Args:
      num_tags: An integer. The number of tags.
    """
    self._num_tags = num_tags

  @property
  def state_size(self):
    return 1

  @property
  def output_size(self):
    return 1

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeBackwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of backpointers for the next
          step (in time order).
      state: A [batch_size, 1] matrix containing the tag index of the next
          step.
      scope: Unused variable scope of this cell.

    Returns:
      new_tags, new_tags: A pair of [batch_size, 1] tensors containing the new
        tag indices.
    """
    state = array_ops.squeeze(state, axis=[1])                # [B]
    batch_size = array_ops.shape(inputs)[0]
    b_indices = math_ops.range(batch_size)                    # [B]
    indices = array_ops.stack([b_indices, state], axis=1)     # [B, 2]
    new_tags = array_ops.expand_dims(
        gen_array_ops.gather_nd(inputs, indices),             # [B]
        axis=-1)                                              # [B, 1]

    return new_tags, new_tags


def crf_decode(potentials, transition_params, sequence_length):
  """Decode the highest scoring sequence of tags in TensorFlow.

  This is the in-graph, tensor-based counterpart of `viterbi_decode`.

  Args:
    potentials: A [batch_size, max_seq_len, num_tags] tensor of
        unary potentials.
    transition_params: A [num_tags, num_tags] matrix of
        binary potentials.
    sequence_length: A [batch_size] vector of true sequence lengths.

  Returns:
    decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
        Contains the highest scoring tag indices.
    best_score: A [batch_size] vector, containing the score of `decode_tags`.
  """
  # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
  # and the max activation.
  def _single_seq_fn():
    squeezed_potentials = array_ops.squeeze(potentials, [1])
    decode_tags = array_ops.expand_dims(
        math_ops.argmax(squeezed_potentials, axis=1), 1)
    best_score = math_ops.reduce_max(squeezed_potentials, axis=1)
    return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score

  def _multi_seq_fn():
    """Decoding of highest scoring sequence."""

    # For simplicity, in shape comments, denote:
    # 'batch_size' by 'B', 'max_seq_len' by 'T', 'num_tags' by 'O' (output).
    num_tags = potentials.get_shape()[2].value

    # Computes forward decoding. Get last score and backpointers.
    crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
    initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
    initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
    inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
    backpointers, last_score = rnn.dynamic_rnn(  # [B, T - 1, O], [B, O]
        crf_fwd_cell,
        inputs=inputs,
        sequence_length=sequence_length - 1,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    backpointers = gen_array_ops.reverse_sequence(  # [B, T - 1, O]
        backpointers, sequence_length - 1, seq_dim=1)

    # Computes backward decoding. Extract tag indices from backpointers.
    crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
    initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),  # [B]
                                  dtype=dtypes.int32)
    initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
    decode_tags, _ = rnn.dynamic_rnn(  # [B, T - 1, 1]
        crf_bwd_cell,
        inputs=backpointers,
        sequence_length=sequence_length - 1,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T - 1]
    decode_tags = array_ops.concat([initial_state, decode_tags],   # [B, T]
                                   axis=1)
    decode_tags = gen_array_ops.reverse_sequence(  # [B, T]
        decode_tags, sequence_length, seq_dim=1)

    best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
    return decode_tags, best_score

  return utils.smart_cond(
      pred=math_ops.equal(
          potentials.shape[1].value or array_ops.shape(potentials)[1], 1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)
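
# Example test-time usage of `crf_decode` (a minimal sketch; `unary`, `lens`,
# and `transitions` are hypothetical placeholder tensors, assuming
# `import tensorflow as tf`):
#
#   decode_tags, best_score = crf_decode(unary, transitions, lens)
#   # decode_tags: [batch_size, max_seq_len] int32 tag indices
#   # best_score:  [batch_size] score of each decoded sequence
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     tags_out, scores_out = sess.run([decode_tags, best_score])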