# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Estimator for Dynamic RNNs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.util import nest


# TODO(jtbates): Remove PredictionType when all non-experimental targets which
# depend on it point to rnn_common.PredictionType.
class PredictionType(object):
  SINGLE_VALUE = 1
  MULTIPLE_VALUE = 2


def _get_state_name(i):
  """Constructs the name string for state component `i`."""
  return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)


def state_tuple_to_dict(state):
  """Returns a dict containing flattened `state`.

  Args:
    state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s must
      have the same rank and agree on all dimensions except the last.

  Returns:
    A dict containing the `Tensor`s that make up `state`. The keys of the dict
    are of the form "STATE_PREFIX_i" where `i` is the position of this
    `Tensor` in a depth-first traversal of `state`.
  """
  with ops.name_scope('state_tuple_to_dict'):
    flat_state = nest.flatten(state)
    state_dict = {}
    for i, state_component in enumerate(flat_state):
      state_name = _get_state_name(i)
      state_value = (None if state_component is None
                     else array_ops.identity(state_component, name=state_name))
      state_dict[state_name] = state_value
  return state_dict


def dict_to_state_tuple(input_dict, cell):
  """Reconstructs nested `state` from a dict containing state `Tensor`s.

  Args:
    input_dict: A dict of `Tensor`s.
    cell: An instance of `RNNCell`.
  Returns:
    If `input_dict` does not contain keys 'STATE_PREFIX_i' for `0 <= i < n`
    where `n` is the number of nested entries in `cell.state_size`, this
    function returns `None`. Otherwise, returns a `Tensor` if `cell.state_size`
    is an `int` or a nested tuple of `Tensor`s if `cell.state_size` is a nested
    tuple.
  Raises:
    ValueError: State is partially specified. The `input_dict` must contain
      values for all state components or none at all.
  """
  flat_state_sizes = nest.flatten(cell.state_size)
  state_tensors = []
  with ops.name_scope('dict_to_state_tuple'):
    for i, state_size in enumerate(flat_state_sizes):
      state_name = _get_state_name(i)
      state_tensor = input_dict.get(state_name)
      if state_tensor is not None:
        rank_check = check_ops.assert_rank(
            state_tensor, 2, name='check_state_{}_rank'.format(i))
        shape_check = check_ops.assert_equal(
            array_ops.shape(state_tensor)[1],
            state_size,
            name='check_state_{}_shape'.format(i))
        with ops.control_dependencies([rank_check, shape_check]):
          state_tensor = array_ops.identity(state_tensor, name=state_name)
        state_tensors.append(state_tensor)
    if not state_tensors:
      return None
    elif len(state_tensors) == len(flat_state_sizes):
      dummy_state = cell.zero_state(batch_size=1, dtype=dtypes.bool)
      return nest.pack_sequence_as(dummy_state, state_tensors)
    else:
      raise ValueError(
          'RNN state was partially specified. '
          'Expected zero or {} state Tensors; got {}'.
          format(len(flat_state_sizes), len(state_tensors)))


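# Illustrative sketch (not part of the library API): `state_tuple_to_dict` and
# `dict_to_state_tuple` are inverses for a fixed cell. For a `BasicLSTMCell`,
# whose state is an `LSTMStateTuple(c, h)`, the flattened dict has two entries
# keyed '<STATE_PREFIX>_0' and '<STATE_PREFIX>_1'. A minimal round trip,
# assuming `tf.contrib.rnn.BasicLSTMCell` is available:
#
#   cell = tf.contrib.rnn.BasicLSTMCell(num_units=4)
#   zero_state = cell.zero_state(batch_size=2, dtype=tf.float32)
#   state_dict = state_tuple_to_dict(zero_state)      # two [2, 4] Tensors
#   restored = dict_to_state_tuple(state_dict, cell)  # LSTMStateTuple again

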
def _concatenate_context_input(sequence_input, context_input):
  """Replicates `context_input` across all timesteps of `sequence_input`.

  Expands dimension 1 of `context_input` then tiles it `padded_length` times.
  This value is appended to `sequence_input` on dimension 2 and the result is
  returned.

  Args:
    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
  """
  seq_rank_check = check_ops.assert_rank(
      sequence_input,
      3,
      message='sequence_input must have rank 3',
      data=[array_ops.shape(sequence_input)])
  seq_type_check = check_ops.assert_type(
      sequence_input,
      dtypes.float32,
      message='sequence_input must have dtype float32; got {}.'.format(
          sequence_input.dtype))
  ctx_rank_check = check_ops.assert_rank(
      context_input,
      2,
      message='context_input must have rank 2',
      data=[array_ops.shape(context_input)])
  ctx_type_check = check_ops.assert_type(
      context_input,
      dtypes.float32,
      message='context_input must have dtype float32; got {}.'.format(
          context_input.dtype))
  with ops.control_dependencies(
      [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
    padded_length = array_ops.shape(sequence_input)[1]
    tiled_context_input = array_ops.tile(
        array_ops.expand_dims(context_input, 1),
        array_ops.concat([[1], [padded_length], [1]], 0))
  return array_ops.concat([sequence_input, tiled_context_input], 2)


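# Shape sketch (illustrative only): with a batch of two sequences padded to
# length 5, a [2, 5, 3] `sequence_input` and a [2, 4] `context_input` yield a
# [2, 5, 7] result; the context is tiled to [2, 5, 4] and concatenated on the
# last dimension.
#
#   seq = array_ops.zeros([2, 5, 3])
#   ctx = array_ops.ones([2, 4])
#   combined = _concatenate_context_input(seq, ctx)  # shape [2, 5, 7]

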
def build_sequence_input(features,
                         sequence_feature_columns,
                         context_feature_columns,
                         weight_collections=None,
                         scope=None):
  """Combine sequence and context features into input for an RNN.

  Args:
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    weight_collections: List of graph collections to which weights are added.
    scope: Optional scope, passed through to parsing ops.
  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, ?]`.
    This will be used as input to an RNN.
  """
  features = features.copy()
  features.update(layers.transform_features(
      features,
      list(sequence_feature_columns) + list(context_feature_columns or [])))
  sequence_input = layers.sequence_input_from_feature_columns(
      columns_to_tensors=features,
      feature_columns=sequence_feature_columns,
      weight_collections=weight_collections,
      scope=scope)
  if context_feature_columns is not None:
    context_input = layers.input_from_feature_columns(
        columns_to_tensors=features,
        feature_columns=context_feature_columns,
        weight_collections=weight_collections,
        scope=scope)
    sequence_input = _concatenate_context_input(sequence_input, context_input)
  return sequence_input


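# Illustrative sketch (assumed column and feature names, not part of this
# module): sequence features are described with contrib feature columns, and
# context features are replicated across timesteps via
# `_concatenate_context_input` above.
#
#   seq_columns = [layers.real_valued_column('measurements', dimension=1)]
#   ctx_columns = [layers.real_valued_column('baseline', dimension=1)]
#   rnn_input = build_sequence_input(features, seq_columns, ctx_columns)
#   # rnn_input has shape [batch_size, padded_length, d], ready for an RNN.

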
def construct_rnn(initial_state,
                  sequence_input,
                  cell,
                  num_label_columns,
                  dtype=dtypes.float32,
                  parallel_iterations=32,
                  swap_memory=True):
  """Build an RNN and apply a fully connected layer to get the desired output.

  Args:
    initial_state: The initial state to pass to the RNN. If `None`, the
      default starting state for `cell` is used.
    sequence_input: A `Tensor` with shape `[batch_size, padded_length, d]`
      that will be passed as input to the RNN.
    cell: An initialized `RNNCell`.
    num_label_columns: The desired output dimension.
    dtype: dtype of `cell`.
    parallel_iterations: Number of iterations to run in parallel. Values >> 1
      use more memory but take less time, while smaller values use less memory
      but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
  Returns:
    activations: The output of the RNN, projected to `num_label_columns`
      dimensions.
    final_state: A `Tensor` or nested tuple of `Tensor`s representing the final
      state output by the RNN.
  """
  with ops.name_scope('RNN'):
    rnn_outputs, final_state = rnn.dynamic_rnn(
        cell=cell,
        inputs=sequence_input,
        initial_state=initial_state,
        dtype=dtype,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        time_major=False)
    activations = layers.fully_connected(
        inputs=rnn_outputs,
        num_outputs=num_label_columns,
        activation_fn=None,
        trainable=True)
    return activations, final_state


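# Illustrative sketch: a two-layer GRU over the output of
# `build_sequence_input`, projected to one output per timestep.
# `rnn_common.construct_rnn_cell` is the same helper the model_fn below uses
# to build its cell; it is called here positionally, mirroring that call.
#
#   cell = rnn_common.construct_rnn_cell([16, 16], 'gru', None)
#   activations, final_state = construct_rnn(
#       initial_state=None,
#       sequence_input=rnn_input,  # [batch_size, padded_length, d]
#       cell=cell,
#       num_label_columns=1)       # activations: [batch, padded_length, 1]

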
def _single_value_predictions(activations,
                              sequence_length,
                              target_column,
                              problem_type,
                              predict_probabilities):
  """Maps `activations` from the RNN to predictions for single value models.

  If `predict_probabilities` is `False`, this function returns a `dict`
  containing a single entry with key `PREDICTIONS_KEY`. If
  `predict_probabilities` is `True`, it will contain a second entry with key
  `PROBABILITIES_KEY`. The value of this entry is a `Tensor` of probabilities
  with shape `[batch_size, num_classes]`.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, used to compute predictions.
    problem_type: Either `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    predict_probabilities: A Python boolean, indicating whether probabilities
      should be returned. Should only be set to `True` for
      classification/logistic regression problems.
  Returns:
    A `dict` mapping strings to `Tensors`.
  """
  with ops.name_scope('SingleValuePrediction'):
    last_activations = rnn_common.select_last_activations(
        activations, sequence_length)
    predictions_name = (prediction_key.PredictionKey.CLASSES
                        if problem_type == constants.ProblemType.CLASSIFICATION
                        else prediction_key.PredictionKey.SCORES)
    if predict_probabilities:
      probabilities = target_column.logits_to_predictions(
          last_activations, proba=True)
      prediction_dict = {
          prediction_key.PredictionKey.PROBABILITIES: probabilities,
          predictions_name: math_ops.argmax(probabilities, 1)}
    else:
      predictions = target_column.logits_to_predictions(
          last_activations, proba=False)
      prediction_dict = {predictions_name: predictions}
    return prediction_dict


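# Illustrative note: for a 3-class classification problem with
# `predict_probabilities=True`, the returned dict maps
# `PredictionKey.PROBABILITIES` to a [batch_size, 3] Tensor and
# `PredictionKey.CLASSES` to a [batch_size] Tensor (the argmax over classes),
# both computed from the last unpadded timestep of each sequence.

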
def _multi_value_loss(
    activations, labels, sequence_length, target_column, features):
  """Maps `activations` from the RNN to loss for multi value models.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    labels: A `Tensor` with shape `[batch_size, padded_length]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, used to compute the loss.
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.
  Returns:
    A scalar `Tensor` containing the loss.
  """
  with ops.name_scope('MultiValueLoss'):
    activations_masked, labels_masked = rnn_common.mask_activations_and_labels(
        activations, labels, sequence_length)
    return target_column.loss(activations_masked, labels_masked, features)


def _single_value_loss(
    activations, labels, sequence_length, target_column, features):
  """Maps `activations` from the RNN to loss for single value models.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    labels: A `Tensor` with shape `[batch_size]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, used to compute the loss.
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.
  Returns:
    A scalar `Tensor` containing the loss.
  """

  with ops.name_scope('SingleValueLoss'):
    last_activations = rnn_common.select_last_activations(
        activations, sequence_length)
    return target_column.loss(last_activations, labels, features)


def _get_output_alternatives(prediction_type,
                             problem_type,
                             prediction_dict):
  """Constructs output alternatives dict for `ModelFnOps`.

  Args:
    prediction_type: either `MULTIPLE_VALUE` or `SINGLE_VALUE`.
    problem_type: either `CLASSIFICATION` or `LINEAR_REGRESSION`.
    prediction_dict: a dictionary mapping strings to `Tensor`s containing
      predictions.

  Returns:
    `None` or a dictionary mapping a string to an output alternative.

  Raises:
    ValueError: `prediction_type` is not one of `SINGLE_VALUE` or
      `MULTIPLE_VALUE`.
  """
  if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
    return None
  if prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
    prediction_dict_no_state = {
        k: v
        for k, v in prediction_dict.items()
        if rnn_common.RNNKeys.STATE_PREFIX not in k
    }
    return {'dynamic_rnn_output': (problem_type, prediction_dict_no_state)}
  raise ValueError('Unrecognized prediction_type: {}'.format(prediction_type))


def _get_dynamic_rnn_model_fn(
    cell_type,
    num_units,
    target_column,
    problem_type,
    prediction_type,
    optimizer,
    sequence_feature_columns,
    context_feature_columns=None,
    predict_probabilities=False,
    learning_rate=None,
    gradient_clipping_norm=None,
    dropout_keep_probabilities=None,
    sequence_length_key=rnn_common.RNNKeys.SEQUENCE_LENGTH_KEY,
    dtype=dtypes.float32,
    parallel_iterations=None,
    swap_memory=True,
    name='DynamicRNNModel'):
  """Creates an RNN model function for an `Estimator`.

  The model function returns an instance of `ModelFnOps`. When
  `problem_type == ProblemType.CLASSIFICATION` and
  `predict_probabilities == True`, the returned `ModelFnOps` includes an output
  alternative containing the classes and their associated probabilities. When
  `predict_probabilities == False`, only the classes are included. When
  `problem_type == ProblemType.LINEAR_REGRESSION`, the output alternative
  contains only the predicted values.

  Args:
    cell_type: A string, a subclass of `RNNCell` or an instance of an `RNNCell`.
    num_units: A single `int` or a list of `int`s. The size of the `RNNCell`s.
    target_column: An initialized `TargetColumn`, used to calculate prediction
      and loss.
    problem_type: `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    prediction_type: `PredictionType.SINGLE_VALUE` or
      `PredictionType.MULTIPLE_VALUE`.
    optimizer: A subclass of `Optimizer`, an instance of an `Optimizer` or a
      string.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    predict_probabilities: A boolean indicating whether to predict probabilities
      for all classes. Must only be used with
      `ProblemType.CLASSIFICATION`.
    learning_rate: Learning rate used for optimization. This argument has no
      effect if `optimizer` is an instance of an `Optimizer`.
    gradient_clipping_norm: A float. Gradients will be clipped to this value.
    dropout_keep_probabilities: a list of dropout keep probabilities or `None`.
      If a list is given, it must have length `len(num_units) + 1`.
    sequence_length_key: The key that will be used to look up sequence length in
      the `features` dict.
    dtype: The dtype of the state and output of the given `cell`.
    parallel_iterations: Number of iterations to run in parallel. Values >> 1
      use more memory but take less time, while smaller values use less memory
      but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    name: A string that will be used to create a scope for the RNN.

  Returns:
    A model function to be passed to an `Estimator`.

  Raises:
    ValueError: `problem_type` is not one of
      `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
    ValueError: `prediction_type` is not one of `PredictionType.SINGLE_VALUE`
      or `PredictionType.MULTIPLE_VALUE`.
    ValueError: `predict_probabilities` is `True` for `problem_type` other
      than `ProblemType.CLASSIFICATION`.
    ValueError: `len(dropout_keep_probabilities)` is not `len(num_units) + 1`.
  """
  if problem_type not in (constants.ProblemType.CLASSIFICATION,
                          constants.ProblemType.LINEAR_REGRESSION):
    raise ValueError(
        'problem_type must be ProblemType.LINEAR_REGRESSION or '
        'ProblemType.CLASSIFICATION; got {}'.
        format(problem_type))
  if prediction_type not in (rnn_common.PredictionType.SINGLE_VALUE,
                             rnn_common.PredictionType.MULTIPLE_VALUE):
    raise ValueError(
        'prediction_type must be PredictionType.MULTIPLE_VALUE or '
        'PredictionType.SINGLE_VALUE; got {}'.
        format(prediction_type))
  if (problem_type != constants.ProblemType.CLASSIFICATION
      and predict_probabilities):
    raise ValueError(
        'predict_probabilities can only be set to True for problem_type'
        ' ProblemType.CLASSIFICATION; got {}.'.format(problem_type))
  def _dynamic_rnn_model_fn(features, labels, mode):
    """The model to be passed to an `Estimator`."""
    with ops.name_scope(name):
      sequence_length = features.get(sequence_length_key)
      sequence_input = build_sequence_input(features,
                                            sequence_feature_columns,
                                            context_feature_columns)
      dropout = (dropout_keep_probabilities
                 if mode == model_fn.ModeKeys.TRAIN
                 else None)
      # The estimator docstring promises to use the cell type that
      # `rnn_common.construct_rnn_cell` selects from these arguments.
      cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout)
      initial_state = dict_to_state_tuple(features, cell)
      rnn_activations, final_state = construct_rnn(
          initial_state,
          sequence_input,
          cell,
          target_column.num_label_columns,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory)

      loss = None  # Created below for modes TRAIN and EVAL.
      if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
        prediction_dict = rnn_common.multi_value_predictions(
            rnn_activations, target_column, problem_type, predict_probabilities)
        if mode != model_fn.ModeKeys.INFER:
          loss = _multi_value_loss(
              rnn_activations, labels, sequence_length, target_column, features)
      elif prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
        prediction_dict = _single_value_predictions(
            rnn_activations, sequence_length, target_column,
            problem_type, predict_probabilities)
        if mode != model_fn.ModeKeys.INFER:
          loss = _single_value_loss(
              rnn_activations, labels, sequence_length, target_column, features)
      state_dict = state_tuple_to_dict(final_state)
      prediction_dict.update(state_dict)

      eval_metric_ops = None
      if mode != model_fn.ModeKeys.INFER:
        eval_metric_ops = rnn_common.get_eval_metric_ops(
            problem_type, prediction_type, sequence_length, prediction_dict,
            labels)

      train_op = None
      if mode == model_fn.ModeKeys.TRAIN:
        train_op = optimizers.optimize_loss(
            loss=loss,
            global_step=None,  # Get it internally.
            learning_rate=learning_rate,
            optimizer=optimizer,
            clip_gradients=gradient_clipping_norm,
            summaries=optimizers.OPTIMIZER_SUMMARIES)

    output_alternatives = _get_output_alternatives(prediction_type,
                                                   problem_type,
                                                   prediction_dict)

    return model_fn.ModelFnOps(mode=mode,
                               predictions=prediction_dict,
                               loss=loss,
                               train_op=train_op,
                               eval_metric_ops=eval_metric_ops,
                               output_alternatives=output_alternatives)
  return _dynamic_rnn_model_fn


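# Illustrative sketch: the factory above returns a closure with the standard
# `(features, labels, mode) -> ModelFnOps` signature, so it can be handed
# directly to `estimator.Estimator`, which is what `DynamicRnnEstimator`
# below does. Assuming `seq_columns` as in the earlier sketches:
#
#   model_fn_ = _get_dynamic_rnn_model_fn(
#       cell_type='lstm',
#       num_units=[32],
#       target_column=layers.multi_class_target(n_classes=2),
#       problem_type=constants.ProblemType.CLASSIFICATION,
#       prediction_type=rnn_common.PredictionType.SINGLE_VALUE,
#       optimizer='SGD',
#       sequence_feature_columns=seq_columns,
#       learning_rate=0.1)
#   est = estimator.Estimator(model_fn=model_fn_)

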
class DynamicRnnEstimator(estimator.Estimator):

  def __init__(self,
               problem_type,
               prediction_type,
               sequence_feature_columns,
               context_feature_columns=None,
               num_classes=None,
               num_units=None,
               cell_type='basic_rnn',
               optimizer='SGD',
               learning_rate=0.1,
               predict_probabilities=False,
               momentum=None,
               gradient_clipping_norm=5.0,
               dropout_keep_probabilities=None,
               model_dir=None,
               feature_engineering_fn=None,
               config=None):
    """Initializes a `DynamicRnnEstimator`.

    The input function passed to this `Estimator` optionally contains keys
    `RNNKeys.SEQUENCE_LENGTH_KEY`. The value corresponding to
    `RNNKeys.SEQUENCE_LENGTH_KEY` must be a vector of size `batch_size` where
    entry `n` corresponds to the length of the `n`th sequence in the batch. The
    sequence length feature is required for batches with sequences of varying
    lengths. It will be used to calculate loss and evaluation metrics. If
    `RNNKeys.SEQUENCE_LENGTH_KEY` is not included, all sequences are assumed to
    have length equal to the size of dimension 1 of the input to the RNN.

    In order to specify an initial state, the input function must include keys
    `STATE_PREFIX_i` for all `0 <= i < n` where `n` is the number of nested
    elements in `cell.state_size`. The input function must contain values for
    all state components or none of them. If none are included, then the default
    (zero) state is used as an initial state. See the documentation for
    `dict_to_state_tuple` and `state_tuple_to_dict` for further details.
    The input function can call rnn_common.construct_rnn_cell() to obtain the
    same cell type that this class will select from arguments to __init__.

    The `predict()` method of the `Estimator` returns a dictionary with keys
    `STATE_PREFIX_i` for `0 <= i < n` where `n` is the number of nested elements
    in `cell.state_size`, along with `PredictionKey.CLASSES` for problem type
    `CLASSIFICATION` or `PredictionKey.SCORES` for problem type
    `LINEAR_REGRESSION`.  The value keyed by
    `PredictionKey.CLASSES` or `PredictionKey.SCORES` has shape
    `[batch_size, padded_length]` in the multi-value case and shape
    `[batch_size]` in the single-value case.  Here, `padded_length` is the
    largest value in the `RNNKeys.SEQUENCE_LENGTH_KEY` `Tensor` passed as
    input. Entry `[i, j]` is the prediction associated with sequence `i` and
    time step `j`. If the problem type is `CLASSIFICATION` and
    `predict_probabilities` is `True`, it will also include key
    `PredictionKey.PROBABILITIES`.

    Args:
      problem_type: whether the `Estimator` is intended for a regression or
        classification problem. Value must be one of
        `ProblemType.CLASSIFICATION` or `ProblemType.LINEAR_REGRESSION`.
      prediction_type: whether the `Estimator` should return a value for each
        step in the sequence, or just a single value for the final time step.
        Must be one of `PredictionType.SINGLE_VALUE` or
        `PredictionType.MULTIPLE_VALUE`.
      sequence_feature_columns: An iterable containing all the feature columns
        describing sequence features. All items in the iterable should be
        instances of classes derived from `FeatureColumn`.
      context_feature_columns: An iterable containing all the feature columns
        describing context features, i.e., features that apply across all time
        steps. All items in the set should be instances of classes derived from
        `FeatureColumn`.
      num_classes: the number of classes for a classification problem. Only
        used when `problem_type=ProblemType.CLASSIFICATION`.
      num_units: A list of integers indicating the number of units in the
        `RNNCell`s in each layer.
      cell_type: A subclass of `RNNCell` or one of 'basic_rnn', 'lstm' or 'gru'.
      optimizer: The type of optimizer to use. Either a subclass of
        `Optimizer`, an instance of an `Optimizer`, a callback that returns an
        optimizer, or a string. Strings must be one of 'Adagrad', 'Adam',
        'Ftrl', 'Momentum', 'RMSProp' or 'SGD'. See `layers.optimize_loss` for
        more details.
      learning_rate: Learning rate. This argument has no effect if `optimizer`
        is an instance of an `Optimizer`.
      predict_probabilities: A boolean indicating whether to predict
        probabilities for all classes. Used only if `problem_type` is
        `ProblemType.CLASSIFICATION`.
      momentum: Momentum value. Only used if `optimizer` is 'Momentum'.
      gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
        then no clipping is performed.
      dropout_keep_probabilities: a list of dropout keep probabilities or
        `None`. If a list is given, it must have length `len(num_units) + 1`.
        If `None`, then no dropout is applied.
      model_dir: The directory in which to save and restore the model graph,
        parameters, etc.
      feature_engineering_fn: Takes features and labels which are the output of
        `input_fn` and returns features and labels which will be fed into
        `model_fn`. Please check `model_fn` for a definition of features and
        labels.
      config: A `RunConfig` instance.

    Raises:
      ValueError: `problem_type` is not one of
        `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
      ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but
        `num_classes` is not specified.
      ValueError: `prediction_type` is not one of
        `PredictionType.MULTIPLE_VALUE` or `PredictionType.SINGLE_VALUE`.
    """
    if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
      name = 'MultiValueDynamicRNN'
    elif prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
      name = 'SingleValueDynamicRNN'
    else:
      raise ValueError(
          'prediction_type must be one of PredictionType.MULTIPLE_VALUE or '
          'PredictionType.SINGLE_VALUE; got {}'.format(prediction_type))

    if problem_type == constants.ProblemType.LINEAR_REGRESSION:
      name += 'Regressor'
      target_column = layers.regression_target()
    elif problem_type == constants.ProblemType.CLASSIFICATION:
      if not num_classes:
        raise ValueError('For CLASSIFICATION problem_type, num_classes must be '
                         'specified.')
      target_column = layers.multi_class_target(n_classes=num_classes)
      name += 'Classifier'
    else:
      raise ValueError(
          'problem_type must be either ProblemType.LINEAR_REGRESSION '
          'or ProblemType.CLASSIFICATION; got {}'.format(
              problem_type))

    if optimizer == 'Momentum':
      optimizer = momentum_opt.MomentumOptimizer(learning_rate, momentum)
    dynamic_rnn_model_fn = _get_dynamic_rnn_model_fn(
        cell_type=cell_type,
        num_units=num_units,
        target_column=target_column,
        problem_type=problem_type,
        prediction_type=prediction_type,
        optimizer=optimizer,
        sequence_feature_columns=sequence_feature_columns,
        context_feature_columns=context_feature_columns,
        predict_probabilities=predict_probabilities,
        learning_rate=learning_rate,
        gradient_clipping_norm=gradient_clipping_norm,
        dropout_keep_probabilities=dropout_keep_probabilities,
        name=name)

    super(DynamicRnnEstimator, self).__init__(
        model_fn=dynamic_rnn_model_fn,
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)
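

# Illustrative usage sketch (assumed feature names and input functions, not
# part of the library): a single-value classifier over padded sequences of
# scalar measurements. The input_fn supplies a float feature keyed
# 'measurements' (shaped to match the column) and, optionally, the sequence
# length feature keyed by rnn_common.RNNKeys.SEQUENCE_LENGTH_KEY.
#
#   seq_columns = [layers.real_valued_column('measurements', dimension=1)]
#   classifier = DynamicRnnEstimator(
#       problem_type=constants.ProblemType.CLASSIFICATION,
#       prediction_type=rnn_common.PredictionType.SINGLE_VALUE,
#       sequence_feature_columns=seq_columns,
#       num_classes=2,
#       num_units=[32],
#       cell_type='lstm',
#       optimizer='Adam',
#       learning_rate=0.01)
#   classifier.fit(input_fn=my_train_input_fn, steps=1000)
#   predictions = classifier.predict(input_fn=my_eval_input_fn)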
    692