# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras estimator API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util.tf_export import keras_export

# Keras has an undeclared dependency on tensorflow/estimator:estimator_py.
# As long as you depend on the //third_party/py/tensorflow:tensorflow target,
# everything will work as normal.


# LINT.IfChange
@keras_export('keras.estimator.model_to_estimator')
def model_to_estimator(
    keras_model=None,
    keras_model_path=None,
    custom_objects=None,
    model_dir=None,
    config=None):
  """Constructs an `Estimator` instance from given keras model.

  This is a thin forwarding wrapper: it lazily imports the implementation
  from the separate `tensorflow_estimator` package and passes every argument
  through unchanged.

  For usage example, please see:
  [Creating estimators from Keras
  Models](https://tensorflow.org/guide/estimators#model_to_estimator).

  Args:
    keras_model: A compiled Keras model object. This argument is mutually
      exclusive with `keras_model_path`.
    keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
      format, which can be generated with the `save()` method of a Keras model.
      This argument is mutually exclusive with `keras_model`.
    custom_objects: Dictionary for custom objects.
    model_dir: Directory to save `Estimator` model parameters, graph, summary
      files for TensorBoard, etc.
    config: `RunConfig` to config `Estimator`.

  Returns:
    An Estimator from given keras model.

  Raises:
    NotImplementedError: if the `tensorflow_estimator` package cannot be
      imported in this installation.

    The following are raised by the delegated `tensorflow_estimator`
    implementation, not by this wrapper itself:

    ValueError: if neither keras_model nor keras_model_path was given.
    ValueError: if both keras_model and keras_model_path were given.
    ValueError: if the keras_model_path is a GCS URI.
    ValueError: if keras_model has not been compiled.
  """
  # Imported lazily: the estimator package is an optional, separately
  # installed dependency, so its absence must not break importing Keras.
  try:
    from tensorflow_estimator.python.estimator import keras as keras_lib  # pylint: disable=g-import-not-at-top
  except ImportError:
    # Raised inside the handler so that, on Python 3, the original
    # ImportError is kept as implicit context (`__context__`) for debugging.
    # `raise ... from` is avoided because this file keeps Python 2
    # `__future__` compatibility.
    raise NotImplementedError(
        'tf.keras.estimator.model_to_estimator function not available in your '
        'installation.')
  return keras_lib.model_to_estimator(
      keras_model=keras_model,
      keras_model_path=keras_model_path,
      custom_objects=custom_objects,
      model_dir=model_dir,
      config=config)

# LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py)