# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in activation functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.layers.base import Layer
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


@tf_export('keras.activations.softmax')
def softmax(x, axis=-1):
  """Softmax activation function.

  Arguments:
      x : Tensor.
      axis: Integer, axis along which the softmax normalization is applied.

  Returns:
      Tensor, output of softmax transformation.

  Raises:
      ValueError: In case `dim(x) == 1`.
  """
  ndim = K.ndim(x)
  if ndim == 2 and axis in (-1, 1):
    # `K.softmax` is called without an `axis` argument, so it is only used
    # when the requested axis is the last one of a rank-2 tensor.
    # NOTE(review): assumes the backend softmax normalizes over the last
    # axis -- confirm against the backend implementation.
    return K.softmax(x)
  elif ndim >= 2:
    # Explicit, axis-aware softmax. Subtracting the max along `axis` before
    # exponentiating keeps `K.exp` from overflowing for large inputs.
    e = K.exp(x - K.max(x, axis=axis, keepdims=True))
    s = K.sum(e, axis=axis, keepdims=True)
    return e / s
  else:
    raise ValueError('Cannot apply softmax to a tensor that is 1D')


@tf_export('keras.activations.elu')
def elu(x, alpha=1.0):
  """Exponential Linear Unit.

  Arguments:
      x: Input tensor.
      alpha: Scalar forwarded unchanged to `K.elu`.

  Returns:
      The result of `K.elu(x, alpha)`.
  """
  return K.elu(x, alpha)


@tf_export('keras.activations.selu')
def selu(x):
  """Scaled Exponential Linear Unit. (Klambauer et al., 2017).

  Computed as `scale * K.elu(x, alpha)` using the fixed `alpha` and `scale`
  constants given in Klambauer et al., 2017.

  Arguments:
      x: A tensor or variable to compute the activation function for.

  Returns:
      Tensor with the same shape and dtype as `x`.

  # Note
      - To be used together with the initialization "lecun_normal".
      - To be used together with the dropout variant "AlphaDropout".
  """
  alpha = 1.6732632423543772848170429916717
  scale = 1.0507009873554804934193349852946
  return scale * K.elu(x, alpha)


@tf_export('keras.activations.softplus')
def softplus(x):
  """Softplus activation; returns `K.softplus(x)`."""
  return K.softplus(x)


@tf_export('keras.activations.softsign')
def softsign(x):
  """Softsign activation; returns `K.softsign(x)`."""
  return K.softsign(x)


@tf_export('keras.activations.relu')
def relu(x, alpha=0., max_value=None):
  """Rectified Linear Unit.

  Arguments:
      x: Input tensor.
      alpha: Forwarded to `K.relu` as `alpha`.
      max_value: Forwarded to `K.relu` as `max_value`; `None` by default.

  Returns:
      The result of `K.relu(x, alpha=alpha, max_value=max_value)`.
  """
  return K.relu(x, alpha=alpha, max_value=max_value)


@tf_export('keras.activations.tanh')
def tanh(x):
  """Hyperbolic tangent activation; returns `K.tanh(x)`."""
  return K.tanh(x)


@tf_export('keras.activations.sigmoid')
def sigmoid(x):
  """Sigmoid activation; returns `K.sigmoid(x)`."""
  return K.sigmoid(x)


@tf_export('keras.activations.hard_sigmoid')
def hard_sigmoid(x):
  """Hard sigmoid activation; returns `K.hard_sigmoid(x)`."""
  return K.hard_sigmoid(x)


@tf_export('keras.activations.linear')
def linear(x):
  """Linear (identity) activation; returns `x` unchanged."""
  return x


@tf_export('keras.activations.serialize')
def serialize(activation):
  """Serializes an activation function to its name.

  Arguments:
      activation: A function (or other object with a `__name__`).

  Returns:
      The `__name__` attribute of `activation`.
  """
  return activation.__name__


@tf_export('keras.activations.deserialize')
def deserialize(name, custom_objects=None):
  """Looks up an activation function by name.

  Names are resolved against this module's globals and, when given,
  `custom_objects`, via `deserialize_keras_object`.

  Arguments:
      name: Name of the activation to retrieve.
      custom_objects: Optional mapping of additional names to objects.

  Returns:
      Whatever `deserialize_keras_object` resolves `name` to.
  """
  return deserialize_keras_object(
      name,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='activation function')


@tf_export('keras.activations.get')
def get(identifier):
  """Resolves an activation identifier to a callable.

  Arguments:
      identifier: `None`, a string name, or a callable.

  Returns:
      `linear` when `identifier` is `None`; the deserialized function for a
      string; `identifier` itself when it is callable.

  Raises:
      ValueError: If `identifier` is neither `None`, a string, nor callable.
  """
  if identifier is None:
    return linear
  if isinstance(identifier, six.string_types):
    identifier = str(identifier)
    return deserialize(identifier)
  elif callable(identifier):
    if isinstance(identifier, Layer):
      logging.warning(
          'Do not pass a layer instance (such as {identifier}) as the '
          'activation argument of another layer. Instead, advanced '
          'activation layers should be used just like any other '
          'layer in a model.'.format(identifier=identifier.__class__.__name__))
    return identifier
  else:
    # Format the identifier into a single message; passing it as a second
    # positional argument made the exception message render as a tuple.
    raise ValueError('Could not interpret '
                     'activation function identifier: {}'.format(identifier))