Home | History | Annotate | Download | only in keras
      1 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Keras backend config API."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.util.tf_export import keras_export
     21 
     22 # The type of float to use throughout a session.
     23 _FLOATX = 'float32'
     24 
     25 # Epsilon fuzz factor used throughout the codebase.
     26 _EPSILON = 1e-7
     27 
     28 # Default image data format, one of "channels_last", "channels_first".
     29 _IMAGE_DATA_FORMAT = 'channels_last'
     30 
     31 
     32 @keras_export('keras.backend.epsilon')
     33 def epsilon():
     34   """Returns the value of the fuzz factor used in numeric expressions.
     35 
     36   Returns:
     37       A float.
     38 
     39   Example:
     40   ```python
     41   keras.backend.epsilon() >>>1e-07
     42   ```
     43   """
     44   return _EPSILON
     45 
     46 
     47 @keras_export('keras.backend.set_epsilon')
     48 def set_epsilon(value):
     49   """Sets the value of the fuzz factor used in numeric expressions.
     50 
     51   Arguments:
     52       value: float. New value of epsilon.
     53   Example: ```python from keras import backend as K K.epsilon() >>> 1e-07
     54     K.set_epsilon(1e-05) K.epsilon() >>> 1e-05 ```
     55   """
     56   global _EPSILON
     57   _EPSILON = value
     58 
     59 
     60 @keras_export('keras.backend.floatx')
     61 def floatx():
     62   """Returns the default float type, as a string.
     63 
     64   E.g. 'float16', 'float32', 'float64'.
     65 
     66   Returns:
     67       String, the current default float type.
     68 
     69   Example:
     70   ```python
     71   keras.backend.floatx() >>> 'float32'
     72   ```
     73   """
     74   return _FLOATX
     75 
     76 
     77 @keras_export('keras.backend.set_floatx')
     78 def set_floatx(value):
     79   """Sets the default float type.
     80 
     81   Arguments:
     82       value: String; 'float16', 'float32', or 'float64'.
     83   Example: ```python from keras import backend as K K.floatx() >>> 'float32'
     84     K.set_floatx('float16') K.floatx() >>> 'float16' ```
     85 
     86   Raises:
     87       ValueError: In case of invalid value.
     88   """
     89   global _FLOATX
     90   if value not in {'float16', 'float32', 'float64'}:
     91     raise ValueError('Unknown floatx type: ' + str(value))
     92   _FLOATX = str(value)
     93 
     94 
     95 @keras_export('keras.backend.image_data_format')
     96 def image_data_format():
     97   """Returns the default image data format convention.
     98 
     99   Returns:
    100       A string, either `'channels_first'` or `'channels_last'`
    101 
    102   Example:
    103   ```python
    104   keras.backend.image_data_format() >>> 'channels_first'
    105   ```
    106   """
    107   return _IMAGE_DATA_FORMAT
    108 
    109 
    110 @keras_export('keras.backend.set_image_data_format')
    111 def set_image_data_format(data_format):
    112   """Sets the value of the image data format convention.
    113 
    114   Arguments:
    115       data_format: string. `'channels_first'` or `'channels_last'`.
    116   Example: ```python from keras import backend as K K.image_data_format() >>>
    117     'channels_first' K.set_image_data_format('channels_last')
    118     K.image_data_format() >>> 'channels_last' ```
    119 
    120   Raises:
    121       ValueError: In case of invalid `data_format` value.
    122   """
    123   global _IMAGE_DATA_FORMAT
    124   if data_format not in {'channels_last', 'channels_first'}:
    125     raise ValueError('Unknown data_format: ' + str(data_format))
    126   _IMAGE_DATA_FORMAT = str(data_format)
    127