1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Wrapper for using the Scikit-Learn API with Keras models. 16 """ 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import copy 22 import types 23 24 import numpy as np 25 26 from tensorflow.python.keras._impl.keras.models import Sequential 27 from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg 28 from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical 29 from tensorflow.python.util.tf_export import tf_export 30 31 32 class BaseWrapper(object): 33 """Base class for the Keras scikit-learn wrapper. 34 35 Warning: This class should not be used directly. 36 Use descendant classes instead. 37 38 Arguments: 39 build_fn: callable function or class instance 40 **sk_params: model parameters & fitting parameters 41 42 The `build_fn` should construct, compile and return a Keras model, which 43 will then be used to fit/predict. One of the following 44 three values could be passed to `build_fn`: 45 1. A function 46 2. An instance of a class that implements the `__call__` method 47 3. None. This means you implement a class that inherits from either 48 `KerasClassifier` or `KerasRegressor`. The `__call__` method of the 49 present class will then be treated as the default `build_fn`. 50 51 `sk_params` takes both model parameters and fitting parameters. Legal model 52 parameters are the arguments of `build_fn`. Note that like all other 53 estimators in scikit-learn, `build_fn` should provide default values for 54 its arguments, so that you could create the estimator without passing any 55 values to `sk_params`. 56 57 `sk_params` could also accept parameters for calling `fit`, `predict`, 58 `predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`). 59 fitting (predicting) parameters are selected in the following order: 60 61 1. Values passed to the dictionary arguments of 62 `fit`, `predict`, `predict_proba`, and `score` methods 63 2. Values passed to `sk_params` 64 3. The default values of the `keras.models.Sequential` 65 `fit`, `predict`, `predict_proba` and `score` methods 66 67 When using scikit-learn's `grid_search` API, legal tunable parameters are 68 those you could pass to `sk_params`, including fitting parameters. 69 In other words, you could use `grid_search` to search for the best 70 `batch_size` or `epochs` as well as the model parameters. 71 """ 72 73 def __init__(self, build_fn=None, **sk_params): 74 self.build_fn = build_fn 75 self.sk_params = sk_params 76 self.check_params(sk_params) 77 78 def check_params(self, params): 79 """Checks for user typos in `params`. 80 81 Arguments: 82 params: dictionary; the parameters to be checked 83 84 Raises: 85 ValueError: if any member of `params` is not a valid argument. 86 """ 87 legal_params_fns = [ 88 Sequential.fit, Sequential.predict, Sequential.predict_classes, 89 Sequential.evaluate 90 ] 91 if self.build_fn is None: 92 legal_params_fns.append(self.__call__) 93 elif (not isinstance(self.build_fn, types.FunctionType) and 94 not isinstance(self.build_fn, types.MethodType)): 95 legal_params_fns.append(self.build_fn.__call__) 96 else: 97 legal_params_fns.append(self.build_fn) 98 99 for params_name in params: 100 for fn in legal_params_fns: 101 if has_arg(fn, params_name): 102 break 103 else: 104 if params_name != 'nb_epoch': 105 raise ValueError('{} is not a legal parameter'.format(params_name)) 106 107 def get_params(self, **params): # pylint: disable=unused-argument 108 """Gets parameters for this estimator. 109 110 Arguments: 111 **params: ignored (exists for API compatibility). 112 113 Returns: 114 Dictionary of parameter names mapped to their values. 115 """ 116 res = copy.deepcopy(self.sk_params) 117 res.update({'build_fn': self.build_fn}) 118 return res 119 120 def set_params(self, **params): 121 """Sets the parameters of this estimator. 122 123 Arguments: 124 **params: Dictionary of parameter names mapped to their values. 125 126 Returns: 127 self 128 """ 129 self.check_params(params) 130 self.sk_params.update(params) 131 return self 132 133 def fit(self, x, y, **kwargs): 134 """Constructs a new model with `build_fn` & fit the model to `(x, y)`. 135 136 Arguments: 137 x : array-like, shape `(n_samples, n_features)` 138 Training samples where `n_samples` is the number of samples 139 and `n_features` is the number of features. 140 y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` 141 True labels for `x`. 142 **kwargs: dictionary arguments 143 Legal arguments are the arguments of `Sequential.fit` 144 145 Returns: 146 history : object 147 details about the training history at each epoch. 148 """ 149 if self.build_fn is None: 150 self.model = self.__call__(**self.filter_sk_params(self.__call__)) 151 elif (not isinstance(self.build_fn, types.FunctionType) and 152 not isinstance(self.build_fn, types.MethodType)): 153 self.model = self.build_fn( 154 **self.filter_sk_params(self.build_fn.__call__)) 155 else: 156 self.model = self.build_fn(**self.filter_sk_params(self.build_fn)) 157 158 loss_name = self.model.loss 159 if hasattr(loss_name, '__name__'): 160 loss_name = loss_name.__name__ 161 if loss_name == 'categorical_crossentropy' and len(y.shape) != 2: 162 y = to_categorical(y) 163 164 fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit)) 165 fit_args.update(kwargs) 166 167 history = self.model.fit(x, y, **fit_args) 168 169 return history 170 171 def filter_sk_params(self, fn, override=None): 172 """Filters `sk_params` and returns those in `fn`'s arguments. 173 174 Arguments: 175 fn : arbitrary function 176 override: dictionary, values to override `sk_params` 177 178 Returns: 179 res : dictionary containing variables 180 in both `sk_params` and `fn`'s arguments. 181 """ 182 override = override or {} 183 res = {} 184 for name, value in self.sk_params.items(): 185 if has_arg(fn, name): 186 res.update({name: value}) 187 res.update(override) 188 return res 189 190 191 @tf_export('keras.wrappers.scikit_learn.KerasClassifier') 192 class KerasClassifier(BaseWrapper): 193 """Implementation of the scikit-learn classifier API for Keras. 194 """ 195 196 def fit(self, x, y, **kwargs): 197 """Constructs a new model with `build_fn` & fit the model to `(x, y)`. 198 199 Arguments: 200 x : array-like, shape `(n_samples, n_features)` 201 Training samples where `n_samples` is the number of samples 202 and `n_features` is the number of features. 203 y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` 204 True labels for `x`. 205 **kwargs: dictionary arguments 206 Legal arguments are the arguments of `Sequential.fit` 207 208 Returns: 209 history : object 210 details about the training history at each epoch. 211 212 Raises: 213 ValueError: In case of invalid shape for `y` argument. 214 """ 215 y = np.array(y) 216 if len(y.shape) == 2 and y.shape[1] > 1: 217 self.classes_ = np.arange(y.shape[1]) 218 elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1: 219 self.classes_ = np.unique(y) 220 y = np.searchsorted(self.classes_, y) 221 else: 222 raise ValueError('Invalid shape for y: ' + str(y.shape)) 223 self.n_classes_ = len(self.classes_) 224 return super(KerasClassifier, self).fit(x, y, **kwargs) 225 226 def predict(self, x, **kwargs): 227 """Returns the class predictions for the given test data. 228 229 Arguments: 230 x: array-like, shape `(n_samples, n_features)` 231 Test samples where `n_samples` is the number of samples 232 and `n_features` is the number of features. 233 **kwargs: dictionary arguments 234 Legal arguments are the arguments 235 of `Sequential.predict_classes`. 236 237 Returns: 238 preds: array-like, shape `(n_samples,)` 239 Class predictions. 240 """ 241 kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs) 242 classes = self.model.predict_classes(x, **kwargs) 243 return self.classes_[classes] 244 245 def predict_proba(self, x, **kwargs): 246 """Returns class probability estimates for the given test data. 247 248 Arguments: 249 x: array-like, shape `(n_samples, n_features)` 250 Test samples where `n_samples` is the number of samples 251 and `n_features` is the number of features. 252 **kwargs: dictionary arguments 253 Legal arguments are the arguments 254 of `Sequential.predict_classes`. 255 256 Returns: 257 proba: array-like, shape `(n_samples, n_outputs)` 258 Class probability estimates. 259 In the case of binary classification, 260 to match the scikit-learn API, 261 will return an array of shape `(n_samples, 2)` 262 (instead of `(n_sample, 1)` as in Keras). 263 """ 264 kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs) 265 probs = self.model.predict_proba(x, **kwargs) 266 267 # check if binary classification 268 if probs.shape[1] == 1: 269 # first column is probability of class 0 and second is of class 1 270 probs = np.hstack([1 - probs, probs]) 271 return probs 272 273 def score(self, x, y, **kwargs): 274 """Returns the mean accuracy on the given test data and labels. 275 276 Arguments: 277 x: array-like, shape `(n_samples, n_features)` 278 Test samples where `n_samples` is the number of samples 279 and `n_features` is the number of features. 280 y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` 281 True labels for `x`. 282 **kwargs: dictionary arguments 283 Legal arguments are the arguments of `Sequential.evaluate`. 284 285 Returns: 286 score: float 287 Mean accuracy of predictions on `x` wrt. `y`. 288 289 Raises: 290 ValueError: If the underlying model isn't configured to 291 compute accuracy. You should pass `metrics=["accuracy"]` to 292 the `.compile()` method of the model. 293 """ 294 y = np.searchsorted(self.classes_, y) 295 kwargs = self.filter_sk_params(Sequential.evaluate, kwargs) 296 297 loss_name = self.model.loss 298 if hasattr(loss_name, '__name__'): 299 loss_name = loss_name.__name__ 300 if loss_name == 'categorical_crossentropy' and len(y.shape) != 2: 301 y = to_categorical(y) 302 303 outputs = self.model.evaluate(x, y, **kwargs) 304 if not isinstance(outputs, list): 305 outputs = [outputs] 306 for name, output in zip(self.model.metrics_names, outputs): 307 if name == 'acc': 308 return output 309 raise ValueError('The model is not configured to compute accuracy. ' 310 'You should pass `metrics=["accuracy"]` to ' 311 'the `model.compile()` method.') 312 313 314 @tf_export('keras.wrappers.scikit_learn.KerasRegressor') 315 class KerasRegressor(BaseWrapper): 316 """Implementation of the scikit-learn regressor API for Keras. 317 """ 318 319 def predict(self, x, **kwargs): 320 """Returns predictions for the given test data. 321 322 Arguments: 323 x: array-like, shape `(n_samples, n_features)` 324 Test samples where `n_samples` is the number of samples 325 and `n_features` is the number of features. 326 **kwargs: dictionary arguments 327 Legal arguments are the arguments of `Sequential.predict`. 328 329 Returns: 330 preds: array-like, shape `(n_samples,)` 331 Predictions. 332 """ 333 kwargs = self.filter_sk_params(Sequential.predict, kwargs) 334 return np.squeeze(self.model.predict(x, **kwargs)) 335 336 def score(self, x, y, **kwargs): 337 """Returns the mean loss on the given test data and labels. 338 339 Arguments: 340 x: array-like, shape `(n_samples, n_features)` 341 Test samples where `n_samples` is the number of samples 342 and `n_features` is the number of features. 343 y: array-like, shape `(n_samples,)` 344 True labels for `x`. 345 **kwargs: dictionary arguments 346 Legal arguments are the arguments of `Sequential.evaluate`. 347 348 Returns: 349 score: float 350 Mean accuracy of predictions on `x` wrt. `y`. 351 """ 352 kwargs = self.filter_sk_params(Sequential.evaluate, kwargs) 353 loss = self.model.evaluate(x, y, **kwargs) 354 if isinstance(loss, list): 355 return -loss[0] 356 return -loss 357