# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Base layer code and base model (Network) code.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
import os

import numpy as np
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary
from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.layers import network as tf_network
from tensorflow.python.layers import utils as tf_layers_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=g-import-not-at-top
try:
  import h5py
except ImportError:
  h5py = None

try:
  import yaml
except ImportError:
  yaml = None
# pylint: enable=g-import-not-at-top

# pylint: disable=invalid-name
InputSpec = tf_base_layers.InputSpec
Node = tf_base_layers.Node
TFBaseLayer = tf_base_layers.Layer
# pylint: enable=invalid-name


@tf_export('keras.layers.Layer')
class Layer(tf_base_layers.Layer):
  """Abstract base layer class.

  # Properties
      name: String, must be unique within a model.
      input_spec: List of InputSpec class instances;
          each entry describes one required input:
              - ndim
              - dtype
          A layer with `n` input tensors must have
          an `input_spec` of length `n`.
      trainable: Boolean, whether the layer weights
          will be updated during training.
      uses_learning_phase: Whether any operation
          of the layer uses `K.in_train_phase()`
          or `K.in_test_phase()`.
      input_shape: Shape tuple. Provided for convenience,
          but note that there may be cases in which this
          attribute is ill-defined (e.g. a shared layer
          with multiple input shapes), in which case
          requesting `input_shape` will raise an Exception.
          Prefer using `layer.get_input_shape_at(node_index)`.
      output_shape: Shape tuple. See above.
      inbound_nodes: List of nodes.
      outbound_nodes: List of nodes.
      input, output: Input/output tensor(s). Note that if the layer is used
          more than once (shared layer), this is ill-defined
          and will raise an exception. In such cases, use
          `layer.get_input_at(node_index)`.
      input_mask, output_mask: Same as above, for masks.
      trainable_weights: List of variables.
      non_trainable_weights: List of variables.
      weights: The concatenation of the lists trainable_weights and
          non_trainable_weights (in this order).

  # Methods
      call(inputs, **kwargs): Where the layer's logic lives.
      __call__(inputs, **kwargs): Wrapper around the layer logic (`call`).
          If inputs is a Keras tensor:
              - Connect current layer with last layer from tensor:
                  `self._add_inbound_node(last_layer)`
              - Add layer to tensor history
          If layer is not built:
              - Build from inputs shape
      get_weights()
      set_weights(weights)
      get_config()
      count_params()
      compute_output_shape(input_shape)
      compute_mask(x, mask)
      get_input_at(node_index)
      get_output_at(node_index)
      get_input_shape_at(node_index)
      get_output_shape_at(node_index)
      get_input_mask_at(node_index)
      get_output_mask_at(node_index)

  # Class Methods
      from_config(config)

  # Internal methods:
      build(input_shape)
      _add_inbound_node(layer, index=0)
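
  # Example

  A minimal sketch of a custom layer (`MyDense` and `units` are
  illustrative names, not part of this module):

  ```python
  class MyDense(Layer):

    def __init__(self, units, **kwargs):
      super(MyDense, self).__init__(**kwargs)
      self.units = units

    def build(self, input_shape):
      # Create the trainable kernel once the input shape is known.
      self.kernel = self.add_weight(
          name='kernel',
          shape=(int(input_shape[-1]), self.units),
          initializer='glorot_uniform')
      super(MyDense, self).build(input_shape)

    def call(self, inputs):
      return K.dot(inputs, self.kernel)
  ```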
  """

  def __init__(self, **kwargs):
    # These properties should be set by the user via keyword arguments.
    # note that 'dtype', 'input_shape' and 'batch_input_shape'
    # are only applicable to input layers: do not pass these keywords
    # to non-input layers.
    allowed_kwargs = {
        'activity_regularizer',
        'input_shape',
        'batch_input_shape',
        'batch_size',
        'dtype',
        'name',
        'trainable',
        'weights',
    }
    # Validate optional keyword arguments.
    for kwarg in kwargs:
      if kwarg not in allowed_kwargs:
        raise TypeError('Keyword argument not understood:', kwarg)

    # Get layer name.
    name = kwargs.get('name')

    # Get `trainable` status.
    trainable = kwargs.get('trainable', True)

    # Get `dtype`.
    dtype = kwargs.get('dtype')
    if dtype is None:
      dtype = K.floatx()

    # Call super, which will set all properties common to Keras layers
    # and core TF layers.
    super(Layer, self).__init__(
        name=name, dtype=dtype, trainable=trainable,
        activity_regularizer=kwargs.get('activity_regularizer'))

    # Add properties that are Keras-only for now.
    self.supports_masking = False

    # Manage input shape information if passed.
    if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
      # In this case we will later create an input layer
      # to insert before the current layer
      if 'batch_input_shape' in kwargs:
        batch_input_shape = tuple(kwargs['batch_input_shape'])
      elif 'input_shape' in kwargs:
        if 'batch_size' in kwargs:
          batch_size = kwargs['batch_size']
        else:
          batch_size = None
        batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
      self._batch_input_shape = batch_input_shape

    # Manage initial weight values if passed.
    if 'weights' in kwargs:
      self._initial_weights = kwargs['weights']
    else:
      self._initial_weights = None

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=True,
                 constraint=None):
    """Adds a weight variable to the layer.

    Arguments:
        name: String, the name for the weight variable.
        shape: The shape tuple of the weight.
        dtype: The dtype of the weight.
        initializer: An Initializer instance (callable).
        regularizer: An optional Regularizer instance.
        trainable: A boolean, whether the weight should
            be trained via backprop or not (assuming
            that the layer itself is also trainable).
        constraint: An optional Constraint instance.

    Returns:
        The created weight variable.
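
    Example (a sketch of typical usage from a custom layer's `build`;
    the names and sizes are illustrative):

    ```python
    def build(self, input_shape):
      self.kernel = self.add_weight(
          name='kernel',
          shape=(int(input_shape[-1]), 16),
          initializer='glorot_uniform',
          regularizer=regularizers.l2(0.01),
          trainable=True)
    ```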
    """
    if dtype is None:
      dtype = K.floatx()
    weight = self.add_variable(name, shape,
                               dtype=dtype,
                               initializer=initializers.get(initializer),
                               regularizer=regularizers.get(regularizer),
                               constraint=constraints.get(constraint),
                               trainable=trainable)
    return weight

  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """This is where the layer's logic lives.

    Arguments:
        inputs: Input tensor, or list/tuple of input tensors.
        **kwargs: Additional keyword arguments.

    Returns:
        A tensor or list/tuple of tensors.
    """
    return inputs

  def __call__(self, inputs, **kwargs):
    """Wrapper around self.call(), for handling internal references.

    If a Keras tensor is passed:
        - We call self._add_inbound_node().
        - If necessary, we `build` the layer to match
            the shape of the input(s).
        - We update the _keras_history of the output tensor(s)
            with the current layer.
            This is done as part of _add_inbound_node().

    Arguments:
        inputs: Can be a tensor or list/tuple of tensors.
        **kwargs: Additional keyword arguments to be passed to `call()`.

    Returns:
        Output of the layer's `call` method.

    Raises:
        ValueError: in case the layer is missing shape information
            for its `build` call.
    """
    # Actually call the layer (optionally building it).
    output = super(Layer, self).__call__(inputs, **kwargs)
    if context.in_eager_mode():
      return output

    # Un-built subclassed network: build it
    if isinstance(self, Network) and not self.inputs:
      self._set_inputs(inputs, training=kwargs.get('training'))

    # Update learning phase info.
    output_tensors = _to_list(output)
    uses_lp = any(
        [getattr(x, '_uses_learning_phase', False) for x in _to_list(inputs)])
    uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp
    for i in range(len(output_tensors)):
      output_tensors[i]._uses_learning_phase = getattr(
          output_tensors[i], '_uses_learning_phase', False) or uses_lp

    # Optionally load weight values that were specified at layer instantiation.
    if hasattr(self, '_initial_weights') and self._initial_weights is not None:
      self.set_weights(self._initial_weights)
      del self._initial_weights
    return output

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer.
    Assumes that the layer will be built
    to match the input shape provided.

    Arguments:
        input_shape: Shape tuple (tuple of integers)
            or list of shape tuples (one per input tensor of the layer).
            Shape tuples can include None for free dimensions,
            instead of an integer.

    Returns:
        An output shape tuple.
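
    Example of a typical override (a sketch; `self.units` is an
    assumed attribute of a custom layer):

    ```python
    def compute_output_shape(self, input_shape):
      input_shape = tensor_shape.TensorShape(input_shape).as_list()
      return tensor_shape.TensorShape(input_shape[:-1] + [self.units])
    ```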
    """
    logging.warning(
        'All custom layers should implement the '
        '`compute_output_shape` method. This layer (' + self.name + ') '
        'is relying on the base `Layer.compute_output_shape` implementation, '
        'which will start raising a `NotImplementedError` '
        'as of July 1st, 2018.')
    return input_shape

  def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
    """Computes an output mask tensor.

    Arguments:
        inputs: Tensor or list of tensors.
        mask: Tensor or list of tensors.

    Returns:
        None or a tensor (or list of tensors,
            one per output tensor of the layer).
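
    Example of a mask-preserving layer (a sketch):

    ```python
    class MaskedIdentity(Layer):

      def __init__(self, **kwargs):
        super(MaskedIdentity, self).__init__(**kwargs)
        # Opting in to masking makes this base implementation
        # carry the input mask over to the output.
        self.supports_masking = True
    ```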
    """
    if not self.supports_masking:
      if mask is not None:
        if isinstance(mask, list):
          if any(m is not None for m in mask):
            raise TypeError('Layer ' + self.name + ' does not support masking, '
                            'but was passed an input_mask: ' + str(mask))
        else:
          raise TypeError('Layer ' + self.name + ' does not support masking, '
                          'but was passed an input_mask: ' + str(mask))
      # masking not explicitly supported: return None as mask
      return None
    # if masking is explicitly supported, by default
    # carry over the input mask
    return mask

  def get_input_mask_at(self, node_index):
    """Retrieves the input mask tensor(s) of a layer at a given node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        A mask tensor
        (or list of tensors if the layer has multiple inputs).
    """
    inputs = self.get_input_at(node_index)
    if isinstance(inputs, list):
      return [getattr(x, '_keras_mask', None) for x in inputs]
    else:
      return getattr(inputs, '_keras_mask', None)

  def get_output_mask_at(self, node_index):
    """Retrieves the output mask tensor(s) of a layer at a given node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        A mask tensor
        (or list of tensors if the layer has multiple outputs).
    """
    output = self.get_output_at(node_index)
    if isinstance(output, list):
      return [getattr(x, '_keras_mask', None) for x in output]
    else:
      return getattr(output, '_keras_mask', None)

  @property
  def input_mask(self):
    """Retrieves the input mask tensor(s) of a layer.

    Only applicable if the layer has exactly one inbound node,
    i.e. if it is connected to one incoming layer.

    Returns:
        Input mask tensor (potentially None) or list of input
        mask tensors.

    Raises:
        AttributeError: if the layer is connected to
        more than one incoming layer.
    """
    inputs = self.input
    if isinstance(inputs, list):
      return [getattr(x, '_keras_mask', None) for x in inputs]
    else:
      return getattr(inputs, '_keras_mask', None)

  @property
  def output_mask(self):
    """Retrieves the output mask tensor(s) of a layer.

    Only applicable if the layer has exactly one inbound node,
    i.e. if it is connected to one incoming layer.

    Returns:
        Output mask tensor (potentially None) or list of output
        mask tensors.

    Raises:
        AttributeError: if the layer is connected to
        more than one incoming layer.
    """
    output = self.output
    if isinstance(output, list):
      return [getattr(x, '_keras_mask', None) for x in output]
    else:
      return getattr(output, '_keras_mask', None)

  def set_weights(self, weights):
    """Sets the weights of the layer, from Numpy arrays.

    Arguments:
        weights: a list of Numpy arrays. The number
            of arrays and their shapes must match the
            weights of the layer (i.e. it should match the
            output of `get_weights`).

    Raises:
        ValueError: If the provided weights list does not match the
            layer's specifications.
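
    Example (a sketch; `layer_a` and `layer_b` are assumed to be
    existing layers of identical architecture):

    ```python
    # Copy weights from one layer to another.
    layer_b.set_weights(layer_a.get_weights())
    ```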
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError('You called `set_weights(weights)` on layer "' +
                       self.name + '" with a weight list of length ' +
                       str(len(weights)) + ', but the layer was expecting ' +
                       str(len(params)) + ' weights. Provided weights: ' +
                       str(weights)[:50] + '...')
    if not params:
      return
    weight_value_tuples = []
    param_values = K.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError('Layer weight shape ' + str(pv.shape) +
                         ' not compatible with '
                         'provided weight shape ' + str(w.shape))
      weight_value_tuples.append((p, w))
    K.batch_set_value(weight_value_tuples)

  def get_weights(self):
    """Returns the current weights of the layer.

    Returns:
        Weights values as a list of numpy arrays.
    """
    params = self.weights
    return K.batch_get_value(params)

  def get_config(self):
    """Returns the config of the layer.

    A layer config is a Python dictionary (serializable)
    containing the configuration of a layer.
    The same layer can be reinstantiated later
    (without its trained weights) from this configuration.

    The config of a layer does not include connectivity
    information, nor the layer class name. These are handled
    by `Network` (one layer of abstraction above).

    Returns:
        Python dictionary.
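
    Example (cloning a layer from its config; the clone starts
    from fresh, untrained weights):

    ```python
    config = layer.get_config()
    clone = layer.__class__.from_config(config)
    ```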
    """
    config = {'name': self.name, 'trainable': self.trainable}
    if hasattr(self, '_batch_input_shape'):
      config['batch_input_shape'] = self._batch_input_shape
    if hasattr(self, 'dtype'):
      config['dtype'] = self.dtype
    return config

  @classmethod
  def from_config(cls, config):
    """Creates a layer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same layer from the config
    dictionary. It does not handle layer connectivity
    (handled by Network), nor weights (handled by `set_weights`).

    Arguments:
        config: A Python dictionary, typically the
            output of get_config.

    Returns:
        A layer instance.
    """
    return cls(**config)

  @tf_base_layers.Layer.activity_regularizer.setter
  def activity_regularizer(self, activity_regularizer):
    self._activity_regularizer = activity_regularizer


@tf_export('keras.layers.InputLayer')
class InputLayer(tf_network.InputLayer, Layer):
  """Layer to be used as an entry point into a graph.

  It can either wrap an existing tensor (pass an `input_tensor`
  argument) or create a placeholder tensor (pass an `input_shape`
  argument).

  Arguments:
      input_shape: Shape tuple, not including the batch axis.
      batch_size: Optional input batch size (integer or None).
      dtype: Datatype of the input.
      input_tensor: Optional tensor to use as layer input
          instead of creating a placeholder.
      sparse: Boolean, whether the placeholder created
          is meant to be sparse.
      name: Name of the layer (string).
  """

  def __init__(self,
               input_shape=None,
               batch_size=None,
               dtype=None,
               input_tensor=None,
               sparse=False,
               name=None,
               **kwargs):
    if 'batch_input_shape' in kwargs:
      batch_input_shape = kwargs.pop('batch_input_shape')
      if input_shape and batch_input_shape:
        raise ValueError('Only provide the input_shape OR '
                         'batch_input_shape argument to '
                         'InputLayer, not both at the same time.')
      batch_size = batch_input_shape[0]
      input_shape = batch_input_shape[1:]
    if kwargs:
      raise ValueError('Unrecognized keyword arguments:', kwargs.keys())

    if not name:
      prefix = 'input'
      name = prefix + '_' + str(K.get_uid(prefix))

    if not dtype:
      if input_tensor is None:
        dtype = K.floatx()
      else:
        dtype = K.dtype(input_tensor)
    super(InputLayer, self).__init__(input_shape=input_shape,
                                     batch_size=batch_size,
                                     dtype=dtype,
                                     input_tensor=input_tensor,
                                     sparse=sparse,
                                     name=name)

  def get_config(self):
    config = {
        'batch_input_shape': self._batch_input_shape,
        'dtype': self.dtype,
        'sparse': self.sparse,
        'name': self.name
    }
    return config


@tf_export('keras.layers.Input', 'keras.Input')
def Input(  # pylint: disable=invalid-name
    shape=None,
    batch_size=None,
    name=None,
    dtype=None,
    sparse=False,
    tensor=None,
    **kwargs):
  """`Input()` is used to instantiate a Keras tensor.

  A Keras tensor is a tensor object from the underlying backend
  (Theano or TensorFlow), which we augment with certain
  attributes that allow us to build a Keras model
  just by knowing the inputs and outputs of the model.

  For instance, if a, b and c are Keras tensors,
  it becomes possible to do:
  `model = Model(input=[a, b], output=c)`

  The added Keras attribute is:
      `_keras_history`: Last layer applied to the tensor.
          The entire layer graph is retrievable from that layer,
          recursively.

  Arguments:
      shape: A shape tuple (integers), not including the batch size.
          For instance, `shape=(32,)` indicates that the expected input
          will be batches of 32-dimensional vectors.
      batch_size: optional static batch size (integer).
      name: An optional name string for the layer.
          Should be unique in a model (do not reuse the same name twice).
          It will be autogenerated if it isn't provided.
      dtype: The data type expected by the input, as a string
          (`float32`, `float64`, `int32`...)
      sparse: A boolean specifying whether the placeholder
          to be created is sparse.
      tensor: Optional existing tensor to wrap into the `Input` layer.
          If set, the layer will not create a placeholder tensor.
      **kwargs: deprecated arguments support.

  Returns:
      A tensor.

  Example:

      ```python
      # this is a logistic regression in Keras
      x = Input(shape=(32,))
      y = Dense(16, activation='softmax')(x)
      model = Model(x, y)
      ```

  Raises:
    ValueError: in case of invalid arguments.
  """
  if 'batch_shape' in kwargs:
    batch_shape = kwargs.pop('batch_shape')
    if shape and batch_shape:
      raise ValueError('Only provide the shape OR '
                       'batch_shape argument to '
                       'Input, not both at the same time.')
    batch_size = batch_shape[0]
    shape = batch_shape[1:]
  if kwargs:
    raise ValueError('Unrecognized keyword arguments:', kwargs.keys())

  if dtype is None:
    dtype = K.floatx()
  if not shape and tensor is None:
    raise ValueError('Please provide to Input either a `shape`'
                     ' or a `tensor` argument. Note that '
                     '`shape` does not include the batch '
                     'dimension.')
  input_layer = InputLayer(
      input_shape=shape,
      batch_size=batch_size,
      name=name,
      dtype=dtype,
      sparse=sparse,
      input_tensor=tensor)
  # Return tensor including `_keras_history`.
  # Note that in this case train_output and test_output are the same pointer.
  outputs = input_layer._inbound_nodes[0].output_tensors
  if len(outputs) == 1:
    return outputs[0]
  else:
    return outputs


class Network(tf_network.GraphNetwork, Layer):
  """A Network is a directed acyclic graph of layers.

  It is the topological form of a "model". A Model
  is simply a Network with added training routines.

  # Properties
      name
      inputs
      outputs
      input_layers
      output_layers
      input_spec (list of class instances)
          each entry describes one required input:
              - ndim
              - dtype
      trainable (boolean)
      input_shape
      output_shape
      inbound_nodes: list of nodes
      outbound_nodes: list of nodes
      trainable_weights (list of variables)
      non_trainable_weights (list of variables)

  # Methods
      summary
      get_layer
      get_weights
      set_weights
      get_config
      compute_output_shape

  # Class Methods
      from_config
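
  # Example

  A graph network is typically constructed via the functional API,
  e.g. through the `Model` subclass (a sketch):

  ```python
  x = Input(shape=(32,))
  y = Dense(16, activation='softmax')(x)
  model = Model(x, y)
  ```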
  """

  def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
    # Signature detection
    if (len(args) == 2 or
        len(args) == 1 and 'outputs' in kwargs or
        'inputs' in kwargs and 'outputs' in kwargs):
      # Graph network
      self._init_graph_network(*args, **kwargs)
    else:
      # Subclassed network
      self._init_subclassed_network(**kwargs)

  def _init_graph_network(self, inputs, outputs, name=None):
    # TODO(fchollet): merge back tf.layers.Network and tf.keras.Network
    # into a single class tf.keras.Network
    super(Network, self).__init__(inputs, outputs, name=name)

    self._is_compiled = False
    self._expects_training_arg = False

    self.supports_masking = False
    self.optimizer = None

    # Fill in the output mask cache.
    masks = []
    for x in self.inputs:
      mask = x._keras_mask if hasattr(x, '_keras_mask') else None
      masks.append(mask)
    mask_cache_key = (tf_layers_util.object_list_uid(self.inputs) + '_' +
                      tf_layers_util.object_list_uid(masks))
    masks = []
    for x in self.outputs:
      mask = x._keras_mask if hasattr(x, '_keras_mask') else None
      masks.append(mask)
    if len(masks) == 1:
      mask = masks[0]
    else:
      mask = masks
    self._output_mask_cache[mask_cache_key] = mask

    # Build self.input_names and self.output_names.
    self.input_names = []
    self.output_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for i, layer in enumerate(self._input_layers):
      self.input_names.append(layer.name)
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        self._feed_input_shapes.append(K.int_shape(self.inputs[i]))
        # layer.input gives an error in eager mode
        if context.in_graph_mode():
          self._feed_inputs.append(layer.input)
    for layer in self._output_layers:
      self.output_names.append(layer.name)

  def _init_subclassed_network(self, name=None):
    self._init_set_name(name)
    self._layers = []
    self._is_graph_network = False
    self._is_compiled = False
    if 'training' in tf_inspect.getargspec(self.call).args:
      self._expects_training_arg = True
    else:
      self._expects_training_arg = False

    self.outputs = None
    self.inputs = None
    self.trainable = True
    self.supports_masking = False
    self.built = False
    self.optimizer = None

    # Not used, exists for compatibility purposes due to implementation of
    # the base layer tf.layers.Layer - TODO(fchollet): clean up when refactoring
    self._scope = None
    self._reuse = None
    self._dtype = None
    self._graph = None
    self._activity_regularizer = None

    # Used in symbolic mode only
    self._updates = []
    self._losses = []

    # Used in symbolic mode only, only in conjunction with graph networks
    self._outbound_nodes = []
    self._inbound_nodes = []

  def __setattr__(self, name, value):
    if isinstance(value, (tf_base_layers.Layer, Network)):
      try:
        is_graph_network = self._is_graph_network
      except AttributeError:
        raise RuntimeError('It looks like you are subclassing `Model` and you '
                           'forgot to call `super(YourClass, self).__init__()`.'
                           ' Always start with this line.')
      if not is_graph_network:
        if value not in self._layers:
          self._layers.append(value)
    super(Network, self).__setattr__(name, value)

  def add_variable(self, name, shape, dtype=None, initializer=None,
                   regularizer=None, trainable=True, constraint=None):
    raise NotImplementedError('`add_variable` is not supported on Networks')

  def add_loss(self, *args, **kwargs):
    if context.in_eager_mode():
      raise NotImplementedError('`add_loss` is not supported in eager-mode '
                                'on Networks')
    super(Network, self).add_loss(*args, **kwargs)

  @property
  def uses_learning_phase(self):
    return any(
        [getattr(x, '_uses_learning_phase', False) for x in self.outputs])

  @property
  def stateful(self):
    return any([(hasattr(layer, 'stateful') and layer.stateful)
                for layer in self.layers])

  def reset_states(self):
    for layer in self.layers:
      if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
        layer.reset_states()

  @property
  def state_updates(self):
    """Returns the `updates` from all layers that are stateful.

    This is useful for separating training updates and
    state updates, e.g. when we need to update a layer's internal state
    during prediction.

    Returns:
        A list of update ops.
    """
    state_updates = []
    for layer in self.layers:
      if getattr(layer, 'stateful', False):
        if hasattr(layer, 'updates'):
          state_updates += layer.updates
    return state_updates

  def get_weights(self):
    """Retrieves the weights of the model.

    Returns:
        A flat list of Numpy arrays.
    """
    weights = []
    for layer in self.layers:
      weights += layer.weights
    return K.batch_get_value(weights)

  def set_weights(self, weights):
    """Sets the weights of the model.

    Arguments:
        weights: A list of Numpy arrays with shapes and types matching
            the output of `model.get_weights()`.
    """
    tuples = []
    for layer in self.layers:
      num_param = len(layer.weights)
      layer_weights = weights[:num_param]
      for sw, w in zip(layer.weights, layer_weights):
        tuples.append((sw, w))
      weights = weights[num_param:]
    K.batch_set_value(tuples)

  def compute_mask(self, inputs, mask):
    if not self._is_graph_network:
      return None

    inputs = _to_list(inputs)
    if mask is None:
      masks = [None for _ in range(len(inputs))]
    else:
      masks = _to_list(mask)
    cache_key = (tf_layers_util.object_list_uid(inputs)
                 + '_' + tf_layers_util.object_list_uid(masks))
    if cache_key in self._output_mask_cache:
      return self._output_mask_cache[cache_key]
    else:
      _, output_masks = self._run_internal_graph(inputs, masks)
      return output_masks

  def get_config(self):
    if not self._is_graph_network:
      raise NotImplementedError

    config = {
        'name': self.name,
    }
    node_conversion_map = {}
    for layer in self.layers:
      if issubclass(layer.__class__, Network):
        # Networks start with a pre-existing node
        # linking their input to output.
        kept_nodes = 1
      else:
        kept_nodes = 0
      for original_node_index, node in enumerate(layer._inbound_nodes):
        node_key = tf_network._make_node_key(layer.name,
                                             original_node_index)
        if node_key in self._network_nodes:
          node_conversion_map[node_key] = kept_nodes
          kept_nodes += 1
    layer_configs = []
    for layer in self.layers:  # From the earliest layers on.
      layer_class_name = layer.__class__.__name__
      layer_config = layer.get_config()
      filtered_inbound_nodes = []
      for original_node_index, node in enumerate(layer._inbound_nodes):
        node_key = tf_network._make_node_key(layer.name,
                                             original_node_index)
        if node_key in self._network_nodes:
          # The node is relevant to the model:
          # add to filtered_inbound_nodes.
          if node.arguments:
            try:
              json.dumps(node.arguments)
              kwargs = node.arguments
            except TypeError:
              logging.warning(
                  'Layer ' + layer.name +
                  ' was passed non-serializable keyword arguments: ' +
                  str(node.arguments) + '. They will not be included '
                  'in the serialized model (and thus will be missing '
                  'at deserialization time).')
              kwargs = {}
          else:
            kwargs = {}
          if node.inbound_layers:
            node_data = []
            for i in range(len(node.inbound_layers)):
              inbound_layer = node.inbound_layers[i]
              node_index = node.node_indices[i]
              tensor_index = node.tensor_indices[i]
              node_key = tf_network._make_node_key(inbound_layer.name,
                                                   node_index)
              new_node_index = node_conversion_map.get(node_key, 0)
              node_data.append(
                  [inbound_layer.name, new_node_index, tensor_index, kwargs])
            filtered_inbound_nodes.append(node_data)
      layer_configs.append({
          'name': layer.name,
          'class_name': layer_class_name,
          'config': layer_config,
          'inbound_nodes': filtered_inbound_nodes,
      })
    config['layers'] = layer_configs

    # Gather info about inputs and outputs.
    model_inputs = []
    for i in range(len(self._input_layers)):
      layer, node_index, tensor_index = self._input_coordinates[i]
      node_key = tf_network._make_node_key(layer.name,
                                           node_index)
      if node_key not in self._network_nodes:
        continue
      new_node_index = node_conversion_map[node_key]
      model_inputs.append([layer.name, new_node_index, tensor_index])
    config['input_layers'] = model_inputs
    model_outputs = []
    for i in range(len(self._output_layers)):
      layer, node_index, tensor_index = self._output_coordinates[i]
      node_key = tf_network._make_node_key(layer.name,
                                           node_index)
      if node_key not in self._network_nodes:
        continue
      new_node_index = node_conversion_map[node_key]
      model_outputs.append([layer.name, new_node_index, tensor_index])
    config['output_layers'] = model_outputs
    return copy.deepcopy(config)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`).

    Arguments:
        config: Model config dictionary.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A model instance.

    Raises:
        ValueError: In case of improperly formatted config dict.
    """
    # Layer instances created during
    # the graph reconstruction process
    created_layers = {}

    # Dictionary mapping layer instances to
    # node data that specifies a layer call.
    # It acts as a queue that maintains any unprocessed
    # layer call until it becomes possible to process it
    # (i.e. until the input tensors to the call all exist).
    unprocessed_nodes = {}

    def add_unprocessed_node(layer, node_data):
      if layer not in unprocessed_nodes:
        unprocessed_nodes[layer] = [node_data]
      else:
        unprocessed_nodes[layer].append(node_data)

    def process_node(layer, node_data):
      """Deserialize a node.

      Arguments:
          layer: layer instance.
          node_data: node config dict.

      Raises:
          ValueError: In case of improperly formatted `node_data` dict.
      """
      input_tensors = []
      for input_data in node_data:
        inbound_layer_name = input_data[0]
        inbound_node_index = input_data[1]
        inbound_tensor_index = input_data[2]
        if len(input_data) == 3:
          kwargs = {}
        elif len(input_data) == 4:
          kwargs = input_data[3]
        else:
          raise ValueError('Improperly formatted model config.')
        if inbound_layer_name not in created_layers:
          add_unprocessed_node(layer, node_data)
          return
        inbound_layer = created_layers[inbound_layer_name]
        if len(inbound_layer._inbound_nodes) <= inbound_node_index:
          add_unprocessed_node(layer, node_data)
          return
        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
        input_tensors.append(inbound_node.output_tensors[inbound_tensor_index])
      # Call layer on its inputs, thus creating the node
      # and building the layer if needed.
      if input_tensors:
        if len(input_tensors) == 1:
          layer(input_tensors[0], **kwargs)
        else:
          layer(input_tensors, **kwargs)

    def process_layer(layer_data):
      """Deserialize a layer, then call it on appropriate inputs.

      Arguments:
          layer_data: layer config dict.

      Raises:
          ValueError: In case of improperly formatted `layer_data` dict.
      """
      layer_name = layer_data['name']

      # Instantiate layer.
      from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top

      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
      created_layers[layer_name] = layer

      # Gather layer inputs.
      inbound_nodes_data = layer_data['inbound_nodes']
      for node_data in inbound_nodes_data:
        # We don't process nodes (i.e. make layer calls)
        # on the fly because the inbound node may not yet exist,
        # in case of layer shared at different topological depths
        # (e.g. a model such as A(B(A(B(x)))))
        add_unprocessed_node(layer, node_data)

    # First, we create all layers and enqueue nodes to be processed
    for layer_data in config['layers']:
      process_layer(layer_data)
    # Then we process nodes in order of layer depth.
    # Nodes that cannot yet be processed (if the inbound node
    # does not yet exist) are re-enqueued, and the process
    # is repeated until all nodes are processed.
    while unprocessed_nodes:
      for layer_data in config['layers']:
        layer = created_layers[layer_data['name']]
        if layer in unprocessed_nodes:
          for node_data in unprocessed_nodes.pop(layer):
            process_node(layer, node_data)

    name = config.get('name')
    input_tensors = []
    output_tensors = []
    for layer_data in config['input_layers']:
      layer_name, node_index, tensor_index = layer_data
      assert layer_name in created_layers
      layer = created_layers[layer_name]
      layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
      input_tensors.append(layer_output_tensors[tensor_index])
    for layer_data in config['output_layers']:
      layer_name, node_index, tensor_index = layer_data
      assert layer_name in created_layers
      layer = created_layers[layer_name]
      layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
      output_tensors.append(layer_output_tensors[tensor_index])
    return cls(inputs=input_tensors, outputs=output_tensors, name=name)

  def save(self, filepath, overwrite=True, include_optimizer=True):
    """Save the model to a single HDF5 file.

    The savefile includes:
        - The model architecture, allowing the model to be re-instantiated.
        - The model weights.
        - The state of the optimizer, allowing training to be resumed
            exactly where you left off.

    This allows you to save the entirety of the state of a model
    in a single file.

    Saved models can be reinstantiated via `keras.models.load_model`.
    The model returned by `load_model`
    is a compiled model ready to be used (unless the saved model
    was never compiled in the first place).

    Arguments:
        filepath: String, path to the file to save the weights to.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.
        include_optimizer: If True, save optimizer's state together.

    Example:

    ```python
    from keras.models import load_model

    model.save('my_model.h5')  # creates an HDF5 file 'my_model.h5'
    del model  # deletes the existing model

    # returns a compiled model
    # identical to the previous one
    model = load_model('my_model.h5')
    ```
    """
    if not self._is_graph_network:
      raise NotImplementedError

    from tensorflow.python.keras._impl.keras.models import save_model  # pylint: disable=g-import-not-at-top
    save_model(self, filepath, overwrite, include_optimizer)

  def save_weights(self, filepath, overwrite=True):
    """Dumps all layer weights to an HDF5 file.

    The weight file has:
        - `layer_names` (attribute), a list of strings
            (ordered names of model layers).
        - For every layer, a `group` named `layer.name`
            - For every such layer group, a group attribute `weight_names`,
                a list of strings
                (ordered names of the weight tensors of the layer).
            - For every weight in the layer, a dataset
                storing the weight value, named after the weight tensor.

    Arguments:
        filepath: String, path to the file to save the weights to.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.

    Raises:
        ImportError: If h5py is not available.
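
    Example:

    ```python
    model.save_weights('my_model_weights.h5')  # writes an HDF5 file
    model.load_weights('my_model_weights.h5')  # restores the saved values
    ```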
    """
    if h5py is None:
      raise ImportError('`save_weights` requires h5py.')
    # If file exists and should not be overwritten:
    if not overwrite and os.path.isfile(filepath):
      proceed = ask_to_proceed_with_overwrite(filepath)
      if not proceed:
        return
    with h5py.File(filepath, 'w') as f:
      save_weights_to_hdf5_group(f, self.layers)

  def load_weights(self, filepath, by_name=False):
    """Loads all layer weights from an HDF5 save file.

    If `by_name` is False (default) weights are loaded
    based on the network's topology, meaning the architecture
    should be the same as when the weights were saved.
    Note that layers that don't have weights are not taken
    into account in the topological ordering, so adding or
    removing layers is fine as long as they don't have weights.

    If `by_name` is True, weights are loaded into layers
    only if they share the same name. This is useful
    for fine-tuning or transfer-learning models where
    some of the layers have changed.
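
    Example (a sketch; both models are assumed to reuse layer names):

    ```python
    # Transfer weights between models that share layer names.
    model_b.load_weights('model_a_weights.h5', by_name=True)
    ```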

    Arguments:
        filepath: String, path to the weights file to load.
        by_name: Boolean, whether to load weights by name
            or by topological order.

    Raises:
        ImportError: If h5py is not available.
    """
    if h5py is None:
      raise ImportError('`load_weights` requires h5py.')
    with h5py.File(filepath, 'r') as f:
      if 'layer_names' not in f.attrs and 'model_weights' in f:
        f = f['model_weights']
      if by_name:
        load_weights_from_hdf5_group_by_name(f, self.layers)
      else:
        load_weights_from_hdf5_group(f, self.layers)

  def _updated_config(self):
    """Util shared between different serialization methods.

    Returns:
        Model config with Keras version information added.
    """
    from tensorflow.python.keras._impl.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

    config = self.get_config()
    model_config = {
        'class_name': self.__class__.__name__,
        'config': config,
        'keras_version': keras_version,
        'backend': K.backend()
    }
    return model_config

  def to_json(self, **kwargs):
    """Returns a JSON string containing the network configuration.

    To load a network from a JSON save file, use
    `keras.models.model_from_json(json_string, custom_objects={})`.

    Arguments:
        **kwargs: Additional keyword arguments
            to be passed to `json.dumps()`.

    Returns:
        A JSON string.
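
    Example:

    ```python
    json_string = model.to_json()
    fresh_model = keras.models.model_from_json(json_string)
    # The JSON holds the architecture only; weights transfer separately.
    fresh_model.set_weights(model.get_weights())
    ```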
    """
    if not self._is_graph_network:
      raise NotImplementedError

    def get_json_type(obj):
      # If obj is any numpy type
      if type(obj).__module__ == np.__name__:
        return obj.item()

      # If obj is a python 'type'
      if type(obj).__name__ == type.__name__:
        return obj.__name__

      raise TypeError('Not JSON Serializable:', obj)

    model_config = self._updated_config()
    return json.dumps(model_config, default=get_json_type, **kwargs)

  def to_yaml(self, **kwargs):
    """Returns a yaml string containing the network configuration.

    To load a network from a yaml save file, use
    `keras.models.model_from_yaml(yaml_string, custom_objects={})`.

    `custom_objects` should be a dictionary mapping
    the names of custom losses / layers / etc to the corresponding
    functions / classes.

    Arguments:
        **kwargs: Additional keyword arguments
            to be passed to `yaml.dump()`.

    Returns:
        A YAML string.

    Raises:
        ImportError: if yaml module is not found.
    """
    if not self._is_graph_network:
      raise NotImplementedError

    if yaml is None:
      raise ImportError('Requires yaml module installed.')
    return yaml.dump(self._updated_config(), **kwargs)

  def summary(self, line_length=None, positions=None, print_fn=None):
    """Prints a string summary of the network.

    Arguments:
        line_length: Total length of printed lines
            (e.g. set this to adapt the display to different
            terminal window sizes).
        positions: Relative or absolute positions of log elements
            in each line. If not provided,
            defaults to `[.33, .55, .67, 1.]`.
        print_fn: Print function to use. Defaults to `print`.
            It will be called on each line of the summary.
            You can set it to a custom function
            in order to capture the string summary.
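
    Example (capturing the summary as a string):

    ```python
    lines = []
    model.summary(print_fn=lines.append)
    summary_string = '\n'.join(lines)
    ```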
   1289     """
   1290     print_layer_summary(self,
   1291                         line_length=line_length,
   1292                         positions=positions,
   1293                         print_fn=print_fn)
   1294 
   1295 
   1296 def get_source_inputs(tensor, layer=None, node_index=None):
   1297   """Returns the list of input tensors necessary to compute `tensor`.
   1298 
   1299   Output will always be a list of tensors
   1300   (potentially with 1 element).
   1301 
   1302   Arguments:
   1303       tensor: The tensor to start from.
   1304       layer: Origin layer of the tensor. Will be
   1305           determined via tensor._keras_history if not provided.
   1306       node_index: Origin node index of the tensor.
   1307 
   1308   Returns:
   1309       List of input tensors.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if not node.inbound_layers:
      # Reached an Input layer, stop recursion.
      return node.input_tensors
    else:
      source_tensors = []
      for i in range(len(node.inbound_layers)):
        x = node.input_tensors[i]
        layer = node.inbound_layers[i]
        node_index = node.node_indices[i]
        previous_sources = get_source_inputs(x, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if x not in source_tensors:
            source_tensors.append(x)
      return source_tensors


def _to_list(x):
  """Normalizes a list/tensor into a list.

  If a tensor is passed, we return
  a list of size 1 containing the tensor.

  Arguments:
      x: target object to be normalized.

  Returns:
      A list.
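
  Example:

  A quick sketch of the behavior:

  ```python
  _to_list([1, 2])  # -> [1, 2]
  _to_list('a')     # -> ['a']
  ```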
  """
  if isinstance(x, list):
    return x
  return [x]


def save_weights_to_hdf5_group(f, layers):
  """Saves the weights of a list of layers to a HDF5 group.

  Arguments:
      f: HDF5 group to which the weights will be saved.
      layers: List of layer instances whose weights should be saved.
  """
  from tensorflow.python.keras._impl.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

  f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers]
  f.attrs['backend'] = K.backend().encode('utf8')
  f.attrs['keras_version'] = str(keras_version).encode('utf8')

  for layer in layers:
    g = f.create_group(layer.name)
    symbolic_weights = layer.weights
    weight_values = K.batch_get_value(symbolic_weights)
    weight_names = []
    for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)):
      if hasattr(w, 'name') and w.name:
        name = str(w.name)
      else:
        name = 'param_' + str(i)
      weight_names.append(name.encode('utf8'))
    g.attrs['weight_names'] = weight_names
    for name, val in zip(weight_names, weight_values):
      param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
      if not val.shape:
        # Scalar weights are assigned directly to the dataset.
        param_dset[()] = val
      else:
        param_dset[:] = val
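
# Usage sketch for `save_weights_to_hdf5_group` (assumes h5py is installed
# and `model` is a built Keras model); this writes the layout that
# `load_weights_from_hdf5_group` below expects:
#
#   with h5py.File('weights.h5', 'w') as f:
#     save_weights_to_hdf5_group(f, model.layers)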


def preprocess_weights_for_loading(layer,
                                   weights,
                                   original_keras_version=None,
                                   original_backend=None):
  """Converts layer weights from Keras 1 format to Keras 2.

  Arguments:
      layer: Layer instance.
      weights: List of weight values (Numpy arrays).
      original_keras_version: Keras version for the weights, as a string.
      original_backend: Keras backend the weights were trained with,
          as a string.

  Returns:
      A list of weight values (Numpy arrays).
  """
  if layer.__class__.__name__ == 'Bidirectional':
    num_weights_per_layer = len(weights) // 2
    forward_weights = preprocess_weights_for_loading(
        layer.forward_layer, weights[:num_weights_per_layer],
        original_keras_version, original_backend)
    backward_weights = preprocess_weights_for_loading(
        layer.backward_layer, weights[num_weights_per_layer:],
        original_keras_version, original_backend)
    weights = forward_weights + backward_weights

  if original_keras_version == '1':
    if layer.__class__.__name__ == 'TimeDistributed':
      weights = preprocess_weights_for_loading(
          layer.layer, weights, original_keras_version, original_backend)

    if layer.__class__.__name__ == 'Conv1D':
      shape = weights[0].shape
      # Handle Keras 1.1 format.
      if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters:
        # Legacy shape:
        # (filters, input_dim, filter_length, 1)
        assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0],
                                                           1)
        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
      weights[0] = weights[0][:, 0, :, :]

    if layer.__class__.__name__ == 'Conv2D':
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, kernel_rows, kernel_cols)
        # new: (kernel_rows, kernel_cols, stack_size, filters)
        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))

    if layer.__class__.__name__ == 'Conv2DTranspose':
      if layer.data_format == 'channels_last':
        # old: (kernel_rows, kernel_cols, stack_size, filters)
        # new: (kernel_rows, kernel_cols, filters, stack_size)
        weights[0] = np.transpose(weights[0], (0, 1, 3, 2))
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, kernel_rows, kernel_cols)
        # new: (kernel_rows, kernel_cols, filters, stack_size)
        weights[0] = np.transpose(weights[0], (2, 3, 0, 1))

    if layer.__class__.__name__ == 'Conv3D':
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, ...)
        # new: (..., stack_size, filters)
        weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0))

    if layer.__class__.__name__ == 'GRU':
      if len(weights) == 9:
        kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[4], weights[7]], axis=-1)
        bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1)
        weights = [kernel, recurrent_kernel, bias]

    if layer.__class__.__name__ == 'LSTM':
      if len(weights) == 12:
        # old: i, c, f, o
        # new: i, f, c, o
        kernel = np.concatenate(
            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
        bias = np.concatenate(
            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
        weights = [kernel, recurrent_kernel, bias]

    if layer.__class__.__name__ == 'ConvLSTM2D':
      if len(weights) == 12:
        kernel = np.concatenate(
            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
        bias = np.concatenate(
            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
        if layer.data_format == 'channels_first':
          # old: (filters, stack_size, kernel_rows, kernel_cols)
          # new: (kernel_rows, kernel_cols, stack_size, filters)
          kernel = np.transpose(kernel, (2, 3, 1, 0))
          recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0))
        weights = [kernel, recurrent_kernel, bias]

    if layer.__class__.__name__ in ['Model', 'Sequential']:
      new_weights = []
      # Trainable weights.
      for sublayer in layer.layers:
        num_weights = len(sublayer.trainable_weights)
        if num_weights > 0:
          new_weights.extend(
              preprocess_weights_for_loading(
                  layer=sublayer,
                  weights=weights[:num_weights],
                  original_keras_version=original_keras_version,
                  original_backend=original_backend))
          weights = weights[num_weights:]

      # Non-trainable weights.
      for sublayer in layer.layers:
        num_weights = len([
            l for l in sublayer.weights if l not in sublayer.trainable_weights
        ])
        if num_weights > 0:
          new_weights.extend(
              preprocess_weights_for_loading(
                  layer=sublayer,
                  weights=weights[:num_weights],
                  original_keras_version=original_keras_version,
                  original_backend=original_backend))
          weights = weights[num_weights:]
      weights = new_weights

  conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D']
  if layer.__class__.__name__ in conv_layers:
    if original_backend == 'theano':
      weights[0] = conv_utils.convert_kernel(weights[0])
      if layer.__class__.__name__ == 'ConvLSTM2D':
        weights[1] = conv_utils.convert_kernel(weights[1])
    if K.int_shape(layer.weights[0]) != weights[0].shape:
      weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
      if layer.__class__.__name__ == 'ConvLSTM2D':
        weights[1] = np.transpose(weights[1], (3, 2, 0, 1))

  # Convert the weights of CuDNNLSTM so that they can be loaded into LSTM.
  if layer.__class__.__name__ == 'LSTM' and len(weights) == 3:
    # Determine whether we are loading a CuDNNLSTM layer from the number of
    # bias weights: CuDNNLSTM has (units * 8) bias weights, while LSTM has
    # (units * 4). If there is no bias weight in the file, skip this
    # conversion.
    units = weights[1].shape[0]
    bias = weights[2]
    if len(bias) == units * 8:
      # Reshape the kernels.
      kernels = np.split(weights[0], 4, axis=1)
      kernels = [
          kernel.reshape(-1).reshape(kernel.shape, order='F')
          for kernel in kernels
      ]
      weights[0] = np.concatenate(kernels, axis=1)

      # Transpose the recurrent kernels.
      recurrent_kernels = np.split(weights[1], 4, axis=1)
      recurrent_kernels = [kernel.T for kernel in recurrent_kernels]
      weights[1] = np.concatenate(recurrent_kernels, axis=1)

      # Sum the input and recurrent halves of the CuDNN bias to obtain the
      # single LSTM bias.
      weights[2] = bias[:units * 4] + bias[units * 4:]

  return weights
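
# Conversion sketch for `preprocess_weights_for_loading`, with hypothetical
# shapes: `conv2d_layer` is assumed to be a built
# Conv2D(64, (5, 5), data_format='channels_first') layer on a 3-channel
# input. A Keras 1 kernel stored as (filters, stack_size, rows, cols) comes
# back in the Keras 2 layout (rows, cols, stack_size, filters):
#
#   w = [np.zeros((64, 3, 5, 5)), np.zeros((64,))]
#   w = preprocess_weights_for_loading(conv2d_layer, w,
#                                      original_keras_version='1')
#   assert w[0].shape == (5, 5, 3, 64)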


def load_weights_from_hdf5_group(f, layers):
  """Implements topological (order-based) weight loading.

  Arguments:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file.
  """
  if 'keras_version' in f.attrs:
    original_keras_version = f.attrs['keras_version'].decode('utf8')
  else:
    original_keras_version = '1'
  if 'backend' in f.attrs:
    original_backend = f.attrs['backend'].decode('utf8')
  else:
    original_backend = None

  filtered_layers = []
  for layer in layers:
    weights = layer.weights
    if weights:
      filtered_layers.append(layer)

  layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
  filtered_layer_names = []
  for name in layer_names:
    g = f[name]
    weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
    if weight_names:
      filtered_layer_names.append(name)
  layer_names = filtered_layer_names
  if len(layer_names) != len(filtered_layers):
    raise ValueError('You are trying to load a weight file '
                     'containing ' + str(len(layer_names)) +
                     ' layers into a model with ' + str(len(filtered_layers)) +
                     ' layers.')

  # We batch weight value assignments in a single backend call
  # which provides a speedup in TensorFlow.
  weight_value_tuples = []
  for k, name in enumerate(layer_names):
    g = f[name]
    weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
    weight_values = [g[weight_name] for weight_name in weight_names]
    layer = filtered_layers[k]
    symbolic_weights = layer.weights
    weight_values = preprocess_weights_for_loading(
        layer, weight_values, original_keras_version, original_backend)
    if len(weight_values) != len(symbolic_weights):
      raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                       '" in the current model) was found to '
                       'correspond to layer ' + name + ' in the save file. '
                       'However the new layer ' + layer.name + ' expects ' +
                       str(len(symbolic_weights)) +
                       ' weights, but the saved weights have ' +
                       str(len(weight_values)) + ' elements.')
    weight_value_tuples += zip(symbolic_weights, weight_values)
  K.batch_set_value(weight_value_tuples)
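
# Usage sketch for `load_weights_from_hdf5_group` (assumes h5py is installed
# and 'weights.h5' was written by `save_weights_to_hdf5_group` for a model
# with the same architecture); weights are matched to layers by order:
#
#   with h5py.File('weights.h5', 'r') as f:
#     load_weights_from_hdf5_group(f, model.layers)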


def load_weights_from_hdf5_group_by_name(f, layers):
  """Implements name-based weight loading (instead of topological loading).

  Layers that have no matching name are skipped.

  Arguments:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file.
  """
  if 'keras_version' in f.attrs:
    original_keras_version = f.attrs['keras_version'].decode('utf8')
  else:
    original_keras_version = '1'
  if 'backend' in f.attrs:
    original_backend = f.attrs['backend'].decode('utf8')
  else:
    original_backend = None

  # New file format.
  layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]

  # Reverse index of layer name to list of layers with name.
  index = {}
  for layer in layers:
    if layer.name:
      index.setdefault(layer.name, []).append(layer)

  # We batch weight value assignments in a single backend call
  # which provides a speedup in TensorFlow.
  weight_value_tuples = []
  for k, name in enumerate(layer_names):
    g = f[name]
    weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
    weight_values = [g[weight_name] for weight_name in weight_names]

    for layer in index.get(name, []):
      symbolic_weights = layer.weights
      weight_values = preprocess_weights_for_loading(
          layer, weight_values, original_keras_version, original_backend)
      if len(weight_values) != len(symbolic_weights):
        raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                         '") expects ' + str(len(symbolic_weights)) +
                         ' weight(s), but the saved weights have ' +
                         str(len(weight_values)) + ' element(s).')
      # Set values.
      for i in range(len(weight_values)):
        weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
  K.batch_set_value(weight_value_tuples)
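
# Usage sketch for `load_weights_from_hdf5_group_by_name` (hypothetical file
# name): name-based loading suits transfer learning, since only layers whose
# names match the save file are restored and all others are skipped:
#
#   with h5py.File('pretrained.h5', 'r') as f:
#     load_weights_from_hdf5_group_by_name(f, model.layers)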


def shape_type_conversion(fn):
  """Decorator that handles tuple/TensorShape conversion.

  Used in `compute_output_shape` and `build`.

  Arguments:
    fn: function to wrap.

  Returns:
    Wrapped function.
  """

  def wrapper(instance, input_shape):
    if input_shape is not None:
      if isinstance(input_shape, list):
        input_shape = [
            tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape]
      else:
        input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
    output_shape = fn(instance, input_shape)
    if output_shape is not None:
      if isinstance(output_shape, list):
        return [tensor_shape.TensorShape(x) for x in output_shape]
      return tensor_shape.TensorShape(output_shape)

  return wrapper
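
# Decoration sketch for `shape_type_conversion`, with a hypothetical layer:
# the wrapped method receives plain tuples such as (None, 32) instead of
# TensorShape objects, and any shape it returns is converted back to
# TensorShape:
#
#   class Identity(Layer):
#
#     @shape_type_conversion
#     def compute_output_shape(self, input_shape):
#       # `input_shape` arrives here as a plain tuple.
#       return input_shape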