Home | History | Annotate | Download | only in keras

Lines Matching refs:model

16 """Code for model cloning, plus model-related API entries.
40 Model = training.Model # pylint: disable=invalid-name
53 def _clone_functional_model(model, input_tensors=None, share_weights=False):
54 """Clone a functional `Model` instance.
56 Model cloning is similar to calling a model on new inputs,
61 model: Instance of `Model`.
63 to build the model upon. If not provided,
66 cloned and original model. Note this still clones the input layers.
67 This is required when we create a per-replica copy of the model with
72 An instance of `Model` reproducing the behavior
73 of the original model, on top of new inputs tensors,
77 ValueError: in case of invalid `model` argument value.
79 if not isinstance(model, Model):
80 raise ValueError('Expected `model` argument '
81 'to be a `Model` instance, got ', model)
82 if isinstance(model, Sequential):
83 raise ValueError('Expected `model` argument '
84 'to be a functional `Model` instance, '
85 'got a `Sequential` instance instead:', model)
90 # Create placeholders to build the model on top of.
92 for layer in model._input_layers:
110 original_input_layer = model._input_layers[i]
123 for x, y in zip(model.inputs, input_tensors):
126 # Iterated over every node in the reference model, in depth order.
127 depth_keys = list(model._nodes_by_depth.keys())
130 nodes = model._nodes_by_depth[depth]
163 # Check that we did compute the model outputs,
164 # then instantiate a new model from inputs and outputs.
166 for x in model.outputs:
170 input_tensors = nest.pack_sequence_as(model._nested_inputs, input_tensors)
171 output_tensors = nest.pack_sequence_as(model._nested_outputs, output_tensors)
172 return Model(input_tensors, output_tensors, name=model.name)
175 def _clone_sequential_model(model, input_tensors=None, share_weights=False):
176 """Clone a `Sequential` model instance.
178 Model cloning is similar to calling a model on new inputs,
183 model: Instance of `Sequential`.
185 to build the model upon. If not provided,
188 cloned and original model. Note this still clones the input layers.
189 This is required when we create a per-replica copy of the model with
195 of the original model, on top of new inputs tensors,
199 ValueError: in case of invalid `model` argument value.
201 if not isinstance(model, Sequential):
202 raise ValueError('Expected `model` argument '
203 'to be a `Sequential` model instance, '
204 'but got:', model)
206 # Use model._layers to ensure that all layers are cloned. The model's layers
207 # property will exclude the initial InputLayer (if it exists) in the model,
208 # resulting in a different Sequential model structure.
213 for layer in model._layers:
219 layers = [_clone_layer(layer) for layer in model._layers]
220 return Sequential(layers=layers, name=model.name)
222 # If input tensors are provided, the original model's InputLayer is
225 layer for layer in model._layers if not isinstance(layer, InputLayer)]
229 raise ValueError('To clone a `Sequential` model, we expect '
239 return Sequential(layers=[origin_layer] + layers, name=model.name)
241 raise ValueError('Cannot clone a `Sequential` model on top '
247 return Sequential(layers=[input_layer] + layers, name=model.name)
251 def clone_model(model, input_tensors=None):
252 """Clone any `Model` instance.
254 Model cloning is similar to calling a model on new inputs,
259 model: Instance of `Model`
260 (could be a functional model or a Sequential model).
262 to build the model upon. If not provided,
266 An instance of `Model` reproducing the behavior
267 of the original model, on top of new inputs tensors,
271 ValueError: in case of invalid `model` argument value.
273 if isinstance(model, Sequential):
274 return _clone_sequential_model(model, input_tensors=input_tensors)
276 return _clone_functional_model(model, input_tensors=input_tensors)
279 # "Clone" a subclassed model by reseting all of the attributes.
280 def _in_place_subclassed_model_reset(model):
281 """Substitute for model cloning that works for subclassed models.
284 To "instantiate" an identical model in a new TF graph, we reuse the original
285 model object, but we clear its state.
287 After calling this function on a model instance, you can use the model
288 instance as if it were a model clone (in particular you can use it in a new
291 This method clears the state of the input model. It is thus destructive.
296 model: Instance of a Keras model created via subclassing.
299 ValueError: In case the model uses a subclassed model as inner layer.
301 assert not model._is_graph_network # Only makes sense for subclassed networks
302 # Retrieve all layers tracked by the model as well as their attribute names
304 for name in dir(model):
306 value = getattr(model, name)
311 assert value in model.layers
325 'model: %s' % name)
327 # Replace layers on the model with fresh layers
329 original_layers = model._layers[:]
330 setattr_tracking = model._setattr_tracking
331 model._setattr_tracking = False
332 model._layers = []
341 'model: %s' % layer)
344 setattr(model, name, fresh_layer)
345 model._layers.append(fresh_layer)
347 # Cache original model build attributes (in addition to layers)
348 if (not hasattr(model, '_original_attributes_cache') or
349 model._original_attributes_cache is None):
350 if model.built:
375 attributes_cache[name] = getattr(model, name)
376 model._original_attributes_cache = attributes_cache
377 _reset_build_compile_trackers(model)
378 model._setattr_tracking = setattr_tracking
381 def _reset_build_compile_trackers(model):
382 """Reset state trackers for model.
392 model: the model that is being reset
395 model.built = False
396 model.inputs = None
397 model.outputs = None
399 model._is_compiled = False # pylint:disable=protected-access
400 model.optimizer = None
403 model):
404 """Restores the original state of a model after it was "reset".
410 model: Instance of a Keras model created via subclassing, on which
413 assert not model._is_graph_network
415 if (hasattr(model, '_original_attributes_cache') and
416 model._original_attributes_cache is not None):
421 setattr_tracking = model._setattr_tracking
422 model._setattr_tracking = False
423 model._layers = []
424 for name, value in model._original_attributes_cache.items():
425 setattr(model, name, value)
427 model._layers.append(value)
428 model._original_attributes_cache = None
429 model._setattr_tracking = setattr_tracking
431 # Restore to the state of a never-called model.
432 _reset_build_compile_trackers(model)
436 model, input_tensors=None, target_tensors=None, custom_objects=None,
438 """Clone a `Model` and build/compile it with the same settings used before.
441 model. When using a separate graph, `in_place_reset` must be `False`.
447 model: `tf.keras.Model` object. Can be Functional, Sequential, or
449 input_tensors: Optional list of input tensors to build the model upon. If
451 target_tensors: Optional list of target tensors for compiling the model. If
455 compile_clone: Boolean, whether to compile model clone (default `True`).
456 in_place_reset: Boolean, whether to reset the model in place. Only used if
457 the model is a subclassed model. In the case of a subclassed model,
459 original model, use the function
460 `in_place_subclassed_model_state_restoration(model)`.
463 model is cloned into an Estimator model function, because Estimators
467 Clone of the model.
471 - cloning a subclassed model with `in_place_reset` set to False.
472 - compiling the clone when the original model has not been compiled.
476 orig_optimizer = model.optimizer
479 'Error when cloning model: compile_clone was set to True, but the '
480 'original model has not been compiled.')
482 if model._is_graph_network or isinstance(model, Sequential):
485 clone = clone_model(model, input_tensors=input_tensors)
487 clone = clone_model(model, input_tensors=input_tensors)
491 getattr(model, '_build_input_shape', None) is not None]):
492 # Set model inputs to build the model and add input/output properties.
494 # sequential model has multiple inputs.
496 K.placeholder(model._build_input_shape, dtype=model.inputs[0].dtype))
500 'This model is a subclassed model. '
501 'Such a model cannot be cloned, but there is a workaround where '
502 'the model is reset in-place. To use this, please set the argument '
504 'original model. To restore the attributes, call '
505 '`in_place_subclassed_model_state_restoration(model)`.')
506 clone = model
526 model.loss,
527 metrics=metrics_module.clone_metrics(model._compile_metrics),
528 loss_weights=model.loss_weights,
529 sample_weight_mode=model.sample_weight_mode,
531 model._compile_weighted_metrics),