      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 # pylint: disable=protected-access
     16 """Recurrent layers and their base classes.
     17 """
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import numbers
     23 import numpy as np
     24 
     25 from tensorflow.python.framework import tensor_shape
     26 from tensorflow.python.keras._impl.keras import activations
     27 from tensorflow.python.keras._impl.keras import backend as K
     28 from tensorflow.python.keras._impl.keras import constraints
     29 from tensorflow.python.keras._impl.keras import initializers
     30 from tensorflow.python.keras._impl.keras import regularizers
     31 from tensorflow.python.keras._impl.keras.engine import InputSpec
     32 from tensorflow.python.keras._impl.keras.engine import Layer
     33 from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
     34 from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
     35 from tensorflow.python.platform import tf_logging as logging
     36 from tensorflow.python.util.tf_export import tf_export
     37 
     38 
     39 @tf_export('keras.layers.StackedRNNCells')
     40 class StackedRNNCells(Layer):
     41   """Wrapper allowing a stack of RNN cells to behave as a single cell.
     42 
     43   Used to implement efficient stacked RNNs.
     44 
     45   Arguments:
     46       cells: List of RNN cell instances.
     47 
     48   Examples:
     49 
     50   ```python
     51       cells = [
     52           keras.layers.LSTMCell(output_dim),
     53           keras.layers.LSTMCell(output_dim),
     54           keras.layers.LSTMCell(output_dim),
     55       ]
     56 
     57       inputs = keras.Input((timesteps, input_dim))
     58       x = keras.layers.RNN(cells)(inputs)
     59   ```
     60   """
     61 
     62   def __init__(self, cells, **kwargs):
     63     for cell in cells:
     64       if not hasattr(cell, 'call'):
     65         raise ValueError('All cells must have a `call` method. '
      66                          'Received cells:', cells)
     67       if not hasattr(cell, 'state_size'):
     68         raise ValueError('All cells must have a '
     69                          '`state_size` attribute. '
      70                          'Received cells:', cells)
     71     self.cells = cells
     72     super(StackedRNNCells, self).__init__(**kwargs)
     73 
     74   @property
     75   def state_size(self):
     76     # States are a flat list
     77     # in reverse order of the cell stack.
      78     # This preserves the requirement
     79     # `stack.state_size[0] == output_dim`.
     80     # e.g. states of a 2-layer LSTM would be
     81     # `[h2, c2, h1, c1]`
     82     # (assuming one LSTM has states [h, c])
     83     state_size = []
     84     for cell in self.cells[::-1]:
     85       if hasattr(cell.state_size, '__len__'):
     86         state_size += list(cell.state_size)
     87       else:
     88         state_size.append(cell.state_size)
     89     return tuple(state_size)
     90 
     91   def call(self, inputs, states, constants=None, **kwargs):
     92     # Recover per-cell states.
     93     nested_states = []
     94     for cell in self.cells[::-1]:
     95       if hasattr(cell.state_size, '__len__'):
     96         nested_states.append(states[:len(cell.state_size)])
     97         states = states[len(cell.state_size):]
     98       else:
     99         nested_states.append([states[0]])
    100         states = states[1:]
    101     nested_states = nested_states[::-1]
    102 
    103     # Call the cells in order and store the returned states.
    104     new_nested_states = []
    105     for cell, states in zip(self.cells, nested_states):
    106       if has_arg(cell.call, 'constants'):
    107         inputs, states = cell.call(inputs, states, constants=constants,
    108                                    **kwargs)
    109       else:
    110         inputs, states = cell.call(inputs, states, **kwargs)
    111 
    112       new_nested_states.append(states)
    113 
    114     # Format the new states as a flat list
    115     # in reverse cell order.
    116     states = []
    117     for cell_states in new_nested_states[::-1]:
    118       states += cell_states
    119     return inputs, states
    120 
    121   @shape_type_conversion
    122   def build(self, input_shape):
    123     if isinstance(input_shape, list):
    124       constants_shape = input_shape[1:]
    125       input_shape = input_shape[0]
    126     for cell in self.cells:
    127       if isinstance(cell, Layer):
    128         if has_arg(cell.call, 'constants'):
    129           cell.build([input_shape] + constants_shape)
    130         else:
    131           cell.build(input_shape)
    132       if hasattr(cell.state_size, '__len__'):
    133         output_dim = cell.state_size[0]
    134       else:
    135         output_dim = cell.state_size
    136       input_shape = (input_shape[0], output_dim)
    137     self.built = True
    138 
    139   def get_config(self):
    140     cells = []
    141     for cell in self.cells:
    142       cells.append({
    143           'class_name': cell.__class__.__name__,
    144           'config': cell.get_config()
    145       })
    146     config = {'cells': cells}
    147     base_config = super(StackedRNNCells, self).get_config()
    148     return dict(list(base_config.items()) + list(config.items()))
    149 
    150   @classmethod
    151   def from_config(cls, config, custom_objects=None):
    152     from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    153     cells = []
    154     for cell_config in config.pop('cells'):
    155       cells.append(
    156           deserialize_layer(cell_config, custom_objects=custom_objects))
    157     return cls(cells, **config)
    158 
    159   @property
    160   def trainable_weights(self):
    161     if not self.trainable:
    162       return []
    163     weights = []
    164     for cell in self.cells:
    165       if isinstance(cell, Layer):
    166         weights += cell.trainable_weights
    167     return weights
    168 
    169   @property
    170   def non_trainable_weights(self):
    171     weights = []
    172     for cell in self.cells:
    173       if isinstance(cell, Layer):
    174         weights += cell.non_trainable_weights
    175     if not self.trainable:
    176       trainable_weights = []
    177       for cell in self.cells:
    178         if isinstance(cell, Layer):
    179           trainable_weights += cell.trainable_weights
    180       return trainable_weights + weights
    181     return weights
    182 
    183   def get_weights(self):
     184     """Retrieves the weights of the stacked cells.
    185 
    186     Returns:
    187         A flat list of Numpy arrays.
    188     """
    189     weights = []
    190     for cell in self.cells:
    191       if isinstance(cell, Layer):
    192         weights += cell.weights
    193     return K.batch_get_value(weights)
    194 
    195   def set_weights(self, weights):
     196     """Sets the weights of the stacked cells.
    197 
    198     Arguments:
    199         weights: A list of Numpy arrays with shapes and types matching
     200             the output of `get_weights()`.
    201     """
    202     tuples = []
    203     for cell in self.cells:
    204       if isinstance(cell, Layer):
    205         num_param = len(cell.weights)
     206         cell_weights = weights[:num_param]
     207         for sw, w in zip(cell.weights, cell_weights):
    208           tuples.append((sw, w))
    209         weights = weights[num_param:]
    210     K.batch_set_value(tuples)
    211 
    212   @property
    213   def losses(self):
    214     losses = []
    215     for cell in self.cells:
    216       if isinstance(cell, Layer):
    217         losses += cell.losses
    218     return losses + self._losses
    219 
    220   @property
    221   def updates(self):
    222     updates = []
    223     for cell in self.cells:
    224       if isinstance(cell, Layer):
    225         updates += cell.updates
    226     return updates + self._updates
    227 
    228 
    229 @tf_export('keras.layers.RNN')
    230 class RNN(Layer):
    231   """Base class for recurrent layers.
    232 
    233   Arguments:
     234       cell: An RNN cell instance. An RNN cell is a class that has:
    235           - a `call(input_at_t, states_at_t)` method, returning
    236               `(output_at_t, states_at_t_plus_1)`. The call method of the
    237               cell can also take the optional argument `constants`, see
    238               section "Note on passing external constants" below.
    239           - a `state_size` attribute. This can be a single integer
    240               (single state) in which case it is
    241               the size of the recurrent state
    242               (which should be the same as the size of the cell output).
    243               This can also be a list/tuple of integers
    244               (one size per state). In this case, the first entry
    245               (`state_size[0]`) should be the same as
    246               the size of the cell output.
    247           It is also possible for `cell` to be a list of RNN cell instances,
     248           in which case the cells get stacked one after the other in the RNN,
    249           implementing an efficient stacked RNN.
     250       return_sequences: Boolean. Whether to return the last output
    251           in the output sequence, or the full sequence.
    252       return_state: Boolean. Whether to return the last state
    253           in addition to the output.
    254       go_backwards: Boolean (default False).
    255           If True, process the input sequence backwards and return the
    256           reversed sequence.
    257       stateful: Boolean (default False). If True, the last state
    258           for each sample at index i in a batch will be used as initial
    259           state for the sample of index i in the following batch.
    260       unroll: Boolean (default False).
    261           If True, the network will be unrolled,
    262           else a symbolic loop will be used.
     263           Unrolling can speed up an RNN,
    264           although it tends to be more memory-intensive.
    265           Unrolling is only suitable for short sequences.
    266       input_dim: dimensionality of the input (integer).
    267           This argument (or alternatively,
    268           the keyword argument `input_shape`)
    269           is required when using this layer as the first layer in a model.
    270       input_length: Length of input sequences, to be specified
    271           when it is constant.
    272           This argument is required if you are going to connect
    273           `Flatten` then `Dense` layers upstream
    274           (without it, the shape of the dense outputs cannot be computed).
    275           Note that if the recurrent layer is not the first layer
    276           in your model, you would need to specify the input length
    277           at the level of the first layer
    278           (e.g. via the `input_shape` argument)
    279 
    280   Input shape:
    281       3D tensor with shape `(batch_size, timesteps, input_dim)`.
    282 
    283   Output shape:
    284       - if `return_state`: a list of tensors. The first tensor is
    285           the output. The remaining tensors are the last states,
    286           each with shape `(batch_size, units)`.
    287       - if `return_sequences`: 3D tensor with shape
    288           `(batch_size, timesteps, units)`.
    289       - else, 2D tensor with shape `(batch_size, units)`.
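
               As a short sketch of these combinations (assuming an LSTM layer; the
               sizes below are placeholder values):

               ```python
                   x = keras.Input((None, 8))  # (batch_size, timesteps, input_dim)
                   lstm = keras.layers.LSTM(32, return_sequences=True, return_state=True)
                   whole_sequence, state_h, state_c = lstm(x)
                   # whole_sequence: (batch_size, timesteps, 32)
                   # state_h, state_c: (batch_size, 32) each
               ```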
    290 
    291   # Masking
    292       This layer supports masking for input data with a variable number
    293       of timesteps. To introduce masks to your data,
    294       use an [Embedding](embeddings.md) layer with the `mask_zero` parameter
    295       set to `True`.
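
               A minimal sketch (the vocabulary size and dimensions below are
               placeholder values):

               ```python
                   inputs = keras.Input((None,), dtype='int32')
                   x = keras.layers.Embedding(input_dim=1000, output_dim=64,
                                              mask_zero=True)(inputs)
                   # Timesteps corresponding to the padding index 0 are masked out.
                   outputs = keras.layers.LSTM(32)(x)
               ```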
    296 
    297   # Note on using statefulness in RNNs
    298       You can set RNN layers to be 'stateful', which means that the states
    299       computed for the samples in one batch will be reused as initial states
    300       for the samples in the next batch. This assumes a one-to-one mapping
    301       between samples in different successive batches.
    302 
    303       To enable statefulness:
    304           - specify `stateful=True` in the layer constructor.
    305           - specify a fixed batch size for your model, by passing
     306               if using a Sequential model:
    307                 `batch_input_shape=(...)` to the first layer in your model.
     308               else, for a functional model with one or more Input layers:
    309                 `batch_shape=(...)` to all the first layers in your model.
    310               This is the expected shape of your inputs
    311               *including the batch size*.
    312               It should be a tuple of integers, e.g. `(32, 10, 100)`.
    313           - specify `shuffle=False` when calling fit().
    314 
    315       To reset the states of your model, call `.reset_states()` on either
    316       a specific layer, or on your entire model.
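
               A short sketch of the Sequential form (the batch size, timesteps and
               feature size below are placeholder values):

               ```python
                   model = keras.models.Sequential()
                   model.add(keras.layers.LSTM(32, stateful=True,
                                               batch_input_shape=(32, 10, 16)))
                   model.add(keras.layers.Dense(10))

                   # model.fit(x, y, batch_size=32, shuffle=False)
                   model.reset_states()  # clears the states carried across batches
               ```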
    317 
    318   # Note on specifying the initial state of RNNs
    319       You can specify the initial state of RNN layers symbolically by
    320       calling them with the keyword argument `initial_state`. The value of
    321       `initial_state` should be a tensor or list of tensors representing
    322       the initial state of the RNN layer.
    323 
    324       You can specify the initial state of RNN layers numerically by
    325       calling `reset_states` with the keyword argument `states`. The value of
    326       `states` should be a numpy array or list of numpy arrays representing
    327       the initial state of the RNN layer.
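
               A minimal sketch of the symbolic form (unit sizes are placeholders):

               ```python
                   state_h = keras.Input((64,))
                   state_c = keras.Input((64,))
                   inputs = keras.Input((None, 10))
                   # An LSTM expects one initial tensor per state, here `[h, c]`.
                   outputs = keras.layers.LSTM(64)(inputs,
                                                   initial_state=[state_h, state_c])
               ```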
    328 
    329   # Note on passing external constants to RNNs
    330       You can pass "external" constants to the cell using the `constants`
     331       keyword argument of the `RNN.__call__` (as well as `RNN.call`) method. This
    332       requires that the `cell.call` method accepts the same keyword argument
    333       `constants`. Such constants can be used to condition the cell
    334       transformation on additional static inputs (not changing over time),
    335       a.k.a. an attention mechanism.
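
               A sketch of a cell whose `call` accepts `constants` (the class and
               sizes here are illustrative only):

               ```python
                   class ConditionedCell(keras.layers.Layer):

                       def __init__(self, units, **kwargs):
                           self.units = units
                           self.state_size = units
                           super(ConditionedCell, self).__init__(**kwargs)

                       def build(self, input_shape):
                           # `input_shape` is `[step_input_shape, constant_shape]` here,
                           # because the wrapping RNN is called with a constant.
                           step_shape, constant_shape = input_shape
                           self.kernel = self.add_weight(
                               shape=(step_shape[-1], self.units), name='kernel',
                               initializer='uniform')
                           self.recurrent_kernel = self.add_weight(
                               shape=(self.units, self.units), name='recurrent_kernel',
                               initializer='uniform')
                           self.constant_kernel = self.add_weight(
                               shape=(constant_shape[-1], self.units),
                               name='constant_kernel', initializer='uniform')
                           self.built = True

                       def call(self, inputs, states, constants):
                           [prev_output] = states
                           [static_input] = constants
                           output = (K.dot(inputs, self.kernel) +
                                     K.dot(prev_output, self.recurrent_kernel) +
                                     K.dot(static_input, self.constant_kernel))
                           return output, [output]

                   x = keras.Input((None, 5))
                   c = keras.Input((3,))
                   y = keras.layers.RNN(ConditionedCell(32))(x, constants=c)
               ```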
    336 
    337   Examples:
    338 
    339   ```python
    340       # First, let's define a RNN Cell, as a layer subclass.
    341 
    342       class MinimalRNNCell(keras.layers.Layer):
    343 
    344           def __init__(self, units, **kwargs):
    345               self.units = units
    346               self.state_size = units
    347               super(MinimalRNNCell, self).__init__(**kwargs)
    348 
    349           def build(self, input_shape):
    350               self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
    351                                             initializer='uniform',
    352                                             name='kernel')
    353               self.recurrent_kernel = self.add_weight(
    354                   shape=(self.units, self.units),
    355                   initializer='uniform',
    356                   name='recurrent_kernel')
    357               self.built = True
    358 
    359           def call(self, inputs, states):
    360               prev_output = states[0]
    361               h = K.dot(inputs, self.kernel)
    362               output = h + K.dot(prev_output, self.recurrent_kernel)
    363               return output, [output]
    364 
    365       # Let's use this cell in a RNN layer:
    366 
    367       cell = MinimalRNNCell(32)
    368       x = keras.Input((None, 5))
    369       layer = RNN(cell)
    370       y = layer(x)
    371 
    372       # Here's how to use the cell to build a stacked RNN:
    373 
    374       cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
    375       x = keras.Input((None, 5))
    376       layer = RNN(cells)
    377       y = layer(x)
    378   ```
    379   """
    380 
    381   def __init__(self,
    382                cell,
    383                return_sequences=False,
    384                return_state=False,
    385                go_backwards=False,
    386                stateful=False,
    387                unroll=False,
    388                **kwargs):
    389     if isinstance(cell, (list, tuple)):
    390       cell = StackedRNNCells(cell)
    391     if not hasattr(cell, 'call'):
    392       raise ValueError('`cell` should have a `call` method. '
    393                        'The RNN was passed:', cell)
    394     if not hasattr(cell, 'state_size'):
    395       raise ValueError('The RNN cell should have '
    396                        'an attribute `state_size` '
    397                        '(tuple of integers, '
    398                        'one integer per RNN state).')
    399     super(RNN, self).__init__(**kwargs)
    400     self.cell = cell
    401     self.return_sequences = return_sequences
    402     self.return_state = return_state
    403     self.go_backwards = go_backwards
    404     self.stateful = stateful
    405     self.unroll = unroll
    406 
    407     self.supports_masking = True
    408     self.input_spec = [InputSpec(ndim=3)]
    409     self.state_spec = None
    410     self._states = None
    411     self.constants_spec = None
    412     self._num_constants = None
    413 
    414   @property
    415   def states(self):
    416     if self._states is None:
    417       if isinstance(self.cell.state_size, numbers.Integral):
    418         num_states = 1
    419       else:
    420         num_states = len(self.cell.state_size)
    421       return [None for _ in range(num_states)]
    422     return self._states
    423 
    424   @states.setter
    425   def states(self, states):
    426     self._states = states
    427 
    428   @shape_type_conversion
    429   def compute_output_shape(self, input_shape):
    430     if isinstance(input_shape, list):
    431       input_shape = input_shape[0]
    432 
    433     if hasattr(self.cell.state_size, '__len__'):
    434       state_size = self.cell.state_size
    435     else:
    436       state_size = [self.cell.state_size]
    437     output_dim = state_size[0]
    438 
    439     if self.return_sequences:
    440       output_shape = (input_shape[0], input_shape[1], output_dim)
    441     else:
    442       output_shape = (input_shape[0], output_dim)
    443 
    444     if self.return_state:
    445       state_shape = [(input_shape[0], dim) for dim in state_size]
    446       return [output_shape] + state_shape
    447     else:
    448       return output_shape
    449 
    450   def compute_mask(self, inputs, mask):
    451     if isinstance(mask, list):
    452       mask = mask[0]
    453     output_mask = mask if self.return_sequences else None
    454     if self.return_state:
    455       state_mask = [None for _ in self.states]
    456       return [output_mask] + state_mask
    457     else:
    458       return output_mask
    459 
    460   @shape_type_conversion
    461   def build(self, input_shape):
    462     # Note input_shape will be list of shapes of initial states and
    463     # constants if these are passed in __call__.
    464     if self._num_constants is not None:
    465       constants_shape = input_shape[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
    466     else:
    467       constants_shape = None
    468 
    469     if isinstance(input_shape, list):
    470       input_shape = input_shape[0]
    471 
    472     batch_size = input_shape[0] if self.stateful else None
    473     input_dim = input_shape[-1]
    474     self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim))
    475 
    476     # allow cell (if layer) to build before we set or validate state_spec
    477     if isinstance(self.cell, Layer):
    478       step_input_shape = (input_shape[0],) + input_shape[2:]
    479       if constants_shape is not None:
    480         self.cell.build([step_input_shape] + constants_shape)
    481       else:
    482         self.cell.build(step_input_shape)
    483 
    484     # set or validate state_spec
    485     if hasattr(self.cell.state_size, '__len__'):
    486       state_size = list(self.cell.state_size)
    487     else:
    488       state_size = [self.cell.state_size]
    489 
    490     if self.state_spec is not None:
    491       # initial_state was passed in call, check compatibility
    492       if [spec.shape[-1] for spec in self.state_spec] != state_size:
    493         raise ValueError(
    494             'An `initial_state` was passed that is not compatible with '
    495             '`cell.state_size`. Received `state_spec`={}; '
    496             'however `cell.state_size` is '
    497             '{}'.format(self.state_spec, self.cell.state_size))
    498     else:
    499       self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
    500     if self.stateful:
    501       self.reset_states()
    502 
    503   def get_initial_state(self, inputs):
    504     # build an all-zero tensor of shape (samples, output_dim)
    505     initial_state = K.zeros_like(inputs)  # (samples, timesteps, input_dim)
    506     initial_state = K.sum(initial_state, axis=(1, 2))  # (samples,)
    507     initial_state = K.expand_dims(initial_state)  # (samples, 1)
    508     if hasattr(self.cell.state_size, '__len__'):
    509       return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size]
    510     else:
    511       return [K.tile(initial_state, [1, self.cell.state_size])]
    512 
    513   def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
    514     inputs, initial_state, constants = self._standardize_args(
    515         inputs, initial_state, constants)
    516 
    517     if initial_state is None and constants is None:
    518       return super(RNN, self).__call__(inputs, **kwargs)
    519 
    520     # If any of `initial_state` or `constants` are specified and are Keras
    521     # tensors, then add them to the inputs and temporarily modify the
    522     # input_spec to include them.
    523 
    524     additional_inputs = []
    525     additional_specs = []
    526     if initial_state is not None:
    527       kwargs['initial_state'] = initial_state
    528       additional_inputs += initial_state
    529       self.state_spec = [
    530           InputSpec(shape=K.int_shape(state)) for state in initial_state
    531       ]
    532       additional_specs += self.state_spec
    533     if constants is not None:
    534       kwargs['constants'] = constants
    535       additional_inputs += constants
    536       self.constants_spec = [
    537           InputSpec(shape=K.int_shape(constant)) for constant in constants
    538       ]
    539       self._num_constants = len(constants)
    540       additional_specs += self.constants_spec
    541     # at this point additional_inputs cannot be empty
    542     is_keras_tensor = K.is_keras_tensor(additional_inputs[0])
    543     for tensor in additional_inputs:
    544       if K.is_keras_tensor(tensor) != is_keras_tensor:
    545         raise ValueError('The initial state or constants of an RNN'
    546                          ' layer cannot be specified with a mix of'
    547                          ' Keras tensors and non-Keras tensors'
     548                          ' (a "Keras tensor" is a tensor that was'
     549                          ' returned by a Keras layer, or by `Input`)')
    550 
    551     if is_keras_tensor:
    552       # Compute the full input spec, including state and constants
    553       full_input = [inputs] + additional_inputs
    554       full_input_spec = self.input_spec + additional_specs
    555       # Perform the call with temporarily replaced input_spec
    556       original_input_spec = self.input_spec
    557       self.input_spec = full_input_spec
    558       output = super(RNN, self).__call__(full_input, **kwargs)
    559       self.input_spec = original_input_spec
    560       return output
    561     else:
    562       return super(RNN, self).__call__(inputs, **kwargs)
    563 
    564   def call(self,
    565            inputs,
    566            mask=None,
    567            training=None,
    568            initial_state=None,
    569            constants=None):
    570     # input shape: `(samples, time (padded with zeros), input_dim)`
    571     # note that the .build() method of subclasses MUST define
    572     # self.input_spec and self.state_spec with complete input shapes.
    573     if isinstance(inputs, list):
    574       inputs = inputs[0]
    575     if initial_state is not None:
    576       pass
    577     elif self.stateful:
    578       initial_state = self.states
    579     else:
    580       initial_state = self.get_initial_state(inputs)
    581 
    582     if isinstance(mask, list):
    583       mask = mask[0]
    584 
    585     if len(initial_state) != len(self.states):
    586       raise ValueError(
    587           'Layer has ' + str(len(self.states)) + ' states but was passed ' +
    588           str(len(initial_state)) + ' initial states.')
    589     input_shape = K.int_shape(inputs)
    590     timesteps = input_shape[1]
    591     if self.unroll and timesteps in [None, 1]:
    592       raise ValueError('Cannot unroll a RNN if the '
    593                        'time dimension is undefined or equal to 1. \n'
    594                        '- If using a Sequential model, '
    595                        'specify the time dimension by passing '
    596                        'an `input_shape` or `batch_input_shape` '
    597                        'argument to your first layer. If your '
    598                        'first layer is an Embedding, you can '
    599                        'also use the `input_length` argument.\n'
    600                        '- If using the functional API, specify '
    601                        'the time dimension by passing a `shape` '
    602                        'or `batch_shape` argument to your Input layer.')
    603 
    604     kwargs = {}
    605     if has_arg(self.cell.call, 'training'):
    606       kwargs['training'] = training
    607 
    608     if constants:
    609       if not has_arg(self.cell.call, 'constants'):
    610         raise ValueError('RNN cell does not support constants')
    611 
    612       def step(inputs, states):
    613         constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
    614         states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
    615         return self.cell.call(inputs, states, constants=constants, **kwargs)
    616     else:
    617 
    618       def step(inputs, states):
    619         return self.cell.call(inputs, states, **kwargs)
    620 
    621     last_output, outputs, states = K.rnn(
    622         step,
    623         inputs,
    624         initial_state,
    625         constants=constants,
    626         go_backwards=self.go_backwards,
    627         mask=mask,
    628         unroll=self.unroll,
    629         input_length=timesteps)
    630     if self.stateful:
    631       updates = []
    632       for i in range(len(states)):
    633         updates.append(K.update(self.states[i], states[i]))
    634       self.add_update(updates, inputs)
    635 
    636     if self.return_sequences:
    637       output = outputs
    638     else:
    639       output = last_output
    640 
    641     # Properly set learning phase
    642     if getattr(last_output, '_uses_learning_phase', False):
    643       output._uses_learning_phase = True
    644       for state in states:
    645         state._uses_learning_phase = True
    646 
    647     if self.return_state:
    648       if not isinstance(states, (list, tuple)):
    649         states = [states]
    650       else:
    651         states = list(states)
    652       return [output] + states
    653     else:
    654       return output
    655 
    656   def _standardize_args(self, inputs, initial_state, constants):
    657     """Standardize `__call__` to a single list of tensor inputs.
    658 
    659     When running a model loaded from file, the input tensors
    660     `initial_state` and `constants` can be passed to `RNN.__call__` as part
    661     of `inputs` instead of by the dedicated keyword arguments. This method
    662     makes sure the arguments are separated and that `initial_state` and
    663     `constants` are lists of tensors (or None).
    664 
    665     Arguments:
    666         inputs: tensor or list/tuple of tensors
    667         initial_state: tensor or list of tensors or None
    668         constants: tensor or list of tensors or None
    669 
    670     Returns:
    671         inputs: tensor
    672         initial_state: list of tensors or None
    673         constants: list of tensors or None
    674     """
    675     if isinstance(inputs, list):
    676       assert initial_state is None and constants is None
    677       if self._num_constants is not None:
    678         constants = inputs[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
    679         inputs = inputs[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
    680       if len(inputs) > 1:
    681         initial_state = inputs[1:]
    682       inputs = inputs[0]
    683 
    684     def to_list_or_none(x):
    685       if x is None or isinstance(x, list):
    686         return x
    687       if isinstance(x, tuple):
    688         return list(x)
    689       return [x]
    690 
    691     initial_state = to_list_or_none(initial_state)
    692     constants = to_list_or_none(constants)
    693 
    694     return inputs, initial_state, constants
    695 
    696   def reset_states(self, states=None):
    697     if not self.stateful:
    698       raise AttributeError('Layer must be stateful.')
    699     batch_size = self.input_spec[0].shape[0]
    700     if not batch_size:
    701       raise ValueError('If a RNN is stateful, it needs to know '
    702                        'its batch size. Specify the batch size '
    703                        'of your input tensors: \n'
    704                        '- If using a Sequential model, '
    705                        'specify the batch size by passing '
    706                        'a `batch_input_shape` '
    707                        'argument to your first layer.\n'
    708                        '- If using the functional API, specify '
    709                        'the batch size by passing a '
    710                        '`batch_shape` argument to your Input layer.')
    711     # initialize state if None
    712     if self.states[0] is None:
    713       if hasattr(self.cell.state_size, '__len__'):
    714         self.states = [
    715             K.zeros((batch_size, dim)) for dim in self.cell.state_size
    716         ]
    717       else:
    718         self.states = [K.zeros((batch_size, self.cell.state_size))]
    719     elif states is None:
    720       if hasattr(self.cell.state_size, '__len__'):
    721         for state, dim in zip(self.states, self.cell.state_size):
    722           K.set_value(state, np.zeros((batch_size, dim)))
    723       else:
    724         K.set_value(self.states[0], np.zeros((batch_size,
    725                                               self.cell.state_size)))
    726     else:
    727       if not isinstance(states, (list, tuple)):
    728         states = [states]
    729       if len(states) != len(self.states):
    730         raise ValueError('Layer ' + self.name + ' expects ' +
    731                          str(len(self.states)) + ' states, '
    732                          'but it received ' + str(len(states)) +
    733                          ' state values. Input received: ' + str(states))
    734       for index, (value, state) in enumerate(zip(states, self.states)):
    735         if hasattr(self.cell.state_size, '__len__'):
    736           dim = self.cell.state_size[index]
    737         else:
    738           dim = self.cell.state_size
    739         if value.shape != (batch_size, dim):
    740           raise ValueError(
    741               'State ' + str(index) + ' is incompatible with layer ' +
    742               self.name + ': expected shape=' + str(
    743                   (batch_size, dim)) + ', found shape=' + str(value.shape))
    744         # TODO(fchollet): consider batch calls to `set_value`.
    745         K.set_value(state, value)
    746 
    747   def get_config(self):
    748     config = {
    749         'return_sequences': self.return_sequences,
    750         'return_state': self.return_state,
    751         'go_backwards': self.go_backwards,
    752         'stateful': self.stateful,
    753         'unroll': self.unroll
    754     }
    755     if self._num_constants is not None:
    756       config['num_constants'] = self._num_constants
    757 
    758     cell_config = self.cell.get_config()
    759     config['cell'] = {
    760         'class_name': self.cell.__class__.__name__,
    761         'config': cell_config
    762     }
    763     base_config = super(RNN, self).get_config()
    764     return dict(list(base_config.items()) + list(config.items()))
    765 
    766   @classmethod
    767   def from_config(cls, config, custom_objects=None):
    768     from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    769     cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
    770     num_constants = config.pop('num_constants', None)
    771     layer = cls(cell, **config)
    772     layer._num_constants = num_constants
    773     return layer
    774 
    775   @property
    776   def trainable_weights(self):
    777     if not self.trainable:
    778       return []
    779     if isinstance(self.cell, Layer):
    780       return self.cell.trainable_weights
    781     return []
    782 
    783   @property
    784   def non_trainable_weights(self):
    785     if isinstance(self.cell, Layer):
    786       if not self.trainable:
    787         return self.cell.weights
    788       return self.cell.non_trainable_weights
    789     return []
    790 
    791   @property
    792   def losses(self):
    793     losses = []
    794     if isinstance(self.cell, Layer):
    795       losses += self.cell.losses
    796     return losses + self._losses
    797 
    798   @property
    799   def updates(self):
    800     updates = []
    801     if isinstance(self.cell, Layer):
    802       updates += self.cell.updates
    803     return updates + self._updates
    804 
    805 
    806 @tf_export('keras.layers.SimpleRNNCell')
    807 class SimpleRNNCell(Layer):
    808   """Cell class for SimpleRNN.
    809 
    810   Arguments:
    811       units: Positive integer, dimensionality of the output space.
    812       activation: Activation function to use.
    813           Default: hyperbolic tangent (`tanh`).
    814           If you pass `None`, no activation is applied
     815           (i.e. "linear" activation: `a(x) = x`).
    816       use_bias: Boolean, whether the layer uses a bias vector.
    817       kernel_initializer: Initializer for the `kernel` weights matrix,
    818           used for the linear transformation of the inputs.
    819       recurrent_initializer: Initializer for the `recurrent_kernel`
    820           weights matrix,
    821           used for the linear transformation of the recurrent state.
    822       bias_initializer: Initializer for the bias vector.
    823       kernel_regularizer: Regularizer function applied to
    824           the `kernel` weights matrix.
    825       recurrent_regularizer: Regularizer function applied to
    826           the `recurrent_kernel` weights matrix.
    827       bias_regularizer: Regularizer function applied to the bias vector.
    828       kernel_constraint: Constraint function applied to
    829           the `kernel` weights matrix.
    830       recurrent_constraint: Constraint function applied to
    831           the `recurrent_kernel` weights matrix.
    832       bias_constraint: Constraint function applied to the bias vector.
    833       dropout: Float between 0 and 1.
    834           Fraction of the units to drop for
    835           the linear transformation of the inputs.
    836       recurrent_dropout: Float between 0 and 1.
    837           Fraction of the units to drop for
    838           the linear transformation of the recurrent state.
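
           This cell is meant to be wrapped in a `RNN` layer. A minimal usage
           sketch (sizes are placeholders):

           ```python
               cell = keras.layers.SimpleRNNCell(32)
               x = keras.Input((None, 5))
               y = keras.layers.RNN(cell)(x)
           ```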
    839   """
    840 
    841   def __init__(self,
    842                units,
    843                activation='tanh',
    844                use_bias=True,
    845                kernel_initializer='glorot_uniform',
    846                recurrent_initializer='orthogonal',
    847                bias_initializer='zeros',
    848                kernel_regularizer=None,
    849                recurrent_regularizer=None,
    850                bias_regularizer=None,
    851                kernel_constraint=None,
    852                recurrent_constraint=None,
    853                bias_constraint=None,
    854                dropout=0.,
    855                recurrent_dropout=0.,
    856                **kwargs):
    857     super(SimpleRNNCell, self).__init__(**kwargs)
    858     self.units = units
    859     self.activation = activations.get(activation)
    860     self.use_bias = use_bias
    861 
    862     self.kernel_initializer = initializers.get(kernel_initializer)
    863     self.recurrent_initializer = initializers.get(recurrent_initializer)
    864     self.bias_initializer = initializers.get(bias_initializer)
    865 
    866     self.kernel_regularizer = regularizers.get(kernel_regularizer)
    867     self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    868     self.bias_regularizer = regularizers.get(bias_regularizer)
    869 
    870     self.kernel_constraint = constraints.get(kernel_constraint)
    871     self.recurrent_constraint = constraints.get(recurrent_constraint)
    872     self.bias_constraint = constraints.get(bias_constraint)
    873 
    874     self.dropout = min(1., max(0., dropout))
    875     self.recurrent_dropout = min(1., max(0., recurrent_dropout))
    876     self.state_size = self.units
    877     self._dropout_mask = None
    878     self._recurrent_dropout_mask = None
    879 
    880   @shape_type_conversion
    881   def build(self, input_shape):
    882     self.kernel = self.add_weight(
    883         shape=(input_shape[-1], self.units),
    884         name='kernel',
    885         initializer=self.kernel_initializer,
    886         regularizer=self.kernel_regularizer,
    887         constraint=self.kernel_constraint)
    888     self.recurrent_kernel = self.add_weight(
    889         shape=(self.units, self.units),
    890         name='recurrent_kernel',
    891         initializer=self.recurrent_initializer,
    892         regularizer=self.recurrent_regularizer,
    893         constraint=self.recurrent_constraint)
    894     if self.use_bias:
    895       self.bias = self.add_weight(
    896           shape=(self.units,),
    897           name='bias',
    898           initializer=self.bias_initializer,
    899           regularizer=self.bias_regularizer,
    900           constraint=self.bias_constraint)
    901     else:
    902       self.bias = None
    903     self.built = True
    904 
    905   def call(self, inputs, states, training=None):
    906     prev_output = states[0]
    907     if 0 < self.dropout < 1 and self._dropout_mask is None:
    908       self._dropout_mask = _generate_dropout_mask(
    909           _generate_dropout_ones(inputs,
    910                                  K.shape(inputs)[-1]),
    911           self.dropout,
    912           training=training)
    913     if (0 < self.recurrent_dropout < 1 and
    914         self._recurrent_dropout_mask is None):
    915       self._recurrent_dropout_mask = _generate_dropout_mask(
    916           _generate_dropout_ones(inputs, self.units),
    917           self.recurrent_dropout,
    918           training=training)
    919 
    920     dp_mask = self._dropout_mask
    921     rec_dp_mask = self._recurrent_dropout_mask
    922 
    923     if dp_mask is not None:
    924       h = K.dot(inputs * dp_mask, self.kernel)
    925     else:
    926       h = K.dot(inputs, self.kernel)
    927     if self.bias is not None:
    928       h = K.bias_add(h, self.bias)
    929 
    930     if rec_dp_mask is not None:
    931       prev_output *= rec_dp_mask
    932     output = h + K.dot(prev_output, self.recurrent_kernel)
    933     if self.activation is not None:
    934       output = self.activation(output)
    935 
    936     # Properly set learning phase on output tensor.
    937     if 0 < self.dropout + self.recurrent_dropout:
    938       if training is None:
    939         output._uses_learning_phase = True
    940     return output, [output]
    941 
    942   def get_config(self):
    943     config = {
    944         'units':
    945             self.units,
    946         'activation':
    947             activations.serialize(self.activation),
    948         'use_bias':
    949             self.use_bias,
    950         'kernel_initializer':
    951             initializers.serialize(self.kernel_initializer),
    952         'recurrent_initializer':
    953             initializers.serialize(self.recurrent_initializer),
    954         'bias_initializer':
    955             initializers.serialize(self.bias_initializer),
    956         'kernel_regularizer':
    957             regularizers.serialize(self.kernel_regularizer),
    958         'recurrent_regularizer':
    959             regularizers.serialize(self.recurrent_regularizer),
    960         'bias_regularizer':
    961             regularizers.serialize(self.bias_regularizer),
    962         'kernel_constraint':
    963             constraints.serialize(self.kernel_constraint),
    964         'recurrent_constraint':
    965             constraints.serialize(self.recurrent_constraint),
    966         'bias_constraint':
    967             constraints.serialize(self.bias_constraint),
    968         'dropout':
    969             self.dropout,
    970         'recurrent_dropout':
    971             self.recurrent_dropout
    972     }
    973     base_config = super(SimpleRNNCell, self).get_config()
    974     return dict(list(base_config.items()) + list(config.items()))
    975 
    976 
    977 @tf_export('keras.layers.SimpleRNN')
    978 class SimpleRNN(RNN):
    979   """Fully-connected RNN where the output is to be fed back to input.
    980 
    981   Arguments:
    982       units: Positive integer, dimensionality of the output space.
    983       activation: Activation function to use.
    984           Default: hyperbolic tangent (`tanh`).
    985           If you pass None, no activation is applied
     986           (i.e. "linear" activation: `a(x) = x`).
    987       use_bias: Boolean, whether the layer uses a bias vector.
    988       kernel_initializer: Initializer for the `kernel` weights matrix,
    989           used for the linear transformation of the inputs.
    990       recurrent_initializer: Initializer for the `recurrent_kernel`
    991           weights matrix,
    992           used for the linear transformation of the recurrent state.
    993       bias_initializer: Initializer for the bias vector.
    994       kernel_regularizer: Regularizer function applied to
    995           the `kernel` weights matrix.
    996       recurrent_regularizer: Regularizer function applied to
    997           the `recurrent_kernel` weights matrix.
    998       bias_regularizer: Regularizer function applied to the bias vector.
    999       activity_regularizer: Regularizer function applied to
    1000           the output of the layer (its "activation").
   1001       kernel_constraint: Constraint function applied to
   1002           the `kernel` weights matrix.
   1003       recurrent_constraint: Constraint function applied to
   1004           the `recurrent_kernel` weights matrix.
   1005       bias_constraint: Constraint function applied to the bias vector.
   1006       dropout: Float between 0 and 1.
   1007           Fraction of the units to drop for
   1008           the linear transformation of the inputs.
   1009       recurrent_dropout: Float between 0 and 1.
   1010           Fraction of the units to drop for
   1011           the linear transformation of the recurrent state.
    1012       return_sequences: Boolean. Whether to return the last output
   1013           in the output sequence, or the full sequence.
   1014       return_state: Boolean. Whether to return the last state
   1015           in addition to the output.
   1016       go_backwards: Boolean (default False).
   1017           If True, process the input sequence backwards and return the
   1018           reversed sequence.
   1019       stateful: Boolean (default False). If True, the last state
   1020           for each sample at index i in a batch will be used as initial
   1021           state for the sample of index i in the following batch.
   1022       unroll: Boolean (default False).
   1023           If True, the network will be unrolled,
   1024           else a symbolic loop will be used.
    1025           Unrolling can speed up an RNN,
   1026           although it tends to be more memory-intensive.
   1027           Unrolling is only suitable for short sequences.
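
           A minimal usage sketch (the input shape values are placeholders):

           ```python
               model = keras.models.Sequential()
               model.add(keras.layers.SimpleRNN(32, input_shape=(10, 8)))
               model.add(keras.layers.Dense(1))
           ```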
   1028   """
   1029 
   1030   def __init__(self,
   1031                units,
   1032                activation='tanh',
   1033                use_bias=True,
   1034                kernel_initializer='glorot_uniform',
   1035                recurrent_initializer='orthogonal',
   1036                bias_initializer='zeros',
   1037                kernel_regularizer=None,
   1038                recurrent_regularizer=None,
   1039                bias_regularizer=None,
   1040                activity_regularizer=None,
   1041                kernel_constraint=None,
   1042                recurrent_constraint=None,
   1043                bias_constraint=None,
   1044                dropout=0.,
   1045                recurrent_dropout=0.,
   1046                return_sequences=False,
   1047                return_state=False,
   1048                go_backwards=False,
   1049                stateful=False,
   1050                unroll=False,
   1051                **kwargs):
   1052     if 'implementation' in kwargs:
   1053       kwargs.pop('implementation')
   1054       logging.warning('The `implementation` argument '
   1055                       'in `SimpleRNN` has been deprecated. '
   1056                       'Please remove it from your layer call.')
   1057     cell = SimpleRNNCell(
   1058         units,
   1059         activation=activation,
   1060         use_bias=use_bias,
   1061         kernel_initializer=kernel_initializer,
   1062         recurrent_initializer=recurrent_initializer,
   1063         bias_initializer=bias_initializer,
   1064         kernel_regularizer=kernel_regularizer,
   1065         recurrent_regularizer=recurrent_regularizer,
   1066         bias_regularizer=bias_regularizer,
   1067         kernel_constraint=kernel_constraint,
   1068         recurrent_constraint=recurrent_constraint,
   1069         bias_constraint=bias_constraint,
   1070         dropout=dropout,
   1071         recurrent_dropout=recurrent_dropout)
   1072     super(SimpleRNN, self).__init__(
   1073         cell,
   1074         return_sequences=return_sequences,
   1075         return_state=return_state,
   1076         go_backwards=go_backwards,
   1077         stateful=stateful,
   1078         unroll=unroll,
   1079         **kwargs)
   1080     self.activity_regularizer = regularizers.get(activity_regularizer)
   1081 
   1082   def call(self, inputs, mask=None, training=None, initial_state=None):
   1083     self.cell._dropout_mask = None
   1084     self.cell._recurrent_dropout_mask = None
   1085     return super(SimpleRNN, self).call(
   1086         inputs, mask=mask, training=training, initial_state=initial_state)
   1087 
   1088   @property
   1089   def units(self):
   1090     return self.cell.units
   1091 
   1092   @property
   1093   def activation(self):
   1094     return self.cell.activation
   1095 
   1096   @property
   1097   def use_bias(self):
   1098     return self.cell.use_bias
   1099 
   1100   @property
   1101   def kernel_initializer(self):
   1102     return self.cell.kernel_initializer
   1103 
   1104   @property
   1105   def recurrent_initializer(self):
   1106     return self.cell.recurrent_initializer
   1107 
   1108   @property
   1109   def bias_initializer(self):
   1110     return self.cell.bias_initializer
   1111 
   1112   @property
   1113   def kernel_regularizer(self):
   1114     return self.cell.kernel_regularizer
   1115 
   1116   @property
   1117   def recurrent_regularizer(self):
   1118     return self.cell.recurrent_regularizer
   1119 
   1120   @property
   1121   def bias_regularizer(self):
   1122     return self.cell.bias_regularizer
   1123 
   1124   @property
   1125   def kernel_constraint(self):
   1126     return self.cell.kernel_constraint
   1127 
   1128   @property
   1129   def recurrent_constraint(self):
   1130     return self.cell.recurrent_constraint
   1131 
   1132   @property
   1133   def bias_constraint(self):
   1134     return self.cell.bias_constraint
   1135 
   1136   @property
   1137   def dropout(self):
   1138     return self.cell.dropout
   1139 
   1140   @property
   1141   def recurrent_dropout(self):
   1142     return self.cell.recurrent_dropout
   1143 
   1144   def get_config(self):
   1145     config = {
   1146         'units':
   1147             self.units,
   1148         'activation':
   1149             activations.serialize(self.activation),
   1150         'use_bias':
   1151             self.use_bias,
   1152         'kernel_initializer':
   1153             initializers.serialize(self.kernel_initializer),
   1154         'recurrent_initializer':
   1155             initializers.serialize(self.recurrent_initializer),
   1156         'bias_initializer':
   1157             initializers.serialize(self.bias_initializer),
   1158         'kernel_regularizer':
   1159             regularizers.serialize(self.kernel_regularizer),
   1160         'recurrent_regularizer':
   1161             regularizers.serialize(self.recurrent_regularizer),
   1162         'bias_regularizer':
   1163             regularizers.serialize(self.bias_regularizer),
   1164         'activity_regularizer':
   1165             regularizers.serialize(self.activity_regularizer),
   1166         'kernel_constraint':
   1167             constraints.serialize(self.kernel_constraint),
   1168         'recurrent_constraint':
   1169             constraints.serialize(self.recurrent_constraint),
   1170         'bias_constraint':
   1171             constraints.serialize(self.bias_constraint),
   1172         'dropout':
   1173             self.dropout,
   1174         'recurrent_dropout':
   1175             self.recurrent_dropout
   1176     }
   1177     base_config = super(SimpleRNN, self).get_config()
   1178     del base_config['cell']
   1179     return dict(list(base_config.items()) + list(config.items()))
   1180 
   1181   @classmethod
   1182   def from_config(cls, config):
   1183     if 'implementation' in config:
   1184       config.pop('implementation')
   1185     return cls(**config)
   1186 
   1187 
   1188 @tf_export('keras.layers.GRUCell')
   1189 class GRUCell(Layer):
   1190   """Cell class for the GRU layer.
   1191 
   1192   Arguments:
   1193       units: Positive integer, dimensionality of the output space.
   1194       activation: Activation function to use.
   1195           Default: hyperbolic tangent (`tanh`).
   1196           If you pass None, no activation is applied
    1197           (i.e. "linear" activation: `a(x) = x`).
   1198       recurrent_activation: Activation function to use
   1199           for the recurrent step.
   1200           Default: hard sigmoid (`hard_sigmoid`).
   1201           If you pass `None`, no activation is applied
    1202           (i.e. "linear" activation: `a(x) = x`).
   1203       use_bias: Boolean, whether the layer uses a bias vector.
   1204       kernel_initializer: Initializer for the `kernel` weights matrix,
   1205           used for the linear transformation of the inputs.
   1206       recurrent_initializer: Initializer for the `recurrent_kernel`
   1207           weights matrix,
   1208           used for the linear transformation of the recurrent state.
   1209       bias_initializer: Initializer for the bias vector.
   1210       kernel_regularizer: Regularizer function applied to
   1211           the `kernel` weights matrix.
   1212       recurrent_regularizer: Regularizer function applied to
   1213           the `recurrent_kernel` weights matrix.
   1214       bias_regularizer: Regularizer function applied to the bias vector.
   1215       kernel_constraint: Constraint function applied to
   1216           the `kernel` weights matrix.
   1217       recurrent_constraint: Constraint function applied to
   1218           the `recurrent_kernel` weights matrix.
   1219       bias_constraint: Constraint function applied to the bias vector.
   1220       dropout: Float between 0 and 1.
   1221           Fraction of the units to drop for
   1222           the linear transformation of the inputs.
   1223       recurrent_dropout: Float between 0 and 1.
   1224           Fraction of the units to drop for
   1225           the linear transformation of the recurrent state.
   1226       implementation: Implementation mode, either 1 or 2.
   1227           Mode 1 will structure its operations as a larger number of
   1228           smaller dot products and additions, whereas mode 2 will
   1229           batch them into fewer, larger operations. These modes will
   1230           have different performance profiles on different hardware and
   1231           for different applications.
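
           A minimal usage sketch (sizes are placeholders; the cell is typically
           wrapped in a `RNN` layer):

           ```python
               cell = keras.layers.GRUCell(64)
               x = keras.Input((None, 8))
               y = keras.layers.RNN(cell)(x)
           ```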
   1232   """
   1233 
   1234   def __init__(self,
   1235                units,
   1236                activation='tanh',
   1237                recurrent_activation='hard_sigmoid',
   1238                use_bias=True,
   1239                kernel_initializer='glorot_uniform',
   1240                recurrent_initializer='orthogonal',
   1241                bias_initializer='zeros',
   1242                kernel_regularizer=None,
   1243                recurrent_regularizer=None,
   1244                bias_regularizer=None,
   1245                kernel_constraint=None,
   1246                recurrent_constraint=None,
   1247                bias_constraint=None,
   1248                dropout=0.,
   1249                recurrent_dropout=0.,
   1250                implementation=1,
   1251                **kwargs):
   1252     super(GRUCell, self).__init__(**kwargs)
   1253     self.units = units
   1254     self.activation = activations.get(activation)
   1255     self.recurrent_activation = activations.get(recurrent_activation)
   1256     self.use_bias = use_bias
   1257 
   1258     self.kernel_initializer = initializers.get(kernel_initializer)
   1259     self.recurrent_initializer = initializers.get(recurrent_initializer)
   1260     self.bias_initializer = initializers.get(bias_initializer)
   1261 
   1262     self.kernel_regularizer = regularizers.get(kernel_regularizer)
   1263     self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
   1264     self.bias_regularizer = regularizers.get(bias_regularizer)
   1265 
   1266     self.kernel_constraint = constraints.get(kernel_constraint)
   1267     self.recurrent_constraint = constraints.get(recurrent_constraint)
   1268     self.bias_constraint = constraints.get(bias_constraint)
   1269 
   1270     self.dropout = min(1., max(0., dropout))
   1271     self.recurrent_dropout = min(1., max(0., recurrent_dropout))
   1272     self.implementation = implementation
   1273     self.state_size = self.units
   1274     self._dropout_mask = None
   1275     self._recurrent_dropout_mask = None
   1276 
   1277   @shape_type_conversion
   1278   def build(self, input_shape):
   1279     input_dim = input_shape[-1]
   1280     self.kernel = self.add_weight(
   1281         shape=(input_dim, self.units * 3),
   1282         name='kernel',
   1283         initializer=self.kernel_initializer,
   1284         regularizer=self.kernel_regularizer,
   1285         constraint=self.kernel_constraint)
   1286     self.recurrent_kernel = self.add_weight(
   1287         shape=(self.units, self.units * 3),
   1288         name='recurrent_kernel',
   1289         initializer=self.recurrent_initializer,
   1290         regularizer=self.recurrent_regularizer,
   1291         constraint=self.recurrent_constraint)
   1292 
   1293     if self.use_bias:
   1294       self.bias = self.add_weight(
   1295           shape=(self.units * 3,),
   1296           name='bias',
   1297           initializer=self.bias_initializer,
   1298           regularizer=self.bias_regularizer,
   1299           constraint=self.bias_constraint)
   1300     else:
   1301       self.bias = None
   1302 
   1303     self.kernel_z = self.kernel[:, :self.units]
   1304     self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units]
   1305     self.kernel_r = self.kernel[:, self.units:self.units * 2]
   1306     self.recurrent_kernel_r = self.recurrent_kernel[:, self.units:
   1307                                                     self.units * 2]
   1308     self.kernel_h = self.kernel[:, self.units * 2:]
   1309     self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:]
   1310 
   1311     if self.use_bias:
   1312       self.bias_z = self.bias[:self.units]
   1313       self.bias_r = self.bias[self.units:self.units * 2]
   1314       self.bias_h = self.bias[self.units * 2:]
   1315     else:
   1316       self.bias_z = None
   1317       self.bias_r = None
   1318       self.bias_h = None
   1319     self.built = True
   1320 
   1321   def call(self, inputs, states, training=None):
   1322     h_tm1 = states[0]  # previous memory
   1323 
   1324     if 0 < self.dropout < 1 and self._dropout_mask is None:
   1325       self._dropout_mask = _generate_dropout_mask(
   1326           _generate_dropout_ones(inputs,
   1327                                  K.shape(inputs)[-1]),
   1328           self.dropout,
   1329           training=training,
   1330           count=3)
   1331     if (0 < self.recurrent_dropout < 1 and
   1332         self._recurrent_dropout_mask is None):
   1333       self._recurrent_dropout_mask = _generate_dropout_mask(
   1334           _generate_dropout_ones(inputs, self.units),
   1335           self.recurrent_dropout,
   1336           training=training,
   1337           count=3)
   1338 
   1339     # dropout matrices for input units
   1340     dp_mask = self._dropout_mask
   1341     # dropout matrices for recurrent units
   1342     rec_dp_mask = self._recurrent_dropout_mask
   1343 
   1344     if self.implementation == 1:
   1345       if 0. < self.dropout < 1.:
   1346         inputs_z = inputs * dp_mask[0]
   1347         inputs_r = inputs * dp_mask[1]
   1348         inputs_h = inputs * dp_mask[2]
   1349       else:
   1350         inputs_z = inputs
   1351         inputs_r = inputs
   1352         inputs_h = inputs
   1353       x_z = K.dot(inputs_z, self.kernel_z)
   1354       x_r = K.dot(inputs_r, self.kernel_r)
   1355       x_h = K.dot(inputs_h, self.kernel_h)
   1356       if self.use_bias:
   1357         x_z = K.bias_add(x_z, self.bias_z)
   1358         x_r = K.bias_add(x_r, self.bias_r)
   1359         x_h = K.bias_add(x_h, self.bias_h)
   1360 
   1361       if 0. < self.recurrent_dropout < 1.:
   1362         h_tm1_z = h_tm1 * rec_dp_mask[0]
   1363         h_tm1_r = h_tm1 * rec_dp_mask[1]
   1364         h_tm1_h = h_tm1 * rec_dp_mask[2]
   1365       else:
   1366         h_tm1_z = h_tm1
   1367         h_tm1_r = h_tm1
   1368         h_tm1_h = h_tm1
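               # GRU gate math (implementation 1): the update gate z and reset gate r
               # use the recurrent activation; the candidate state hh applies the reset
               # gate to the previous state before the recurrent projection.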
   1369       z = self.recurrent_activation(
   1370           x_z + K.dot(h_tm1_z, self.recurrent_kernel_z))
   1371       r = self.recurrent_activation(
   1372           x_r + K.dot(h_tm1_r, self.recurrent_kernel_r))
   1373 
   1374       hh = self.activation(x_h + K.dot(r * h_tm1_h, self.recurrent_kernel_h))
   1375     else:
   1376       if 0. < self.dropout < 1.:
   1377         inputs *= dp_mask[0]
   1378       matrix_x = K.dot(inputs, self.kernel)
   1379       if self.use_bias:
   1380         matrix_x = K.bias_add(matrix_x, self.bias)
   1381       if 0. < self.recurrent_dropout < 1.:
   1382         h_tm1 *= rec_dp_mask[0]
   1383       matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
   1384 
   1385       x_z = matrix_x[:, :self.units]
   1386       x_r = matrix_x[:, self.units:2 * self.units]
   1387       recurrent_z = matrix_inner[:, :self.units]
   1388       recurrent_r = matrix_inner[:, self.units:2 * self.units]
   1389 
   1390       z = self.recurrent_activation(x_z + recurrent_z)
   1391       r = self.recurrent_activation(x_r + recurrent_r)
   1392 
   1393       x_h = matrix_x[:, 2 * self.units:]
   1394       recurrent_h = K.dot(r * h_tm1, self.recurrent_kernel[:, 2 * self.units:])
   1395       hh = self.activation(x_h + recurrent_h)
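             # New state: interpolate between the previous state and the candidate,
             # weighted by the update gate z.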
   1396     h = z * h_tm1 + (1 - z) * hh
   1397     if 0 < self.dropout + self.recurrent_dropout:
   1398       if training is None:
   1399         h._uses_learning_phase = True
   1400     return h, [h]
   1401 
   1402   def get_config(self):
   1403     config = {
   1404         'units':
   1405             self.units,
   1406         'activation':
   1407             activations.serialize(self.activation),
   1408         'recurrent_activation':
   1409             activations.serialize(self.recurrent_activation),
   1410         'use_bias':
   1411             self.use_bias,
   1412         'kernel_initializer':
   1413             initializers.serialize(self.kernel_initializer),
   1414         'recurrent_initializer':
   1415             initializers.serialize(self.recurrent_initializer),
   1416         'bias_initializer':
   1417             initializers.serialize(self.bias_initializer),
   1418         'kernel_regularizer':
   1419             regularizers.serialize(self.kernel_regularizer),
   1420         'recurrent_regularizer':
   1421             regularizers.serialize(self.recurrent_regularizer),
   1422         'bias_regularizer':
   1423             regularizers.serialize(self.bias_regularizer),
   1424         'kernel_constraint':
   1425             constraints.serialize(self.kernel_constraint),
   1426         'recurrent_constraint':
   1427             constraints.serialize(self.recurrent_constraint),
   1428         'bias_constraint':
   1429             constraints.serialize(self.bias_constraint),
   1430         'dropout':
   1431             self.dropout,
   1432         'recurrent_dropout':
   1433             self.recurrent_dropout,
   1434         'implementation':
   1435             self.implementation
   1436     }
   1437     base_config = super(GRUCell, self).get_config()
   1438     return dict(list(base_config.items()) + list(config.items()))
   1439 
   1440 
   1441 @tf_export('keras.layers.GRU')
   1442 class GRU(RNN):
   1443   """Gated Recurrent Unit - Cho et al.
   1444 
   1445   2014.
   1446 
   1447   Arguments:
   1448       units: Positive integer, dimensionality of the output space.
   1449       activation: Activation function to use.
   1450           Default: hyperbolic tangent (`tanh`).
   1451           If you pass `None`, no activation is applied
   1452           (ie. "linear" activation: `a(x) = x`).
   1453       recurrent_activation: Activation function to use
   1454           for the recurrent step.
   1455           Default: hard sigmoid (`hard_sigmoid`).
   1456           If you pass `None`, no activation is applied
   1457           (ie. "linear" activation: `a(x) = x`).
   1458       use_bias: Boolean, whether the layer uses a bias vector.
   1459       kernel_initializer: Initializer for the `kernel` weights matrix,
   1460           used for the linear transformation of the inputs.
   1461       recurrent_initializer: Initializer for the `recurrent_kernel`
   1462           weights matrix,
   1463           used for the linear transformation of the recurrent state.
   1464       bias_initializer: Initializer for the bias vector.
   1465       kernel_regularizer: Regularizer function applied to
   1466           the `kernel` weights matrix.
   1467       recurrent_regularizer: Regularizer function applied to
   1468           the `recurrent_kernel` weights matrix.
   1469       bias_regularizer: Regularizer function applied to the bias vector.
   1470       activity_regularizer: Regularizer function applied to
    1471           the output of the layer (its "activation").
   1472       kernel_constraint: Constraint function applied to
   1473           the `kernel` weights matrix.
   1474       recurrent_constraint: Constraint function applied to
   1475           the `recurrent_kernel` weights matrix.
   1476       bias_constraint: Constraint function applied to the bias vector.
   1477       dropout: Float between 0 and 1.
   1478           Fraction of the units to drop for
   1479           the linear transformation of the inputs.
   1480       recurrent_dropout: Float between 0 and 1.
   1481           Fraction of the units to drop for
   1482           the linear transformation of the recurrent state.
   1483       implementation: Implementation mode, either 1 or 2.
   1484           Mode 1 will structure its operations as a larger number of
   1485           smaller dot products and additions, whereas mode 2 will
   1486           batch them into fewer, larger operations. These modes will
   1487           have different performance profiles on different hardware and
   1488           for different applications.
    1489       return_sequences: Boolean. Whether to return the last output
   1490           in the output sequence, or the full sequence.
   1491       return_state: Boolean. Whether to return the last state
   1492           in addition to the output.
   1493       go_backwards: Boolean (default False).
   1494           If True, process the input sequence backwards and return the
   1495           reversed sequence.
   1496       stateful: Boolean (default False). If True, the last state
   1497           for each sample at index i in a batch will be used as initial
   1498           state for the sample of index i in the following batch.
   1499       unroll: Boolean (default False).
   1500           If True, the network will be unrolled,
   1501           else a symbolic loop will be used.
    1502           Unrolling can speed up an RNN,
   1503           although it tends to be more memory-intensive.
   1504           Unrolling is only suitable for short sequences.
   1505 
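           Example (a minimal usage sketch; `timesteps` and `input_dim` are
           placeholder dimensions):

           ```python
               inputs = keras.Input((timesteps, input_dim))
               # Return the full output sequence, shaped (batch, timesteps, 32).
               outputs = keras.layers.GRU(32, return_sequences=True)(inputs)
           ```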
   1506   """
   1507 
   1508   def __init__(self,
   1509                units,
   1510                activation='tanh',
   1511                recurrent_activation='hard_sigmoid',
   1512                use_bias=True,
   1513                kernel_initializer='glorot_uniform',
   1514                recurrent_initializer='orthogonal',
   1515                bias_initializer='zeros',
   1516                kernel_regularizer=None,
   1517                recurrent_regularizer=None,
   1518                bias_regularizer=None,
   1519                activity_regularizer=None,
   1520                kernel_constraint=None,
   1521                recurrent_constraint=None,
   1522                bias_constraint=None,
   1523                dropout=0.,
   1524                recurrent_dropout=0.,
   1525                implementation=1,
   1526                return_sequences=False,
   1527                return_state=False,
   1528                go_backwards=False,
   1529                stateful=False,
   1530                unroll=False,
   1531                **kwargs):
   1532     if implementation == 0:
   1533       logging.warning('`implementation=0` has been deprecated, '
    1534                       'and now defaults to `implementation=1`. '
   1535                       'Please update your layer call.')
   1536     cell = GRUCell(
   1537         units,
   1538         activation=activation,
   1539         recurrent_activation=recurrent_activation,
   1540         use_bias=use_bias,
   1541         kernel_initializer=kernel_initializer,
   1542         recurrent_initializer=recurrent_initializer,
   1543         bias_initializer=bias_initializer,
   1544         kernel_regularizer=kernel_regularizer,
   1545         recurrent_regularizer=recurrent_regularizer,
   1546         bias_regularizer=bias_regularizer,
   1547         kernel_constraint=kernel_constraint,
   1548         recurrent_constraint=recurrent_constraint,
   1549         bias_constraint=bias_constraint,
   1550         dropout=dropout,
   1551         recurrent_dropout=recurrent_dropout,
   1552         implementation=implementation)
   1553     super(GRU, self).__init__(
   1554         cell,
   1555         return_sequences=return_sequences,
   1556         return_state=return_state,
   1557         go_backwards=go_backwards,
   1558         stateful=stateful,
   1559         unroll=unroll,
   1560         **kwargs)
   1561     self.activity_regularizer = regularizers.get(activity_regularizer)
   1562 
   1563   def call(self, inputs, mask=None, training=None, initial_state=None):
   1564     self.cell._dropout_mask = None
   1565     self.cell._recurrent_dropout_mask = None
   1566     return super(GRU, self).call(
   1567         inputs, mask=mask, training=training, initial_state=initial_state)
   1568 
   1569   @property
   1570   def units(self):
   1571     return self.cell.units
   1572 
   1573   @property
   1574   def activation(self):
   1575     return self.cell.activation
   1576 
   1577   @property
   1578   def recurrent_activation(self):
   1579     return self.cell.recurrent_activation
   1580 
   1581   @property
   1582   def use_bias(self):
   1583     return self.cell.use_bias
   1584 
   1585   @property
   1586   def kernel_initializer(self):
   1587     return self.cell.kernel_initializer
   1588 
   1589   @property
   1590   def recurrent_initializer(self):
   1591     return self.cell.recurrent_initializer
   1592 
   1593   @property
   1594   def bias_initializer(self):
   1595     return self.cell.bias_initializer
   1596 
   1597   @property
   1598   def kernel_regularizer(self):
   1599     return self.cell.kernel_regularizer
   1600 
   1601   @property
   1602   def recurrent_regularizer(self):
   1603     return self.cell.recurrent_regularizer
   1604 
   1605   @property
   1606   def bias_regularizer(self):
   1607     return self.cell.bias_regularizer
   1608 
   1609   @property
   1610   def kernel_constraint(self):
   1611     return self.cell.kernel_constraint
   1612 
   1613   @property
   1614   def recurrent_constraint(self):
   1615     return self.cell.recurrent_constraint
   1616 
   1617   @property
   1618   def bias_constraint(self):
   1619     return self.cell.bias_constraint
   1620 
   1621   @property
   1622   def dropout(self):
   1623     return self.cell.dropout
   1624 
   1625   @property
   1626   def recurrent_dropout(self):
   1627     return self.cell.recurrent_dropout
   1628 
   1629   @property
   1630   def implementation(self):
   1631     return self.cell.implementation
   1632 
   1633   def get_config(self):
   1634     config = {
   1635         'units':
   1636             self.units,
   1637         'activation':
   1638             activations.serialize(self.activation),
   1639         'recurrent_activation':
   1640             activations.serialize(self.recurrent_activation),
   1641         'use_bias':
   1642             self.use_bias,
   1643         'kernel_initializer':
   1644             initializers.serialize(self.kernel_initializer),
   1645         'recurrent_initializer':
   1646             initializers.serialize(self.recurrent_initializer),
   1647         'bias_initializer':
   1648             initializers.serialize(self.bias_initializer),
   1649         'kernel_regularizer':
   1650             regularizers.serialize(self.kernel_regularizer),
   1651         'recurrent_regularizer':
   1652             regularizers.serialize(self.recurrent_regularizer),
   1653         'bias_regularizer':
   1654             regularizers.serialize(self.bias_regularizer),
   1655         'activity_regularizer':
   1656             regularizers.serialize(self.activity_regularizer),
   1657         'kernel_constraint':
   1658             constraints.serialize(self.kernel_constraint),
   1659         'recurrent_constraint':
   1660             constraints.serialize(self.recurrent_constraint),
   1661         'bias_constraint':
   1662             constraints.serialize(self.bias_constraint),
   1663         'dropout':
   1664             self.dropout,
   1665         'recurrent_dropout':
   1666             self.recurrent_dropout,
   1667         'implementation':
   1668             self.implementation
   1669     }
   1670     base_config = super(GRU, self).get_config()
   1671     del base_config['cell']
   1672     return dict(list(base_config.items()) + list(config.items()))
   1673 
   1674   @classmethod
   1675   def from_config(cls, config):
   1676     if 'implementation' in config and config['implementation'] == 0:
   1677       config['implementation'] = 1
   1678     return cls(**config)
   1679 
   1680 
   1681 @tf_export('keras.layers.LSTMCell')
   1682 class LSTMCell(Layer):
   1683   """Cell class for the LSTM layer.
   1684 
   1685   Arguments:
   1686       units: Positive integer, dimensionality of the output space.
   1687       activation: Activation function to use.
   1688           Default: hyperbolic tangent (`tanh`).
   1689           If you pass `None`, no activation is applied
   1690           (ie. "linear" activation: `a(x) = x`).
   1691       recurrent_activation: Activation function to use
   1692           for the recurrent step.
   1693           Default: hard sigmoid (`hard_sigmoid`).
   1694           If you pass `None`, no activation is applied
    1695           (ie. "linear" activation: `a(x) = x`).
   1696       use_bias: Boolean, whether the layer uses a bias vector.
   1697       kernel_initializer: Initializer for the `kernel` weights matrix,
   1698           used for the linear transformation of the inputs.
   1699       recurrent_initializer: Initializer for the `recurrent_kernel`
   1700           weights matrix,
   1701           used for the linear transformation of the recurrent state.
   1702       bias_initializer: Initializer for the bias vector.
   1703       unit_forget_bias: Boolean.
   1704           If True, add 1 to the bias of the forget gate at initialization.
   1705           Setting it to true will also force `bias_initializer="zeros"`.
   1706           This is recommended in [Jozefowicz et
   1707             al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
   1708       kernel_regularizer: Regularizer function applied to
   1709           the `kernel` weights matrix.
   1710       recurrent_regularizer: Regularizer function applied to
   1711           the `recurrent_kernel` weights matrix.
   1712       bias_regularizer: Regularizer function applied to the bias vector.
   1713       kernel_constraint: Constraint function applied to
   1714           the `kernel` weights matrix.
   1715       recurrent_constraint: Constraint function applied to
   1716           the `recurrent_kernel` weights matrix.
   1717       bias_constraint: Constraint function applied to the bias vector.
   1718       dropout: Float between 0 and 1.
   1719           Fraction of the units to drop for
   1720           the linear transformation of the inputs.
   1721       recurrent_dropout: Float between 0 and 1.
   1722           Fraction of the units to drop for
   1723           the linear transformation of the recurrent state.
   1724       implementation: Implementation mode, either 1 or 2.
   1725           Mode 1 will structure its operations as a larger number of
   1726           smaller dot products and additions, whereas mode 2 will
   1727           batch them into fewer, larger operations. These modes will
   1728           have different performance profiles on different hardware and
   1729           for different applications.
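
           Example (a minimal usage sketch; `timesteps` and `input_dim` are
           placeholder dimensions):

           ```python
               cell = keras.layers.LSTMCell(32)
               inputs = keras.Input((timesteps, input_dim))
               outputs = keras.layers.RNN(cell)(inputs)
           ```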
   1730   """
   1731 
   1732   def __init__(self,
   1733                units,
   1734                activation='tanh',
   1735                recurrent_activation='hard_sigmoid',
   1736                use_bias=True,
   1737                kernel_initializer='glorot_uniform',
   1738                recurrent_initializer='orthogonal',
   1739                bias_initializer='zeros',
   1740                unit_forget_bias=True,
   1741                kernel_regularizer=None,
   1742                recurrent_regularizer=None,
   1743                bias_regularizer=None,
   1744                kernel_constraint=None,
   1745                recurrent_constraint=None,
   1746                bias_constraint=None,
   1747                dropout=0.,
   1748                recurrent_dropout=0.,
   1749                implementation=1,
   1750                **kwargs):
   1751     super(LSTMCell, self).__init__(**kwargs)
   1752     self.units = units
   1753     self.activation = activations.get(activation)
   1754     self.recurrent_activation = activations.get(recurrent_activation)
   1755     self.use_bias = use_bias
   1756 
   1757     self.kernel_initializer = initializers.get(kernel_initializer)
   1758     self.recurrent_initializer = initializers.get(recurrent_initializer)
   1759     self.bias_initializer = initializers.get(bias_initializer)
   1760     self.unit_forget_bias = unit_forget_bias
   1761 
   1762     self.kernel_regularizer = regularizers.get(kernel_regularizer)
   1763     self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
   1764     self.bias_regularizer = regularizers.get(bias_regularizer)
   1765 
   1766     self.kernel_constraint = constraints.get(kernel_constraint)
   1767     self.recurrent_constraint = constraints.get(recurrent_constraint)
   1768     self.bias_constraint = constraints.get(bias_constraint)
   1769 
   1770     self.dropout = min(1., max(0., dropout))
   1771     self.recurrent_dropout = min(1., max(0., recurrent_dropout))
   1772     self.implementation = implementation
   1773     self.state_size = (self.units, self.units)
   1774     self._dropout_mask = None
   1775     self._recurrent_dropout_mask = None
   1776 
   1777   @shape_type_conversion
   1778   def build(self, input_shape):
   1779     input_dim = input_shape[-1]
   1780     self.kernel = self.add_weight(
   1781         shape=(input_dim, self.units * 4),
   1782         name='kernel',
   1783         initializer=self.kernel_initializer,
   1784         regularizer=self.kernel_regularizer,
   1785         constraint=self.kernel_constraint)
   1786     self.recurrent_kernel = self.add_weight(
   1787         shape=(self.units, self.units * 4),
   1788         name='recurrent_kernel',
   1789         initializer=self.recurrent_initializer,
   1790         regularizer=self.recurrent_regularizer,
   1791         constraint=self.recurrent_constraint)
   1792 
   1793     if self.use_bias:
   1794       if self.unit_forget_bias:
   1795 
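                 # The bias is laid out as [input gate, forget gate, cell, output gate];
                 # the forget-gate slice is initialized to ones, per Jozefowicz et al.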
   1796         def bias_initializer(_, *args, **kwargs):
   1797           return K.concatenate([
   1798               self.bias_initializer((self.units,), *args, **kwargs),
   1799               initializers.Ones()((self.units,), *args, **kwargs),
   1800               self.bias_initializer((self.units * 2,), *args, **kwargs),
   1801           ])
   1802       else:
   1803         bias_initializer = self.bias_initializer
   1804       self.bias = self.add_weight(
   1805           shape=(self.units * 4,),
   1806           name='bias',
   1807           initializer=bias_initializer,
   1808           regularizer=self.bias_regularizer,
   1809           constraint=self.bias_constraint)
   1810     else:
   1811       self.bias = None
   1812 
   1813     self.kernel_i = self.kernel[:, :self.units]
   1814     self.kernel_f = self.kernel[:, self.units:self.units * 2]
   1815     self.kernel_c = self.kernel[:, self.units * 2:self.units * 3]
   1816     self.kernel_o = self.kernel[:, self.units * 3:]
   1817 
   1818     self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
   1819     self.recurrent_kernel_f = self.recurrent_kernel[:, self.units:
   1820                                                     self.units * 2]
   1821     self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2:
   1822                                                     self.units * 3]
   1823     self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]
   1824 
   1825     if self.use_bias:
   1826       self.bias_i = self.bias[:self.units]
   1827       self.bias_f = self.bias[self.units:self.units * 2]
   1828       self.bias_c = self.bias[self.units * 2:self.units * 3]
   1829       self.bias_o = self.bias[self.units * 3:]
   1830     else:
   1831       self.bias_i = None
   1832       self.bias_f = None
   1833       self.bias_c = None
   1834       self.bias_o = None
   1835     self.built = True
   1836 
   1837   def call(self, inputs, states, training=None):
   1838     if 0 < self.dropout < 1 and self._dropout_mask is None:
   1839       self._dropout_mask = _generate_dropout_mask(
   1840           _generate_dropout_ones(inputs,
   1841                                  K.shape(inputs)[-1]),
   1842           self.dropout,
   1843           training=training,
   1844           count=4)
   1845     if (0 < self.recurrent_dropout < 1 and
   1846         self._recurrent_dropout_mask is None):
   1847       self._recurrent_dropout_mask = _generate_dropout_mask(
   1848           _generate_dropout_ones(inputs, self.units),
   1849           self.recurrent_dropout,
   1850           training=training,
   1851           count=4)
   1852 
   1853     # dropout matrices for input units
   1854     dp_mask = self._dropout_mask
   1855     # dropout matrices for recurrent units
   1856     rec_dp_mask = self._recurrent_dropout_mask
   1857 
   1858     h_tm1 = states[0]  # previous memory state
   1859     c_tm1 = states[1]  # previous carry state
   1860 
   1861     if self.implementation == 1:
   1862       if 0 < self.dropout < 1.:
   1863         inputs_i = inputs * dp_mask[0]
   1864         inputs_f = inputs * dp_mask[1]
   1865         inputs_c = inputs * dp_mask[2]
   1866         inputs_o = inputs * dp_mask[3]
   1867       else:
   1868         inputs_i = inputs
   1869         inputs_f = inputs
   1870         inputs_c = inputs
   1871         inputs_o = inputs
   1872       x_i = K.dot(inputs_i, self.kernel_i)
   1873       x_f = K.dot(inputs_f, self.kernel_f)
   1874       x_c = K.dot(inputs_c, self.kernel_c)
   1875       x_o = K.dot(inputs_o, self.kernel_o)
   1876       if self.use_bias:
   1877         x_i = K.bias_add(x_i, self.bias_i)
   1878         x_f = K.bias_add(x_f, self.bias_f)
   1879         x_c = K.bias_add(x_c, self.bias_c)
   1880         x_o = K.bias_add(x_o, self.bias_o)
   1881 
   1882       if 0 < self.recurrent_dropout < 1.:
   1883         h_tm1_i = h_tm1 * rec_dp_mask[0]
   1884         h_tm1_f = h_tm1 * rec_dp_mask[1]
   1885         h_tm1_c = h_tm1 * rec_dp_mask[2]
   1886         h_tm1_o = h_tm1 * rec_dp_mask[3]
   1887       else:
   1888         h_tm1_i = h_tm1
   1889         h_tm1_f = h_tm1
   1890         h_tm1_c = h_tm1
   1891         h_tm1_o = h_tm1
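               # LSTM gate math (implementation 1): input gate i, forget gate f and
               # output gate o use the recurrent activation; the new carry state is
               # c = f * c_tm1 + i * activation(candidate).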
   1892       i = self.recurrent_activation(
   1893           x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
   1894       f = self.recurrent_activation(
   1895           x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
   1896       c = f * c_tm1 + i * self.activation(
   1897           x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
   1898       o = self.recurrent_activation(
   1899           x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
   1900     else:
   1901       if 0. < self.dropout < 1.:
   1902         inputs *= dp_mask[0]
   1903       z = K.dot(inputs, self.kernel)
   1904       if 0. < self.recurrent_dropout < 1.:
   1905         h_tm1 *= rec_dp_mask[0]
   1906       z += K.dot(h_tm1, self.recurrent_kernel)
   1907       if self.use_bias:
   1908         z = K.bias_add(z, self.bias)
   1909 
   1910       z0 = z[:, :self.units]
   1911       z1 = z[:, self.units:2 * self.units]
   1912       z2 = z[:, 2 * self.units:3 * self.units]
   1913       z3 = z[:, 3 * self.units:]
   1914 
   1915       i = self.recurrent_activation(z0)
   1916       f = self.recurrent_activation(z1)
   1917       c = f * c_tm1 + i * self.activation(z2)
   1918       o = self.recurrent_activation(z3)
   1919 
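             # Hidden state: the output gate modulates the activated carry state.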
   1920     h = o * self.activation(c)
   1921     if 0 < self.dropout + self.recurrent_dropout:
   1922       if training is None:
   1923         h._uses_learning_phase = True
   1924     return h, [h, c]
   1925 
   1926   def get_config(self):
   1927     config = {
   1928         'units':
   1929             self.units,
   1930         'activation':
   1931             activations.serialize(self.activation),
   1932         'recurrent_activation':
   1933             activations.serialize(self.recurrent_activation),
   1934         'use_bias':
   1935             self.use_bias,
   1936         'kernel_initializer':
   1937             initializers.serialize(self.kernel_initializer),
   1938         'recurrent_initializer':
   1939             initializers.serialize(self.recurrent_initializer),
   1940         'bias_initializer':
   1941             initializers.serialize(self.bias_initializer),
   1942         'unit_forget_bias':
   1943             self.unit_forget_bias,
   1944         'kernel_regularizer':
   1945             regularizers.serialize(self.kernel_regularizer),
   1946         'recurrent_regularizer':
   1947             regularizers.serialize(self.recurrent_regularizer),
   1948         'bias_regularizer':
   1949             regularizers.serialize(self.bias_regularizer),
   1950         'kernel_constraint':
   1951             constraints.serialize(self.kernel_constraint),
   1952         'recurrent_constraint':
   1953             constraints.serialize(self.recurrent_constraint),
   1954         'bias_constraint':
   1955             constraints.serialize(self.bias_constraint),
   1956         'dropout':
   1957             self.dropout,
   1958         'recurrent_dropout':
   1959             self.recurrent_dropout,
   1960         'implementation':
   1961             self.implementation
   1962     }
   1963     base_config = super(LSTMCell, self).get_config()
   1964     return dict(list(base_config.items()) + list(config.items()))
   1965 
   1966 
   1967 @tf_export('keras.layers.LSTM')
   1968 class LSTM(RNN):
   1969   """Long-Short Term Memory layer - Hochreiter 1997.
   1970 
   1971   Arguments:
   1972       units: Positive integer, dimensionality of the output space.
   1973       activation: Activation function to use.
   1974           Default: hyperbolic tangent (`tanh`).
   1975           If you pass `None`, no activation is applied
   1976           (ie. "linear" activation: `a(x) = x`).
   1977       recurrent_activation: Activation function to use
   1978           for the recurrent step.
   1979           Default: hard sigmoid (`hard_sigmoid`).
   1980           If you pass `None`, no activation is applied
   1981           (ie. "linear" activation: `a(x) = x`).
   1982       use_bias: Boolean, whether the layer uses a bias vector.
   1983       kernel_initializer: Initializer for the `kernel` weights matrix,
    1984           used for the linear transformation of the inputs.
   1985       recurrent_initializer: Initializer for the `recurrent_kernel`
   1986           weights matrix,
    1987           used for the linear transformation of the recurrent state.
   1988       bias_initializer: Initializer for the bias vector.
   1989       unit_forget_bias: Boolean.
   1990           If True, add 1 to the bias of the forget gate at initialization.
   1991           Setting it to true will also force `bias_initializer="zeros"`.
   1992           This is recommended in [Jozefowicz et
   1993             al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
   1994       kernel_regularizer: Regularizer function applied to
   1995           the `kernel` weights matrix.
   1996       recurrent_regularizer: Regularizer function applied to
   1997           the `recurrent_kernel` weights matrix.
   1998       bias_regularizer: Regularizer function applied to the bias vector.
   1999       activity_regularizer: Regularizer function applied to
    2000           the output of the layer (its "activation").
   2001       kernel_constraint: Constraint function applied to
   2002           the `kernel` weights matrix.
   2003       recurrent_constraint: Constraint function applied to
   2004           the `recurrent_kernel` weights matrix.
   2005       bias_constraint: Constraint function applied to the bias vector.
   2006       dropout: Float between 0 and 1.
   2007           Fraction of the units to drop for
   2008           the linear transformation of the inputs.
   2009       recurrent_dropout: Float between 0 and 1.
   2010           Fraction of the units to drop for
   2011           the linear transformation of the recurrent state.
   2012       implementation: Implementation mode, either 1 or 2.
   2013           Mode 1 will structure its operations as a larger number of
   2014           smaller dot products and additions, whereas mode 2 will
   2015           batch them into fewer, larger operations. These modes will
   2016           have different performance profiles on different hardware and
   2017           for different applications.
    2018       return_sequences: Boolean. Whether to return the last output
   2019           in the output sequence, or the full sequence.
   2020       return_state: Boolean. Whether to return the last state
   2021           in addition to the output.
   2022       go_backwards: Boolean (default False).
   2023           If True, process the input sequence backwards and return the
   2024           reversed sequence.
   2025       stateful: Boolean (default False). If True, the last state
   2026           for each sample at index i in a batch will be used as initial
   2027           state for the sample of index i in the following batch.
   2028       unroll: Boolean (default False).
   2029           If True, the network will be unrolled,
   2030           else a symbolic loop will be used.
    2031           Unrolling can speed up an RNN,
   2032           although it tends to be more memory-intensive.
   2033           Unrolling is only suitable for short sequences.
   2034 
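           Example (a minimal usage sketch; `timesteps` and `input_dim` are
           placeholder dimensions):

           ```python
               inputs = keras.Input((timesteps, input_dim))
               # With `return_state=True`, the layer also returns the final hidden
               # and carry states.
               outputs, state_h, state_c = keras.layers.LSTM(32, return_state=True)(inputs)
           ```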
   2035   """
   2036 
   2037   def __init__(self,
   2038                units,
   2039                activation='tanh',
   2040                recurrent_activation='hard_sigmoid',
   2041                use_bias=True,
   2042                kernel_initializer='glorot_uniform',
   2043                recurrent_initializer='orthogonal',
   2044                bias_initializer='zeros',
   2045                unit_forget_bias=True,
   2046                kernel_regularizer=None,
   2047                recurrent_regularizer=None,
   2048                bias_regularizer=None,
   2049                activity_regularizer=None,
   2050                kernel_constraint=None,
   2051                recurrent_constraint=None,
   2052                bias_constraint=None,
   2053                dropout=0.,
   2054                recurrent_dropout=0.,
   2055                implementation=1,
   2056                return_sequences=False,
   2057                return_state=False,
   2058                go_backwards=False,
   2059                stateful=False,
   2060                unroll=False,
   2061                **kwargs):
   2062     if implementation == 0:
   2063       logging.warning('`implementation=0` has been deprecated, '
    2064                       'and now defaults to `implementation=1`. '
   2065                       'Please update your layer call.')
   2066     cell = LSTMCell(
   2067         units,
   2068         activation=activation,
   2069         recurrent_activation=recurrent_activation,
   2070         use_bias=use_bias,
   2071         kernel_initializer=kernel_initializer,
   2072         recurrent_initializer=recurrent_initializer,
   2073         unit_forget_bias=unit_forget_bias,
   2074         bias_initializer=bias_initializer,
   2075         kernel_regularizer=kernel_regularizer,
   2076         recurrent_regularizer=recurrent_regularizer,
   2077         bias_regularizer=bias_regularizer,
   2078         kernel_constraint=kernel_constraint,
   2079         recurrent_constraint=recurrent_constraint,
   2080         bias_constraint=bias_constraint,
   2081         dropout=dropout,
   2082         recurrent_dropout=recurrent_dropout,
   2083         implementation=implementation)
   2084     super(LSTM, self).__init__(
   2085         cell,
   2086         return_sequences=return_sequences,
   2087         return_state=return_state,
   2088         go_backwards=go_backwards,
   2089         stateful=stateful,
   2090         unroll=unroll,
   2091         **kwargs)
   2092     self.activity_regularizer = regularizers.get(activity_regularizer)
   2093 
   2094   def call(self, inputs, mask=None, training=None, initial_state=None):
   2095     self.cell._dropout_mask = None
   2096     self.cell._recurrent_dropout_mask = None
   2097     return super(LSTM, self).call(
   2098         inputs, mask=mask, training=training, initial_state=initial_state)
   2099 
   2100   @property
   2101   def units(self):
   2102     return self.cell.units
   2103 
   2104   @property
   2105   def activation(self):
   2106     return self.cell.activation
   2107 
   2108   @property
   2109   def recurrent_activation(self):
   2110     return self.cell.recurrent_activation
   2111 
   2112   @property
   2113   def use_bias(self):
   2114     return self.cell.use_bias
   2115 
   2116   @property
   2117   def kernel_initializer(self):
   2118     return self.cell.kernel_initializer
   2119 
   2120   @property
   2121   def recurrent_initializer(self):
   2122     return self.cell.recurrent_initializer
   2123 
   2124   @property
   2125   def bias_initializer(self):
   2126     return self.cell.bias_initializer
   2127 
   2128   @property
   2129   def unit_forget_bias(self):
   2130     return self.cell.unit_forget_bias
   2131 
   2132   @property
   2133   def kernel_regularizer(self):
   2134     return self.cell.kernel_regularizer
   2135 
   2136   @property
   2137   def recurrent_regularizer(self):
   2138     return self.cell.recurrent_regularizer
   2139 
   2140   @property
   2141   def bias_regularizer(self):
   2142     return self.cell.bias_regularizer
   2143 
   2144   @property
   2145   def kernel_constraint(self):
   2146     return self.cell.kernel_constraint
   2147 
   2148   @property
   2149   def recurrent_constraint(self):
   2150     return self.cell.recurrent_constraint
   2151 
   2152   @property
   2153   def bias_constraint(self):
   2154     return self.cell.bias_constraint
   2155 
   2156   @property
   2157   def dropout(self):
   2158     return self.cell.dropout
   2159 
   2160   @property
   2161   def recurrent_dropout(self):
   2162     return self.cell.recurrent_dropout
   2163 
   2164   @property
   2165   def implementation(self):
   2166     return self.cell.implementation
   2167 
   2168   def get_config(self):
   2169     config = {
   2170         'units':
   2171             self.units,
   2172         'activation':
   2173             activations.serialize(self.activation),
   2174         'recurrent_activation':
   2175             activations.serialize(self.recurrent_activation),
   2176         'use_bias':
   2177             self.use_bias,
   2178         'kernel_initializer':
   2179             initializers.serialize(self.kernel_initializer),
   2180         'recurrent_initializer':
   2181             initializers.serialize(self.recurrent_initializer),
   2182         'bias_initializer':
   2183             initializers.serialize(self.bias_initializer),
   2184         'unit_forget_bias':
   2185             self.unit_forget_bias,
   2186         'kernel_regularizer':
   2187             regularizers.serialize(self.kernel_regularizer),
   2188         'recurrent_regularizer':
   2189             regularizers.serialize(self.recurrent_regularizer),
   2190         'bias_regularizer':
   2191             regularizers.serialize(self.bias_regularizer),
   2192         'activity_regularizer':
   2193             regularizers.serialize(self.activity_regularizer),
   2194         'kernel_constraint':
   2195             constraints.serialize(self.kernel_constraint),
   2196         'recurrent_constraint':
   2197             constraints.serialize(self.recurrent_constraint),
   2198         'bias_constraint':
   2199             constraints.serialize(self.bias_constraint),
   2200         'dropout':
   2201             self.dropout,
   2202         'recurrent_dropout':
   2203             self.recurrent_dropout,
   2204         'implementation':
   2205             self.implementation
   2206     }
   2207     base_config = super(LSTM, self).get_config()
   2208     del base_config['cell']
   2209     return dict(list(base_config.items()) + list(config.items()))
   2210 
   2211   @classmethod
   2212   def from_config(cls, config):
   2213     if 'implementation' in config and config['implementation'] == 0:
   2214       config['implementation'] = 1
   2215     return cls(**config)
   2216 
   2217 
   2218 def _generate_dropout_ones(inputs, dims):
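           # An all-ones tensor of shape (batch_size, dims), where batch_size is
           # taken from `inputs`.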
   2219   return K.ones((K.shape(inputs)[0], dims))
   2220 
   2221 
   2222 def _generate_dropout_mask(ones, rate, training=None, count=1):
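           # Builds `count` dropout masks from `ones`: in the training phase entries
           # are dropped at `rate` (via `K.dropout`); otherwise each mask stays all
           # ones.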
   2223 
   2224   def dropped_inputs():
   2225     return K.dropout(ones, rate)
   2226 
   2227   if count > 1:
   2228     return [
   2229         K.in_train_phase(dropped_inputs, ones, training=training)
   2230         for _ in range(count)
   2231     ]
   2232   return K.in_train_phase(dropped_inputs, ones, training=training)
   2233 
   2234 
   2235 class Recurrent(Layer):
   2236   """Deprecated abstract base class for recurrent layers.
   2237 
   2238   It still exists because it is leveraged by the convolutional-recurrent layers.
   2239   It will be removed entirely in the future.
   2240   It was never part of the public API.
   2241   Do not use.
   2242 
   2243   Arguments:
   2244       weights: list of Numpy arrays to set as initial weights.
   2245           The list should have 3 elements, of shapes:
   2246           `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`.
   2247       return_sequences: Boolean. Whether to return the last output
   2248           in the output sequence, or the full sequence.
   2249       return_state: Boolean. Whether to return the last state
   2250           in addition to the output.
   2251       go_backwards: Boolean (default False).
   2252           If True, process the input sequence backwards and return the
   2253           reversed sequence.
   2254       stateful: Boolean (default False). If True, the last state
   2255           for each sample at index i in a batch will be used as initial
   2256           state for the sample of index i in the following batch.
   2257       unroll: Boolean (default False).
   2258           If True, the network will be unrolled,
   2259           else a symbolic loop will be used.
    2260           Unrolling can speed up an RNN,
   2261           although it tends to be more memory-intensive.
   2262           Unrolling is only suitable for short sequences.
   2263       implementation: one of {0, 1, or 2}.
   2264           If set to 0, the RNN will use
   2265           an implementation that uses fewer, larger matrix products,
   2266           thus running faster on CPU but consuming more memory.
   2267           If set to 1, the RNN will use more matrix products,
   2268           but smaller ones, thus running slower
   2269           (may actually be faster on GPU) while consuming less memory.
   2270           If set to 2 (LSTM/GRU only),
   2271           the RNN will combine the input gate,
   2272           the forget gate and the output gate into a single matrix,
   2273           enabling more time-efficient parallelization on the GPU.
   2274           Note: RNN dropout must be shared for all gates,
   2275           resulting in a slightly reduced regularization.
   2276       input_dim: dimensionality of the input (integer).
   2277           This argument (or alternatively, the keyword argument `input_shape`)
   2278           is required when using this layer as the first layer in a model.
   2279       input_length: Length of input sequences, to be specified
   2280           when it is constant.
   2281           This argument is required if you are going to connect
   2282           `Flatten` then `Dense` layers upstream
   2283           (without it, the shape of the dense outputs cannot be computed).
   2284           Note that if the recurrent layer is not the first layer
   2285           in your model, you would need to specify the input length
   2286           at the level of the first layer
    2287           (e.g. via the `input_shape` argument).
   2288 
   2289   Input shape:
   2290       3D tensor with shape `(batch_size, timesteps, input_dim)`,
   2291       (Optional) 2D tensors with shape `(batch_size, output_dim)`.
   2292 
   2293   Output shape:
   2294       - if `return_state`: a list of tensors. The first tensor is
   2295           the output. The remaining tensors are the last states,
   2296           each with shape `(batch_size, units)`.
   2297       - if `return_sequences`: 3D tensor with shape
   2298           `(batch_size, timesteps, units)`.
   2299       - else, 2D tensor with shape `(batch_size, units)`.
   2300 
   2301   # Masking
   2302       This layer supports masking for input data with a variable number
   2303       of timesteps. To introduce masks to your data,
   2304       use an `Embedding` layer with the `mask_zero` parameter
   2305       set to `True`.
   2306 
   2307   # Note on using statefulness in RNNs
   2308       You can set RNN layers to be 'stateful', which means that the states
   2309       computed for the samples in one batch will be reused as initial states
   2310       for the samples in the next batch. This assumes a one-to-one mapping
   2311       between samples in different successive batches.
   2312 
   2313       To enable statefulness:
   2314           - specify `stateful=True` in the layer constructor.
   2315           - specify a fixed batch size for your model, by passing
    2316               - if using a Sequential model:
   2317                 `batch_input_shape=(...)` to the first layer in your model.
    2318               - else, for a functional model with one or more Input layers:
   2319                 `batch_shape=(...)` to all the first layers in your model.
   2320               This is the expected shape of your inputs
   2321               *including the batch size*.
   2322               It should be a tuple of integers, e.g. `(32, 10, 100)`.
   2323           - specify `shuffle=False` when calling fit().
   2324 
   2325       To reset the states of your model, call `.reset_states()` on either
   2326       a specific layer, or on your entire model.
   2327 
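               A minimal sketch of a stateful model (the batch size, timesteps and
               feature dimension below are placeholders):

               ```python
                   model = keras.models.Sequential()
                   model.add(keras.layers.LSTM(32, stateful=True,
                                               batch_input_shape=(32, 10, 16)))
                   model.add(keras.layers.Dense(1))
                   # Train on ordered batches with `shuffle=False`, then clear the
                   # carried-over states between epochs or sequences:
                   model.reset_states()
               ```
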
   2328   # Note on specifying the initial state of RNNs
   2329       You can specify the initial state of RNN layers symbolically by
   2330       calling them with the keyword argument `initial_state`. The value of
   2331       `initial_state` should be a tensor or list of tensors representing
   2332       the initial state of the RNN layer.
   2333 
   2334       You can specify the initial state of RNN layers numerically by
   2335       calling `reset_states` with the keyword argument `states`. The value of
   2336       `states` should be a numpy array or list of numpy arrays representing
   2337       the initial state of the RNN layer.
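
               A minimal sketch of a symbolic initial state (`timesteps`, `input_dim`
               and `units` are placeholders):

               ```python
                   inputs = keras.Input((timesteps, input_dim))
                   initial = keras.Input((units,))
                   outputs = keras.layers.GRU(units)(inputs, initial_state=initial)
               ```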
   2338   """
   2339 
   2340   def __init__(self,
   2341                return_sequences=False,
   2342                return_state=False,
   2343                go_backwards=False,
   2344                stateful=False,
   2345                unroll=False,
   2346                implementation=0,
   2347                **kwargs):
   2348     super(Recurrent, self).__init__(**kwargs)
   2349     self.return_sequences = return_sequences
   2350     self.return_state = return_state
   2351     self.go_backwards = go_backwards
   2352     self.stateful = stateful
   2353     self.unroll = unroll
   2354     self.implementation = implementation
   2355     self.supports_masking = True
   2356     self.input_spec = [InputSpec(ndim=3)]
   2357     self.state_spec = None
   2358     self.dropout = 0
   2359     self.recurrent_dropout = 0
   2360 
   2361   @shape_type_conversion
   2362   def compute_output_shape(self, input_shape):
   2363     if isinstance(input_shape, list):
   2364       input_shape = input_shape[0]
   2365     input_shape = tensor_shape.TensorShape(input_shape).as_list()
   2366     if self.return_sequences:
   2367       output_shape = (input_shape[0], input_shape[1], self.units)
   2368     else:
   2369       output_shape = (input_shape[0], self.units)
   2370 
   2371     if self.return_state:
   2372       state_shape = [tensor_shape.TensorShape(
   2373           (input_shape[0], self.units)) for _ in self.states]
   2374       return [tensor_shape.TensorShape(output_shape)] + state_shape
   2375     return tensor_shape.TensorShape(output_shape)
   2376 
   2377   def compute_mask(self, inputs, mask):
   2378     if isinstance(mask, list):
   2379       mask = mask[0]
   2380     output_mask = mask if self.return_sequences else None
   2381     if self.return_state:
   2382       state_mask = [None for _ in self.states]
   2383       return [output_mask] + state_mask
   2384     return output_mask
   2385 
   2386   def step(self, inputs, states):
   2387     raise NotImplementedError
   2388 
   2389   def get_constants(self, inputs, training=None):
   2390     return []
   2391 
   2392   def get_initial_state(self, inputs):
   2393     # build an all-zero tensor of shape (samples, output_dim)
   2394     initial_state = K.zeros_like(inputs)  # (samples, timesteps, input_dim)
   2395     initial_state = K.sum(initial_state, axis=(1, 2))  # (samples,)
   2396     initial_state = K.expand_dims(initial_state)  # (samples, 1)
   2397     initial_state = K.tile(initial_state, [1,
   2398                                            self.units])  # (samples, output_dim)
   2399     initial_state = [initial_state for _ in range(len(self.states))]
   2400     return initial_state
   2401 
   2402   def preprocess_input(self, inputs, training=None):
   2403     return inputs
   2404 
   2405   def __call__(self, inputs, initial_state=None, **kwargs):
   2406     if (isinstance(inputs, (list, tuple)) and
   2407         len(inputs) > 1
   2408         and initial_state is None):
   2409       initial_state = inputs[1:]
   2410       inputs = inputs[0]
   2411 
   2412     # If `initial_state` is specified,
   2413     # and if it a Keras tensor,
   2414     # then add it to the inputs and temporarily
   2415     # modify the input spec to include the state.
   2416     if initial_state is None:
   2417       return super(Recurrent, self).__call__(inputs, **kwargs)
   2418 
   2419     if not isinstance(initial_state, (list, tuple)):
   2420       initial_state = [initial_state]
   2421 
   2422     is_keras_tensor = hasattr(initial_state[0], '_keras_history')
   2423     for tensor in initial_state:
   2424       if hasattr(tensor, '_keras_history') != is_keras_tensor:
   2425         raise ValueError('The initial state of an RNN layer cannot be'
   2426                          ' specified with a mix of Keras tensors and'
   2427                          ' non-Keras tensors')
   2428 
   2429     if is_keras_tensor:
   2430       # Compute the full input spec, including state
   2431       input_spec = self.input_spec
   2432       state_spec = self.state_spec
   2433       if not isinstance(input_spec, list):
   2434         input_spec = [input_spec]
   2435       if not isinstance(state_spec, list):
   2436         state_spec = [state_spec]
   2437       self.input_spec = input_spec + state_spec
   2438 
   2439       # Compute the full inputs, including state
   2440       inputs = [inputs] + list(initial_state)
   2441 
   2442       # Perform the call
   2443       output = super(Recurrent, self).__call__(inputs, **kwargs)
   2444 
   2445       # Restore original input spec
   2446       self.input_spec = input_spec
   2447       return output
   2448     else:
   2449       kwargs['initial_state'] = initial_state
   2450       return super(Recurrent, self).__call__(inputs, **kwargs)
   2451 
   2452   def call(self, inputs, mask=None, training=None, initial_state=None):
   2453     # input shape: `(samples, time (padded with zeros), input_dim)`
   2454     # note that the .build() method of subclasses MUST define
   2455     # self.input_spec and self.state_spec with complete input shapes.
   2456     if isinstance(inputs, list):
   2457       initial_state = inputs[1:]
   2458       inputs = inputs[0]
   2459     elif initial_state is not None:
   2460       pass
   2461     elif self.stateful:
   2462       initial_state = self.states
   2463     else:
   2464       initial_state = self.get_initial_state(inputs)
   2465 
   2466     if isinstance(mask, list):
   2467       mask = mask[0]
   2468 
   2469     if len(initial_state) != len(self.states):
   2470       raise ValueError('Layer has ' + str(len(self.states)) +
   2471                        ' states but was passed ' + str(len(initial_state)) +
   2472                        ' initial states.')
   2473     input_shape = K.int_shape(inputs)
   2474     if self.unroll and input_shape[1] is None:
    2475       raise ValueError('Cannot unroll an RNN if the '
   2476                        'time dimension is undefined. \n'
   2477                        '- If using a Sequential model, '
   2478                        'specify the time dimension by passing '
   2479                        'an `input_shape` or `batch_input_shape` '
   2480                        'argument to your first layer. If your '
   2481                        'first layer is an Embedding, you can '
   2482                        'also use the `input_length` argument.\n'
   2483                        '- If using the functional API, specify '
   2484                        'the time dimension by passing a `shape` '
   2485                        'or `batch_shape` argument to your Input layer.')
   2486     constants = self.get_constants(inputs, training=None)
   2487     preprocessed_input = self.preprocess_input(inputs, training=None)
    last_output, outputs, states = K.rnn(
        self.step,
        preprocessed_input,
        initial_state,
        go_backwards=self.go_backwards,
        mask=mask,
        constants=constants,
        unroll=self.unroll)
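    # For stateful layers, carry the final states over to the next batch by
    # assigning them to the layer's persistent state variables.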
    if self.stateful:
      updates = []
      for i in range(len(states)):
        updates.append(K.update(self.states[i], states[i]))
      self.add_update(updates, inputs)

    # Properly set learning phase
    if self.dropout + self.recurrent_dropout > 0:
      last_output._uses_learning_phase = True
      outputs._uses_learning_phase = True

    if not self.return_sequences:
      outputs = last_output

    if self.return_state:
      if not isinstance(states, (list, tuple)):
        states = [states]
      else:
        states = list(states)
      return [outputs] + states
    return outputs

  def reset_states(self, states=None):
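    """Resets the layer states to `states`, or to all-zeros if `None`.

    Only valid for stateful layers whose batch size is known (i.e. the
    first dimension of the input spec is defined).

    Arguments:
        states: Optional numpy array or list of numpy arrays of shape
            `(batch_size, units)` to assign to the states. If `None`,
            every state is reset to zeros.

    Raises:
        AttributeError: If the layer is not stateful.
        ValueError: If the batch size is undefined, or if `states` does not
            match the number or shape of the layer states.
    """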
    if not self.stateful:
      raise AttributeError('Layer must be stateful.')
    batch_size = self.input_spec[0].shape[0]
    if not batch_size:
      raise ValueError('If an RNN is stateful, it needs to know '
                       'its batch size. Specify the batch size '
                       'of your input tensors:\n'
                       '- If using a Sequential model, '
                       'specify the batch size by passing '
                       'a `batch_input_shape` '
                       'argument to your first layer.\n'
                       '- If using the functional API, specify '
                       'the batch size by passing a '
                       '`batch_shape` argument to your Input layer.')
    # initialize state if None
    if self.states[0] is None:
      self.states = [K.zeros((batch_size, self.units)) for _ in self.states]
    elif states is None:
      for state in self.states:
        K.set_value(state, np.zeros((batch_size, self.units)))
    else:
      if not isinstance(states, (list, tuple)):
        states = [states]
      if len(states) != len(self.states):
        raise ValueError('Layer ' + self.name + ' expects ' +
                         str(len(self.states)) + ' states, '
                         'but it received ' + str(len(states)) +
                         ' state values. Input received: ' + str(states))
      for index, (value, state) in enumerate(zip(states, self.states)):
        if value.shape != (batch_size, self.units):
          raise ValueError('State ' + str(index) +
                           ' is incompatible with layer ' + self.name +
                           ': expected shape=' + str((batch_size, self.units)) +
                           ', found shape=' + str(value.shape))
        K.set_value(state, value)

  def get_config(self):
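    """Returns the layer configuration as a Python dict.

    The dict contains the arguments needed to re-instantiate the layer,
    merged with the configuration of the base `Layer`.
    """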
    config = {
        'return_sequences': self.return_sequences,
        'return_state': self.return_state,
        'go_backwards': self.go_backwards,
        'stateful': self.stateful,
        'unroll': self.unroll,
        'implementation': self.implementation
    }
    base_config = super(Recurrent, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))