1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 # pylint: disable=protected-access 16 """Base layer code and base model (Network) code. 17 """ 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import copy 23 import json 24 import os 25 26 import numpy as np 27 from six.moves import zip # pylint: disable=redefined-builtin 28 29 from tensorflow.python.eager import context 30 from tensorflow.python.framework import tensor_shape 31 from tensorflow.python.keras._impl.keras import backend as K 32 from tensorflow.python.keras._impl.keras import constraints 33 from tensorflow.python.keras._impl.keras import initializers 34 from tensorflow.python.keras._impl.keras import regularizers 35 from tensorflow.python.keras._impl.keras.utils import conv_utils 36 from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite 37 from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary 38 from tensorflow.python.layers import base as tf_base_layers 39 from tensorflow.python.layers import network as tf_network 40 from tensorflow.python.layers import utils as tf_layers_util 41 from tensorflow.python.platform import tf_logging as logging 42 from tensorflow.python.util import tf_inspect 43 from tensorflow.python.util.tf_export import tf_export 44 45 46 # pylint: disable=g-import-not-at-top 47 try: 48 import h5py 49 except ImportError: 50 h5py = None 51 52 try: 53 import yaml 54 except ImportError: 55 yaml = None 56 # pylint: enable=g-import-not-at-top 57 58 # pylint: disable=invalid-name 59 InputSpec = tf_base_layers.InputSpec 60 Node = tf_base_layers.Node 61 TFBaseLayer = tf_base_layers.Layer 62 # pylint: enable=invalid-name 63 64 65 @tf_export('keras.layers.Layer') 66 class Layer(tf_base_layers.Layer): 67 """Abstract base layer class. 68 69 # Properties 70 name: String, must be unique within a model. 71 input_spec: List of InputSpec class instances 72 each entry describes one required input: 73 - ndim 74 - dtype 75 A layer with `n` input tensors must have 76 an `input_spec` of length `n`. 77 trainable: Boolean, whether the layer weights 78 will be updated during training. 79 uses_learning_phase: Whether any operation 80 of the layer uses `K.in_training_phase()` 81 or `K.in_test_phase()`. 82 input_shape: Shape tuple. Provided for convenience, 83 but note that there may be cases in which this 84 attribute is ill-defined (e.g. a shared layer 85 with multiple input shapes), in which case 86 requesting `input_shape` will raise an Exception. 87 Prefer using `layer.get_input_shape_for(input_shape)`, 88 or `layer.get_input_shape_at(node_index)`. 89 output_shape: Shape tuple. See above. 90 inbound_nodes: List of nodes. 91 outbound_nodes: List of nodes. 92 input, output: Input/output tensor(s). Note that if the layer is used 93 more than once (shared layer), this is ill-defined 94 and will raise an exception. In such cases, use 95 `layer.get_input_at(node_index)`. 96 input_mask, output_mask: Same as above, for masks. 97 trainable_weights: List of variables. 98 non_trainable_weights: List of variables. 99 weights: The concatenation of the lists trainable_weights and 100 non_trainable_weights (in this order). 101 102 # Methods 103 call(x, mask=None): Where the layer's logic lives. 104 __call__(x, mask=None): Wrapper around the layer logic (`call`). 105 If x is a Keras tensor: 106 - Connect current layer with last layer from tensor: 107 `self._add_inbound_node(last_layer)` 108 - Add layer to tensor history 109 If layer is not built: 110 - Build from inputs shape 111 get_weights() 112 set_weights(weights) 113 get_config() 114 count_params() 115 compute_output_shape(input_shape) 116 compute_mask(x, mask) 117 get_input_at(node_index) 118 get_output_at(node_index) 119 get_input_shape_at(node_index) 120 get_output_shape_at(node_index) 121 get_input_mask_at(node_index) 122 get_output_mask_at(node_index) 123 124 # Class Methods 125 from_config(config) 126 127 # Internal methods: 128 build(input_shape) 129 _add_inbound_node(layer, index=0) 130 """ 131 132 def __init__(self, **kwargs): 133 # These properties should be set by the user via keyword arguments. 134 # note that 'dtype', 'input_shape' and 'batch_input_shape' 135 # are only applicable to input layers: do not pass these keywords 136 # to non-input layers. 137 allowed_kwargs = { 138 'activity_regularizer', 139 'input_shape', 140 'batch_input_shape', 141 'batch_size', 142 'dtype', 143 'name', 144 'trainable', 145 'weights', 146 } 147 # Validate optional keyword arguments. 148 for kwarg in kwargs: 149 if kwarg not in allowed_kwargs: 150 raise TypeError('Keyword argument not understood:', kwarg) 151 152 # Get layer name. 153 name = kwargs.get('name') 154 155 # Get `trainable` status. 156 trainable = kwargs.get('trainable', True) 157 158 # Get `dtype`. 159 dtype = kwargs.get('dtype') 160 if dtype is None: 161 dtype = K.floatx() 162 163 # Call super, which will set all properties common to Keras layers 164 # and core TF layers. 165 super(Layer, self).__init__( 166 name=name, dtype=dtype, trainable=trainable, 167 activity_regularizer=kwargs.get('activity_regularizer')) 168 169 # Add properties that are Keras-only for now. 170 self.supports_masking = False 171 172 # Manage input shape information if passed. 173 if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: 174 # In this case we will later create an input layer 175 # to insert before the current layer 176 if 'batch_input_shape' in kwargs: 177 batch_input_shape = tuple(kwargs['batch_input_shape']) 178 elif 'input_shape' in kwargs: 179 if 'batch_size' in kwargs: 180 batch_size = kwargs['batch_size'] 181 else: 182 batch_size = None 183 batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) 184 self._batch_input_shape = batch_input_shape 185 186 # Manage initial weight values if passed. 187 if 'weights' in kwargs: 188 self._initial_weights = kwargs['weights'] 189 else: 190 self._initial_weights = None 191 192 def add_weight(self, 193 name, 194 shape, 195 dtype=None, 196 initializer=None, 197 regularizer=None, 198 trainable=True, 199 constraint=None): 200 """Adds a weight variable to the layer. 201 202 Arguments: 203 name: String, the name for the weight variable. 204 shape: The shape tuple of the weight. 205 dtype: The dtype of the weight. 206 initializer: An Initializer instance (callable). 207 regularizer: An optional Regularizer instance. 208 trainable: A boolean, whether the weight should 209 be trained via backprop or not (assuming 210 that the layer itself is also trainable). 211 constraint: An optional Constraint instance. 212 213 Returns: 214 The created weight variable. 215 """ 216 if dtype is None: 217 dtype = K.floatx() 218 weight = self.add_variable(name, shape, 219 dtype=dtype, 220 initializer=initializers.get(initializer), 221 regularizer=regularizers.get(regularizer), 222 constraint=constraints.get(constraint), 223 trainable=trainable) 224 return weight 225 226 def call(self, inputs, **kwargs): # pylint: disable=unused-argument 227 """This is where the layer's logic lives. 228 229 Arguments: 230 inputs: Input tensor, or list/tuple of input tensors. 231 **kwargs: Additional keyword arguments. 232 233 Returns: 234 A tensor or list/tuple of tensors. 235 """ 236 return inputs 237 238 def __call__(self, inputs, **kwargs): 239 """Wrapper around self.call(), for handling internal references. 240 241 If a Keras tensor is passed: 242 - We call self._add_inbound_node(). 243 - If necessary, we `build` the layer to match 244 the shape of the input(s). 245 - We update the _keras_history of the output tensor(s) 246 with the current layer. 247 This is done as part of _add_inbound_node(). 248 249 Arguments: 250 inputs: Can be a tensor or list/tuple of tensors. 251 **kwargs: Additional keyword arguments to be passed to `call()`. 252 253 Returns: 254 Output of the layer's `call` method. 255 256 Raises: 257 ValueError: in case the layer is missing shape information 258 for its `build` call. 259 """ 260 # Actually call the layer (optionally building it). 261 output = super(Layer, self).__call__(inputs, **kwargs) 262 if context.in_eager_mode(): 263 return output 264 265 # Un-built subclassed network: build it 266 if isinstance(self, Network) and not self.inputs: 267 self._set_inputs(inputs, training=kwargs.get('training')) 268 269 # Update learning phase info. 270 output_tensors = _to_list(output) 271 uses_lp = any( 272 [getattr(x, '_uses_learning_phase', False) for x in _to_list(inputs)]) 273 uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp 274 for i in range(len(output_tensors)): 275 output_tensors[i]._uses_learning_phase = getattr( 276 output_tensors[i], '_uses_learning_phase', False) or uses_lp 277 278 # Optionally load weight values that were specified at layer instantiation. 279 if hasattr(self, '_initial_weights') and self._initial_weights is not None: 280 self.set_weights(self._initial_weights) 281 del self._initial_weights 282 return output 283 284 def compute_output_shape(self, input_shape): 285 """Computes the output shape of the layer. 286 287 Assumes that the layer will be built 288 to match that input shape provided. 289 290 Arguments: 291 input_shape: Shape tuple (tuple of integers) 292 or list of shape tuples (one per output tensor of the layer). 293 Shape tuples can include None for free dimensions, 294 instead of an integer. 295 296 Returns: 297 An input shape tuple. 298 """ 299 logging.warning( 300 'All custom layers should implement the ' 301 '`compute_output_shape` method. This layer (' + self.name + ') ' 302 'is relying on the base `Layer.compute_output_shape` implementation, ' 303 'which will start raising a `NotImplementedError` ' 304 'as of July 1st, 2018.') 305 return input_shape 306 307 def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument 308 """Computes an output mask tensor. 309 310 Arguments: 311 inputs: Tensor or list of tensors. 312 mask: Tensor or list of tensors. 313 314 Returns: 315 None or a tensor (or list of tensors, 316 one per output tensor of the layer). 317 """ 318 if not self.supports_masking: 319 if mask is not None: 320 if isinstance(mask, list): 321 if any(m is not None for m in mask): 322 raise TypeError('Layer ' + self.name + ' does not support masking, ' 323 'but was passed an input_mask: ' + str(mask)) 324 else: 325 raise TypeError('Layer ' + self.name + ' does not support masking, ' 326 'but was passed an input_mask: ' + str(mask)) 327 # masking not explicitly supported: return None as mask 328 return None 329 # if masking is explicitly supported, by default 330 # carry over the input mask 331 return mask 332 333 def get_input_mask_at(self, node_index): 334 """Retrieves the input mask tensor(s) of a layer at a given node. 335 336 Arguments: 337 node_index: Integer, index of the node 338 from which to retrieve the attribute. 339 E.g. `node_index=0` will correspond to the 340 first time the layer was called. 341 342 Returns: 343 A mask tensor 344 (or list of tensors if the layer has multiple inputs). 345 """ 346 inputs = self.get_input_at(node_index) 347 if isinstance(inputs, list): 348 return [getattr(x, '_keras_mask', None) for x in inputs] 349 else: 350 return getattr(inputs, '_keras_mask', None) 351 352 def get_output_mask_at(self, node_index): 353 """Retrieves the output mask tensor(s) of a layer at a given node. 354 355 Arguments: 356 node_index: Integer, index of the node 357 from which to retrieve the attribute. 358 E.g. `node_index=0` will correspond to the 359 first time the layer was called. 360 361 Returns: 362 A mask tensor 363 (or list of tensors if the layer has multiple outputs). 364 """ 365 output = self.get_output_at(node_index) 366 if isinstance(output, list): 367 return [getattr(x, '_keras_mask', None) for x in output] 368 else: 369 return getattr(output, '_keras_mask', None) 370 371 @property 372 def input_mask(self): 373 """Retrieves the input mask tensor(s) of a layer. 374 375 Only applicable if the layer has exactly one inbound node, 376 i.e. if it is connected to one incoming layer. 377 378 Returns: 379 Input mask tensor (potentially None) or list of input 380 mask tensors. 381 382 Raises: 383 AttributeError: if the layer is connected to 384 more than one incoming layers. 385 """ 386 inputs = self.input 387 if isinstance(inputs, list): 388 return [getattr(x, '_keras_mask', None) for x in inputs] 389 else: 390 return getattr(inputs, '_keras_mask', None) 391 392 @property 393 def output_mask(self): 394 """Retrieves the output mask tensor(s) of a layer. 395 396 Only applicable if the layer has exactly one inbound node, 397 i.e. if it is connected to one incoming layer. 398 399 Returns: 400 Output mask tensor (potentially None) or list of output 401 mask tensors. 402 403 Raises: 404 AttributeError: if the layer is connected to 405 more than one incoming layers. 406 """ 407 output = self.output 408 if isinstance(output, list): 409 return [getattr(x, '_keras_mask', None) for x in output] 410 else: 411 return getattr(output, '_keras_mask', None) 412 413 def set_weights(self, weights): 414 """Sets the weights of the layer, from Numpy arrays. 415 416 Arguments: 417 weights: a list of Numpy arrays. The number 418 of arrays and their shape must match 419 number of the dimensions of the weights 420 of the layer (i.e. it should match the 421 output of `get_weights`). 422 423 Raises: 424 ValueError: If the provided weights list does not match the 425 layer's specifications. 426 """ 427 params = self.weights 428 if len(params) != len(weights): 429 raise ValueError('You called `set_weights(weights)` on layer "' + 430 self.name + '" with a weight list of length ' + 431 str(len(weights)) + ', but the layer was expecting ' + 432 str(len(params)) + ' weights. Provided weights: ' + 433 str(weights)[:50] + '...') 434 if not params: 435 return 436 weight_value_tuples = [] 437 param_values = K.batch_get_value(params) 438 for pv, p, w in zip(param_values, params, weights): 439 if pv.shape != w.shape: 440 raise ValueError('Layer weight shape ' + str(pv.shape) + 441 ' not compatible with ' 442 'provided weight shape ' + str(w.shape)) 443 weight_value_tuples.append((p, w)) 444 K.batch_set_value(weight_value_tuples) 445 446 def get_weights(self): 447 """Returns the current weights of the layer. 448 449 Returns: 450 Weights values as a list of numpy arrays. 451 """ 452 params = self.weights 453 return K.batch_get_value(params) 454 455 def get_config(self): 456 """Returns the config of the layer. 457 458 A layer config is a Python dictionary (serializable) 459 containing the configuration of a layer. 460 The same layer can be reinstantiated later 461 (without its trained weights) from this configuration. 462 463 The config of a layer does not include connectivity 464 information, nor the layer class name. These are handled 465 by `Network` (one layer of abstraction above). 466 467 Returns: 468 Python dictionary. 469 """ 470 config = {'name': self.name, 'trainable': self.trainable} 471 if hasattr(self, '_batch_input_shape'): 472 config['batch_input_shape'] = self._batch_input_shape 473 if hasattr(self, 'dtype'): 474 config['dtype'] = self.dtype 475 return config 476 477 @classmethod 478 def from_config(cls, config): 479 """Creates a layer from its config. 480 481 This method is the reverse of `get_config`, 482 capable of instantiating the same layer from the config 483 dictionary. It does not handle layer connectivity 484 (handled by Network), nor weights (handled by `set_weights`). 485 486 Arguments: 487 config: A Python dictionary, typically the 488 output of get_config. 489 490 Returns: 491 A layer instance. 492 """ 493 return cls(**config) 494 495 @tf_base_layers.Layer.activity_regularizer.setter 496 def activity_regularizer(self, activity_regularizer): 497 self._activity_regularizer = activity_regularizer 498 499 500 @tf_export('keras.layers.InputLayer') 501 class InputLayer(tf_network.InputLayer, Layer): 502 """Layer to be used as an entry point into a graph. 503 504 It can either wrap an existing tensor (pass an `input_tensor` argument) 505 or create its a placeholder tensor (pass argument `input_shape`. 506 507 Arguments: 508 input_shape: Shape tuple, not including the batch axis. 509 batch_size: Optional input batch size (integer or None). 510 dtype: Datatype of the input. 511 input_tensor: Optional tensor to use as layer input 512 instead of creating a placeholder. 513 sparse: Boolean, whether the placeholder created 514 is meant to be sparse. 515 name: Name of the layer (string). 516 """ 517 518 def __init__(self, 519 input_shape=None, 520 batch_size=None, 521 dtype=None, 522 input_tensor=None, 523 sparse=False, 524 name=None, 525 **kwargs): 526 if 'batch_input_shape' in kwargs: 527 batch_input_shape = kwargs.pop('batch_input_shape') 528 if input_shape and batch_input_shape: 529 raise ValueError('Only provide the input_shape OR ' 530 'batch_input_shape argument to ' 531 'InputLayer, not both at the same time.') 532 batch_size = batch_input_shape[0] 533 input_shape = batch_input_shape[1:] 534 if kwargs: 535 raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) 536 537 if not name: 538 prefix = 'input' 539 name = prefix + '_' + str(K.get_uid(prefix)) 540 541 if not dtype: 542 if input_tensor is None: 543 dtype = K.floatx() 544 else: 545 dtype = K.dtype(input_tensor) 546 super(InputLayer, self).__init__(input_shape=input_shape, 547 batch_size=batch_size, 548 dtype=dtype, 549 input_tensor=input_tensor, 550 sparse=sparse, 551 name=name) 552 553 def get_config(self): 554 config = { 555 'batch_input_shape': self._batch_input_shape, 556 'dtype': self.dtype, 557 'sparse': self.sparse, 558 'name': self.name 559 } 560 return config 561 562 563 @tf_export('keras.layers.Input', 'keras.Input') 564 def Input( # pylint: disable=invalid-name 565 shape=None, 566 batch_size=None, 567 name=None, 568 dtype=None, 569 sparse=False, 570 tensor=None, 571 **kwargs): 572 """`Input()` is used to instantiate a Keras tensor. 573 574 A Keras tensor is a tensor object from the underlying backend 575 (Theano or TensorFlow), which we augment with certain 576 attributes that allow us to build a Keras model 577 just by knowing the inputs and outputs of the model. 578 579 For instance, if a, b and c are Keras tensors, 580 it becomes possible to do: 581 `model = Model(input=[a, b], output=c)` 582 583 The added Keras attribute is: 584 `_keras_history`: Last layer applied to the tensor. 585 the entire layer graph is retrievable from that layer, 586 recursively. 587 588 Arguments: 589 shape: A shape tuple (integers), not including the batch size. 590 For instance, `shape=(32,)` indicates that the expected input 591 will be batches of 32-dimensional vectors. 592 batch_size: optional static batch size (integer). 593 name: An optional name string for the layer. 594 Should be unique in a model (do not reuse the same name twice). 595 It will be autogenerated if it isn't provided. 596 dtype: The data type expected by the input, as a string 597 (`float32`, `float64`, `int32`...) 598 sparse: A boolean specifying whether the placeholder 599 to be created is sparse. 600 tensor: Optional existing tensor to wrap into the `Input` layer. 601 If set, the layer will not create a placeholder tensor. 602 **kwargs: deprecated arguments support. 603 604 Returns: 605 A tensor. 606 607 Example: 608 609 ```python 610 # this is a logistic regression in Keras 611 x = Input(shape=(32,)) 612 y = Dense(16, activation='softmax')(x) 613 model = Model(x, y) 614 ``` 615 616 Raises: 617 ValueError: in case of invalid arguments. 618 """ 619 if 'batch_shape' in kwargs: 620 batch_shape = kwargs.pop('batch_shape') 621 if shape and batch_shape: 622 raise ValueError('Only provide the shape OR ' 623 'batch_shape argument to ' 624 'Input, not both at the same time.') 625 batch_size = batch_shape[0] 626 shape = batch_shape[1:] 627 if kwargs: 628 raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) 629 630 if dtype is None: 631 dtype = K.floatx() 632 if not shape and tensor is None: 633 raise ValueError('Please provide to Input either a `shape`' 634 ' or a `tensor` argument. Note that ' 635 '`shape` does not include the batch ' 636 'dimension.') 637 input_layer = InputLayer( 638 input_shape=shape, 639 batch_size=batch_size, 640 name=name, 641 dtype=dtype, 642 sparse=sparse, 643 input_tensor=tensor) 644 # Return tensor including `_keras_history`. 645 # Note that in this case train_output and test_output are the same pointer. 646 outputs = input_layer._inbound_nodes[0].output_tensors 647 if len(outputs) == 1: 648 return outputs[0] 649 else: 650 return outputs 651 652 653 class Network(tf_network.GraphNetwork, Layer): 654 """A Network is a directed acyclic graph of layers. 655 656 It is the topological form of a "model". A Model 657 is simply a Network with added training routines. 658 659 # Properties 660 name 661 inputs 662 outputs 663 input_layers 664 output_layers 665 input_spec (list of class instances) 666 each entry describes one required input: 667 - ndim 668 - dtype 669 trainable (boolean) 670 input_shape 671 output_shape 672 inbound_nodes: list of nodes 673 outbound_nodes: list of nodes 674 trainable_weights (list of variables) 675 non_trainable_weights (list of variables) 676 677 # Methods 678 summary 679 get_layer 680 get_weights 681 set_weights 682 get_config 683 compute_output_shape 684 685 # Class Methods 686 from_config 687 """ 688 689 def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called 690 # Signature detection 691 if (len(args) == 2 or 692 len(args) == 1 and 'outputs' in kwargs or 693 'inputs' in kwargs and 'outputs' in kwargs): 694 # Graph network 695 self._init_graph_network(*args, **kwargs) 696 else: 697 # Subclassed network 698 self._init_subclassed_network(**kwargs) 699 700 def _init_graph_network(self, inputs, outputs, name=None): 701 # TODO(fchollet): merge back tf.layers.Network and tf.keras.Network 702 # into a single class tf.keras.Network 703 super(Network, self).__init__(inputs, outputs, name=name) 704 705 self._is_compiled = False 706 self._expects_training_arg = False 707 708 self.supports_masking = False 709 self.optimizer = None 710 711 # Fill in the output mask cache. 712 masks = [] 713 for x in self.inputs: 714 mask = x._keras_mask if hasattr(x, '_keras_mask') else None 715 masks.append(mask) 716 mask_cache_key = (tf_layers_util.object_list_uid(self.inputs) + '_' + 717 tf_layers_util.object_list_uid(masks)) 718 masks = [] 719 for x in self.outputs: 720 mask = x._keras_mask if hasattr(x, '_keras_mask') else None 721 masks.append(mask) 722 if len(masks) == 1: 723 mask = masks[0] 724 else: 725 mask = masks 726 self._output_mask_cache[mask_cache_key] = mask 727 728 # Build self.input_names and self.output_names. 729 self.input_names = [] 730 self.output_names = [] 731 self._feed_input_names = [] 732 self._feed_inputs = [] 733 self._feed_input_shapes = [] 734 for i, layer in enumerate(self._input_layers): 735 self.input_names.append(layer.name) 736 if layer.is_placeholder: 737 self._feed_input_names.append(layer.name) 738 self._feed_input_shapes.append(K.int_shape(self.inputs[i])) 739 # layer.input gives an error in eager mode 740 if context.in_graph_mode(): 741 self._feed_inputs.append(layer.input) 742 for layer in self._output_layers: 743 self.output_names.append(layer.name) 744 745 def _init_subclassed_network(self, name=None): 746 self._init_set_name(name) 747 self._layers = [] 748 self._is_graph_network = False 749 self._is_compiled = False 750 if 'training' in tf_inspect.getargspec(self.call).args: 751 self._expects_training_arg = True 752 else: 753 self._expects_training_arg = False 754 755 self.outputs = None 756 self.inputs = None 757 self.trainable = True 758 self.supports_masking = False 759 self.built = False 760 self.optimizer = None 761 762 # Not used, exists for compatibility purposes due to implementation of 763 # the base layer tf.layers.Layer - TODO(fchollet): clean up when refactoring 764 self._scope = None 765 self._reuse = None 766 self._dtype = None 767 self._graph = None 768 self._activity_regularizer = None 769 770 # Used in symbolic mode only 771 self._updates = [] 772 self._losses = [] 773 774 # Used in symbolic mode only, only in conjonction with graph-networks 775 self._outbound_nodes = [] 776 self._inbound_nodes = [] 777 778 def __setattr__(self, name, value): 779 if isinstance(value, (tf_base_layers.Layer, Network)): 780 try: 781 is_graph_network = self._is_graph_network 782 except AttributeError: 783 raise RuntimeError('It looks like you are subclassing `Model` and you ' 784 'forgot to call `super(YourClass, self).__init__()`.' 785 ' Always start with this line.') 786 if not is_graph_network: 787 if value not in self._layers: 788 self._layers.append(value) 789 super(Network, self).__setattr__(name, value) 790 791 def add_variable(self, name, shape, dtype=None, initializer=None, 792 regularizer=None, trainable=True, constraint=None): 793 raise NotImplementedError('`add_variable` is not supported on Networks') 794 795 def add_loss(self, *args, **kwargs): 796 if context.in_eager_mode(): 797 raise NotImplementedError('`add_loss` is not supported in eager-mode ' 798 'on Networks') 799 super(Network, self).add_loss(*args, **kwargs) 800 801 @property 802 def uses_learning_phase(self): 803 return any( 804 [getattr(x, '_uses_learning_phase', False) for x in self.outputs]) 805 806 @property 807 def stateful(self): 808 return any([(hasattr(layer, 'stateful') and layer.stateful) 809 for layer in self.layers]) 810 811 def reset_states(self): 812 for layer in self.layers: 813 if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False): 814 layer.reset_states() 815 816 @property 817 def state_updates(self): 818 """Returns the `updates` from all layers that are stateful. 819 820 This is useful for separating training updates and 821 state updates, e.g. when we need to update a layer's internal state 822 during prediction. 823 824 Returns: 825 A list of update ops. 826 """ 827 state_updates = [] 828 for layer in self.layers: 829 if getattr(layer, 'stateful', False): 830 if hasattr(layer, 'updates'): 831 state_updates += layer.updates 832 return state_updates 833 834 def get_weights(self): 835 """Retrieves the weights of the model. 836 837 Returns: 838 A flat list of Numpy arrays. 839 """ 840 weights = [] 841 for layer in self.layers: 842 weights += layer.weights 843 return K.batch_get_value(weights) 844 845 def set_weights(self, weights): 846 """Sets the weights of the model. 847 848 Arguments: 849 weights: A list of Numpy arrays with shapes and types matching 850 the output of `model.get_weights()`. 851 """ 852 tuples = [] 853 for layer in self.layers: 854 num_param = len(layer.weights) 855 layer_weights = weights[:num_param] 856 for sw, w in zip(layer.weights, layer_weights): 857 tuples.append((sw, w)) 858 weights = weights[num_param:] 859 K.batch_set_value(tuples) 860 861 def compute_mask(self, inputs, mask): 862 if not self._is_graph_network: 863 return None 864 865 inputs = _to_list(inputs) 866 if mask is None: 867 masks = [None for _ in range(len(inputs))] 868 else: 869 masks = _to_list(mask) 870 cache_key = (tf_layers_util.object_list_uid(inputs) 871 + '_' + tf_layers_util.object_list_uid(masks)) 872 if cache_key in self._output_mask_cache: 873 return self._output_mask_cache[cache_key] 874 else: 875 _, output_masks = self._run_internal_graph(inputs, masks) 876 return output_masks 877 878 def get_config(self): 879 if not self._is_graph_network: 880 raise NotImplementedError 881 882 config = { 883 'name': self.name, 884 } 885 node_conversion_map = {} 886 for layer in self.layers: 887 if issubclass(layer.__class__, Network): 888 # Networks start with a pre-existing node 889 # linking their input to output. 890 kept_nodes = 1 891 else: 892 kept_nodes = 0 893 for original_node_index, node in enumerate(layer._inbound_nodes): 894 node_key = tf_network._make_node_key(layer.name, 895 original_node_index) 896 if node_key in self._network_nodes: 897 node_conversion_map[node_key] = kept_nodes 898 kept_nodes += 1 899 layer_configs = [] 900 for layer in self.layers: # From the earliest layers on. 901 layer_class_name = layer.__class__.__name__ 902 layer_config = layer.get_config() 903 filtered_inbound_nodes = [] 904 for original_node_index, node in enumerate(layer._inbound_nodes): 905 node_key = tf_network._make_node_key(layer.name, 906 original_node_index) 907 if node_key in self._network_nodes: 908 # The node is relevant to the model: 909 # add to filtered_inbound_nodes. 910 if node.arguments: 911 try: 912 json.dumps(node.arguments) 913 kwargs = node.arguments 914 except TypeError: 915 logging.warning( 916 'Layer ' + layer.name + 917 ' was passed non-serializable keyword arguments: ' + 918 str(node.arguments) + '. They will not be included ' 919 'in the serialized model (and thus will be missing ' 920 'at deserialization time).') 921 kwargs = {} 922 else: 923 kwargs = {} 924 if node.inbound_layers: 925 node_data = [] 926 for i in range(len(node.inbound_layers)): 927 inbound_layer = node.inbound_layers[i] 928 node_index = node.node_indices[i] 929 tensor_index = node.tensor_indices[i] 930 node_key = tf_network._make_node_key(inbound_layer.name, 931 node_index) 932 new_node_index = node_conversion_map.get(node_key, 0) 933 node_data.append( 934 [inbound_layer.name, new_node_index, tensor_index, kwargs]) 935 filtered_inbound_nodes.append(node_data) 936 layer_configs.append({ 937 'name': layer.name, 938 'class_name': layer_class_name, 939 'config': layer_config, 940 'inbound_nodes': filtered_inbound_nodes, 941 }) 942 config['layers'] = layer_configs 943 944 # Gather info about inputs and outputs. 945 model_inputs = [] 946 for i in range(len(self._input_layers)): 947 layer, node_index, tensor_index = self._input_coordinates[i] 948 node_key = tf_network._make_node_key(layer.name, 949 node_index) 950 if node_key not in self._network_nodes: 951 continue 952 new_node_index = node_conversion_map[node_key] 953 model_inputs.append([layer.name, new_node_index, tensor_index]) 954 config['input_layers'] = model_inputs 955 model_outputs = [] 956 for i in range(len(self._output_layers)): 957 layer, node_index, tensor_index = self._output_coordinates[i] 958 node_key = tf_network._make_node_key(layer.name, 959 node_index) 960 if node_key not in self._network_nodes: 961 continue 962 new_node_index = node_conversion_map[node_key] 963 model_outputs.append([layer.name, new_node_index, tensor_index]) 964 config['output_layers'] = model_outputs 965 return copy.deepcopy(config) 966 967 @classmethod 968 def from_config(cls, config, custom_objects=None): 969 """Instantiates a Model from its config (output of `get_config()`). 970 971 Arguments: 972 config: Model config dictionary. 973 custom_objects: Optional dictionary mapping names 974 (strings) to custom classes or functions to be 975 considered during deserialization. 976 977 Returns: 978 A model instance. 979 980 Raises: 981 ValueError: In case of improperly formatted config dict. 982 """ 983 # Layer instances created during 984 # the graph reconstruction process 985 created_layers = {} 986 987 # Dictionary mapping layer instances to 988 # node data that specifies a layer call. 989 # It acts as a queue that maintains any unprocessed 990 # layer call until it becomes possible to process it 991 # (i.e. until the input tensors to the call all exist). 992 unprocessed_nodes = {} 993 994 def add_unprocessed_node(layer, node_data): 995 if layer not in unprocessed_nodes: 996 unprocessed_nodes[layer] = [node_data] 997 else: 998 unprocessed_nodes[layer].append(node_data) 999 1000 def process_node(layer, node_data): 1001 """Deserialize a node. 1002 1003 Arguments: 1004 layer: layer instance. 1005 node_data: node config dict. 1006 1007 Raises: 1008 ValueError: In case of improperly formatted `node_data` dict. 1009 """ 1010 input_tensors = [] 1011 for input_data in node_data: 1012 inbound_layer_name = input_data[0] 1013 inbound_node_index = input_data[1] 1014 inbound_tensor_index = input_data[2] 1015 if len(input_data) == 3: 1016 kwargs = {} 1017 elif len(input_data) == 4: 1018 kwargs = input_data[3] 1019 else: 1020 raise ValueError('Improperly formatted model config.') 1021 if inbound_layer_name not in created_layers: 1022 add_unprocessed_node(layer, node_data) 1023 return 1024 inbound_layer = created_layers[inbound_layer_name] 1025 if len(inbound_layer._inbound_nodes) <= inbound_node_index: 1026 add_unprocessed_node(layer, node_data) 1027 return 1028 inbound_node = inbound_layer._inbound_nodes[inbound_node_index] 1029 input_tensors.append(inbound_node.output_tensors[inbound_tensor_index]) 1030 # Call layer on its inputs, thus creating the node 1031 # and building the layer if needed. 1032 if input_tensors: 1033 if len(input_tensors) == 1: 1034 layer(input_tensors[0], **kwargs) 1035 else: 1036 layer(input_tensors, **kwargs) 1037 1038 def process_layer(layer_data): 1039 """Deserialize a layer, then call it on appropriate inputs. 1040 1041 Arguments: 1042 layer_data: layer config dict. 1043 1044 Raises: 1045 ValueError: In case of improperly formatted `layer_data` dict. 1046 """ 1047 layer_name = layer_data['name'] 1048 1049 # Instantiate layer. 1050 from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 1051 1052 layer = deserialize_layer(layer_data, custom_objects=custom_objects) 1053 created_layers[layer_name] = layer 1054 1055 # Gather layer inputs. 1056 inbound_nodes_data = layer_data['inbound_nodes'] 1057 for node_data in inbound_nodes_data: 1058 # We don't process nodes (i.e. make layer calls) 1059 # on the fly because the inbound node may not yet exist, 1060 # in case of layer shared at different topological depths 1061 # (e.g. a model such as A(B(A(B(x))))) 1062 add_unprocessed_node(layer, node_data) 1063 1064 # First, we create all layers and enqueue nodes to be processed 1065 for layer_data in config['layers']: 1066 process_layer(layer_data) 1067 # Then we process nodes in order of layer depth. 1068 # Nodes that cannot yet be processed (if the inbound node 1069 # does not yet exist) are re-enqueued, and the process 1070 # is repeated until all nodes are processed. 1071 while unprocessed_nodes: 1072 for layer_data in config['layers']: 1073 layer = created_layers[layer_data['name']] 1074 if layer in unprocessed_nodes: 1075 for node_data in unprocessed_nodes.pop(layer): 1076 process_node(layer, node_data) 1077 1078 name = config.get('name') 1079 input_tensors = [] 1080 output_tensors = [] 1081 for layer_data in config['input_layers']: 1082 layer_name, node_index, tensor_index = layer_data 1083 assert layer_name in created_layers 1084 layer = created_layers[layer_name] 1085 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors 1086 input_tensors.append(layer_output_tensors[tensor_index]) 1087 for layer_data in config['output_layers']: 1088 layer_name, node_index, tensor_index = layer_data 1089 assert layer_name in created_layers 1090 layer = created_layers[layer_name] 1091 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors 1092 output_tensors.append(layer_output_tensors[tensor_index]) 1093 return cls(inputs=input_tensors, outputs=output_tensors, name=name) 1094 1095 def save(self, filepath, overwrite=True, include_optimizer=True): 1096 """Save the model to a single HDF5 file. 1097 1098 The savefile includes: 1099 - The model architecture, allowing to re-instantiate the model. 1100 - The model weights. 1101 - The state of the optimizer, allowing to resume training 1102 exactly where you left off. 1103 1104 This allows you to save the entirety of the state of a model 1105 in a single file. 1106 1107 Saved models can be reinstantiated via `keras.models.load_model`. 1108 The model returned by `load_model` 1109 is a compiled model ready to be used (unless the saved model 1110 was never compiled in the first place). 1111 1112 Arguments: 1113 filepath: String, path to the file to save the weights to. 1114 overwrite: Whether to silently overwrite any existing file at the 1115 target location, or provide the user with a manual prompt. 1116 include_optimizer: If True, save optimizer's state together. 1117 1118 Example: 1119 1120 ```python 1121 from keras.models import load_model 1122 1123 model.save('my_model.h5') # creates a HDF5 file 'my_model.h5' 1124 del model # deletes the existing model 1125 1126 # returns a compiled model 1127 # identical to the previous one 1128 model = load_model('my_model.h5') 1129 ``` 1130 """ 1131 if not self._is_graph_network: 1132 raise NotImplementedError 1133 1134 from tensorflow.python.keras._impl.keras.models import save_model # pylint: disable=g-import-not-at-top 1135 save_model(self, filepath, overwrite, include_optimizer) 1136 1137 def save_weights(self, filepath, overwrite=True): 1138 """Dumps all layer weights to a HDF5 file. 1139 1140 The weight file has: 1141 - `layer_names` (attribute), a list of strings 1142 (ordered names of model layers). 1143 - For every layer, a `group` named `layer.name` 1144 - For every such layer group, a group attribute `weight_names`, 1145 a list of strings 1146 (ordered names of weights tensor of the layer). 1147 - For every weight in the layer, a dataset 1148 storing the weight value, named after the weight tensor. 1149 1150 Arguments: 1151 filepath: String, path to the file to save the weights to. 1152 overwrite: Whether to silently overwrite any existing file at the 1153 target location, or provide the user with a manual prompt. 1154 1155 Raises: 1156 ImportError: If h5py is not available. 1157 """ 1158 if h5py is None: 1159 raise ImportError('`save_weights` requires h5py.') 1160 # If file exists and should not be overwritten: 1161 if not overwrite and os.path.isfile(filepath): 1162 proceed = ask_to_proceed_with_overwrite(filepath) 1163 if not proceed: 1164 return 1165 with h5py.File(filepath, 'w') as f: 1166 save_weights_to_hdf5_group(f, self.layers) 1167 1168 def load_weights(self, filepath, by_name=False): 1169 """Loads all layer weights from a HDF5 save file. 1170 1171 If `by_name` is False (default) weights are loaded 1172 based on the network's topology, meaning the architecture 1173 should be the same as when the weights were saved. 1174 Note that layers that don't have weights are not taken 1175 into account in the topological ordering, so adding or 1176 removing layers is fine as long as they don't have weights. 1177 1178 If `by_name` is True, weights are loaded into layers 1179 only if they share the same name. This is useful 1180 for fine-tuning or transfer-learning models where 1181 some of the layers have changed. 1182 1183 Arguments: 1184 filepath: String, path to the weights file to load. 1185 by_name: Boolean, whether to load weights by name 1186 or by topological order. 1187 1188 Raises: 1189 ImportError: If h5py is not available. 1190 """ 1191 if h5py is None: 1192 raise ImportError('`load_weights` requires h5py.') 1193 with h5py.File(filepath, 'r') as f: 1194 if 'layer_names' not in f.attrs and 'model_weights' in f: 1195 f = f['model_weights'] 1196 if by_name: 1197 load_weights_from_hdf5_group_by_name(f, self.layers) 1198 else: 1199 load_weights_from_hdf5_group(f, self.layers) 1200 1201 def _updated_config(self): 1202 """Util hared between different serialization methods. 1203 1204 Returns: 1205 Model config with Keras version information added. 1206 """ 1207 from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top 1208 1209 config = self.get_config() 1210 model_config = { 1211 'class_name': self.__class__.__name__, 1212 'config': config, 1213 'keras_version': keras_version, 1214 'backend': K.backend() 1215 } 1216 return model_config 1217 1218 def to_json(self, **kwargs): 1219 """Returns a JSON string containing the network configuration. 1220 1221 To load a network from a JSON save file, use 1222 `keras.models.model_from_json(json_string, custom_objects={})`. 1223 1224 Arguments: 1225 **kwargs: Additional keyword arguments 1226 to be passed to `json.dumps()`. 1227 1228 Returns: 1229 A JSON string. 1230 """ 1231 if not self._is_graph_network: 1232 raise NotImplementedError 1233 1234 def get_json_type(obj): 1235 # If obj is any numpy type 1236 if type(obj).__module__ == np.__name__: 1237 return obj.item() 1238 1239 # If obj is a python 'type' 1240 if type(obj).__name__ == type.__name__: 1241 return obj.__name__ 1242 1243 raise TypeError('Not JSON Serializable:', obj) 1244 1245 model_config = self._updated_config() 1246 return json.dumps(model_config, default=get_json_type, **kwargs) 1247 1248 def to_yaml(self, **kwargs): 1249 """Returns a yaml string containing the network configuration. 1250 1251 To load a network from a yaml save file, use 1252 `keras.models.model_from_yaml(yaml_string, custom_objects={})`. 1253 1254 `custom_objects` should be a dictionary mapping 1255 the names of custom losses / layers / etc to the corresponding 1256 functions / classes. 1257 1258 Arguments: 1259 **kwargs: Additional keyword arguments 1260 to be passed to `yaml.dump()`. 1261 1262 Returns: 1263 A YAML string. 1264 1265 Raises: 1266 ImportError: if yaml module is not found. 1267 """ 1268 if not self._is_graph_network: 1269 raise NotImplementedError 1270 1271 if yaml is None: 1272 raise ImportError('Requires yaml module installed.') 1273 return yaml.dump(self._updated_config(), **kwargs) 1274 1275 def summary(self, line_length=None, positions=None, print_fn=None): 1276 """Prints a string summary of the network. 1277 1278 Arguments: 1279 line_length: Total length of printed lines 1280 (e.g. set this to adapt the display to different 1281 terminal window sizes). 1282 positions: Relative or absolute positions of log elements 1283 in each line. If not provided, 1284 defaults to `[.33, .55, .67, 1.]`. 1285 print_fn: Print function to use. Defaults to `print`. 1286 It will be called on each line of the summary. 1287 You can set it to a custom function 1288 in order to capture the string summary. 1289 """ 1290 print_layer_summary(self, 1291 line_length=line_length, 1292 positions=positions, 1293 print_fn=print_fn) 1294 1295 1296 def get_source_inputs(tensor, layer=None, node_index=None): 1297 """Returns the list of input tensors necessary to compute `tensor`. 1298 1299 Output will always be a list of tensors 1300 (potentially with 1 element). 1301 1302 Arguments: 1303 tensor: The tensor to start from. 1304 layer: Origin layer of the tensor. Will be 1305 determined via tensor._keras_history if not provided. 1306 node_index: Origin node index of the tensor. 1307 1308 Returns: 1309 List of input tensors. 1310 """ 1311 if not hasattr(tensor, '_keras_history'): 1312 return tensor 1313 1314 if layer is None or node_index: 1315 layer, node_index, _ = tensor._keras_history 1316 if not layer._inbound_nodes: 1317 return [tensor] 1318 else: 1319 node = layer._inbound_nodes[node_index] 1320 if not node.inbound_layers: 1321 # Reached an Input layer, stop recursion. 1322 return node.input_tensors 1323 else: 1324 source_tensors = [] 1325 for i in range(len(node.inbound_layers)): 1326 x = node.input_tensors[i] 1327 layer = node.inbound_layers[i] 1328 node_index = node.node_indices[i] 1329 previous_sources = get_source_inputs(x, layer, node_index) 1330 # Avoid input redundancy. 1331 for x in previous_sources: 1332 if x not in source_tensors: 1333 source_tensors.append(x) 1334 return source_tensors 1335 1336 1337 def _to_list(x): 1338 """Normalizes a list/tensor into a list. 1339 1340 If a tensor is passed, we return 1341 a list of size 1 containing the tensor. 1342 1343 Arguments: 1344 x: target object to be normalized. 1345 1346 Returns: 1347 A list. 1348 """ 1349 if isinstance(x, list): 1350 return x 1351 return [x] 1352 1353 1354 def save_weights_to_hdf5_group(f, layers): 1355 from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top 1356 1357 f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers] 1358 f.attrs['backend'] = K.backend().encode('utf8') 1359 f.attrs['keras_version'] = str(keras_version).encode('utf8') 1360 1361 for layer in layers: 1362 g = f.create_group(layer.name) 1363 symbolic_weights = layer.weights 1364 weight_values = K.batch_get_value(symbolic_weights) 1365 weight_names = [] 1366 for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)): 1367 if hasattr(w, 'name') and w.name: 1368 name = str(w.name) 1369 else: 1370 name = 'param_' + str(i) 1371 weight_names.append(name.encode('utf8')) 1372 g.attrs['weight_names'] = weight_names 1373 for name, val in zip(weight_names, weight_values): 1374 param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) 1375 if not val.shape: 1376 # scalar 1377 param_dset[()] = val 1378 else: 1379 param_dset[:] = val 1380 1381 1382 def preprocess_weights_for_loading(layer, 1383 weights, 1384 original_keras_version=None, 1385 original_backend=None): 1386 """Converts layers weights from Keras 1 format to Keras 2. 1387 1388 Arguments: 1389 layer: Layer instance. 1390 weights: List of weights values (Numpy arrays). 1391 original_keras_version: Keras version for the weights, as a string. 1392 original_backend: Keras backend the weights were trained with, 1393 as a string. 1394 1395 Returns: 1396 A list of weights values (Numpy arrays). 1397 """ 1398 if layer.__class__.__name__ == 'Bidirectional': 1399 num_weights_per_layer = len(weights) // 2 1400 forward_weights = preprocess_weights_for_loading( 1401 layer.forward_layer, weights[:num_weights_per_layer], 1402 original_keras_version, original_backend) 1403 backward_weights = preprocess_weights_for_loading( 1404 layer.backward_layer, weights[num_weights_per_layer:], 1405 original_keras_version, original_backend) 1406 weights = forward_weights + backward_weights 1407 1408 if original_keras_version == '1': 1409 if layer.__class__.__name__ == 'TimeDistributed': 1410 weights = preprocess_weights_for_loading( 1411 layer.layer, weights, original_keras_version, original_backend) 1412 1413 if layer.__class__.__name__ == 'Conv1D': 1414 shape = weights[0].shape 1415 # Handle Keras 1.1 format 1416 if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters: 1417 # Legacy shape: 1418 # (filters, input_dim, filter_length, 1) 1419 assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0], 1420 1) 1421 weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) 1422 weights[0] = weights[0][:, 0, :, :] 1423 1424 if layer.__class__.__name__ == 'Conv2D': 1425 if layer.data_format == 'channels_first': 1426 # old: (filters, stack_size, kernel_rows, kernel_cols) 1427 # new: (kernel_rows, kernel_cols, stack_size, filters) 1428 weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) 1429 1430 if layer.__class__.__name__ == 'Conv2DTranspose': 1431 if layer.data_format == 'channels_last': 1432 # old: (kernel_rows, kernel_cols, stack_size, filters) 1433 # new: (kernel_rows, kernel_cols, filters, stack_size) 1434 weights[0] = np.transpose(weights[0], (0, 1, 3, 2)) 1435 if layer.data_format == 'channels_first': 1436 # old: (filters, stack_size, kernel_rows, kernel_cols) 1437 # new: (kernel_rows, kernel_cols, filters, stack_size) 1438 weights[0] = np.transpose(weights[0], (2, 3, 0, 1)) 1439 1440 if layer.__class__.__name__ == 'Conv3D': 1441 if layer.data_format == 'channels_first': 1442 # old: (filters, stack_size, ...) 1443 # new: (..., stack_size, filters) 1444 weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0)) 1445 1446 if layer.__class__.__name__ == 'GRU': 1447 if len(weights) == 9: 1448 kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1) 1449 recurrent_kernel = np.concatenate( 1450 [weights[1], weights[4], weights[7]], axis=-1) 1451 bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1) 1452 weights = [kernel, recurrent_kernel, bias] 1453 1454 if layer.__class__.__name__ == 'LSTM': 1455 if len(weights) == 12: 1456 # old: i, c, f, o 1457 # new: i, f, c, o 1458 kernel = np.concatenate( 1459 [weights[0], weights[6], weights[3], weights[9]], axis=-1) 1460 recurrent_kernel = np.concatenate( 1461 [weights[1], weights[7], weights[4], weights[10]], axis=-1) 1462 bias = np.concatenate( 1463 [weights[2], weights[8], weights[5], weights[11]], axis=-1) 1464 weights = [kernel, recurrent_kernel, bias] 1465 1466 if layer.__class__.__name__ == 'ConvLSTM2D': 1467 if len(weights) == 12: 1468 kernel = np.concatenate( 1469 [weights[0], weights[6], weights[3], weights[9]], axis=-1) 1470 recurrent_kernel = np.concatenate( 1471 [weights[1], weights[7], weights[4], weights[10]], axis=-1) 1472 bias = np.concatenate( 1473 [weights[2], weights[8], weights[5], weights[11]], axis=-1) 1474 if layer.data_format == 'channels_first': 1475 # old: (filters, stack_size, kernel_rows, kernel_cols) 1476 # new: (kernel_rows, kernel_cols, stack_size, filters) 1477 kernel = np.transpose(kernel, (2, 3, 1, 0)) 1478 recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) 1479 weights = [kernel, recurrent_kernel, bias] 1480 1481 if layer.__class__.__name__ in ['Model', 'Sequential']: 1482 new_weights = [] 1483 # trainable weights 1484 for sublayer in layer.layers: 1485 num_weights = len(sublayer.trainable_weights) 1486 if num_weights > 0: 1487 new_weights.extend( 1488 preprocess_weights_for_loading( 1489 layer=sublayer, 1490 weights=weights[:num_weights], 1491 original_keras_version=original_keras_version, 1492 original_backend=original_backend)) 1493 weights = weights[num_weights:] 1494 1495 # non-trainable weights 1496 for sublayer in layer.layers: 1497 num_weights = len([ 1498 l for l in sublayer.weights if l not in sublayer.trainable_weights 1499 ]) 1500 if num_weights > 0: 1501 new_weights.extend( 1502 preprocess_weights_for_loading( 1503 layer=sublayer, 1504 weights=weights[:num_weights], 1505 original_keras_version=original_keras_version, 1506 original_backend=original_backend)) 1507 weights = weights[num_weights:] 1508 weights = new_weights 1509 1510 conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] 1511 if layer.__class__.__name__ in conv_layers: 1512 if original_backend == 'theano': 1513 weights[0] = conv_utils.convert_kernel(weights[0]) 1514 if layer.__class__.__name__ == 'ConvLSTM2D': 1515 weights[1] = conv_utils.convert_kernel(weights[1]) 1516 if K.int_shape(layer.weights[0]) != weights[0].shape: 1517 weights[0] = np.transpose(weights[0], (3, 2, 0, 1)) 1518 if layer.__class__.__name__ == 'ConvLSTM2D': 1519 weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) 1520 1521 # Convert the weights of CuDNNLSTM so that they could be loaded into LSTM 1522 if layer.__class__.__name__ == 'LSTM' and len(weights) == 3: 1523 # Determine if loading a CuDNNLSTM layer from the number of bias weights: 1524 # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) 1525 # if there's no bias weight in the file, skip this conversion 1526 units = weights[1].shape[0] 1527 bias = weights[2] 1528 if len(bias) == units * 8: 1529 # reshape the kernels 1530 kernels = np.split(weights[0], 4, axis=1) 1531 kernels = [ 1532 kernel.reshape(-1).reshape(kernel.shape, order='F') 1533 for kernel in kernels 1534 ] 1535 weights[0] = np.concatenate(kernels, axis=1) 1536 1537 # transpose the recurrent kernels 1538 recurrent_kernels = np.split(weights[1], 4, axis=1) 1539 recurrent_kernels = [kernel.T for kernel in recurrent_kernels] 1540 weights[1] = np.concatenate(recurrent_kernels, axis=1) 1541 1542 # split the bias into half and merge 1543 weights[2] = bias[:units * 4] + bias[units * 4:] 1544 1545 return weights 1546 1547 1548 def load_weights_from_hdf5_group(f, layers): 1549 """Implements topological (order-based) weight loading. 1550 1551 Arguments: 1552 f: A pointer to a HDF5 group. 1553 layers: a list of target layers. 1554 1555 Raises: 1556 ValueError: in case of mismatch between provided layers 1557 and weights file. 1558 """ 1559 if 'keras_version' in f.attrs: 1560 original_keras_version = f.attrs['keras_version'].decode('utf8') 1561 else: 1562 original_keras_version = '1' 1563 if 'backend' in f.attrs: 1564 original_backend = f.attrs['backend'].decode('utf8') 1565 else: 1566 original_backend = None 1567 1568 filtered_layers = [] 1569 for layer in layers: 1570 weights = layer.weights 1571 if weights: 1572 filtered_layers.append(layer) 1573 1574 layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] 1575 filtered_layer_names = [] 1576 for name in layer_names: 1577 g = f[name] 1578 weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] 1579 if weight_names: 1580 filtered_layer_names.append(name) 1581 layer_names = filtered_layer_names 1582 if len(layer_names) != len(filtered_layers): 1583 raise ValueError('You are trying to load a weight file ' 1584 'containing ' + str(len(layer_names)) + 1585 ' layers into a model with ' + str(len(filtered_layers)) + 1586 ' layers.') 1587 1588 # We batch weight value assignments in a single backend call 1589 # which provides a speedup in TensorFlow. 1590 weight_value_tuples = [] 1591 for k, name in enumerate(layer_names): 1592 g = f[name] 1593 weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] 1594 weight_values = [g[weight_name] for weight_name in weight_names] 1595 layer = filtered_layers[k] 1596 symbolic_weights = layer.weights 1597 weight_values = preprocess_weights_for_loading( 1598 layer, weight_values, original_keras_version, original_backend) 1599 if len(weight_values) != len(symbolic_weights): 1600 raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + 1601 '" in the current model) was found to ' 1602 'correspond to layer ' + name + ' in the save file. ' 1603 'However the new layer ' + layer.name + ' expects ' + 1604 str(len(symbolic_weights)) + 1605 ' weights, but the saved weights have ' + 1606 str(len(weight_values)) + ' elements.') 1607 weight_value_tuples += zip(symbolic_weights, weight_values) 1608 K.batch_set_value(weight_value_tuples) 1609 1610 1611 def load_weights_from_hdf5_group_by_name(f, layers): 1612 """Implements name-based weight loading. 1613 1614 (instead of topological weight loading). 1615 1616 Layers that have no matching name are skipped. 1617 1618 Arguments: 1619 f: A pointer to a HDF5 group. 1620 layers: a list of target layers. 1621 1622 Raises: 1623 ValueError: in case of mismatch between provided layers 1624 and weights file. 1625 """ 1626 if 'keras_version' in f.attrs: 1627 original_keras_version = f.attrs['keras_version'].decode('utf8') 1628 else: 1629 original_keras_version = '1' 1630 if 'backend' in f.attrs: 1631 original_backend = f.attrs['backend'].decode('utf8') 1632 else: 1633 original_backend = None 1634 1635 # New file format. 1636 layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] 1637 1638 # Reverse index of layer name to list of layers with name. 1639 index = {} 1640 for layer in layers: 1641 if layer.name: 1642 index.setdefault(layer.name, []).append(layer) 1643 1644 # We batch weight value assignments in a single backend call 1645 # which provides a speedup in TensorFlow. 1646 weight_value_tuples = [] 1647 for k, name in enumerate(layer_names): 1648 g = f[name] 1649 weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] 1650 weight_values = [g[weight_name] for weight_name in weight_names] 1651 1652 for layer in index.get(name, []): 1653 symbolic_weights = layer.weights 1654 weight_values = preprocess_weights_for_loading( 1655 layer, weight_values, original_keras_version, original_backend) 1656 if len(weight_values) != len(symbolic_weights): 1657 raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + 1658 '") expects ' + str(len(symbolic_weights)) + 1659 ' weight(s), but the saved weights' + ' have ' + 1660 str(len(weight_values)) + ' element(s).') 1661 # Set values. 1662 for i in range(len(weight_values)): 1663 weight_value_tuples.append((symbolic_weights[i], weight_values[i])) 1664 K.batch_set_value(weight_value_tuples) 1665 1666 1667 def shape_type_conversion(fn): 1668 """Decorator that handles tuple/TensorShape conversion. 1669 1670 Used in `compute_output_shape` and `build`. 1671 1672 Arguments: 1673 fn: function to wrap. 1674 1675 Returns: 1676 Wrapped function. 1677 """ 1678 1679 def wrapper(instance, input_shape): 1680 if input_shape is not None: 1681 if isinstance(input_shape, list): 1682 input_shape = [ 1683 tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape] 1684 else: 1685 input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) 1686 output_shape = fn(instance, input_shape) 1687 if output_shape is not None: 1688 if isinstance(output_shape, list): 1689 return [tensor_shape.TensorShape(x) for x in output_shape] 1690 return tensor_shape.TensorShape(output_shape) 1691 1692 return wrapper 1693