Home | History | Annotate | Download | only in estimator
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Keras estimator API."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.util.tf_export import keras_export
     22 
# Keras has an undeclared dependency on tensorflow/estimator:estimator_py.
# As long as you depend on the //third_party/py/tensorflow:tensorflow target,
# everything will work as normal.
     26 
     27 
     28 # LINT.IfChange
     29 @keras_export('keras.estimator.model_to_estimator')
     30 def model_to_estimator(
     31     keras_model=None,
     32     keras_model_path=None,
     33     custom_objects=None,
     34     model_dir=None,
     35     config=None):
     36   """Constructs an `Estimator` instance from given keras model.
     37 
     38   For usage example, please see:
     39   [Creating estimators from Keras
     40   Models](https://tensorflow.org/guide/estimators#model_to_estimator).
     41 
     42   Args:
     43     keras_model: A compiled Keras model object. This argument is mutually
     44       exclusive with `keras_model_path`.
     45     keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
     46       format, which can be generated with the `save()` method of a Keras model.
     47       This argument is mutually exclusive with `keras_model`.
     48     custom_objects: Dictionary for custom objects.
     49     model_dir: Directory to save `Estimator` model parameters, graph, summary
     50       files for TensorBoard, etc.
     51     config: `RunConfig` to config `Estimator`.
     52 
     53   Returns:
     54     An Estimator from given keras model.
     55 
     56   Raises:
     57     ValueError: if neither keras_model nor keras_model_path was given.
     58     ValueError: if both keras_model and keras_model_path was given.
     59     ValueError: if the keras_model_path is a GCS URI.
     60     ValueError: if keras_model has not been compiled.
     61   """
     62   try:
     63     from tensorflow_estimator.python.estimator import keras as keras_lib  # pylint: disable=g-import-not-at-top
     64   except ImportError:
     65     raise NotImplementedError(
     66         'tf.keras.estimator.model_to_estimator function not available in your '
     67         'installation.')
     68   return keras_lib.model_to_estimator(
     69       keras_model=keras_model,
     70       keras_model_path=keras_model_path,
     71       custom_objects=custom_objects,
     72       model_dir=model_dir,
     73       config=config)
     74 
     75 # LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py)
     76 
     77 
     78