Home | History | Annotate | Download | only in wrappers
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Wrapper for using the Scikit-Learn API with Keras models.
     16 """
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import copy
     22 import types
     23 
     24 import numpy as np
     25 
     26 from tensorflow.python.keras._impl.keras.models import Sequential
     27 from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
     28 from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical
     29 from tensorflow.python.util.tf_export import tf_export
     30 
     31 
     32 class BaseWrapper(object):
     33   """Base class for the Keras scikit-learn wrapper.
     34 
     35   Warning: This class should not be used directly.
     36   Use descendant classes instead.
     37 
     38   Arguments:
     39       build_fn: callable function or class instance
     40       **sk_params: model parameters & fitting parameters
     41 
     42   The `build_fn` should construct, compile and return a Keras model, which
     43   will then be used to fit/predict. One of the following
     44   three values could be passed to `build_fn`:
     45   1. A function
     46   2. An instance of a class that implements the `__call__` method
     47   3. None. This means you implement a class that inherits from either
     48   `KerasClassifier` or `KerasRegressor`. The `__call__` method of the
     49   present class will then be treated as the default `build_fn`.
     50 
     51   `sk_params` takes both model parameters and fitting parameters. Legal model
     52   parameters are the arguments of `build_fn`. Note that like all other
     53   estimators in scikit-learn, `build_fn` should provide default values for
     54   its arguments, so that you could create the estimator without passing any
     55   values to `sk_params`.
     56 
     57   `sk_params` could also accept parameters for calling `fit`, `predict`,
     58   `predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`).
     59   fitting (predicting) parameters are selected in the following order:
     60 
     61   1. Values passed to the dictionary arguments of
     62   `fit`, `predict`, `predict_proba`, and `score` methods
     63   2. Values passed to `sk_params`
     64   3. The default values of the `keras.models.Sequential`
     65   `fit`, `predict`, `predict_proba` and `score` methods
     66 
     67   When using scikit-learn's `grid_search` API, legal tunable parameters are
     68   those you could pass to `sk_params`, including fitting parameters.
     69   In other words, you could use `grid_search` to search for the best
     70   `batch_size` or `epochs` as well as the model parameters.
     71   """
     72 
     73   def __init__(self, build_fn=None, **sk_params):
     74     self.build_fn = build_fn
     75     self.sk_params = sk_params
     76     self.check_params(sk_params)
     77 
     78   def check_params(self, params):
     79     """Checks for user typos in `params`.
     80 
     81     Arguments:
     82         params: dictionary; the parameters to be checked
     83 
     84     Raises:
     85         ValueError: if any member of `params` is not a valid argument.
     86     """
     87     legal_params_fns = [
     88         Sequential.fit, Sequential.predict, Sequential.predict_classes,
     89         Sequential.evaluate
     90     ]
     91     if self.build_fn is None:
     92       legal_params_fns.append(self.__call__)
     93     elif (not isinstance(self.build_fn, types.FunctionType) and
     94           not isinstance(self.build_fn, types.MethodType)):
     95       legal_params_fns.append(self.build_fn.__call__)
     96     else:
     97       legal_params_fns.append(self.build_fn)
     98 
     99     for params_name in params:
    100       for fn in legal_params_fns:
    101         if has_arg(fn, params_name):
    102           break
    103       else:
    104         if params_name != 'nb_epoch':
    105           raise ValueError('{} is not a legal parameter'.format(params_name))
    106 
    107   def get_params(self, **params):  # pylint: disable=unused-argument
    108     """Gets parameters for this estimator.
    109 
    110     Arguments:
    111         **params: ignored (exists for API compatibility).
    112 
    113     Returns:
    114         Dictionary of parameter names mapped to their values.
    115     """
    116     res = copy.deepcopy(self.sk_params)
    117     res.update({'build_fn': self.build_fn})
    118     return res
    119 
    120   def set_params(self, **params):
    121     """Sets the parameters of this estimator.
    122 
    123     Arguments:
    124         **params: Dictionary of parameter names mapped to their values.
    125 
    126     Returns:
    127         self
    128     """
    129     self.check_params(params)
    130     self.sk_params.update(params)
    131     return self
    132 
    133   def fit(self, x, y, **kwargs):
    134     """Constructs a new model with `build_fn` & fit the model to `(x, y)`.
    135 
    136     Arguments:
    137         x : array-like, shape `(n_samples, n_features)`
    138             Training samples where `n_samples` is the number of samples
    139             and `n_features` is the number of features.
    140         y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
    141             True labels for `x`.
    142         **kwargs: dictionary arguments
    143             Legal arguments are the arguments of `Sequential.fit`
    144 
    145     Returns:
    146         history : object
    147             details about the training history at each epoch.
    148     """
    149     if self.build_fn is None:
    150       self.model = self.__call__(**self.filter_sk_params(self.__call__))
    151     elif (not isinstance(self.build_fn, types.FunctionType) and
    152           not isinstance(self.build_fn, types.MethodType)):
    153       self.model = self.build_fn(
    154           **self.filter_sk_params(self.build_fn.__call__))
    155     else:
    156       self.model = self.build_fn(**self.filter_sk_params(self.build_fn))
    157 
    158     loss_name = self.model.loss
    159     if hasattr(loss_name, '__name__'):
    160       loss_name = loss_name.__name__
    161     if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
    162       y = to_categorical(y)
    163 
    164     fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
    165     fit_args.update(kwargs)
    166 
    167     history = self.model.fit(x, y, **fit_args)
    168 
    169     return history
    170 
    171   def filter_sk_params(self, fn, override=None):
    172     """Filters `sk_params` and returns those in `fn`'s arguments.
    173 
    174     Arguments:
    175         fn : arbitrary function
    176         override: dictionary, values to override `sk_params`
    177 
    178     Returns:
    179         res : dictionary containing variables
    180             in both `sk_params` and `fn`'s arguments.
    181     """
    182     override = override or {}
    183     res = {}
    184     for name, value in self.sk_params.items():
    185       if has_arg(fn, name):
    186         res.update({name: value})
    187     res.update(override)
    188     return res
    189 
    190 
    191 @tf_export('keras.wrappers.scikit_learn.KerasClassifier')
    192 class KerasClassifier(BaseWrapper):
    193   """Implementation of the scikit-learn classifier API for Keras.
    194   """
    195 
    196   def fit(self, x, y, **kwargs):
    197     """Constructs a new model with `build_fn` & fit the model to `(x, y)`.
    198 
    199     Arguments:
    200         x : array-like, shape `(n_samples, n_features)`
    201             Training samples where `n_samples` is the number of samples
    202             and `n_features` is the number of features.
    203         y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
    204             True labels for `x`.
    205         **kwargs: dictionary arguments
    206             Legal arguments are the arguments of `Sequential.fit`
    207 
    208     Returns:
    209         history : object
    210             details about the training history at each epoch.
    211 
    212     Raises:
    213         ValueError: In case of invalid shape for `y` argument.
    214     """
    215     y = np.array(y)
    216     if len(y.shape) == 2 and y.shape[1] > 1:
    217       self.classes_ = np.arange(y.shape[1])
    218     elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
    219       self.classes_ = np.unique(y)
    220       y = np.searchsorted(self.classes_, y)
    221     else:
    222       raise ValueError('Invalid shape for y: ' + str(y.shape))
    223     self.n_classes_ = len(self.classes_)
    224     return super(KerasClassifier, self).fit(x, y, **kwargs)
    225 
    226   def predict(self, x, **kwargs):
    227     """Returns the class predictions for the given test data.
    228 
    229     Arguments:
    230         x: array-like, shape `(n_samples, n_features)`
    231             Test samples where `n_samples` is the number of samples
    232             and `n_features` is the number of features.
    233         **kwargs: dictionary arguments
    234             Legal arguments are the arguments
    235             of `Sequential.predict_classes`.
    236 
    237     Returns:
    238         preds: array-like, shape `(n_samples,)`
    239             Class predictions.
    240     """
    241     kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
    242     classes = self.model.predict_classes(x, **kwargs)
    243     return self.classes_[classes]
    244 
    245   def predict_proba(self, x, **kwargs):
    246     """Returns class probability estimates for the given test data.
    247 
    248     Arguments:
    249         x: array-like, shape `(n_samples, n_features)`
    250             Test samples where `n_samples` is the number of samples
    251             and `n_features` is the number of features.
    252         **kwargs: dictionary arguments
    253             Legal arguments are the arguments
    254             of `Sequential.predict_classes`.
    255 
    256     Returns:
    257         proba: array-like, shape `(n_samples, n_outputs)`
    258             Class probability estimates.
    259             In the case of binary classification,
    260             to match the scikit-learn API,
    261             will return an array of shape `(n_samples, 2)`
    262             (instead of `(n_sample, 1)` as in Keras).
    263     """
    264     kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
    265     probs = self.model.predict_proba(x, **kwargs)
    266 
    267     # check if binary classification
    268     if probs.shape[1] == 1:
    269       # first column is probability of class 0 and second is of class 1
    270       probs = np.hstack([1 - probs, probs])
    271     return probs
    272 
    273   def score(self, x, y, **kwargs):
    274     """Returns the mean accuracy on the given test data and labels.
    275 
    276     Arguments:
    277         x: array-like, shape `(n_samples, n_features)`
    278             Test samples where `n_samples` is the number of samples
    279             and `n_features` is the number of features.
    280         y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
    281             True labels for `x`.
    282         **kwargs: dictionary arguments
    283             Legal arguments are the arguments of `Sequential.evaluate`.
    284 
    285     Returns:
    286         score: float
    287             Mean accuracy of predictions on `x` wrt. `y`.
    288 
    289     Raises:
    290         ValueError: If the underlying model isn't configured to
    291             compute accuracy. You should pass `metrics=["accuracy"]` to
    292             the `.compile()` method of the model.
    293     """
    294     y = np.searchsorted(self.classes_, y)
    295     kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
    296 
    297     loss_name = self.model.loss
    298     if hasattr(loss_name, '__name__'):
    299       loss_name = loss_name.__name__
    300     if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
    301       y = to_categorical(y)
    302 
    303     outputs = self.model.evaluate(x, y, **kwargs)
    304     if not isinstance(outputs, list):
    305       outputs = [outputs]
    306     for name, output in zip(self.model.metrics_names, outputs):
    307       if name == 'acc':
    308         return output
    309     raise ValueError('The model is not configured to compute accuracy. '
    310                      'You should pass `metrics=["accuracy"]` to '
    311                      'the `model.compile()` method.')
    312 
    313 
    314 @tf_export('keras.wrappers.scikit_learn.KerasRegressor')
    315 class KerasRegressor(BaseWrapper):
    316   """Implementation of the scikit-learn regressor API for Keras.
    317   """
    318 
    319   def predict(self, x, **kwargs):
    320     """Returns predictions for the given test data.
    321 
    322     Arguments:
    323         x: array-like, shape `(n_samples, n_features)`
    324             Test samples where `n_samples` is the number of samples
    325             and `n_features` is the number of features.
    326         **kwargs: dictionary arguments
    327             Legal arguments are the arguments of `Sequential.predict`.
    328 
    329     Returns:
    330         preds: array-like, shape `(n_samples,)`
    331             Predictions.
    332     """
    333     kwargs = self.filter_sk_params(Sequential.predict, kwargs)
    334     return np.squeeze(self.model.predict(x, **kwargs))
    335 
    336   def score(self, x, y, **kwargs):
    337     """Returns the mean loss on the given test data and labels.
    338 
    339     Arguments:
    340         x: array-like, shape `(n_samples, n_features)`
    341             Test samples where `n_samples` is the number of samples
    342             and `n_features` is the number of features.
    343         y: array-like, shape `(n_samples,)`
    344             True labels for `x`.
    345         **kwargs: dictionary arguments
    346             Legal arguments are the arguments of `Sequential.evaluate`.
    347 
    348     Returns:
    349         score: float
    350             Mean accuracy of predictions on `x` wrt. `y`.
    351     """
    352     kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
    353     loss = self.model.evaluate(x, y, **kwargs)
    354     if isinstance(loss, list):
    355       return -loss[0]
    356     return -loss
    357