1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Module for constructing a linear-chain CRF. 16 17 The following snippet is an example of a CRF layer on top of a batched sequence 18 of unary scores (logits for every word). This example also decodes the most 19 likely sequence at test time. There are two ways to do decoding. One 20 is using crf_decode to do decoding in Tensorflow , and the other one is using 21 viterbi_decode in Numpy. 22 23 log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood( 24 unary_scores, gold_tags, sequence_lengths) 25 26 loss = tf.reduce_mean(-log_likelihood) 27 train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss) 28 29 # Decoding in Tensorflow. 30 viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode( 31 unary_scores, transition_params, sequence_lengths) 32 33 tf_viterbi_sequence, tf_viterbi_score, _ = session.run( 34 [viterbi_sequence, viterbi_score, train_op]) 35 36 # Decoding in Numpy. 37 tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run( 38 [unary_scores, sequence_lengths, transition_params, train_op]) 39 for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores, 40 tf_sequence_lengths): 41 # Remove padding. 42 tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] 43 44 # Compute the highest score and its tag sequence. 45 tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( 46 tf_unary_scores_, tf_transition_params) 47 """ 48 49 from __future__ import absolute_import 50 from __future__ import division 51 from __future__ import print_function 52 53 import numpy as np 54 55 from tensorflow.python.framework import dtypes 56 from tensorflow.python.layers import utils 57 from tensorflow.python.ops import array_ops 58 from tensorflow.python.ops import control_flow_ops 59 from tensorflow.python.ops import gen_array_ops 60 from tensorflow.python.ops import math_ops 61 from tensorflow.python.ops import rnn 62 from tensorflow.python.ops import rnn_cell 63 from tensorflow.python.ops import variable_scope as vs 64 65 __all__ = [ 66 "crf_sequence_score", "crf_log_norm", "crf_log_likelihood", 67 "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell", 68 "viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell", 69 "CrfDecodeBackwardRnnCell" 70 ] 71 72 73 def crf_sequence_score(inputs, tag_indices, sequence_lengths, 74 transition_params): 75 """Computes the unnormalized score for a tag sequence. 76 77 Args: 78 inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials 79 to use as input to the CRF layer. 80 tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we 81 compute the unnormalized score. 82 sequence_lengths: A [batch_size] vector of true sequence lengths. 83 transition_params: A [num_tags, num_tags] transition matrix. 84 Returns: 85 sequence_scores: A [batch_size] vector of unnormalized sequence scores. 86 """ 87 # If max_seq_len is 1, we skip the score calculation and simply gather the 88 # unary potentials of the single tag. 89 def _single_seq_fn(): 90 batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0] 91 example_inds = array_ops.reshape( 92 math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) 93 return array_ops.gather_nd( 94 array_ops.squeeze(inputs, [1]), 95 array_ops.concat([example_inds, tag_indices], axis=1)) 96 97 def _multi_seq_fn(): 98 # Compute the scores of the given tag sequence. 99 unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) 100 binary_scores = crf_binary_score(tag_indices, sequence_lengths, 101 transition_params) 102 sequence_scores = unary_scores + binary_scores 103 return sequence_scores 104 105 return utils.smart_cond( 106 pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], 107 1), 108 true_fn=_single_seq_fn, 109 false_fn=_multi_seq_fn) 110 111 112 def crf_log_norm(inputs, sequence_lengths, transition_params): 113 """Computes the normalization for a CRF. 114 115 Args: 116 inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials 117 to use as input to the CRF layer. 118 sequence_lengths: A [batch_size] vector of true sequence lengths. 119 transition_params: A [num_tags, num_tags] transition matrix. 120 Returns: 121 log_norm: A [batch_size] vector of normalizers for a CRF. 122 """ 123 # Split up the first and rest of the inputs in preparation for the forward 124 # algorithm. 125 first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1]) 126 first_input = array_ops.squeeze(first_input, [1]) 127 128 # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over 129 # the "initial state" (the unary potentials). 130 def _single_seq_fn(): 131 return math_ops.reduce_logsumexp(first_input, [1]) 132 133 def _multi_seq_fn(): 134 """Forward computation of alpha values.""" 135 rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1]) 136 137 # Compute the alpha values in the forward algorithm in order to get the 138 # partition function. 139 forward_cell = CrfForwardRnnCell(transition_params) 140 _, alphas = rnn.dynamic_rnn( 141 cell=forward_cell, 142 inputs=rest_of_input, 143 sequence_length=sequence_lengths - 1, 144 initial_state=first_input, 145 dtype=dtypes.float32) 146 log_norm = math_ops.reduce_logsumexp(alphas, [1]) 147 return log_norm 148 149 max_seq_len = array_ops.shape(inputs)[1] 150 return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1), 151 true_fn=_single_seq_fn, 152 false_fn=_multi_seq_fn) 153 154 155 def crf_log_likelihood(inputs, 156 tag_indices, 157 sequence_lengths, 158 transition_params=None): 159 """Computes the log-likelihood of tag sequences in a CRF. 160 161 Args: 162 inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials 163 to use as input to the CRF layer. 164 tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we 165 compute the log-likelihood. 166 sequence_lengths: A [batch_size] vector of true sequence lengths. 167 transition_params: A [num_tags, num_tags] transition matrix, if available. 168 Returns: 169 log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of 170 each example, given the sequence of tag indices. 171 transition_params: A [num_tags, num_tags] transition matrix. This is either 172 provided by the caller or created in this function. 173 """ 174 # Get shape information. 175 num_tags = inputs.get_shape()[2].value 176 177 # Get the transition matrix if not provided. 178 if transition_params is None: 179 transition_params = vs.get_variable("transitions", [num_tags, num_tags]) 180 181 sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths, 182 transition_params) 183 log_norm = crf_log_norm(inputs, sequence_lengths, transition_params) 184 185 # Normalize the scores to get the log-likelihood per example. 186 log_likelihood = sequence_scores - log_norm 187 return log_likelihood, transition_params 188 189 190 def crf_unary_score(tag_indices, sequence_lengths, inputs): 191 """Computes the unary scores of tag sequences. 192 193 Args: 194 tag_indices: A [batch_size, max_seq_len] matrix of tag indices. 195 sequence_lengths: A [batch_size] vector of true sequence lengths. 196 inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials. 197 Returns: 198 unary_scores: A [batch_size] vector of unary scores. 199 """ 200 batch_size = array_ops.shape(inputs)[0] 201 max_seq_len = array_ops.shape(inputs)[1] 202 num_tags = array_ops.shape(inputs)[2] 203 204 flattened_inputs = array_ops.reshape(inputs, [-1]) 205 206 offsets = array_ops.expand_dims( 207 math_ops.range(batch_size) * max_seq_len * num_tags, 1) 208 offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0) 209 # Use int32 or int64 based on tag_indices' dtype. 210 if tag_indices.dtype == dtypes.int64: 211 offsets = math_ops.to_int64(offsets) 212 flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1]) 213 214 unary_scores = array_ops.reshape( 215 array_ops.gather(flattened_inputs, flattened_tag_indices), 216 [batch_size, max_seq_len]) 217 218 masks = array_ops.sequence_mask(sequence_lengths, 219 maxlen=array_ops.shape(tag_indices)[1], 220 dtype=dtypes.float32) 221 222 unary_scores = math_ops.reduce_sum(unary_scores * masks, 1) 223 return unary_scores 224 225 226 def crf_binary_score(tag_indices, sequence_lengths, transition_params): 227 """Computes the binary scores of tag sequences. 228 229 Args: 230 tag_indices: A [batch_size, max_seq_len] matrix of tag indices. 231 sequence_lengths: A [batch_size] vector of true sequence lengths. 232 transition_params: A [num_tags, num_tags] matrix of binary potentials. 233 Returns: 234 binary_scores: A [batch_size] vector of binary scores. 235 """ 236 # Get shape information. 237 num_tags = transition_params.get_shape()[0] 238 num_transitions = array_ops.shape(tag_indices)[1] - 1 239 240 # Truncate by one on each side of the sequence to get the start and end 241 # indices of each transition. 242 start_tag_indices = array_ops.slice(tag_indices, [0, 0], 243 [-1, num_transitions]) 244 end_tag_indices = array_ops.slice(tag_indices, [0, 1], [-1, num_transitions]) 245 246 # Encode the indices in a flattened representation. 247 flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices 248 flattened_transition_params = array_ops.reshape(transition_params, [-1]) 249 250 # Get the binary scores based on the flattened representation. 251 binary_scores = array_ops.gather(flattened_transition_params, 252 flattened_transition_indices) 253 254 masks = array_ops.sequence_mask(sequence_lengths, 255 maxlen=array_ops.shape(tag_indices)[1], 256 dtype=dtypes.float32) 257 truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1]) 258 binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1) 259 return binary_scores 260 261 262 class CrfForwardRnnCell(rnn_cell.RNNCell): 263 """Computes the alpha values in a linear-chain CRF. 264 265 See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. 266 """ 267 268 def __init__(self, transition_params): 269 """Initialize the CrfForwardRnnCell. 270 271 Args: 272 transition_params: A [num_tags, num_tags] matrix of binary potentials. 273 This matrix is expanded into a [1, num_tags, num_tags] in preparation 274 for the broadcast summation occurring within the cell. 275 """ 276 self._transition_params = array_ops.expand_dims(transition_params, 0) 277 self._num_tags = transition_params.get_shape()[0].value 278 279 @property 280 def state_size(self): 281 return self._num_tags 282 283 @property 284 def output_size(self): 285 return self._num_tags 286 287 def __call__(self, inputs, state, scope=None): 288 """Build the CrfForwardRnnCell. 289 290 Args: 291 inputs: A [batch_size, num_tags] matrix of unary potentials. 292 state: A [batch_size, num_tags] matrix containing the previous alpha 293 values. 294 scope: Unused variable scope of this cell. 295 296 Returns: 297 new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices 298 values containing the new alpha values. 299 """ 300 state = array_ops.expand_dims(state, 2) 301 302 # This addition op broadcasts self._transitions_params along the zeroth 303 # dimension and state along the second dimension. This performs the 304 # multiplication of previous alpha values and the current binary potentials 305 # in log space. 306 transition_scores = state + self._transition_params 307 new_alphas = inputs + math_ops.reduce_logsumexp(transition_scores, [1]) 308 309 # Both the state and the output of this RNN cell contain the alphas values. 310 # The output value is currently unused and simply satisfies the RNN API. 311 # This could be useful in the future if we need to compute marginal 312 # probabilities, which would require the accumulated alpha values at every 313 # time step. 314 return new_alphas, new_alphas 315 316 317 def viterbi_decode(score, transition_params): 318 """Decode the highest scoring sequence of tags outside of TensorFlow. 319 320 This should only be used at test time. 321 322 Args: 323 score: A [seq_len, num_tags] matrix of unary potentials. 324 transition_params: A [num_tags, num_tags] matrix of binary potentials. 325 326 Returns: 327 viterbi: A [seq_len] list of integers containing the highest scoring tag 328 indices. 329 viterbi_score: A float containing the score for the Viterbi sequence. 330 """ 331 trellis = np.zeros_like(score) 332 backpointers = np.zeros_like(score, dtype=np.int32) 333 trellis[0] = score[0] 334 335 for t in range(1, score.shape[0]): 336 v = np.expand_dims(trellis[t - 1], 1) + transition_params 337 trellis[t] = score[t] + np.max(v, 0) 338 backpointers[t] = np.argmax(v, 0) 339 340 viterbi = [np.argmax(trellis[-1])] 341 for bp in reversed(backpointers[1:]): 342 viterbi.append(bp[viterbi[-1]]) 343 viterbi.reverse() 344 345 viterbi_score = np.max(trellis[-1]) 346 return viterbi, viterbi_score 347 348 349 class CrfDecodeForwardRnnCell(rnn_cell.RNNCell): 350 """Computes the forward decoding in a linear-chain CRF. 351 """ 352 353 def __init__(self, transition_params): 354 """Initialize the CrfDecodeForwardRnnCell. 355 356 Args: 357 transition_params: A [num_tags, num_tags] matrix of binary 358 potentials. This matrix is expanded into a 359 [1, num_tags, num_tags] in preparation for the broadcast 360 summation occurring within the cell. 361 """ 362 self._transition_params = array_ops.expand_dims(transition_params, 0) 363 self._num_tags = transition_params.get_shape()[0].value 364 365 @property 366 def state_size(self): 367 return self._num_tags 368 369 @property 370 def output_size(self): 371 return self._num_tags 372 373 def __call__(self, inputs, state, scope=None): 374 """Build the CrfDecodeForwardRnnCell. 375 376 Args: 377 inputs: A [batch_size, num_tags] matrix of unary potentials. 378 state: A [batch_size, num_tags] matrix containing the previous step's 379 score values. 380 scope: Unused variable scope of this cell. 381 382 Returns: 383 backpointers: A [batch_size, num_tags] matrix of backpointers. 384 new_state: A [batch_size, num_tags] matrix of new score values. 385 """ 386 # For simplicity, in shape comments, denote: 387 # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). 388 state = array_ops.expand_dims(state, 2) # [B, O, 1] 389 390 # This addition op broadcasts self._transitions_params along the zeroth 391 # dimension and state along the second dimension. 392 # [B, O, 1] + [1, O, O] -> [B, O, O] 393 transition_scores = state + self._transition_params # [B, O, O] 394 new_state = inputs + math_ops.reduce_max(transition_scores, [1]) # [B, O] 395 backpointers = math_ops.argmax(transition_scores, 1) 396 backpointers = math_ops.cast(backpointers, dtype=dtypes.int32) # [B, O] 397 return backpointers, new_state 398 399 400 class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell): 401 """Computes backward decoding in a linear-chain CRF. 402 """ 403 404 def __init__(self, num_tags): 405 """Initialize the CrfDecodeBackwardRnnCell. 406 407 Args: 408 num_tags: An integer. The number of tags. 409 """ 410 self._num_tags = num_tags 411 412 @property 413 def state_size(self): 414 return 1 415 416 @property 417 def output_size(self): 418 return 1 419 420 def __call__(self, inputs, state, scope=None): 421 """Build the CrfDecodeBackwardRnnCell. 422 423 Args: 424 inputs: A [batch_size, num_tags] matrix of 425 backpointer of next step (in time order). 426 state: A [batch_size, 1] matrix of tag index of next step. 427 scope: Unused variable scope of this cell. 428 429 Returns: 430 new_tags, new_tags: A pair of [batch_size, num_tags] 431 tensors containing the new tag indices. 432 """ 433 state = array_ops.squeeze(state, axis=[1]) # [B] 434 batch_size = array_ops.shape(inputs)[0] 435 b_indices = math_ops.range(batch_size) # [B] 436 indices = array_ops.stack([b_indices, state], axis=1) # [B, 2] 437 new_tags = array_ops.expand_dims( 438 gen_array_ops.gather_nd(inputs, indices), # [B] 439 axis=-1) # [B, 1] 440 441 return new_tags, new_tags 442 443 444 def crf_decode(potentials, transition_params, sequence_length): 445 """Decode the highest scoring sequence of tags in TensorFlow. 446 447 This is a function for tensor. 448 449 Args: 450 potentials: A [batch_size, max_seq_len, num_tags] tensor of 451 unary potentials. 452 transition_params: A [num_tags, num_tags] matrix of 453 binary potentials. 454 sequence_length: A [batch_size] vector of true sequence lengths. 455 456 Returns: 457 decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. 458 Contains the highest scoring tag indices. 459 best_score: A [batch_size] vector, containing the score of `decode_tags`. 460 """ 461 # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag 462 # and the max activation. 463 def _single_seq_fn(): 464 squeezed_potentials = array_ops.squeeze(potentials, [1]) 465 decode_tags = array_ops.expand_dims( 466 math_ops.argmax(squeezed_potentials, axis=1), 1) 467 best_score = math_ops.reduce_max(squeezed_potentials, axis=1) 468 return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score 469 470 def _multi_seq_fn(): 471 """Decoding of highest scoring sequence.""" 472 473 # For simplicity, in shape comments, denote: 474 # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). 475 num_tags = potentials.get_shape()[2].value 476 477 # Computes forward decoding. Get last score and backpointers. 478 crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) 479 initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) 480 initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] 481 inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] 482 backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] 483 crf_fwd_cell, 484 inputs=inputs, 485 sequence_length=sequence_length - 1, 486 initial_state=initial_state, 487 time_major=False, 488 dtype=dtypes.int32) 489 backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O] 490 backpointers, sequence_length - 1, seq_dim=1) 491 492 # Computes backward decoding. Extract tag indices from backpointers. 493 crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) 494 initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), # [B] 495 dtype=dtypes.int32) 496 initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] 497 decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1] 498 crf_bwd_cell, 499 inputs=backpointers, 500 sequence_length=sequence_length - 1, 501 initial_state=initial_state, 502 time_major=False, 503 dtype=dtypes.int32) 504 decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] 505 decode_tags = array_ops.concat([initial_state, decode_tags], # [B, T] 506 axis=1) 507 decode_tags = gen_array_ops.reverse_sequence( # [B, T] 508 decode_tags, sequence_length, seq_dim=1) 509 510 best_score = math_ops.reduce_max(last_score, axis=1) # [B] 511 return decode_tags, best_score 512 513 return utils.smart_cond( 514 pred=math_ops.equal( 515 potentials.shape[1].value or array_ops.shape(potentials)[1], 1), 516 true_fn=_single_seq_fn, 517 false_fn=_multi_seq_fn) 518