# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Estimator for Dynamic RNNs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.util import nest


# TODO(jtbates): Remove PredictionType when all non-experimental targets which
# depend on it point to rnn_common.PredictionType.
class PredictionType(object):
  """Legacy prediction-type constants kept for not-yet-migrated dependents.

  The code in this module compares against `rnn_common.PredictionType`, not
  this class.
  NOTE(review): presumably the values mirror `rnn_common.PredictionType` --
  confirm before relying on interchangeability.
  """
  SINGLE_VALUE = 1
  MULTIPLE_VALUE = 2


def _get_state_name(i):
  """Constructs the name string for state component `i`."""
  return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)


def state_tuple_to_dict(state):
  """Returns a dict containing flattened `state`.

  Args:
    state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s must
      have the same rank and agree on all dimensions except the last.

  Returns:
    A dict containing the `Tensor`s that make up `state`. The keys of the dict
    are of the form "STATE_PREFIX_i" where `i` is the place of this `Tensor`
    in a depth-first traversal of `state`. A component that is `None` is kept
    as `None` under its key.
  """
  with ops.name_scope('state_tuple_to_dict'):
    flat_state = nest.flatten(state)
    state_dict = {}
    for i, state_component in enumerate(flat_state):
      state_name = _get_state_name(i)
      # `identity` gives every non-None component an op named after its key.
      state_value = (None if state_component is None
                     else array_ops.identity(state_component, name=state_name))
      state_dict[state_name] = state_value
  return state_dict


def dict_to_state_tuple(input_dict, cell):
  """Reconstructs nested `state` from a dict containing state `Tensor`s.

  Args:
    input_dict: A dict of `Tensor`s.
    cell: An instance of `RNNCell`.

  Returns:
    If `input_dict` does not contain keys 'STATE_PREFIX_i' for `0 <= i < n`
    where `n` is the number of nested entries in `cell.state_size`, this
    function returns `None`. Otherwise, returns a `Tensor` if `cell.state_size`
    is an `int` or a nested tuple of `Tensor`s if `cell.state_size` is a nested
    tuple.

  Raises:
    ValueError: State is partially specified. The `input_dict` must contain
      values for all state components or none at all.
  """
  flat_state_sizes = nest.flatten(cell.state_size)
  state_tensors = []
  with ops.name_scope('dict_to_state_tuple'):
    for i, state_size in enumerate(flat_state_sizes):
      state_name = _get_state_name(i)
      state_tensor = input_dict.get(state_name)
      if state_tensor is not None:
        # Each supplied component must be rank 2 with dimension 1 equal to
        # the size the cell declares for that component.
        rank_check = check_ops.assert_rank(
            state_tensor, 2, name='check_state_{}_rank'.format(i))
        shape_check = check_ops.assert_equal(
            array_ops.shape(state_tensor)[1],
            state_size,
            name='check_state_{}_shape'.format(i))
        with ops.control_dependencies([rank_check, shape_check]):
          state_tensor = array_ops.identity(state_tensor, name=state_name)
        state_tensors.append(state_tensor)
    if not state_tensors:
      return None
    elif len(state_tensors) == len(flat_state_sizes):
      # The zero state is only used as a structural template for packing; its
      # values (and hence its batch size and dtype) are never read.
      dummy_state = cell.zero_state(batch_size=1, dtype=dtypes.bool)
      return nest.pack_sequence_as(dummy_state, state_tensors)
    else:
      raise ValueError(
          'RNN state was partially specified. '
          'Expected zero or {} state Tensors; got {}'.
          format(len(flat_state_sizes), len(state_tensors)))


def _concatenate_context_input(sequence_input, context_input):
  """Replicates `context_input` across all timesteps of `sequence_input`.

  Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
  This value is appended to `sequence_input` on dimension 2 and the result is
  returned.

  Args:
    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
  """
  seq_rank_check = check_ops.assert_rank(
      sequence_input,
      3,
      message='sequence_input must have rank 3',
      data=[array_ops.shape(sequence_input)])
  seq_type_check = check_ops.assert_type(
      sequence_input,
      dtypes.float32,
      message='sequence_input must have dtype float32; got {}.'.format(
          sequence_input.dtype))
  ctx_rank_check = check_ops.assert_rank(
      context_input,
      2,
      message='context_input must have rank 2',
      data=[array_ops.shape(context_input)])
  ctx_type_check = check_ops.assert_type(
      context_input,
      dtypes.float32,
      message='context_input must have dtype float32; got {}.'.format(
          context_input.dtype))
  # Reading the padded length is gated on all four checks, so everything
  # derived from it below depends on them as well.
  with ops.control_dependencies(
      [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
    padded_length = array_ops.shape(sequence_input)[1]
  tiled_context_input = array_ops.tile(
      array_ops.expand_dims(context_input, 1),
      array_ops.concat([[1], [padded_length], [1]], 0))
  return array_ops.concat([sequence_input, tiled_context_input], 2)


def build_sequence_input(features,
                         sequence_feature_columns,
                         context_feature_columns,
                         weight_collections=None,
                         scope=None):
  """Combine sequence and context features into input for an RNN.

  Args:
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state. The dict is copied, not modified.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features i.e. features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`. May be `None`, in which case no context is appended.
    weight_collections: List of graph collections to which weights are added.
    scope: Optional scope, passed through to parsing ops.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, ?]`.
    This will be used as input to an RNN.
  """
  # Work on a shallow copy so the caller's dict is left untouched.
  features = features.copy()
  features.update(layers.transform_features(
      features,
      list(sequence_feature_columns) + list(context_feature_columns or [])))
  sequence_input = layers.sequence_input_from_feature_columns(
      columns_to_tensors=features,
      feature_columns=sequence_feature_columns,
      weight_collections=weight_collections,
      scope=scope)
  if context_feature_columns is not None:
    context_input = layers.input_from_feature_columns(
        columns_to_tensors=features,
        feature_columns=context_feature_columns,
        weight_collections=weight_collections,
        scope=scope)
    sequence_input = _concatenate_context_input(sequence_input, context_input)
  return sequence_input


def construct_rnn(initial_state,
                  sequence_input,
                  cell,
                  num_label_columns,
                  dtype=dtypes.float32,
                  parallel_iterations=32,
                  swap_memory=True):
  """Build an RNN and apply a fully connected layer to get the desired output.

  Args:
    initial_state: The initial state to pass the RNN. If `None`, the
      default starting state for `cell` is used.
    sequence_input: A `Tensor` with shape `[batch_size, padded_length, d]`
      that will be passed as input to the RNN.
    cell: An initialized `RNNCell`.
    num_label_columns: The desired output dimension.
    dtype: dtype of `cell`.
    parallel_iterations: Number of iterations to run in parallel. Values >> 1
      use more memory but take less time, while smaller values use less memory
      but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU. This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.

  Returns:
    activations: The output of the RNN, projected to `num_label_columns`
      dimensions.
    final_state: A `Tensor` or nested tuple of `Tensor`s representing the final
      state output by the RNN.
  """
  with ops.name_scope('RNN'):
    rnn_outputs, final_state = rnn.dynamic_rnn(
        cell=cell,
        inputs=sequence_input,
        initial_state=initial_state,
        dtype=dtype,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        time_major=False)
    # Linear projection (no activation) of the RNN outputs.
    activations = layers.fully_connected(
        inputs=rnn_outputs,
        num_outputs=num_label_columns,
        activation_fn=None,
        trainable=True)
    return activations, final_state


def _single_value_predictions(activations,
                              sequence_length,
                              target_column,
                              problem_type,
                              predict_probabilities):
  """Maps `activations` from the RNN to predictions for single value models.

  If `predict_probabilities` is `False`, this function returns a `dict`
  containing single entry with key `PREDICTIONS_KEY`. If `predict_probabilities`
  is `True`, it will contain a second entry with key `PROBABILITIES_KEY`. The
  value of this entry is a `Tensor` of probabilities with shape
  `[batch_size, num_classes]`.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, calculate predictions.
    problem_type: Either `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    predict_probabilities: A Python boolean, indicating whether probabilities
      should be returned. Should only be set to `True` for
      classification/logistic regression problems.

  Returns:
    A `dict` mapping strings to `Tensors`.
  """
  with ops.name_scope('SingleValuePrediction'):
    last_activations = rnn_common.select_last_activations(
        activations, sequence_length)
    predictions_name = (prediction_key.PredictionKey.CLASSES
                        if problem_type == constants.ProblemType.CLASSIFICATION
                        else prediction_key.PredictionKey.SCORES)
    if predict_probabilities:
      probabilities = target_column.logits_to_predictions(
          last_activations, proba=True)
      # The predicted value is the argmax of the probabilities along axis 1.
      prediction_dict = {
          prediction_key.PredictionKey.PROBABILITIES: probabilities,
          predictions_name: math_ops.argmax(probabilities, 1)}
    else:
      predictions = target_column.logits_to_predictions(
          last_activations, proba=False)
      prediction_dict = {predictions_name: predictions}
    return prediction_dict


def _multi_value_loss(
    activations, labels, sequence_length, target_column, features):
  """Maps `activations` from the RNN to loss for multi value models.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    labels: A `Tensor` with length `[batch_size, padded_length]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, calculate predictions.
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.

  Returns:
    A scalar `Tensor` containing the loss.
  """
  with ops.name_scope('MultiValueLoss'):
    activations_masked, labels_masked = rnn_common.mask_activations_and_labels(
        activations, labels, sequence_length)
    return target_column.loss(activations_masked, labels_masked, features)


def _single_value_loss(
    activations, labels, sequence_length, target_column, features):
  """Maps `activations` from the RNN to loss for single value models.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    labels: A `Tensor` with length `[batch_size]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, calculate predictions.
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.

  Returns:
    A scalar `Tensor` containing the loss.
  """
  with ops.name_scope('SingleValueLoss'):
    last_activations = rnn_common.select_last_activations(
        activations, sequence_length)
    return target_column.loss(last_activations, labels, features)


def _get_output_alternatives(prediction_type,
                             problem_type,
                             prediction_dict):
  """Constructs output alternatives dict for `ModelFnOps`.

  Args:
    prediction_type: either `MULTIPLE_VALUE` or `SINGLE_VALUE`.
    problem_type: either `CLASSIFICATION` or `LINEAR_REGRESSION`.
    prediction_dict: a dictionary mapping strings to `Tensor`s containing
      predictions.

  Returns:
    `None` or a dictionary mapping a string to an output alternative.

  Raises:
    ValueError: `prediction_type` is not one of `SINGLE_VALUE` or
      `MULTIPLE_VALUE`.
  """
  if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
    return None
  if prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
    # Drop every entry whose key contains the RNN state prefix.
    prediction_dict_no_state = {
        k: v
        for k, v in prediction_dict.items()
        if rnn_common.RNNKeys.STATE_PREFIX not in k
    }
    return {'dynamic_rnn_output': (problem_type, prediction_dict_no_state)}
  raise ValueError('Unrecognized prediction_type: {}'.format(prediction_type))


def _get_dynamic_rnn_model_fn(
    cell_type,
    num_units,
    target_column,
    problem_type,
    prediction_type,
    optimizer,
    sequence_feature_columns,
    context_feature_columns=None,
    predict_probabilities=False,
    learning_rate=None,
    gradient_clipping_norm=None,
    dropout_keep_probabilities=None,
    sequence_length_key=rnn_common.RNNKeys.SEQUENCE_LENGTH_KEY,
    dtype=dtypes.float32,
    parallel_iterations=None,
    swap_memory=True,
    name='DynamicRNNModel'):
  """Creates an RNN model function for an `Estimator`.

  The model function returns an instance of `ModelFnOps`. When
  `problem_type == ProblemType.CLASSIFICATION` and
  `predict_probabilities == True`, the returned `ModelFnOps` includes an output
  alternative containing the classes and their associated probabilities. When
  `predict_probabilities == False`, only the classes are included. When
  `problem_type == ProblemType.LINEAR_REGRESSION`, the output alternative
  contains only the predicted values.

  Args:
    cell_type: A string, a subclass of `RNNCell` or an instance of an `RNNCell`.
    num_units: A single `int` or a list of `int`s. The size of the `RNNCell`s.
    target_column: An initialized `TargetColumn`, used to calculate prediction
      and loss.
    problem_type: `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    prediction_type: `PredictionType.SINGLE_VALUE` or
      `PredictionType.MULTIPLE_VALUE`.
    optimizer: A subclass of `Optimizer`, an instance of an `Optimizer` or a
      string.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    predict_probabilities: A boolean indicating whether to predict probabilities
      for all classes. Must only be used with
      `ProblemType.CLASSIFICATION`.
    learning_rate: Learning rate used for optimization. This argument has no
      effect if `optimizer` is an instance of an `Optimizer`.
    gradient_clipping_norm: A float. Gradients will be clipped to this value.
    dropout_keep_probabilities: a list of dropout keep probabilities or `None`.
      If a list is given, it must have length `len(num_units) + 1`. Only
      applied when the model function runs in mode `TRAIN`.
    sequence_length_key: The key that will be used to look up sequence length in
      the `features` dict.
    dtype: The dtype of the state and output of the given `cell`.
    parallel_iterations: Number of iterations to run in parallel. Values >> 1
      use more memory but take less time, while smaller values use less memory
      but computations take longer. The default `None` is forwarded unchanged
      to `rnn.dynamic_rnn`.
      NOTE(review): presumably `dynamic_rnn` then applies its own default --
      confirm.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU. This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    name: A string that will be used to create a scope for the RNN.

  Returns:
    A model function to be passed to an `Estimator`.

  Raises:
    ValueError: `problem_type` is not one of
      `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
    ValueError: `prediction_type` is not one of `PredictionType.SINGLE_VALUE`
      or `PredictionType.MULTIPLE_VALUE`.
    ValueError: `predict_probabilities` is `True` for `problem_type` other
      than `ProblemType.CLASSIFICATION`.
    ValueError: `len(dropout_keep_probabilities)` is not `len(num_units) + 1`.
      This is not checked in this function; presumably it is raised by
      `rnn_common.construct_rnn_cell` when the model function is invoked.
  """
  if problem_type not in (constants.ProblemType.CLASSIFICATION,
                          constants.ProblemType.LINEAR_REGRESSION):
    raise ValueError(
        'problem_type must be ProblemType.LINEAR_REGRESSION or '
        'ProblemType.CLASSIFICATION; got {}'.
        format(problem_type))
  if prediction_type not in (rnn_common.PredictionType.SINGLE_VALUE,
                             rnn_common.PredictionType.MULTIPLE_VALUE):
    raise ValueError(
        'prediction_type must be PredictionType.MULTIPLE_VALUE or '
        'PredictionType.SINGLE_VALUE; got {}'.
        format(prediction_type))
  if (problem_type != constants.ProblemType.CLASSIFICATION
      and predict_probabilities):
    raise ValueError(
        'predict_probabilities can only be set to True for problem_type'
        ' ProblemType.CLASSIFICATION; got {}.'.format(problem_type))

  def _dynamic_rnn_model_fn(features, labels, mode):
    """The model to be passed to an `Estimator`."""
    with ops.name_scope(name):
      sequence_length = features.get(sequence_length_key)
      sequence_input = build_sequence_input(features,
                                            sequence_feature_columns,
                                            context_feature_columns)
      # Dropout is only applied while training.
      dropout = (dropout_keep_probabilities
                 if mode == model_fn.ModeKeys.TRAIN
                 else None)
      # This class promises to use the cell type selected by that function.
      cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout)
      initial_state = dict_to_state_tuple(features, cell)
      rnn_activations, final_state = construct_rnn(
          initial_state,
          sequence_input,
          cell,
          target_column.num_label_columns,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory)

      loss = None  # Created below for modes TRAIN and EVAL.
      if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
        prediction_dict = rnn_common.multi_value_predictions(
            rnn_activations, target_column, problem_type, predict_probabilities)
        if mode != model_fn.ModeKeys.INFER:
          loss = _multi_value_loss(
              rnn_activations, labels, sequence_length, target_column, features)
      elif prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
        prediction_dict = _single_value_predictions(
            rnn_activations, sequence_length, target_column,
            problem_type, predict_probabilities)
        if mode != model_fn.ModeKeys.INFER:
          loss = _single_value_loss(
              rnn_activations, labels, sequence_length, target_column, features)
      # Expose the final RNN state alongside the predictions.
      state_dict = state_tuple_to_dict(final_state)
      prediction_dict.update(state_dict)

      eval_metric_ops = None
      if mode != model_fn.ModeKeys.INFER:
        eval_metric_ops = rnn_common.get_eval_metric_ops(
            problem_type, prediction_type, sequence_length, prediction_dict,
            labels)

      train_op = None
      if mode == model_fn.ModeKeys.TRAIN:
        train_op = optimizers.optimize_loss(
            loss=loss,
            global_step=None,  # Get it internally.
            learning_rate=learning_rate,
            optimizer=optimizer,
            clip_gradients=gradient_clipping_norm,
            summaries=optimizers.OPTIMIZER_SUMMARIES)

    output_alternatives = _get_output_alternatives(prediction_type,
                                                   problem_type,
                                                   prediction_dict)

    return model_fn.ModelFnOps(mode=mode,
                               predictions=prediction_dict,
                               loss=loss,
                               train_op=train_op,
                               eval_metric_ops=eval_metric_ops,
                               output_alternatives=output_alternatives)
  return _dynamic_rnn_model_fn


class DynamicRnnEstimator(estimator.Estimator):
  """An `Estimator` whose model function runs a dynamic RNN.

  The constructor selects a target column and a scope name from
  `problem_type` and `prediction_type`, builds the model function with
  `_get_dynamic_rnn_model_fn` and hands it to the base `Estimator`.
  """

  def __init__(self,
               problem_type,
               prediction_type,
               sequence_feature_columns,
               context_feature_columns=None,
               num_classes=None,
               num_units=None,
               cell_type='basic_rnn',
               optimizer='SGD',
               learning_rate=0.1,
               predict_probabilities=False,
               momentum=None,
               gradient_clipping_norm=5.0,
               dropout_keep_probabilities=None,
               model_dir=None,
               feature_engineering_fn=None,
               config=None):
    """Initializes a `DynamicRnnEstimator`.

    The input function passed to this `Estimator` optionally contains keys
    `RNNKeys.SEQUENCE_LENGTH_KEY`. The value corresponding to
    `RNNKeys.SEQUENCE_LENGTH_KEY` must be a vector of size `batch_size` where
    entry `n` corresponds to the length of the `n`th sequence in the batch. The
    sequence length feature is required for batches of varying sizes. It will be
    used to calculate loss and evaluation metrics. If
    `RNNKeys.SEQUENCE_LENGTH_KEY` is not included, all sequences are assumed to
    have length equal to the size of dimension 1 of the input to the RNN.

    In order to specify an initial state, the input function must include keys
    `STATE_PREFIX_i` for all `0 <= i < n` where `n` is the number of nested
    elements in `cell.state_size`. The input function must contain values for
    all state components or none of them. If none are included, then the default
    (zero) state is used as an initial state. See the documentation for
    `dict_to_state_tuple` and `state_tuple_to_dict` for further details.
    The input function can call rnn_common.construct_rnn_cell() to obtain the
    same cell type that this class will select from arguments to __init__.

    The `predict()` method of the `Estimator` returns a dictionary with keys
    `STATE_PREFIX_i` for `0 <= i < n` where `n` is the number of nested elements
    in `cell.state_size`, along with `PredictionKey.CLASSES` for problem type
    `CLASSIFICATION` or `PredictionKey.SCORES` for problem type
    `LINEAR_REGRESSION`.  The value keyed by
    `PredictionKey.CLASSES` or `PredictionKey.SCORES` has shape
    `[batch_size, padded_length]` in the multi-value case and shape
    `[batch_size]` in the single-value case.  Here, `padded_length` is the
    largest value in the `RNNKeys.SEQUENCE_LENGTH_KEY` `Tensor` passed as input.
    Entry `[i, j]` is the prediction associated with sequence `i` and time step
    `j`. If the problem type is `CLASSIFICATION` and `predict_probabilities` is
    `True`, it will also include key `PredictionKey.PROBABILITIES`.

    Args:
      problem_type: whether the `Estimator` is intended for a regression or
        classification problem. Value must be one of
        `ProblemType.CLASSIFICATION` or `ProblemType.LINEAR_REGRESSION`.
      prediction_type: whether the `Estimator` should return a value for each
        step in the sequence, or just a single value for the final time step.
        Must be one of `PredictionType.SINGLE_VALUE` or
        `PredictionType.MULTIPLE_VALUE`.
      sequence_feature_columns: An iterable containing all the feature columns
        describing sequence features. All items in the iterable should be
        instances of classes derived from `FeatureColumn`.
      context_feature_columns: An iterable containing all the feature columns
        describing context features, i.e., features that apply across all time
        steps. All items in the set should be instances of classes derived from
        `FeatureColumn`.
      num_classes: the number of classes for a classification problem. Only
        used when `problem_type=ProblemType.CLASSIFICATION`.
      num_units: A list of integers indicating the number of units in the
        `RNNCell`s in each layer.
      cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'.
      optimizer: The type of optimizer to use. Either a subclass of
        `Optimizer`, an instance of an `Optimizer`, a callback that returns an
        optimizer, or a string. Strings must be one of 'Adagrad', 'Adam',
        'Ftrl', 'Momentum', 'RMSProp' or 'SGD'. See `layers.optimize_loss` for
        more details.
      learning_rate: Learning rate. This argument has no effect if `optimizer`
        is an instance of an `Optimizer`.
      predict_probabilities: A boolean indicating whether to predict
        probabilities for all classes. Used only if `problem_type` is
        `ProblemType.CLASSIFICATION`
      momentum: Momentum value. Only used if `optimizer` is 'Momentum'.
      gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
        then no clipping is performed.
      dropout_keep_probabilities: a list of dropout keep probabilities or
        `None`. If a list is given, it must have length `len(num_units) + 1`.
        If `None`, then no dropout is applied.
      model_dir: The directory in which to save and restore the model graph,
        parameters, etc.
      feature_engineering_fn: Takes features and labels which are the output of
        `input_fn` and returns features and labels which will be fed into
        `model_fn`. Please check `model_fn` for a definition of features and
        labels.
      config: A `RunConfig` instance.

    Raises:
      ValueError: `problem_type` is not one of
        `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
      ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but
        `num_classes` is not specified.
      ValueError: `prediction_type` is not one of
        `PredictionType.MULTIPLE_VALUE` or `PredictionType.SINGLE_VALUE`.
    """
    # The scope name is assembled from the prediction type and problem type,
    # e.g. 'SingleValueDynamicRNNClassifier'.
    if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
      name = 'MultiValueDynamicRNN'
    elif prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
      name = 'SingleValueDynamicRNN'
    else:
      raise ValueError(
          'prediction_type must be one of PredictionType.MULTIPLE_VALUE or '
          'PredictionType.SINGLE_VALUE; got {}'.format(prediction_type))

    if problem_type == constants.ProblemType.LINEAR_REGRESSION:
      name += 'Regressor'
      target_column = layers.regression_target()
    elif problem_type == constants.ProblemType.CLASSIFICATION:
      if not num_classes:
        raise ValueError('For CLASSIFICATION problem_type, num_classes must be '
                         'specified.')
      target_column = layers.multi_class_target(n_classes=num_classes)
      name += 'Classifier'
    else:
      raise ValueError(
          'problem_type must be either ProblemType.LINEAR_REGRESSION '
          'or ProblemType.CLASSIFICATION; got {}'.format(
              problem_type))

    # 'Momentum' is instantiated here so the `momentum` argument can be
    # supplied; every other optimizer spec is forwarded as given.
    # NOTE(review): `momentum` is passed through unvalidated; a `None` value
    # presumably only fails once the optimizer is first used -- confirm.
    if optimizer == 'Momentum':
      optimizer = momentum_opt.MomentumOptimizer(learning_rate, momentum)
    dynamic_rnn_model_fn = _get_dynamic_rnn_model_fn(
        cell_type=cell_type,
        num_units=num_units,
        target_column=target_column,
        problem_type=problem_type,
        prediction_type=prediction_type,
        optimizer=optimizer,
        sequence_feature_columns=sequence_feature_columns,
        context_feature_columns=context_feature_columns,
        predict_probabilities=predict_probabilities,
        learning_rate=learning_rate,
        gradient_clipping_norm=gradient_clipping_norm,
        dropout_keep_probabilities=dropout_keep_probabilities,
        name=name)

    super(DynamicRnnEstimator, self).__init__(
        model_fn=dynamic_rnn_model_fn,
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)