# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """LSTM Block Cell ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import abc
     21 
     22 import six
     23 
     24 from tensorflow.contrib.rnn.ops import gen_lstm_ops
     25 from tensorflow.contrib.util import loader
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.keras.engine import input_spec
     29 from tensorflow.python.layers import base as base_layer
     30 from tensorflow.python.ops import array_ops
     31 from tensorflow.python.ops import init_ops
     32 from tensorflow.python.ops import math_ops
     33 from tensorflow.python.ops import nn_ops
     34 from tensorflow.python.ops import rnn_cell_impl
     35 from tensorflow.python.platform import resource_loader
     36 
     37 _lstm_ops_so = loader.load_op_library(
     38     resource_loader.get_path_to_datafile("_lstm_ops.so"))
     39 
     40 LayerRNNCell = rnn_cell_impl.LayerRNNCell  # pylint: disable=invalid-name
     41 
     42 
     43 # pylint: disable=invalid-name
def _lstm_block_cell(x,
                     cs_prev,
                     h_prev,
                     w,
                     b,
                     wci=None,
                     wcf=None,
                     wco=None,
                     forget_bias=None,
                     cell_clip=None,
                     use_peephole=None,
                     name=None):
  r"""Computes the LSTM cell forward propagation for 1 time step.

  This implementation uses a single weight matrix and a single bias vector,
  with an optional peephole connection.

  This kernel op implements the following mathematical equations:

  ```python
  xh = [x, h_prev]
  [i, ci, f, o] = xh * w + b
  f = f + forget_bias

  if not use_peephole:
    wci = wcf = wco = 0

  i = sigmoid(cs_prev * wci + i)
  f = sigmoid(cs_prev * wcf + f)
  ci = tanh(ci)

  cs = ci .* i + cs_prev .* f
  cs = clip(cs, cell_clip)

  o = sigmoid(cs * wco + o)
  co = tanh(cs)
  h = co .* o
  ```
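
  In plain NumPy, one step without peepholes or clipping looks roughly like
  this sketch (illustrative only, not the kernel itself; `sigmoid` is a local
  helper):

  ```python
  import numpy as np

  def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

  def lstm_step(x, cs_prev, h_prev, w, b, forget_bias=1.0):
    # Gate blocks are packed along the last axis of `w` and `b` in
    # [i, ci, f, o] order.
    xh = np.concatenate([x, h_prev], axis=1)
    i, ci, f, o = np.split(np.dot(xh, w) + b, 4, axis=1)
    cs = np.tanh(ci) * sigmoid(i) + cs_prev * sigmoid(f + forget_bias)
    h = np.tanh(cs) * sigmoid(o)
    return cs, h
  ```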

  Args:
    x: A `Tensor`. Must be one of the following types: `float32`.
      The input to the LSTM cell, shape (batch_size, num_inputs).
    cs_prev: A `Tensor`. Must have the same type as `x`.
      Value of the cell state at previous time step.
    h_prev: A `Tensor`. Must have the same type as `x`.
      Output of the previous cell at previous time step.
    w: A `Tensor`. Must have the same type as `x`. The weight matrix.
    b: A `Tensor`. Must have the same type as `x`. The bias vector.
    wci: A `Tensor`. Must have the same type as `x`.
      The weight matrix for input gate peephole connection.
    wcf: A `Tensor`. Must have the same type as `x`.
      The weight matrix for forget gate peephole connection.
    wco: A `Tensor`. Must have the same type as `x`.
      The weight matrix for output gate peephole connection.
    forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
    cell_clip: An optional `float`. Defaults to `-1` (no clipping).
      The value to clip `cs` to; clipping is disabled when set to a negative
      value.
    use_peephole: An optional `bool`. Defaults to `False`.
      Whether to use peephole weights.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
    i: A `Tensor`. Has the same type as `x`. The input gate.
    cs: A `Tensor`. Has the same type as `x`. The cell state before the tanh.
    f: A `Tensor`. Has the same type as `x`. The forget gate.
    o: A `Tensor`. Has the same type as `x`. The output gate.
    ci: A `Tensor`. Has the same type as `x`. The cell input.
    co: A `Tensor`. Has the same type as `x`. The cell after the tanh.
    h: A `Tensor`. Has the same type as `x`. The output h vector.

  Raises:
    ValueError: If cell_size is None.
  """
  if wci is None:
    cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
    if cell_size is None:
      raise ValueError("cell_size from `cs_prev` should not be None.")
    wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
    wcf = wci
    wco = wci

  # pylint: disable=protected-access
  return gen_lstm_ops.lstm_block_cell(
      x=x,
      cs_prev=cs_prev,
      h_prev=h_prev,
      w=w,
      wci=wci,
      wcf=wcf,
      wco=wco,
      b=b,
      forget_bias=forget_bias,
      cell_clip=cell_clip if cell_clip is not None else -1,
      use_peephole=use_peephole,
      name=name)
  # pylint: enable=protected-access


def _block_lstm(seq_len_max,
                x,
                w,
                b,
                cs_prev=None,
                h_prev=None,
                wci=None,
                wcf=None,
                wco=None,
                forget_bias=None,
                cell_clip=None,
                use_peephole=None,
                name=None):
    156   r"""TODO(williamchan): add doc.
    157 
  Args:
    seq_len_max: A `Tensor` of type `int64`.
    x: A list of at least 1 `Tensor` objects of the same type.
    w: A `Tensor`. Must have the same type as `x`.
    b: A `Tensor`. Must have the same type as `x`.
    cs_prev: A `Tensor`. Must have the same type as `x`.
    h_prev: A `Tensor`. Must have the same type as `x`.
    wci: A `Tensor`. Must have the same type as `x`.
    wcf: A `Tensor`. Must have the same type as `x`.
    wco: A `Tensor`. Must have the same type as `x`.
    forget_bias: An optional `float`. Defaults to `1`.
    cell_clip: An optional `float`. Defaults to `-1` (no clipping).
    use_peephole: An optional `bool`. Defaults to `False`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
    Each of i, cs, f, o, ci, co, and h is a list with the same number of
    `Tensor` objects as `x`, of `Tensor` objects with the same type as `x`.

  Raises:
    ValueError: If `b` does not have a valid shape.
  """
  dtype = x[0].dtype
  batch_size = x[0].get_shape().with_rank(2).dims[0].value
  cell_size4 = b.get_shape().with_rank(1).dims[0].value
  if cell_size4 is None:
    raise ValueError("`b` shape must not be None.")
  cell_size = cell_size4 // 4
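  # Default any missing initial states and peephole weights to zeros.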
  zero_state = None
  if cs_prev is None or h_prev is None:
    zero_state = array_ops.constant(
        0, dtype=dtype, shape=[batch_size, cell_size])
  if cs_prev is None:
    cs_prev = zero_state
  if h_prev is None:
    h_prev = zero_state
  if wci is None:
    wci = array_ops.constant(0, dtype=dtype, shape=[cell_size])
    wcf = wci
    wco = wci

  # pylint: disable=protected-access
  i, cs, f, o, ci, co, h = gen_lstm_ops.block_lstm(
      seq_len_max=seq_len_max,
      x=array_ops.stack(x),
      cs_prev=cs_prev,
      h_prev=h_prev,
      w=w,
      wci=wci,
      wcf=wcf,
      wco=wco,
      b=b,
      forget_bias=forget_bias,
      cell_clip=cell_clip if cell_clip is not None else -1,
      name=name,
      use_peephole=use_peephole)

  return array_ops.unstack(i), array_ops.unstack(cs), array_ops.unstack(
      f), array_ops.unstack(o), array_ops.unstack(ci), array_ops.unstack(
          co), array_ops.unstack(h)
  # pylint: enable=protected-access
  # pylint: enable=invalid-name


_lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"]


@ops.RegisterGradient("LSTMBlockCell")
def _LSTMBlockCellGrad(op, *grad):
  """Gradient for LSTMBlockCell."""
  (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
  (i, cs, f, o, ci, co, _) = op.outputs
  (_, cs_grad, _, _, _, _, h_grad) = grad

  batch_size = x.get_shape().with_rank(2).dims[0].value
  if batch_size is None:
    batch_size = -1
  input_size = x.get_shape().with_rank(2).dims[1].value
  if input_size is None:
    raise ValueError("input_size from `x` should not be None.")
  cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
  if cell_size is None:
    raise ValueError("cell_size from `cs_prev` should not be None.")

  (cs_prev_grad, dicfo, wci_grad, wcf_grad,
   wco_grad) = gen_lstm_ops.lstm_block_cell_grad(
       x,
       cs_prev,
       h_prev,
       w,
       wci,
       wcf,
       wco,
       b,
       i,
       cs,
       f,
       o,
       ci,
       co,
       cs_grad,
       h_grad,
       use_peephole=op.get_attr("use_peephole"))

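  # `dicfo` is the gradient with respect to the pre-nonlinearity gate values,
  # packed along the last axis in [i, ci, f, o] order to match `w` and `b`.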
  # Backprop from dicfo to xh.
  xh_grad = math_ops.matmul(dicfo, w, transpose_b=True)

  x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size))
  x_grad.get_shape().merge_with(x.get_shape())

  h_prev_grad = array_ops.slice(xh_grad, (0, input_size),
                                (batch_size, cell_size))
  h_prev_grad.get_shape().merge_with(h_prev.get_shape())

  # Backprop from dicfo to w.
  xh = array_ops.concat([x, h_prev], 1)
  w_grad = math_ops.matmul(xh, dicfo, transpose_a=True)
  w_grad.get_shape().merge_with(w.get_shape())

  # Backprop from dicfo to b.
  b_grad = nn_ops.bias_add_grad(dicfo)
  b_grad.get_shape().merge_with(b.get_shape())

  return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
          wco_grad, b_grad)


@ops.RegisterGradient("BlockLSTM")
def _BlockLSTMGrad(op, *grad):
  """Gradient for BlockLSTM."""
  seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b = op.inputs
  i, cs, f, o, ci, co, h = op.outputs

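  # `grad` is ordered like `op.outputs` (i, cs, f, o, ci, co, h). Only the
  # gradients flowing into `cs` and `h` are consumed here; the remaining
  # outputs are intermediates that the grad kernel reuses directly.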
  cs_grad = grad[1]
  h_grad = grad[6]

  (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, wco_grad,
   b_grad) = gen_lstm_ops.block_lstm_grad(
       seq_len_max,
       x,
       cs_prev,
       h_prev,
       w,
       wci,
       wcf,
       wco,
       b,
       i,
       cs,
       f,
       o,
       ci,
       co,
       h,
       cs_grad,
       h_grad,
       use_peephole=op.get_attr("use_peephole"))

  return [
      None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
      wco_grad, b_grad
  ]


class LSTMBlockCell(LayerRNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add `forget_bias` (default: 1) to the biases of the forget gate in order
  to reduce the scale of forgetting at the beginning of training.

  Unlike `rnn_cell_impl.LSTMCell`, this is a monolithic op and should be much
  faster.  The weight and bias matrices should be compatible as long as the
  variable scope matches.
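
  A minimal usage sketch (illustrative; `inputs` stands for a hypothetical
  `[batch_size, time_len, input_size]` float tensor):

  ```python
  cell = LSTMBlockCell(num_units=128)
  outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
  ```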
    346   """
    347 
    348   def __init__(self,
    349                num_units,
    350                forget_bias=1.0,
    351                cell_clip=None,
    352                use_peephole=False,
    353                dtype=None,
    354                reuse=None,
    355                name="lstm_cell"):
    356     """Initialize the basic LSTM cell.
    357 
    358     Args:
    359       num_units: int, The number of units in the LSTM cell.
    360       forget_bias: float, The bias added to forget gates (see above).
    361       cell_clip: An optional `float`. Defaults to `-1` (no clipping).
    362       use_peephole: Whether to use peephole connections or not.
      dtype: the variable dtype of this layer. Defaults to `tf.float32`.
      reuse: (optional) boolean describing whether to reuse variables in an
        existing scope.  If not `True`, and the existing scope already has the
        given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.  By default this is "lstm_cell", for variable-name compatibility
        with `tf.nn.rnn_cell.LSTMCell`.

      When restoring from CudnnLSTM-trained checkpoints, you must use
      `CudnnCompatibleLSTMBlockCell` instead.
    374     """
    375     super(LSTMBlockCell, self).__init__(_reuse=reuse, dtype=dtype, name=name)
    376     self._num_units = num_units
    377     self._forget_bias = forget_bias
    378     self._use_peephole = use_peephole
    379     self._cell_clip = cell_clip if cell_clip is not None else -1
    380     self._names = {
    381         "W": "kernel",
    382         "b": "bias",
    383         "wci": "w_i_diag",
    384         "wcf": "w_f_diag",
    385         "wco": "w_o_diag",
    386         "scope": "lstm_cell"
    387     }
    388     # Inputs must be 2-dimensional.
    389     self.input_spec = input_spec.InputSpec(ndim=2)
    390 
    391   @property
    392   def state_size(self):
    393     return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
    394 
    395   @property
    396   def output_size(self):
    397     return self._num_units
    398 
    399   def build(self, inputs_shape):
    400     if not inputs_shape.dims[1].value:
    401       raise ValueError(
    402           "Expecting inputs_shape[1] to be set: %s" % str(inputs_shape))
    403     input_size = inputs_shape.dims[1].value
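    # The weights for the four gate blocks are packed along the last axis in
    # [i, ci, f, o] order, matching the equations in `_lstm_block_cell`.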
    self._kernel = self.add_variable(
        self._names["W"], [input_size + self._num_units, self._num_units * 4])
    self._bias = self.add_variable(
        self._names["b"], [self._num_units * 4],
        initializer=init_ops.constant_initializer(0.0))
    if self._use_peephole:
      self._w_i_diag = self.add_variable(self._names["wci"], [self._num_units])
      self._w_f_diag = self.add_variable(self._names["wcf"], [self._num_units])
      self._w_o_diag = self.add_variable(self._names["wco"], [self._num_units])

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM)."""
    if len(state) != 2:
      raise ValueError("Expecting state to be a tuple with length 2.")

    if self._use_peephole:
      wci = self._w_i_diag
      wcf = self._w_f_diag
      wco = self._w_o_diag
    else:
      wci = wcf = wco = array_ops.zeros([self._num_units], dtype=self.dtype)

    (cs_prev, h_prev) = state
    (_, cs, _, _, _, _, h) = _lstm_block_cell(
        inputs,
        cs_prev,
        h_prev,
        self._kernel,
        self._bias,
        wci=wci,
        wcf=wcf,
        wco=wco,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)

    new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
    return h, new_state


@six.add_metaclass(abc.ABCMeta)
class LSTMBlockWrapper(base_layer.Layer):
  """This is a helper class that provides housekeeping for LSTM cells.

  This may be useful for alternative LSTM and similar types of cells.
  Subclasses must implement the `_call_cell` method and the `num_units`
  property.
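
  A minimal subclass sketch (illustrative; `MyFusedCell` is hypothetical):

  ```python
  class MyFusedCell(LSTMBlockWrapper):

    @property
    def num_units(self):
      return self._num_units

    def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
                   sequence_length):
      # Run the recurrence and return (cell_states, outputs), each shaped
      # [time_len, batch_size, num_units].
      ...
  ```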
    452   """
    453 
    454   @abc.abstractproperty
    455   def num_units(self):
    456     """Number of units in this cell (output dimension)."""
    457     pass
    458 
    459   @abc.abstractmethod
    460   def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
    461                  sequence_length):
    462     """Run this LSTM on inputs, starting from the given state.
    463 
    464     This method must be implemented by subclasses and does the actual work
    465     of calling the cell.
    466 
    467     Args:
    468       inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
    469       initial_cell_state: initial value for cell state, shape `[batch_size,
    470         self._num_units]`
    471       initial_output: initial value of cell output, shape `[batch_size,
    472         self._num_units]`
    473       dtype: The data type for the initial state and expected output.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) of size `[batch_size]`, with values
        in `[0, time_len)`, or None.

    Returns:
      A pair containing:

      - State: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
    """
    pass

  def call(self, inputs, initial_state=None, dtype=None, sequence_length=None):
    """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
      initial_state: a tuple `(initial_cell_state, initial_output)` with tensors
        of shape `[batch_size, self._num_units]`. If this is not provided, the
        cell is expected to create a zero initial state of type `dtype`.
      dtype: The data type for the initial state and expected output. Required
        if `initial_state` is not provided or RNN state has a heterogeneous
        dtype.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) of size `[batch_size]`, with values
        in `[0, time_len)`. Defaults to `time_len` for each element.

    Returns:
      A pair containing:

      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
        or a list of time_len tensors of shape `[batch_size, output_size]`,
        to match the type of the `inputs`.
      - Final state: a tuple `(cell_state, output)` matching `initial_state`.

    Raises:
      ValueError: in case of shape mismatches
    """
    is_list = isinstance(inputs, list)
    if is_list:
      inputs = array_ops.stack(inputs)
    inputs_shape = inputs.get_shape().with_rank(3)
    if not inputs_shape.dims[2].value:
      raise ValueError("Expecting inputs_shape[2] to be set: %s" % inputs_shape)
    batch_size = inputs_shape.dims[1].value
    if batch_size is None:
      batch_size = array_ops.shape(inputs)[1]
    time_len = inputs_shape.dims[0].value
    if time_len is None:
      time_len = array_ops.shape(inputs)[0]

    # Provide default values for initial_state and dtype
    if initial_state is None:
      if dtype is None:
        raise ValueError("Either initial_state or dtype needs to be specified")
      z = array_ops.zeros(
          array_ops.stack([batch_size, self.num_units]), dtype=dtype)
      initial_state = z, z
    else:
      if len(initial_state) != 2:
        raise ValueError(
            "Expecting initial_state to be a tuple with length 2 or None")
      if dtype is None:
        dtype = initial_state[0].dtype

    # create the actual cell
    if sequence_length is not None:
      sequence_length = ops.convert_to_tensor(sequence_length)
    initial_cell_state, initial_output = initial_state  # pylint: disable=unpacking-non-sequence
    cell_states, outputs = self._call_cell(
        inputs, initial_cell_state, initial_output, dtype, sequence_length)

    if sequence_length is not None:
      # Mask out the part beyond sequence_length
      mask = array_ops.transpose(
          array_ops.sequence_mask(sequence_length, time_len, dtype=dtype),
          [1, 0])
      mask = array_ops.tile(
          array_ops.expand_dims(mask, [-1]), [1, 1, self.num_units])
      outputs *= mask
      # Prepend initial states to cell_states and outputs for indexing to work
      # correctly, since we want to access the last valid state at
      # sequence_length - 1, which can even be -1, corresponding to the
      # initial state.
      mod_cell_states = array_ops.concat(
          [array_ops.expand_dims(initial_cell_state, [0]), cell_states], 0)
      mod_outputs = array_ops.concat(
          [array_ops.expand_dims(initial_output, [0]), outputs], 0)
      final_cell_state = self._gather_states(mod_cell_states, sequence_length,
                                             batch_size)
      final_output = self._gather_states(mod_outputs, sequence_length,
                                         batch_size)
    else:
      # No sequence_lengths used: final state is the last state
      final_cell_state = cell_states[-1]
      final_output = outputs[-1]

    if is_list:
      # Input was a list, so return a list
      outputs = array_ops.unstack(outputs)

    final_state = rnn_cell_impl.LSTMStateTuple(final_cell_state, final_output)
    return outputs, final_state

  def _gather_states(self, data, indices, batch_size):
    """Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""
    return array_ops.gather_nd(
        data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1))


class LSTMBlockFusedCell(LSTMBlockWrapper):
  """FusedRNNCell implementation of LSTM.

  This is an extremely efficient LSTM implementation that uses a single TF op
  for the entire LSTM. It should be both faster and more memory-efficient than
  LSTMBlockCell defined above.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order
  to reduce the scale of forgetting at the beginning of training.

  The variable naming is consistent with `rnn_cell_impl.LSTMCell`.
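
  A minimal usage sketch (illustrative; `inputs` stands for a hypothetical
  time-major `[time_len, batch_size, input_size]` float tensor):

  ```python
  lstm = LSTMBlockFusedCell(num_units=128)
  outputs, (final_c, final_h) = lstm(inputs, dtype=tf.float32)
  ```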
    598   """
    599 
    600   def __init__(self,
    601                num_units,
    602                forget_bias=1.0,
    603                cell_clip=None,
    604                use_peephole=False,
    605                reuse=None,
    606                dtype=None,
    607                name="lstm_fused_cell"):
    608     """Initialize the LSTM cell.
    609 
    610     Args:
    611       num_units: int, The number of units in the LSTM cell.
    612       forget_bias: float, The bias added to forget gates (see above).
      cell_clip: clip the cell to this value. Default is no cell clipping.
      use_peephole: Whether to use peephole connections or not.
      reuse: (optional) boolean describing whether to reuse variables in an
        existing scope.  If not `True`, and the existing scope already has the
        given variables, an error is raised.
      dtype: the dtype of variables of this layer.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.  By default this is "lstm_fused_cell".
    623     """
    624     super(LSTMBlockFusedCell, self).__init__(
    625         _reuse=reuse, name=name, dtype=dtype)
    626     self._num_units = num_units
    627     self._forget_bias = forget_bias
    628     self._cell_clip = cell_clip if cell_clip is not None else -1
    629     self._use_peephole = use_peephole
    630 
    631     # Inputs must be 3-dimensional.
    632     self.input_spec = input_spec.InputSpec(ndim=3)
    633 
    634   @property
    635   def num_units(self):
    636     """Number of units in this cell (output dimension)."""
    637     return self._num_units
    638 
    639   def build(self, input_shape):
    640     input_size = input_shape.dims[2].value
    641     self._kernel = self.add_variable(
    642         "kernel", [input_size + self._num_units, self._num_units * 4])
    643     self._bias = self.add_variable(
    644         "bias", [self._num_units * 4],
    645         initializer=init_ops.constant_initializer(0.0))
    646     if self._use_peephole:
    647       self._w_i_diag = self.add_variable("w_i_diag", [self._num_units])
    648       self._w_f_diag = self.add_variable("w_f_diag", [self._num_units])
    649       self._w_o_diag = self.add_variable("w_o_diag", [self._num_units])
    650 
    651     self.built = True
    652 
    653   def _call_cell(self,
    654                  inputs,
    655                  initial_cell_state=None,
    656                  initial_output=None,
    657                  dtype=None,
    658                  sequence_length=None):
    659     """Run this LSTM on inputs, starting from the given state.
    660 
    661     Args:
    662       inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
    663       initial_cell_state: initial value for cell state, shape `[batch_size,
    664         self._num_units]`
    665       initial_output: initial value of cell output, shape `[batch_size,
    666         self._num_units]`
    667       dtype: The data type for the initial state and expected output.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) of size `[batch_size]`, with values
        in `[0, time_len)`, or None.

    Returns:
      A pair containing:

      - Cell state (cs): A `3-D` tensor of shape `[time_len, batch_size,
                         output_size]`
      - Output (h): A `3-D` tensor of shape `[time_len, batch_size,
                    output_size]`
    """

    inputs_shape = inputs.get_shape().with_rank(3)
    time_len = inputs_shape.dims[0].value
    if time_len is None:
      time_len = array_ops.shape(inputs)[0]

    if self._use_peephole:
      wci = self._w_i_diag
      wco = self._w_o_diag
      wcf = self._w_f_diag
    else:
      wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype)

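    # The kernel only iterates up to `seq_len_max`; without explicit lengths,
    # cover the full time dimension.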
    if sequence_length is None:
      max_seq_len = math_ops.cast(time_len, dtypes.int64)
    else:
      max_seq_len = math_ops.cast(math_ops.reduce_max(sequence_length),
                                  dtypes.int64)

    _, cs, _, _, _, _, h = gen_lstm_ops.block_lstm(
        seq_len_max=max_seq_len,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=self._kernel,
        wci=wci,
        wcf=wcf,
        wco=wco,
        b=self._bias,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    return cs, h
    713