Home | History | Annotate | Download | only in keras
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 # pylint: disable=protected-access
     16 # pylint: disable=redefined-outer-name
     17 # pylint: disable=redefined-builtin
     18 """Keras backend API.
     19 """
     20 from __future__ import absolute_import
     21 from __future__ import division
     22 from __future__ import print_function
     23 
     24 import collections
     25 import itertools
     26 import json
     27 import os
     28 import threading
     29 import weakref
     30 
     31 import numpy as np
     32 
     33 from tensorflow.core.protobuf import config_pb2
     34 from tensorflow.python.client import session as session_module
     35 from tensorflow.python.distribute import distribute_coordinator as dc
     36 from tensorflow.python.distribute import distribute_coordinator_context as dc_context
     37 from tensorflow.python.distribute import distribution_strategy_context
     38 from tensorflow.python.eager import context
     39 from tensorflow.python.eager import function as eager_function
     40 from tensorflow.python.eager import lift_to_graph
     41 from tensorflow.python.framework import constant_op
     42 from tensorflow.python.framework import dtypes as dtypes_module
     43 from tensorflow.python.framework import func_graph
     44 from tensorflow.python.framework import ops
     45 from tensorflow.python.framework import sparse_tensor
     46 from tensorflow.python.framework import tensor_util
     47 from tensorflow.python.keras import backend_config
     48 from tensorflow.python.ops import array_ops
     49 from tensorflow.python.ops import clip_ops
     50 from tensorflow.python.ops import control_flow_ops
     51 from tensorflow.python.ops import ctc_ops as ctc
     52 from tensorflow.python.ops import functional_ops
     53 from tensorflow.python.ops import gradients as gradients_module
     54 from tensorflow.python.ops import image_ops
     55 from tensorflow.python.ops import init_ops
     56 from tensorflow.python.ops import linalg_ops
     57 from tensorflow.python.ops import logging_ops
     58 from tensorflow.python.ops import map_fn as map_fn_lib
     59 from tensorflow.python.ops import math_ops
     60 from tensorflow.python.ops import nn
     61 from tensorflow.python.ops import random_ops
     62 from tensorflow.python.ops import resource_variable_ops
     63 from tensorflow.python.ops import sparse_ops
     64 from tensorflow.python.ops import state_ops
     65 from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
     66 from tensorflow.python.ops import tensor_array_ops
     67 from tensorflow.python.ops import variables as variables_module
     68 from tensorflow.python.training import server_lib
     69 from tensorflow.python.util import nest
     70 from tensorflow.python.util import tf_contextlib
     71 from tensorflow.python.util import tf_inspect
     72 from tensorflow.python.util.tf_export import keras_export
     73 
     74 py_all = all
     75 py_sum = sum
     76 
     77 # INTERNAL UTILS
     78 
     79 # The internal graph maintained by Keras and used by the symbolic Keras APIs
     80 # while executing eagerly (such as the functional API for model-building).
     81 _GRAPH = None
     82 
     83 # A graph which is used for constructing functions in eager mode.
     84 _CURRENT_SCRATCH_GRAPH = None
     85 
     86 # This is a thread local object that will hold the default internal TF session
     87 # used by Keras. It can be set manually via `set_session(sess)`.
     88 _SESSION = threading.local()
     89 
     90 # This dictionary holds a mapping {graph: learning_phase}.
     91 # A learning phase is a bool tensor used to run Keras models in
     92 # either train mode (learning_phase == 1) or test mode (learning_phase == 0).
     93 _GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
     94 
     95 
     96 # _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
     97 # We keep a separate reference to it to make sure it does not get removed from
     98 # _GRAPH_LEARNING_PHASES.
     99 _DUMMY_EAGER_GRAPH = threading.local()
    100 
    101 # This boolean flag can be set to True to leave variable initialization
    102 # up to the user.
    103 # Change its value via `manual_variable_initialization(value)`.
    104 _MANUAL_VAR_INIT = False
    105 
    106 # This list holds the available devices.
    107 # It is populated when `_get_available_gpus()` is called for the first time.
    108 # We assume our devices don't change henceforth.
    109 _LOCAL_DEVICES = None
    110 
    111 # This dictionary holds a mapping between a graph and variables to initialize
    112 # in the graph.
    113 _GRAPH_VARIABLES = weakref.WeakKeyDictionary()
    114 
    115 # This dictionary holds a mapping between a graph and TF optimizers created in
    116 # the graph.
    117 _GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
    118 
    119 # The below functions are kept accessible from backend for compatibility.
    120 epsilon = backend_config.epsilon
    121 floatx = backend_config.floatx
    122 image_data_format = backend_config.image_data_format
    123 set_epsilon = backend_config.set_epsilon
    124 set_floatx = backend_config.set_floatx
    125 set_image_data_format = backend_config.set_image_data_format
    126 
    127 
    128 @keras_export('keras.backend.backend')
    129 def backend():
    130   """Publicly accessible method for determining the current backend.
    131 
    132   Only exists for API compatibility with multi-backend Keras.
    133 
    134   Returns:
    135       The string "tensorflow".
    136   """
    137   return 'tensorflow'
    138 
    139 
    140 @keras_export('keras.backend.cast_to_floatx')
    141 def cast_to_floatx(x):
    142   """Cast a Numpy array to the default Keras float type.
    143 
    144   Arguments:
    145       x: Numpy array.
    146 
    147   Returns:
    148       The same Numpy array, cast to its new type.
    149 
    150   Example:
    151   ```python
    152       >>> from keras import backend as K
    153       >>> K.floatx()
    154       'float32'
    155       >>> arr = numpy.array([1.0, 2.0], dtype='float64')
    156       >>> arr.dtype
    157       dtype('float64')
    158       >>> new_arr = K.cast_to_floatx(arr)
    159       >>> new_arr
    160       array([ 1.,  2.], dtype=float32)
    161       >>> new_arr.dtype
    162       dtype('float32')
    163   ```
    164   """
    165   return np.asarray(x, dtype=floatx())
    166 
    167 
    168 # A global dictionary mapping graph objects to an index of counters used
    169 # for various layer names in each graph.
    170 # Allows to give unique autogenerated names to layers, in a graph-specific way.
    171 PER_GRAPH_LAYER_NAME_UIDS = weakref.WeakKeyDictionary()
    172 
    173 
    174 @keras_export('keras.backend.get_uid')
    175 def get_uid(prefix=''):
    176   """Associates a string prefix with an integer counter in a TensorFlow graph.
    177 
    178   Arguments:
    179     prefix: String prefix to index.
    180 
    181   Returns:
    182     Unique integer ID.
    183 
    184   Example:
    185 
    186   ```
    187     >>> get_uid('dense')
    188     1
    189     >>> get_uid('dense')
    190     2
    191   ```
    192   """
    193   graph = get_graph()
    194   if graph not in PER_GRAPH_LAYER_NAME_UIDS:
    195     PER_GRAPH_LAYER_NAME_UIDS[graph] = collections.defaultdict(int)
    196   layer_name_uids = PER_GRAPH_LAYER_NAME_UIDS[graph]
    197   layer_name_uids[prefix] += 1
    198   return layer_name_uids[prefix]
    199 
    200 
    201 @keras_export('keras.backend.reset_uids')
    202 def reset_uids():
    203   """Resets graph identifiers.
    204   """
    205   per_graph_layer_name_uids = PER_GRAPH_LAYER_NAME_UIDS
    206   keys = list(per_graph_layer_name_uids.keys())
    207   for key in keys:
    208     del per_graph_layer_name_uids[key]
    209 
    210 
    211 @keras_export('keras.backend.clear_session')
    212 def clear_session():
    213   """Destroys the current TF graph and creates a new one.
    214 
    215   Useful to avoid clutter from old models / layers.
    216   """
    217   global _SESSION
    218   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
    219   global _GRAPH_VARIABLES  # pylint: disable=global-variable-not-assigned
    220   global _GRAPH_TF_OPTIMIZERS  # pylint: disable=global-variable-not-assigned
    221   ops.reset_default_graph()
    222   reset_uids()
    223   _SESSION.session = None
    224   graph = get_graph()
    225   with graph.as_default():
    226     with ops.name_scope(''):
    227       phase = array_ops.placeholder_with_default(
    228           False, shape=(), name='keras_learning_phase')
    229     _GRAPH_LEARNING_PHASES = {}
    230     _GRAPH_LEARNING_PHASES[graph] = phase
    231     _GRAPH_VARIABLES.pop(graph, None)
    232     _GRAPH_TF_OPTIMIZERS.pop(graph, None)
    233 
    234 
    235 @keras_export('keras.backend.manual_variable_initialization')
    236 def manual_variable_initialization(value):
    237   """Sets the manual variable initialization flag.
    238 
    239   This boolean flag determines whether
    240   variables should be initialized
    241   as they are instantiated (default), or if
    242   the user should handle the initialization
    243   (e.g. via `tf.initialize_all_variables()`).
    244 
    245   Arguments:
    246       value: Python boolean.
    247   """
    248   global _MANUAL_VAR_INIT
    249   _MANUAL_VAR_INIT = value
    250 
    251 
    252 @keras_export('keras.backend.learning_phase')
    253 def learning_phase():
    254   """Returns the learning phase flag.
    255 
    256   The learning phase flag is a bool tensor (0 = test, 1 = train)
    257   to be passed as input to any Keras function
    258   that uses a different behavior at train time and test time.
    259 
    260   Returns:
    261       Learning phase (scalar integer tensor or Python integer).
    262   """
    263   if ops.get_default_graph() is _GRAPH:
    264     # Don't enter an init_scope for the learning phase if eager execution
    265     # is enabled but we're inside the Keras workspace graph.
    266     return symbolic_learning_phase()
    267   with ops.init_scope():
    268     # We always check & set the learning phase inside the init_scope,
    269     # otherwise the wrong default_graph will be used to look up the learning
    270     # phase inside of functions & defuns.
    271     #
    272     # This is because functions & defuns (both in graph & in eager mode)
    273     # will always execute non-eagerly using a function-specific default
    274     # subgraph.
    275     if context.executing_eagerly():
    276       if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
    277         # Fallback to inference mode as default.
    278         return 0
    279       return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
    280     return symbolic_learning_phase()
    281 
    282 
    283 def symbolic_learning_phase():
    284   graph = get_graph()
    285   with graph.as_default():
    286     if graph not in _GRAPH_LEARNING_PHASES:
    287       with ops.name_scope(''):
    288         phase = array_ops.placeholder_with_default(
    289             False, shape=(), name='keras_learning_phase')
    290       _GRAPH_LEARNING_PHASES[graph] = phase
    291     return _GRAPH_LEARNING_PHASES[graph]
    292 
    293 
    294 @keras_export('keras.backend.set_learning_phase')
    295 def set_learning_phase(value):
    296   """Sets the learning phase to a fixed value.
    297 
    298   Arguments:
    299       value: Learning phase value, either 0 or 1 (integers).
    300 
    301   Raises:
    302       ValueError: if `value` is neither `0` nor `1`.
    303   """
    304   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
    305   if value not in {0, 1}:
    306     raise ValueError('Expected learning phase to be 0 or 1.')
    307   with ops.init_scope():
    308     if context.executing_eagerly():
    309       # In an eager context, the learning phase values applies to both the eager
    310       # context and the internal Keras graph.
    311       _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
    312     _GRAPH_LEARNING_PHASES[get_graph()] = value
    313 
    314 
    315 def set_eager_learning_phase(value):
    316   """Internal utility that sets the learning phase in eager execution only.
    317 
    318   Arguments:
    319       value: Learning phase value, either 0 or 1 (integers).
    320   """
    321   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
    322   assert value in {0, 1}
    323   assert context.executing_eagerly()
    324   _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
    325 
    326 
    327 @keras_export('keras.backend.learning_phase_scope')
    328 @tf_contextlib.contextmanager
    329 def learning_phase_scope(value):
    330   """Provides a scope within which the learning phase is equal to `value`.
    331 
    332   The learning phase gets restored to its original value upon exiting the scope.
    333 
    334   Arguments:
    335      value: Learning phase value, either 0 or 1 (integers).
    336 
    337   Yields:
    338     None.
    339 
    340   Raises:
    341      ValueError: if `value` is neither `0` nor `1`.
    342   """
    343   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
    344   if value not in {0, 1}:
    345     raise ValueError('Expected learning phase to be 0 or 1.')
    346 
    347   with ops.init_scope():
    348     if context.executing_eagerly():
    349       previous_eager_value = _GRAPH_LEARNING_PHASES.get(
    350           _DUMMY_EAGER_GRAPH, None)
    351     previous_graph_value = _GRAPH_LEARNING_PHASES.get(get_graph(), None)
    352 
    353   try:
    354     set_learning_phase(value)
    355     yield
    356   finally:
    357     # Restore learning phase to initial value.
    358     with ops.init_scope():
    359       if context.executing_eagerly():
    360         if previous_eager_value is not None:
    361           _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_eager_value
    362         elif _DUMMY_EAGER_GRAPH in _GRAPH_LEARNING_PHASES:
    363           del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
    364 
    365       graph = get_graph()
    366       if previous_graph_value is not None:
    367         _GRAPH_LEARNING_PHASES[graph] = previous_graph_value
    368       elif graph in _GRAPH_LEARNING_PHASES:
    369         del _GRAPH_LEARNING_PHASES[graph]
    370 
    371 @tf_contextlib.contextmanager
    372 def eager_learning_phase_scope(value):
    373   """Internal scope that sets the learning phase in eager execution only.
    374 
    375   Arguments:
    376       value: Learning phase value, either 0 or 1 (integers).
    377 
    378   Yields:
    379     None.
    380 
    381   Raises:
    382      ValueError: if `value` is neither `0` nor `1`.
    383   """
    384   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
    385   assert value in {0, 1}
    386   assert context.executing_eagerly()
    387   previous_value = learning_phase()
    388   try:
    389     _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
    390     yield
    391   finally:
    392     # Restore learning phase to initial value.
    393     _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
    394 
    395 
    396 def _current_graph(op_input_list):
    397   """Return the graph members of `op_input_list`, or the current graph."""
    398   return ops._get_graph_from_inputs(op_input_list)
    399 
    400 
    401 def _get_session(op_input_list=()):
    402   """Returns the session object for the current thread."""
    403   global _SESSION
    404   default_session = ops.get_default_session()
    405   if default_session is not None:
    406     session = default_session
    407   else:
    408     if ops.inside_function():
    409       raise RuntimeError('Cannot get session inside Tensorflow graph function.')
    410     # If we don't have a session, or that session does not match the current
    411     # graph, create and cache a new session.
    412     if (getattr(_SESSION, 'session', None) is None or
    413         _SESSION.session.graph is not _current_graph(op_input_list)):
    414       # If we are creating the Session inside a tf.distribute.Strategy scope,
    415       # we ask the strategy for the right session options to use.
    416       if distribution_strategy_context.has_strategy():
    417         configure_and_create_distributed_session(
    418             distribution_strategy_context.get_strategy())
    419       else:
    420         _SESSION.session = session_module.Session(
    421             config=get_default_session_config())
    422     session = _SESSION.session
    423   return session
    424 
    425 
    426 @keras_export(v1=['keras.backend.get_session'])
    427 def get_session(op_input_list=()):
    428   """Returns the TF session to be used by the backend.
    429 
    430   If a default TensorFlow session is available, we will return it.
    431 
    432   Else, we will return the global Keras session assuming it matches
    433   the current graph.
    434 
    435   If no global Keras session exists at this point:
    436   we will create a new global session.
    437 
    438   Note that you can manually set the global session
    439   via `K.set_session(sess)`.
    440 
    441   Arguments:
    442       op_input_list: An option sequence of tensors or ops, which will be used
    443         to determine the current graph. Otherwise the default graph will be
    444         used.
    445 
    446   Returns:
    447       A TensorFlow session.
    448   """
    449   session = _get_session(op_input_list)
    450   if not _MANUAL_VAR_INIT:
    451     with session.graph.as_default():
    452       _initialize_variables(session)
    453   return session
    454 
    455 
    456 def get_graph():
    457   if context.executing_eagerly():
    458     global _GRAPH
    459     if _GRAPH is None:
    460       _GRAPH = func_graph.FuncGraph('keras_graph')
    461     return _GRAPH
    462   else:
    463     return ops.get_default_graph()
    464 
    465 
    466 @tf_contextlib.contextmanager
    467 def _scratch_graph(graph=None):
    468   """Retrieve a shared and temporary func graph.
    469 
    470   The eager execution path lifts a subgraph from the keras global graph into
    471   a scratch graph in order to create a function. DistributionStrategies, in
    472   turn, constructs multiple functions as well as a final combined function. In
    473   order for that logic to work correctly, all of the functions need to be
    474   created on the same scratch FuncGraph.
    475 
    476   Args:
    477     graph: A graph to be used as the current scratch graph. If not set then
    478       a scratch graph will either be retrieved or created:
    479 
    480   Yields:
    481     The current scratch graph.
    482   """
    483   global _CURRENT_SCRATCH_GRAPH
    484   if (_CURRENT_SCRATCH_GRAPH is not None and graph is not None and
    485       _CURRENT_SCRATCH_GRAPH is not graph):
    486     raise ValueError('Multiple scratch graphs specified.')
    487 
    488   if _CURRENT_SCRATCH_GRAPH:
    489     yield _CURRENT_SCRATCH_GRAPH
    490     return
    491 
    492   graph = graph or func_graph.FuncGraph('keras_scratch_graph')
    493   try:
    494     _CURRENT_SCRATCH_GRAPH = graph
    495     yield graph
    496   finally:
    497     _CURRENT_SCRATCH_GRAPH = None
    498 
    499 
    500 @keras_export('keras.backend.set_session')
    501 def set_session(session):
    502   """Sets the global TensorFlow session.
    503 
    504   Arguments:
    505       session: A TF Session.
    506   """
    507   global _SESSION
    508   _SESSION.session = session
    509 
    510 
    511 def get_default_session_config():
    512   if not os.environ.get('OMP_NUM_THREADS'):
    513     config = config_pb2.ConfigProto(allow_soft_placement=True)
    514   else:
    515     num_thread = int(os.environ.get('OMP_NUM_THREADS'))
    516     config = config_pb2.ConfigProto(
    517         intra_op_parallelism_threads=num_thread,
    518         inter_op_parallelism_threads=num_thread,
    519         allow_soft_placement=True)
    520   return config
    521 
    522 
    523 # DEVICE MANIPULATION
    524 
    525 
    526 class _TfDeviceCaptureOp(object):
    527   """Class for capturing the TF device scope."""
    528 
    529   def __init__(self):
    530     self.device = None
    531 
    532   def _set_device(self, device):
    533     """This method captures TF's explicit device scope setting."""
    534     self.device = device
    535 
    536 
    537 def _get_current_tf_device():
    538   """Return explicit device of current context, otherwise returns `None`.
    539 
    540   Returns:
    541       If the current device scope is explicitly set, it returns a string with
    542       the device (`CPU` or `GPU`). If the scope is not explicitly set, it will
    543       return `None`.
    544   """
    545   graph = get_graph()
    546   op = _TfDeviceCaptureOp()
    547   graph._apply_device_functions(op)
    548   return op.device
    549 
    550 
    551 def _is_current_explicit_device(device_type):
    552   """Check if the current device is explicitly set on the device type specified.
    553 
    554   Arguments:
    555       device_type: A string containing `GPU` or `CPU` (case-insensitive).
    556 
    557   Returns:
    558       A boolean indicating if the current device scope is explicitly set on the
    559       device type.
    560 
    561   Raises:
    562       ValueError: If the `device_type` string indicates an unsupported device.
    563   """
    564   device_type = device_type.upper()
    565   if device_type not in ['CPU', 'GPU']:
    566     raise ValueError('`device_type` should be either "CPU" or "GPU".')
    567   device = _get_current_tf_device()
    568   return device is not None and device.device_type == device_type.upper()
    569 
    570 
    571 def _get_available_gpus():
    572   """Get a list of available gpu devices (formatted as strings).
    573 
    574   Returns:
    575       A list of available GPU devices.
    576   """
    577   if ops.executing_eagerly_outside_functions():
    578     # Returns names of devices directly.
    579     return [name for name in context.list_devices() if 'GPU' in name]
    580 
    581   global _LOCAL_DEVICES
    582   if _LOCAL_DEVICES is None:
    583     _LOCAL_DEVICES = get_session().list_devices()
    584   return [x.name for x in _LOCAL_DEVICES if x.device_type == 'GPU']
    585 
    586 
    587 def _has_nchw_support():
    588   """Check whether the current scope supports NCHW ops.
    589 
    590   TensorFlow does not support NCHW on CPU. Therefore we check if we are not
    591   explicitly put on
    592   CPU, and have GPUs available. In this case there will be soft-placing on the
    593   GPU device.
    594 
    595   Returns:
    596       bool: if the current scope device placement would support nchw
    597   """
    598   explicitly_on_cpu = _is_current_explicit_device('CPU')
    599   gpus_available = bool(_get_available_gpus())
    600   return not explicitly_on_cpu and gpus_available
    601 
    602 
    603 # VARIABLE MANIPULATION
    604 
    605 
    606 def _to_tensor(x, dtype):
    607   """Convert the input `x` to a tensor of type `dtype`.
    608 
    609   Arguments:
    610       x: An object to be converted (numpy array, list, tensors).
    611       dtype: The destination type.
    612 
    613   Returns:
    614       A tensor.
    615   """
    616   return ops.convert_to_tensor(x, dtype=dtype)
    617 
    618 
    619 @keras_export('keras.backend.is_sparse')
    620 def is_sparse(tensor):
    621   """Returns whether a tensor is a sparse tensor.
    622 
    623   Arguments:
    624       tensor: A tensor instance.
    625 
    626   Returns:
    627       A boolean.
    628 
    629   Example:
    630   ```python
    631       >>> from keras import backend as K
    632       >>> a = K.placeholder((2, 2), sparse=False)
    633       >>> print(K.is_sparse(a))
    634       False
    635       >>> b = K.placeholder((2, 2), sparse=True)
    636       >>> print(K.is_sparse(b))
    637       True
    638   ```
    639   """
    640   return isinstance(tensor, sparse_tensor.SparseTensor)
    641 
    642 
    643 @keras_export('keras.backend.to_dense')
    644 def to_dense(tensor):
    645   """Converts a sparse tensor into a dense tensor and returns it.
    646 
    647   Arguments:
    648       tensor: A tensor instance (potentially sparse).
    649 
    650   Returns:
    651       A dense tensor.
    652 
    653   Examples:
    654   ```python
    655       >>> from keras import backend as K
    656       >>> b = K.placeholder((2, 2), sparse=True)
    657       >>> print(K.is_sparse(b))
    658       True
    659       >>> c = K.to_dense(b)
    660       >>> print(K.is_sparse(c))
    661       False
    662   ```
    663   """
    664   if is_sparse(tensor):
    665     return sparse_ops.sparse_tensor_to_dense(tensor)
    666   else:
    667     return tensor
    668 
    669 
    670 name_scope = ops.name_scope
    671 
    672 
    673 @keras_export('keras.backend.variable')
    674 def variable(value, dtype=None, name=None, constraint=None):
    675   """Instantiates a variable and returns it.
    676 
    677   Arguments:
    678       value: Numpy array, initial value of the tensor.
    679       dtype: Tensor type.
    680       name: Optional name string for the tensor.
    681       constraint: Optional projection function to be
    682           applied to the variable after an optimizer update.
    683 
    684   Returns:
    685       A variable instance (with Keras metadata included).
    686 
    687   Examples:
    688   ```python
    689       >>> import numpy as np
    690       >>> from keras import backend as K
    691       >>> val = np.array([[1, 2], [3, 4]])
    692       >>> kvar = K.variable(value=val, dtype='float64', name='example_var')
    693       >>> K.dtype(kvar)
    694       'float64'
    695       >>> print(kvar)
    696       example_var
    697       >>> kvar.eval()
    698       array([[ 1.,  2.],
    699              [ 3.,  4.]])
    700   ```
    701   """
    702   if dtype is None:
    703     dtype = floatx()
    704   if hasattr(value, 'tocoo'):
    705     sparse_coo = value.tocoo()
    706     indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(
    707         sparse_coo.col, 1)), 1)
    708     v = sparse_tensor.SparseTensor(
    709         indices=indices, values=sparse_coo.data, dense_shape=sparse_coo.shape)
    710     v._keras_shape = sparse_coo.shape
    711     return v
    712   v = resource_variable_ops.ResourceVariable(
    713       value,
    714       dtype=dtypes_module.as_dtype(dtype),
    715       name=name,
    716       constraint=constraint)
    717   if isinstance(value, np.ndarray):
    718     v._keras_shape = value.shape
    719   elif hasattr(value, 'shape'):
    720     v._keras_shape = int_shape(value)
    721   track_variable(v)
    722   return v
    723 
    724 
    725 def track_tf_optimizer(tf_optimizer):
    726   """Tracks the given TF optimizer for initialization of its variables."""
    727   if context.executing_eagerly():
    728     return
    729   graph = get_graph()
    730   optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
    731   optimizers.add(tf_optimizer)
    732 
    733 
    734 def track_variable(v):
    735   """Tracks the given variable for initialization."""
    736   if context.executing_eagerly():
    737     return
    738   graph = v.graph if hasattr(v, 'graph') else get_graph()
    739   if graph not in _GRAPH_VARIABLES:
    740     _GRAPH_VARIABLES[graph] = weakref.WeakSet()
    741   _GRAPH_VARIABLES[graph].add(v)
    742 
    743 
    744 def _get_variables(graph=None):
    745   """Returns variables corresponding to the given graph for initialization."""
    746   assert not context.executing_eagerly()
    747   variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
    748   for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
    749     variables.update(opt.optimizer.variables())
    750   return variables
    751 
    752 
    753 def _initialize_variables(session):
    754   """Utility to initialize uninitialized variables on the fly."""
    755   variables = _get_variables(get_graph())
    756   candidate_vars = []
    757   for v in variables:
    758     if not getattr(v, '_keras_initialized', False):
    759       candidate_vars.append(v)
    760   if candidate_vars:
    761     # This step is expensive, so we only run it on variables not already
    762     # marked as initialized.
    763     is_initialized = session.run(
    764         [variables_module.is_variable_initialized(v) for v in candidate_vars])
    765     uninitialized_vars = []
    766     for flag, v in zip(is_initialized, candidate_vars):
    767       if not flag:
    768         uninitialized_vars.append(v)
    769       v._keras_initialized = True
    770     if uninitialized_vars:
    771       session.run(variables_module.variables_initializer(uninitialized_vars))
    772 
    773 
    774 @keras_export('keras.backend.constant')
    775 def constant(value, dtype=None, shape=None, name=None):
    776   """Creates a constant tensor.
    777 
    778   Arguments:
    779       value: A constant value (or list)
    780       dtype: The type of the elements of the resulting tensor.
    781       shape: Optional dimensions of resulting tensor.
    782       name: Optional name for the tensor.
    783 
    784   Returns:
    785       A Constant Tensor.
    786   """
    787   if dtype is None:
    788     dtype = floatx()
    789 
    790   # If the outer context is eager but we are executing under the keras
    791   # FuncGraph, we create EagerTensors and use them as constants.
    792   if (ops.executing_eagerly_outside_functions() and
    793       getattr(get_graph(), 'name', '') == 'keras_graph'):
    794     with ops.init_scope():
    795       return constant_op.constant(value, dtype=dtype, shape=shape, name=name)
    796 
    797   return constant_op.constant(value, dtype=dtype, shape=shape, name=name)
    798 
    799 
    800 def is_keras_tensor(x):
    801   """Returns whether `x` is a Keras tensor.
    802 
    803   A "Keras tensor" is a tensor that was returned by a Keras layer,
    804   (`Layer` class) or by `Input`.
    805 
    806   Arguments:
    807       x: A candidate tensor.
    808 
    809   Returns:
    810       A boolean: Whether the argument is a Keras tensor.
    811 
    812   Raises:
    813       ValueError: In case `x` is not a symbolic tensor.
    814 
    815   Examples:
    816   ```python
    817       >>> import tensorflow as tf
    818       >>> import numpy
    819       >>> from keras import backend as K
    820       >>> from keras.layers import Input, Dense
    821       >>> np_var = numpy.array([1, 2])
    822       >>> K.is_keras_tensor(np_var) # A numpy array is not a symbolic tensor.
    823       ValueError
    824       >>> k_var = tf.placeholder('float32', shape=(1,1))
    825       >>> K.is_keras_tensor(k_var) # A variable indirectly created outside of
    826       keras is not a Keras tensor.
    827       False
    828       >>> keras_var = K.variable(np_var)
    829       >>> K.is_keras_tensor(keras_var)  # A variable created with the keras
    830       backend is not a Keras tensor.
    831       False
    832       >>> keras_placeholder = K.placeholder(shape=(2, 4, 5))
    833       >>> K.is_keras_tensor(keras_placeholder)  # A placeholder is not a Keras
    834       tensor.
    835       False
    836       >>> keras_input = Input([10])
    837       >>> K.is_keras_tensor(keras_input) # An Input is a Keras tensor.
    838       True
    839       >>> keras_layer_output = Dense(10)(keras_input)
    840       >>> K.is_keras_tensor(keras_layer_output) # Any Keras layer output is a
    841       Keras tensor.
    842       True
    843   ```
    844   """
    845   if not isinstance(x, (ops.Tensor,
    846                         variables_module.Variable,
    847                         sparse_tensor.SparseTensor)):
    848     raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
    849                      '`. Expected a symbolic tensor instance.')
    850   return hasattr(x, '_keras_history')
    851 
    852 
    853 @keras_export('keras.backend.placeholder')
    854 def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
    855   """Instantiates a placeholder tensor and returns it.
    856 
    857   Arguments:
    858       shape: Shape of the placeholder
    859           (integer tuple, may include `None` entries).
    860       ndim: Number of axes of the tensor.
    861           At least one of {`shape`, `ndim`} must be specified.
    862           If both are specified, `shape` is used.
    863       dtype: Placeholder type.
    864       sparse: Boolean, whether the placeholder should have a sparse type.
    865       name: Optional name string for the placeholder.
    866 
    867   Raises:
    868       ValueError: If called with eager execution.
    869 
    870   Returns:
    871       Tensor instance (with Keras metadata included).
    872 
    873   Examples:
    874   ```python
    875       >>> from keras import backend as K
    876       >>> input_ph = K.placeholder(shape=(2, 4, 5))
    877       >>> input_ph
    878       <tf.Tensor 'Placeholder_4:0' shape=(2, 4, 5) dtype=float32>
    879   ```
    880   """
    881   if dtype is None:
    882     dtype = floatx()
    883   if not shape:
    884     if ndim:
    885       shape = tuple([None for _ in range(ndim)])
    886   with get_graph().as_default():
    887     if sparse:
    888       x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
    889     else:
    890       x = array_ops.placeholder(dtype, shape=shape, name=name)
    891   return x
    892 
    893 
    894 def is_placeholder(x):
    895   """Returns whether `x` is a placeholder.
    896 
    897   Arguments:
    898       x: A candidate placeholder.
    899 
    900   Returns:
    901       Boolean.
    902   """
    903   try:
    904     return x.op.type == 'Placeholder'
    905   except AttributeError:
    906     return False
    907 
    908 
    909 @keras_export('keras.backend.shape')
    910 def shape(x):
    911   """Returns the symbolic shape of a tensor or variable.
    912 
    913   Arguments:
    914       x: A tensor or variable.
    915 
    916   Returns:
    917       A symbolic shape (which is itself a tensor).
    918 
    919   Examples:
    920 
    921   ```python
    922       # TensorFlow example
    923       >>> from keras import backend as K
    924       >>> tf_session = K.get_session()
    925       >>> val = np.array([[1, 2], [3, 4]])
    926       >>> kvar = K.variable(value=val)
    927       >>> input = keras.backend.placeholder(shape=(2, 4, 5))
    928       >>> K.shape(kvar)
    929       <tf.Tensor 'Shape_8:0' shape=(2,) dtype=int32>
    930       >>> K.shape(input)
    931       <tf.Tensor 'Shape_9:0' shape=(3,) dtype=int32>
    932       # To get integer shape (Instead, you can use K.int_shape(x))
    933       >>> K.shape(kvar).eval(session=tf_session)
    934       array([2, 2], dtype=int32)
    935       >>> K.shape(input).eval(session=tf_session)
    936       array([2, 4, 5], dtype=int32)
    937   ```
    938   """
    939   return array_ops.shape(x)
    940 
    941 
    942 @keras_export('keras.backend.int_shape')
    943 def int_shape(x):
    944   """Returns the shape of tensor or variable as a tuple of int or None entries.
    945 
    946   Arguments:
    947       x: Tensor or variable.
    948 
    949   Returns:
    950       A tuple of integers (or None entries).
    951 
    952   Examples:
    953   ```python
    954       >>> from keras import backend as K
    955       >>> input = K.placeholder(shape=(2, 4, 5))
    956       >>> K.int_shape(input)
    957       (2, 4, 5)
    958       >>> val = np.array([[1, 2], [3, 4]])
    959       >>> kvar = K.variable(value=val)
    960       >>> K.int_shape(kvar)
    961       (2, 2)
    962   ```
    963   """
    964   try:
    965     shape = x.shape
    966     if not isinstance(shape, tuple):
    967       shape = tuple(shape.as_list())
    968     return shape
    969   except ValueError:
    970     return None
    971 
    972 
    973 @keras_export('keras.backend.ndim')
    974 def ndim(x):
    975   """Returns the number of axes in a tensor, as an integer.
    976 
    977   Arguments:
    978       x: Tensor or variable.
    979 
    980   Returns:
    981       Integer (scalar), number of axes.
    982 
    983   Examples:
    984   ```python
    985       >>> from keras import backend as K
    986       >>> input = K.placeholder(shape=(2, 4, 5))
    987       >>> val = np.array([[1, 2], [3, 4]])
    988       >>> kvar = K.variable(value=val)
    989       >>> K.ndim(input)
    990       3
    991       >>> K.ndim(kvar)
    992       2
    993   ```
    994   """
    995   dims = x.shape._dims
    996   if dims is not None:
    997     return len(dims)
    998   return None
    999 
   1000 
   1001 @keras_export('keras.backend.dtype')
   1002 def dtype(x):
   1003   """Returns the dtype of a Keras tensor or variable, as a string.
   1004 
   1005   Arguments:
   1006       x: Tensor or variable.
   1007 
   1008   Returns:
   1009       String, dtype of `x`.
   1010 
   1011   Examples:
   1012   ```python
   1013       >>> from keras import backend as K
   1014       >>> K.dtype(K.placeholder(shape=(2,4,5)))
   1015       'float32'
   1016       >>> K.dtype(K.placeholder(shape=(2,4,5), dtype='float32'))
   1017       'float32'
   1018       >>> K.dtype(K.placeholder(shape=(2,4,5), dtype='float64'))
   1019       'float64'
   1020       # Keras variable
   1021       >>> kvar = K.variable(np.array([[1, 2], [3, 4]]))
   1022       >>> K.dtype(kvar)
   1023       'float32'
   1024       >>> kvar = K.variable(np.array([[1, 2], [3, 4]]), dtype='float32')
   1025       >>> K.dtype(kvar)
   1026       'float32'
   1027   ```
   1028   """
   1029   return x.dtype.base_dtype.name
   1030 
   1031 
   1032 @keras_export('keras.backend.eval')
   1033 def eval(x):
   1034   """Evaluates the value of a variable.
   1035 
   1036   Arguments:
   1037       x: A variable.
   1038 
   1039   Returns:
   1040       A Numpy array.
   1041 
   1042   Examples:
   1043   ```python
   1044       >>> from keras import backend as K
   1045       >>> kvar = K.variable(np.array([[1, 2], [3, 4]]), dtype='float32')
   1046       >>> K.eval(kvar)
   1047       array([[ 1.,  2.],
   1048              [ 3.,  4.]], dtype=float32)
   1049   ```
   1050   """
   1051   return get_value(to_dense(x))
   1052 
   1053 
   1054 @keras_export('keras.backend.zeros')
   1055 def zeros(shape, dtype=None, name=None):
   1056   """Instantiates an all-zeros variable and returns it.
   1057 
   1058   Arguments:
   1059       shape: Tuple of integers, shape of returned Keras variable
   1060       dtype: String, data type of returned Keras variable
   1061       name: String, name of returned Keras variable
   1062 
   1063   Returns:
   1064       A variable (including Keras metadata), filled with `0.0`.
   1065       Note that if `shape` was symbolic, we cannot return a variable,
   1066       and will return a dynamically-shaped tensor instead.
   1067 
   1068   Example:
   1069   ```python
   1070       >>> from keras import backend as K
   1071       >>> kvar = K.zeros((3,4))
   1072       >>> K.eval(kvar)
   1073       array([[ 0.,  0.,  0.,  0.],
   1074              [ 0.,  0.,  0.,  0.],
   1075              [ 0.,  0.,  0.,  0.]], dtype=float32)
   1076   ```
   1077   """
   1078   with ops.init_scope():
   1079     if dtype is None:
   1080       dtype = floatx()
   1081     tf_dtype = dtypes_module.as_dtype(dtype)
   1082     v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
   1083     if py_all(v.shape.as_list()):
   1084       return variable(v, dtype=dtype, name=name)
   1085     track_variable(v)
   1086     return v
   1087 
   1088 
   1089 @keras_export('keras.backend.ones')
   1090 def ones(shape, dtype=None, name=None):
   1091   """Instantiates an all-ones variable and returns it.
   1092 
   1093   Arguments:
   1094       shape: Tuple of integers, shape of returned Keras variable.
   1095       dtype: String, data type of returned Keras variable.
   1096       name: String, name of returned Keras variable.
   1097 
   1098   Returns:
   1099       A Keras variable, filled with `1.0`.
   1100       Note that if `shape` was symbolic, we cannot return a variable,
   1101       and will return a dynamically-shaped tensor instead.
   1102 
   1103   Example:
   1104   ```python
   1105       >>> from keras import backend as K
   1106       >>> kvar = K.ones((3,4))
   1107       >>> K.eval(kvar)
   1108       array([[ 1.,  1.,  1.,  1.],
   1109              [ 1.,  1.,  1.,  1.],
   1110              [ 1.,  1.,  1.,  1.]], dtype=float32)
   1111   ```
   1112   """
   1113   with ops.init_scope():
   1114     if dtype is None:
   1115       dtype = floatx()
   1116     tf_dtype = dtypes_module.as_dtype(dtype)
   1117     v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
   1118     if py_all(v.shape.as_list()):
   1119       return variable(v, dtype=dtype, name=name)
   1120     track_variable(v)
   1121     return v
   1122 
   1123 
   1124 @keras_export('keras.backend.eye')
   1125 def eye(size, dtype=None, name=None):
   1126   """Instantiate an identity matrix and returns it.
   1127 
   1128   Arguments:
   1129       size: Integer, number of rows/columns.
   1130       dtype: String, data type of returned Keras variable.
   1131       name: String, name of returned Keras variable.
   1132 
   1133   Returns:
   1134       A Keras variable, an identity matrix.
   1135 
   1136   Example:
   1137   ```python
   1138       >>> from keras import backend as K
   1139       >>> kvar = K.eye(3)
   1140       >>> K.eval(kvar)
   1141       array([[ 1.,  0.,  0.],
   1142              [ 0.,  1.,  0.],
   1143              [ 0.,  0.,  1.]], dtype=float32)
   1144   ```
   1145 
   1146   """
   1147   if dtype is None:
   1148     dtype = floatx()
   1149   tf_dtype = dtypes_module.as_dtype(dtype)
   1150   return variable(linalg_ops.eye(size, dtype=tf_dtype), dtype, name)
   1151 
   1152 
   1153 @keras_export('keras.backend.zeros_like')
   1154 def zeros_like(x, dtype=None, name=None):
   1155   """Instantiates an all-zeros variable of the same shape as another tensor.
   1156 
   1157   Arguments:
   1158       x: Keras variable or Keras tensor.
   1159       dtype: String, dtype of returned Keras variable.
   1160            None uses the dtype of x.
   1161       name: String, name for the variable to create.
   1162 
   1163   Returns:
   1164       A Keras variable with the shape of x filled with zeros.
   1165 
   1166   Example:
   1167   ```python
   1168       >>> from keras import backend as K
   1169       >>> kvar = K.variable(np.random.random((2,3)))
   1170       >>> kvar_zeros = K.zeros_like(kvar)
   1171       >>> K.eval(kvar_zeros)
   1172       array([[ 0.,  0.,  0.],
   1173              [ 0.,  0.,  0.]], dtype=float32)
   1174   ```
   1175   """
   1176   return array_ops.zeros_like(x, dtype=dtype, name=name)
   1177 
   1178 
   1179 @keras_export('keras.backend.ones_like')
   1180 def ones_like(x, dtype=None, name=None):
   1181   """Instantiates an all-ones variable of the same shape as another tensor.
   1182 
   1183   Arguments:
   1184       x: Keras variable or tensor.
   1185       dtype: String, dtype of returned Keras variable.
   1186            None uses the dtype of x.
   1187       name: String, name for the variable to create.
   1188 
   1189   Returns:
   1190       A Keras variable with the shape of x filled with ones.
   1191 
   1192   Example:
   1193   ```python
   1194       >>> from keras import backend as K
   1195       >>> kvar = K.variable(np.random.random((2,3)))
   1196       >>> kvar_ones = K.ones_like(kvar)
   1197       >>> K.eval(kvar_ones)
   1198       array([[ 1.,  1.,  1.],
   1199              [ 1.,  1.,  1.]], dtype=float32)
   1200   ```
   1201   """
   1202   return array_ops.ones_like(x, dtype=dtype, name=name)
   1203 
   1204 
   1205 def identity(x, name=None):
   1206   """Returns a tensor with the same content as the input tensor.
   1207 
   1208   Arguments:
   1209       x: The input tensor.
   1210       name: String, name for the variable to create.
   1211 
   1212   Returns:
   1213       A tensor of the same shape, type and content.
   1214   """
   1215   return array_ops.identity(x, name=name)
   1216 
   1217 
   1218 @keras_export('keras.backend.random_uniform_variable')
   1219 def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
   1220   """Instantiates a variable with values drawn from a uniform distribution.
   1221 
   1222   Arguments:
   1223       shape: Tuple of integers, shape of returned Keras variable.
   1224       low: Float, lower boundary of the output interval.
   1225       high: Float, upper boundary of the output interval.
   1226       dtype: String, dtype of returned Keras variable.
   1227       name: String, name of returned Keras variable.
   1228       seed: Integer, random seed.
   1229 
   1230   Returns:
   1231       A Keras variable, filled with drawn samples.
   1232 
   1233   Example:
   1234   ```python
   1235       # TensorFlow example
   1236       >>> kvar = K.random_uniform_variable((2,3), 0, 1)
   1237       >>> kvar
   1238       <tensorflow.python.ops.variables.Variable object at 0x10ab40b10>
   1239       >>> K.eval(kvar)
   1240       array([[ 0.10940075,  0.10047495,  0.476143  ],
   1241              [ 0.66137183,  0.00869417,  0.89220798]], dtype=float32)
   1242   ```
   1243   """
   1244   if dtype is None:
   1245     dtype = floatx()
   1246   tf_dtype = dtypes_module.as_dtype(dtype)
   1247   if seed is None:
   1248     # ensure that randomness is conditioned by the Numpy RNG
   1249     seed = np.random.randint(10e8)
   1250   value = init_ops.random_uniform_initializer(
   1251       low, high, dtype=tf_dtype, seed=seed)(shape)
   1252   return variable(value, dtype=dtype, name=name)
   1253 
   1254 
   1255 @keras_export('keras.backend.random_normal_variable')
   1256 def random_normal_variable(shape, mean, scale, dtype=None, name=None,
   1257                            seed=None):
   1258   """Instantiates a variable with values drawn from a normal distribution.
   1259 
   1260   Arguments:
   1261       shape: Tuple of integers, shape of returned Keras variable.
   1262       mean: Float, mean of the normal distribution.
   1263       scale: Float, standard deviation of the normal distribution.
   1264       dtype: String, dtype of returned Keras variable.
   1265       name: String, name of returned Keras variable.
   1266       seed: Integer, random seed.
   1267 
   1268   Returns:
   1269       A Keras variable, filled with drawn samples.
   1270 
   1271   Example:
   1272   ```python
   1273       # TensorFlow example
   1274       >>> kvar = K.random_normal_variable((2,3), 0, 1)
   1275       >>> kvar
   1276       <tensorflow.python.ops.variables.Variable object at 0x10ab12dd0>
   1277       >>> K.eval(kvar)
   1278       array([[ 1.19591331,  0.68685907, -0.63814116],
   1279              [ 0.92629528,  0.28055015,  1.70484698]], dtype=float32)
   1280   ```
   1281   """
   1282   if dtype is None:
   1283     dtype = floatx()
   1284   tf_dtype = dtypes_module.as_dtype(dtype)
   1285   if seed is None:
   1286     # ensure that randomness is conditioned by the Numpy RNG
   1287     seed = np.random.randint(10e8)
   1288   value = init_ops.random_normal_initializer(
   1289       mean, scale, dtype=tf_dtype, seed=seed)(shape)
   1290   return variable(value, dtype=dtype, name=name)
   1291 
   1292 
   1293 @keras_export('keras.backend.count_params')
   1294 def count_params(x):
   1295   """Returns the static number of elements in a variable or tensor.
   1296 
   1297   Arguments:
   1298       x: Variable or tensor.
   1299 
   1300   Returns:
   1301       Integer, the number of scalars in `x`.
   1302 
   1303   Example:
   1304   ```python
   1305       >>> kvar = K.zeros((2,3))
   1306       >>> K.count_params(kvar)
   1307       6
   1308       >>> K.eval(kvar)
   1309       array([[ 0.,  0.,  0.],
   1310              [ 0.,  0.,  0.]], dtype=float32)
   1311   ```
   1312   """
   1313   return np.prod(x.shape.as_list())
   1314 
   1315 
   1316 @keras_export('keras.backend.cast')
   1317 def cast(x, dtype):
   1318   """Casts a tensor to a different dtype and returns it.
   1319 
   1320   You can cast a Keras variable but it still returns a Keras tensor.
   1321 
   1322   Arguments:
   1323       x: Keras tensor (or variable).
   1324       dtype: String, either (`'float16'`, `'float32'`, or `'float64'`).
   1325 
   1326   Returns:
   1327       Keras tensor with dtype `dtype`.
   1328 
   1329   Example:
   1330   ```python
   1331       >>> from keras import backend as K
   1332       >>> input = K.placeholder((2, 3), dtype='float32')
   1333       >>> input
   1334       <tf.Tensor 'Placeholder_2:0' shape=(2, 3) dtype=float32>
   1335       # It doesn't work in-place as below.
   1336       >>> K.cast(input, dtype='float16')
   1337       <tf.Tensor 'Cast_1:0' shape=(2, 3) dtype=float16>
   1338       >>> input
   1339       <tf.Tensor 'Placeholder_2:0' shape=(2, 3) dtype=float32>
   1340       # you need to assign it.
   1341       >>> input = K.cast(input, dtype='float16')
   1342       >>> input
   1343       <tf.Tensor 'Cast_2:0' shape=(2, 3) dtype=float16>
   1344   ```
   1345   """
   1346   return math_ops.cast(x, dtype)
   1347 
   1348 
   1349 # UPDATES OPS
   1350 
   1351 
   1352 @keras_export('keras.backend.update')
   1353 def update(x, new_x):
   1354   return state_ops.assign(x, new_x)
   1355 
   1356 
   1357 @keras_export('keras.backend.update_add')
   1358 def update_add(x, increment):
   1359   """Update the value of `x` by adding `increment`.
   1360 
   1361   Arguments:
   1362       x: A Variable.
   1363       increment: A tensor of same shape as `x`.
   1364 
   1365   Returns:
   1366       The variable `x` updated.
   1367   """
   1368   return state_ops.assign_add(x, increment)
   1369 
   1370 
   1371 @keras_export('keras.backend.update_sub')
   1372 def update_sub(x, decrement):
   1373   """Update the value of `x` by subtracting `decrement`.
   1374 
   1375   Arguments:
   1376       x: A Variable.
   1377       decrement: A tensor of same shape as `x`.
   1378 
   1379   Returns:
   1380       The variable `x` updated.
   1381   """
   1382   return state_ops.assign_sub(x, decrement)
   1383 
   1384 
   1385 @keras_export('keras.backend.moving_average_update')
   1386 def moving_average_update(x, value, momentum):
   1387   """Compute the moving average of a variable.
   1388 
   1389   Arguments:
   1390       x: A Variable.
   1391       value: A tensor with the same shape as `variable`.
   1392       momentum: The moving average momentum.
   1393 
   1394   Returns:
   1395       An Operation to update the variable.
   1396   """
   1397   # `training` is higher-up than the Keras backend in the abstraction hierarchy.
   1398   # In particular, `training` depends on layers, and thus on Keras.
   1399   # moving_averages, being low-level ops, should not be part of the training
   1400   # module.
   1401   from tensorflow.python.training import moving_averages  # pylint: disable=g-import-not-at-top
   1402   return moving_averages.assign_moving_average(
   1403       x, value, momentum, zero_debias=True)
   1404 
   1405 
   1406 # LINEAR ALGEBRA
   1407 
   1408 
   1409 @keras_export('keras.backend.dot')
   1410 def dot(x, y):
   1411   """Multiplies 2 tensors (and/or variables) and returns a *tensor*.
   1412 
   1413   When attempting to multiply a nD tensor
   1414   with a nD tensor, it reproduces the Theano behavior.
   1415   (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`)
   1416 
   1417   Arguments:
   1418       x: Tensor or variable.
   1419       y: Tensor or variable.
   1420 
   1421   Returns:
   1422       A tensor, dot product of `x` and `y`.
   1423 
   1424   Examples:
   1425   ```python
   1426       # dot product between tensors
   1427       >>> x = K.placeholder(shape=(2, 3))
   1428       >>> y = K.placeholder(shape=(3, 4))
   1429       >>> xy = K.dot(x, y)
   1430       >>> xy
   1431       <tf.Tensor 'MatMul_9:0' shape=(2, 4) dtype=float32>
   1432   ```
   1433 
   1434   ```python
   1435       # dot product between tensors
   1436       >>> x = K.placeholder(shape=(32, 28, 3))
   1437       >>> y = K.placeholder(shape=(3, 4))
   1438       >>> xy = K.dot(x, y)
   1439       >>> xy
   1440       <tf.Tensor 'MatMul_9:0' shape=(32, 28, 4) dtype=float32>
   1441   ```
   1442 
   1443   ```python
   1444       # Theano-like behavior example
   1445       >>> x = K.random_uniform_variable(shape=(2, 3), low=0, high=1)
   1446       >>> y = K.ones((4, 3, 5))
   1447       >>> xy = K.dot(x, y)
   1448       >>> K.int_shape(xy)
   1449       (2, 4, 5)
   1450   ```
   1451   """
   1452   if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
   1453     x_shape = []
   1454     for i, s in zip(int_shape(x), array_ops.unstack(array_ops.shape(x))):
   1455       if i is not None:
   1456         x_shape.append(i)
   1457       else:
   1458         x_shape.append(s)
   1459     x_shape = tuple(x_shape)
   1460     y_shape = []
   1461     for i, s in zip(int_shape(y), array_ops.unstack(array_ops.shape(y))):
   1462       if i is not None:
   1463         y_shape.append(i)
   1464       else:
   1465         y_shape.append(s)
   1466     y_shape = tuple(y_shape)
   1467     y_permute_dim = list(range(ndim(y)))
   1468     y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
   1469     xt = array_ops.reshape(x, [-1, x_shape[-1]])
   1470     yt = array_ops.reshape(
   1471         array_ops.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
   1472     return array_ops.reshape(
   1473         math_ops.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:])
   1474   if is_sparse(x):
   1475     out = sparse_ops.sparse_tensor_dense_matmul(x, y)
   1476   else:
   1477     out = math_ops.matmul(x, y)
   1478   return out
   1479 
   1480 
   1481 @keras_export('keras.backend.batch_dot')
   1482 def batch_dot(x, y, axes=None):
   1483   """Batchwise dot product.
   1484 
   1485   `batch_dot` is used to compute dot product of `x` and `y` when
   1486   `x` and `y` are data in batch, i.e. in a shape of
   1487   `(batch_size, :)`.
   1488   `batch_dot` results in a tensor or variable with less dimensions
   1489   than the input. If the number of dimensions is reduced to 1,
   1490   we use `expand_dims` to make sure that ndim is at least 2.
   1491 
   1492   Arguments:
   1493       x: Keras tensor or variable with `ndim >= 2`.
   1494       y: Keras tensor or variable with `ndim >= 2`.
   1495       axes: list of (or single) int with target dimensions.
   1496           The lengths of `axes[0]` and `axes[1]` should be the same.
   1497 
   1498   Returns:
   1499       A tensor with shape equal to the concatenation of `x`'s shape
   1500       (less the dimension that was summed over) and `y`'s shape
   1501       (less the batch dimension and the dimension that was summed over).
   1502       If the final rank is 1, we reshape it to `(batch_size, 1)`.
   1503 
   1504   Examples:
   1505       Assume `x = [[1, 2], [3, 4]]` and `y = [[5, 6], [7, 8]]`
   1506       `batch_dot(x, y, axes=1) = [[17, 53]]` which is the main diagonal
   1507       of `x.dot(y.T)`, although we never have to calculate the off-diagonal
   1508       elements.
   1509 
   1510       Shape inference:
   1511       Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
   1512       If `axes` is (1, 2), to find the output shape of resultant tensor,
   1513           loop through each dimension in `x`'s shape and `y`'s shape:
   1514 
   1515       * `x.shape[0]` : 100 : append to output shape
   1516       * `x.shape[1]` : 20 : do not append to output shape,
   1517           dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
   1518       * `y.shape[0]` : 100 : do not append to output shape,
   1519           always ignore first dimension of `y`
   1520       * `y.shape[1]` : 30 : append to output shape
   1521       * `y.shape[2]` : 20 : do not append to output shape,
   1522           dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
   1523       `output_shape` = `(100, 30)`
   1524 
   1525   ```python
   1526       >>> x_batch = K.ones(shape=(32, 20, 1))
   1527       >>> y_batch = K.ones(shape=(32, 30, 20))
   1528       >>> xy_batch_dot = K.batch_dot(x_batch, y_batch, axes=[1, 2])
   1529       >>> K.int_shape(xy_batch_dot)
   1530       (32, 1, 30)
   1531   ```
   1532   """
   1533   if isinstance(axes, int):
   1534     axes = (axes, axes)
   1535   x_ndim = ndim(x)
   1536   y_ndim = ndim(y)
   1537   if axes is None:
   1538     # behaves like tf.batch_matmul as default
   1539     axes = [x_ndim - 1, y_ndim - 2]
   1540   if x_ndim > y_ndim:
   1541     diff = x_ndim - y_ndim
   1542     y = array_ops.reshape(y,
   1543                           array_ops.concat(
   1544                               [array_ops.shape(y), [1] * (diff)], axis=0))
   1545   elif y_ndim > x_ndim:
   1546     diff = y_ndim - x_ndim
   1547     x = array_ops.reshape(x,
   1548                           array_ops.concat(
   1549                               [array_ops.shape(x), [1] * (diff)], axis=0))
   1550   else:
   1551     diff = 0
   1552   if ndim(x) == 2 and ndim(y) == 2:
   1553     if axes[0] == axes[1]:
   1554       out = math_ops.reduce_sum(math_ops.multiply(x, y), axes[0])
   1555     else:
   1556       out = math_ops.reduce_sum(
   1557           math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
   1558   else:
   1559     adj_x = None if axes[0] == ndim(x) - 1 else True
   1560     adj_y = True if axes[1] == ndim(y) - 1 else None
   1561     out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
   1562   if diff:
   1563     if x_ndim > y_ndim:
   1564       idx = x_ndim + y_ndim - 3
   1565     else:
   1566       idx = x_ndim - 1
   1567     out = array_ops.squeeze(out, list(range(idx, idx + diff)))
   1568   if ndim(out) == 1:
   1569     out = expand_dims(out, 1)
   1570   return out
   1571 
   1572 
   1573 @keras_export('keras.backend.transpose')
   1574 def transpose(x):
   1575   """Transposes a tensor and returns it.
   1576 
   1577   Arguments:
   1578       x: Tensor or variable.
   1579 
   1580   Returns:
   1581       A tensor.
   1582 
   1583   Examples:
   1584   ```python
   1585       >>> var = K.variable([[1, 2, 3], [4, 5, 6]])
   1586       >>> K.eval(var)
   1587       array([[ 1.,  2.,  3.],
   1588              [ 4.,  5.,  6.]], dtype=float32)
   1589       >>> var_transposed = K.transpose(var)
   1590       >>> K.eval(var_transposed)
   1591       array([[ 1.,  4.],
   1592              [ 2.,  5.],
   1593              [ 3.,  6.]], dtype=float32)
   1594   ```
   1595 
   1596   ```python
   1597       >>> input = K.placeholder((2, 3))
   1598       >>> input
   1599       <tf.Tensor 'Placeholder_11:0' shape=(2, 3) dtype=float32>
   1600       >>> input_transposed = K.transpose(input)
   1601       >>> input_transposed
   1602       <tf.Tensor 'transpose_4:0' shape=(3, 2) dtype=float32>
   1603 
   1604   ```
   1605   """
   1606   return array_ops.transpose(x)
   1607 
   1608 
   1609 @keras_export('keras.backend.gather')
   1610 def gather(reference, indices):
   1611   """Retrieves the elements of indices `indices` in the tensor `reference`.
   1612 
   1613   Arguments:
   1614       reference: A tensor.
   1615       indices: An integer tensor of indices.
   1616 
   1617   Returns:
   1618       A tensor of same type as `reference`.
   1619   """
   1620   return array_ops.gather(reference, indices)
   1621 
   1622 
   1623 # ELEMENT-WISE OPERATIONS
   1624 
   1625 
   1626 @keras_export('keras.backend.max')
   1627 def max(x, axis=None, keepdims=False):
   1628   """Maximum value in a tensor.
   1629 
   1630   Arguments:
   1631       x: A tensor or variable.
   1632       axis: An integer, the axis to find maximum values.
   1633       keepdims: A boolean, whether to keep the dimensions or not.
   1634           If `keepdims` is `False`, the rank of the tensor is reduced
   1635           by 1. If `keepdims` is `True`,
   1636           the reduced dimension is retained with length 1.
   1637 
   1638   Returns:
   1639       A tensor with maximum values of `x`.
   1640   """
   1641   return math_ops.reduce_max(x, axis, keepdims)
   1642 
   1643 
   1644 @keras_export('keras.backend.min')
   1645 def min(x, axis=None, keepdims=False):
   1646   """Minimum value in a tensor.
   1647 
   1648   Arguments:
   1649       x: A tensor or variable.
   1650       axis: An integer, the axis to find minimum values.
   1651       keepdims: A boolean, whether to keep the dimensions or not.
   1652           If `keepdims` is `False`, the rank of the tensor is reduced
   1653           by 1. If `keepdims` is `True`,
   1654           the reduced dimension is retained with length 1.
   1655 
   1656   Returns:
   1657       A tensor with minimum values of `x`.
   1658   """
   1659   return math_ops.reduce_min(x, axis, keepdims)
   1660 
   1661 
   1662 @keras_export('keras.backend.sum')
   1663 def sum(x, axis=None, keepdims=False):
   1664   """Sum of the values in a tensor, alongside the specified axis.
   1665 
   1666   Arguments:
   1667       x: A tensor or variable.
   1668       axis: An integer, the axis to sum over.
   1669       keepdims: A boolean, whether to keep the dimensions or not.
   1670           If `keepdims` is `False`, the rank of the tensor is reduced
   1671           by 1. If `keepdims` is `True`,
   1672           the reduced dimension is retained with length 1.
   1673 
   1674   Returns:
   1675       A tensor with sum of `x`.
   1676   """
   1677   return math_ops.reduce_sum(x, axis, keepdims)
   1678 
   1679 
   1680 @keras_export('keras.backend.prod')
   1681 def prod(x, axis=None, keepdims=False):
   1682   """Multiplies the values in a tensor, alongside the specified axis.
   1683 
   1684   Arguments:
   1685       x: A tensor or variable.
   1686       axis: An integer, the axis to compute the product.
   1687       keepdims: A boolean, whether to keep the dimensions or not.
   1688           If `keepdims` is `False`, the rank of the tensor is reduced
   1689           by 1. If `keepdims` is `True`,
   1690           the reduced dimension is retained with length 1.
   1691 
   1692   Returns:
   1693       A tensor with the product of elements of `x`.
   1694   """
   1695   return math_ops.reduce_prod(x, axis, keepdims)
   1696 
   1697 
   1698 @keras_export('keras.backend.cumsum')
   1699 def cumsum(x, axis=0):
   1700   """Cumulative sum of the values in a tensor, alongside the specified axis.
   1701 
   1702   Arguments:
   1703       x: A tensor or variable.
   1704       axis: An integer, the axis to compute the sum.
   1705 
   1706   Returns:
   1707       A tensor of the cumulative sum of values of `x` along `axis`.
   1708   """
   1709   return math_ops.cumsum(x, axis=axis)
   1710 
   1711 
   1712 @keras_export('keras.backend.cumprod')
   1713 def cumprod(x, axis=0):
   1714   """Cumulative product of the values in a tensor, alongside the specified axis.
   1715 
   1716   Arguments:
   1717       x: A tensor or variable.
   1718       axis: An integer, the axis to compute the product.
   1719 
   1720   Returns:
   1721       A tensor of the cumulative product of values of `x` along `axis`.
   1722   """
   1723   return math_ops.cumprod(x, axis=axis)
   1724 
   1725 
   1726 @keras_export('keras.backend.var')
   1727 def var(x, axis=None, keepdims=False):
   1728   """Variance of a tensor, alongside the specified axis.
   1729 
   1730   Arguments:
   1731       x: A tensor or variable.
   1732       axis: An integer, the axis to compute the variance.
   1733       keepdims: A boolean, whether to keep the dimensions or not.
   1734           If `keepdims` is `False`, the rank of the tensor is reduced
   1735           by 1. If `keepdims` is `True`,
   1736           the reduced dimension is retained with length 1.
   1737 
   1738   Returns:
   1739       A tensor with the variance of elements of `x`.
   1740   """
   1741   if x.dtype.base_dtype == dtypes_module.bool:
   1742     x = math_ops.cast(x, floatx())
   1743   return math_ops.reduce_variance(x, axis=axis, keepdims=keepdims)
   1744 
   1745 
   1746 @keras_export('keras.backend.std')
   1747 def std(x, axis=None, keepdims=False):
   1748   """Standard deviation of a tensor, alongside the specified axis.
   1749 
   1750   Arguments:
   1751       x: A tensor or variable.
   1752       axis: An integer, the axis to compute the standard deviation.
   1753       keepdims: A boolean, whether to keep the dimensions or not.
   1754           If `keepdims` is `False`, the rank of the tensor is reduced
   1755           by 1. If `keepdims` is `True`,
   1756           the reduced dimension is retained with length 1.
   1757 
   1758   Returns:
   1759       A tensor with the standard deviation of elements of `x`.
   1760   """
   1761   if x.dtype.base_dtype == dtypes_module.bool:
   1762     x = math_ops.cast(x, floatx())
   1763   return math_ops.reduce_std(x, axis=axis, keepdims=keepdims)
   1764 
   1765 
   1766 @keras_export('keras.backend.mean')
   1767 def mean(x, axis=None, keepdims=False):
   1768   """Mean of a tensor, alongside the specified axis.
   1769 
   1770   Arguments:
   1771       x: A tensor or variable.
   1772       axis: A list of integer. Axes to compute the mean.
   1773       keepdims: A boolean, whether to keep the dimensions or not.
   1774           If `keepdims` is `False`, the rank of the tensor is reduced
   1775           by 1 for each entry in `axis`. If `keepdims` is `True`,
   1776           the reduced dimensions are retained with length 1.
   1777 
   1778   Returns:
   1779       A tensor with the mean of elements of `x`.
   1780   """
   1781   if x.dtype.base_dtype == dtypes_module.bool:
   1782     x = math_ops.cast(x, floatx())
   1783   return math_ops.reduce_mean(x, axis, keepdims)
   1784 
   1785 
   1786 @keras_export('keras.backend.any')
   1787 def any(x, axis=None, keepdims=False):
   1788   """Bitwise reduction (logical OR).
   1789 
   1790   Arguments:
   1791       x: Tensor or variable.
   1792       axis: axis along which to perform the reduction.
   1793       keepdims: whether the drop or broadcast the reduction axes.
   1794 
   1795   Returns:
   1796       A uint8 tensor (0s and 1s).
   1797   """
   1798   x = math_ops.cast(x, dtypes_module.bool)
   1799   return math_ops.reduce_any(x, axis, keepdims)
   1800 
   1801 
   1802 @keras_export('keras.backend.all')
   1803 def all(x, axis=None, keepdims=False):
   1804   """Bitwise reduction (logical AND).
   1805 
   1806   Arguments:
   1807       x: Tensor or variable.
   1808       axis: axis along which to perform the reduction.
   1809       keepdims: whether the drop or broadcast the reduction axes.
   1810 
   1811   Returns:
   1812       A uint8 tensor (0s and 1s).
   1813   """
   1814   x = math_ops.cast(x, dtypes_module.bool)
   1815   return math_ops.reduce_all(x, axis, keepdims)
   1816 
   1817 
   1818 @keras_export('keras.backend.argmax')
   1819 def argmax(x, axis=-1):
   1820   """Returns the index of the maximum value along an axis.
   1821 
   1822   Arguments:
   1823       x: Tensor or variable.
   1824       axis: axis along which to perform the reduction.
   1825 
   1826   Returns:
   1827       A tensor.
   1828   """
   1829   return math_ops.argmax(x, axis)
   1830 
   1831 
   1832 @keras_export('keras.backend.argmin')
   1833 def argmin(x, axis=-1):
   1834   """Returns the index of the minimum value along an axis.
   1835 
   1836   Arguments:
   1837       x: Tensor or variable.
   1838       axis: axis along which to perform the reduction.
   1839 
   1840   Returns:
   1841       A tensor.
   1842   """
   1843   return math_ops.argmin(x, axis)
   1844 
   1845 
   1846 @keras_export('keras.backend.square')
   1847 def square(x):
   1848   """Element-wise square.
   1849 
   1850   Arguments:
   1851       x: Tensor or variable.
   1852 
   1853   Returns:
   1854       A tensor.
   1855   """
   1856   return math_ops.square(x)
   1857 
   1858 
   1859 @keras_export('keras.backend.abs')
   1860 def abs(x):
   1861   """Element-wise absolute value.
   1862 
   1863   Arguments:
   1864       x: Tensor or variable.
   1865 
   1866   Returns:
   1867       A tensor.
   1868   """
   1869   return math_ops.abs(x)
   1870 
   1871 
   1872 @keras_export('keras.backend.sqrt')
   1873 def sqrt(x):
   1874   """Element-wise square root.
   1875 
   1876   Arguments:
   1877       x: Tensor or variable.
   1878 
   1879   Returns:
   1880       A tensor.
   1881   """
   1882   zero = _to_tensor(0., x.dtype.base_dtype)
   1883   inf = _to_tensor(np.inf, x.dtype.base_dtype)
   1884   x = clip_ops.clip_by_value(x, zero, inf)
   1885   return math_ops.sqrt(x)
   1886 
   1887 
   1888 @keras_export('keras.backend.exp')
   1889 def exp(x):
   1890   """Element-wise exponential.
   1891 
   1892   Arguments:
   1893       x: Tensor or variable.
   1894 
   1895   Returns:
   1896       A tensor.
   1897   """
   1898   return math_ops.exp(x)
   1899 
   1900 
   1901 @keras_export('keras.backend.log')
   1902 def log(x):
   1903   """Element-wise log.
   1904 
   1905   Arguments:
   1906       x: Tensor or variable.
   1907 
   1908   Returns:
   1909       A tensor.
   1910   """
   1911   return math_ops.log(x)
   1912 
   1913 
   1914 def logsumexp(x, axis=None, keepdims=False):
   1915   """Computes log(sum(exp(elements across dimensions of a tensor))).
   1916 
   1917   This function is more numerically stable than log(sum(exp(x))).
   1918   It avoids overflows caused by taking the exp of large inputs and
   1919   underflows caused by taking the log of small inputs.
   1920 
   1921   Arguments:
   1922       x: A tensor or variable.
   1923       axis: An integer, the axis to reduce over.
   1924       keepdims: A boolean, whether to keep the dimensions or not.
   1925           If `keepdims` is `False`, the rank of the tensor is reduced
   1926           by 1. If `keepdims` is `True`, the reduced dimension is
   1927           retained with length 1.
   1928 
   1929   Returns:
   1930       The reduced tensor.
   1931   """
   1932   return math_ops.reduce_logsumexp(x, axis, keepdims)
   1933 
   1934 
   1935 @keras_export('keras.backend.round')
   1936 def round(x):
   1937   """Element-wise rounding to the closest integer.
   1938 
   1939   In case of tie, the rounding mode used is "half to even".
   1940 
   1941   Arguments:
   1942       x: Tensor or variable.
   1943 
   1944   Returns:
   1945       A tensor.
   1946   """
   1947   return math_ops.round(x)
   1948 
   1949 
   1950 @keras_export('keras.backend.sign')
   1951 def sign(x):
   1952   """Element-wise sign.
   1953 
   1954   Arguments:
   1955       x: Tensor or variable.
   1956 
   1957   Returns:
   1958       A tensor.
   1959   """
   1960   return math_ops.sign(x)
   1961 
   1962 
   1963 @keras_export('keras.backend.pow')
   1964 def pow(x, a):
   1965   """Element-wise exponentiation.
   1966 
   1967   Arguments:
   1968       x: Tensor or variable.
   1969       a: Python integer.
   1970 
   1971   Returns:
   1972       A tensor.
   1973   """
   1974   return math_ops.pow(x, a)
   1975 
   1976 
   1977 @keras_export('keras.backend.clip')
   1978 def clip(x, min_value, max_value):
   1979   """Element-wise value clipping.
   1980 
   1981   Arguments:
   1982       x: Tensor or variable.
   1983       min_value: Python float or integer.
   1984       max_value: Python float or integer.
   1985 
   1986   Returns:
   1987       A tensor.
   1988   """
   1989   if max_value is not None and max_value < min_value:
   1990     max_value = min_value
   1991   if max_value is None:
   1992     max_value = np.inf
   1993   min_value = _to_tensor(min_value, x.dtype.base_dtype)
   1994   max_value = _to_tensor(max_value, x.dtype.base_dtype)
   1995   return clip_ops.clip_by_value(x, min_value, max_value)
   1996 
   1997 
   1998 @keras_export('keras.backend.equal')
   1999 def equal(x, y):
   2000   """Element-wise equality between two tensors.
   2001 
   2002   Arguments:
   2003       x: Tensor or variable.
   2004       y: Tensor or variable.
   2005 
   2006   Returns:
   2007       A bool tensor.
   2008   """
   2009   return math_ops.equal(x, y)
   2010 
   2011 
   2012 @keras_export('keras.backend.not_equal')
   2013 def not_equal(x, y):
   2014   """Element-wise inequality between two tensors.
   2015 
   2016   Arguments:
   2017       x: Tensor or variable.
   2018       y: Tensor or variable.
   2019 
   2020   Returns:
   2021       A bool tensor.
   2022   """
   2023   return math_ops.not_equal(x, y)
   2024 
   2025 
   2026 @keras_export('keras.backend.greater')
   2027 def greater(x, y):
   2028   """Element-wise truth value of (x > y).
   2029 
   2030   Arguments:
   2031       x: Tensor or variable.
   2032       y: Tensor or variable.
   2033 
   2034   Returns:
   2035       A bool tensor.
   2036   """
   2037   return math_ops.greater(x, y)
   2038 
   2039 
   2040 @keras_export('keras.backend.greater_equal')
   2041 def greater_equal(x, y):
   2042   """Element-wise truth value of (x >= y).
   2043 
   2044   Arguments:
   2045       x: Tensor or variable.
   2046       y: Tensor or variable.
   2047 
   2048   Returns:
   2049       A bool tensor.
   2050   """
   2051   return math_ops.greater_equal(x, y)
   2052 
   2053 
   2054 @keras_export('keras.backend.less')
   2055 def less(x, y):
   2056   """Element-wise truth value of (x < y).
   2057 
   2058   Arguments:
   2059       x: Tensor or variable.
   2060       y: Tensor or variable.
   2061 
   2062   Returns:
   2063       A bool tensor.
   2064   """
   2065   return math_ops.less(x, y)
   2066 
   2067 
   2068 @keras_export('keras.backend.less_equal')
   2069 def less_equal(x, y):
   2070   """Element-wise truth value of (x <= y).
   2071 
   2072   Arguments:
   2073       x: Tensor or variable.
   2074       y: Tensor or variable.
   2075 
   2076   Returns:
   2077       A bool tensor.
   2078   """
   2079   return math_ops.less_equal(x, y)
   2080 
   2081 
   2082 @keras_export('keras.backend.maximum')
   2083 def maximum(x, y):
   2084   """Element-wise maximum of two tensors.
   2085 
   2086   Arguments:
   2087       x: Tensor or variable.
   2088       y: Tensor or variable.
   2089 
   2090   Returns:
   2091       A tensor.
   2092   """
   2093   return math_ops.maximum(x, y)
   2094 
   2095 
   2096 @keras_export('keras.backend.minimum')
   2097 def minimum(x, y):
   2098   """Element-wise minimum of two tensors.
   2099 
   2100   Arguments:
   2101       x: Tensor or variable.
   2102       y: Tensor or variable.
   2103 
   2104   Returns:
   2105       A tensor.
   2106   """
   2107   return math_ops.minimum(x, y)
   2108 
   2109 
   2110 @keras_export('keras.backend.sin')
   2111 def sin(x):
   2112   """Computes sin of x element-wise.
   2113 
   2114   Arguments:
   2115       x: Tensor or variable.
   2116 
   2117   Returns:
   2118       A tensor.
   2119   """
   2120   return math_ops.sin(x)
   2121 
   2122 
   2123 @keras_export('keras.backend.cos')
   2124 def cos(x):
   2125   """Computes cos of x element-wise.
   2126 
   2127   Arguments:
   2128       x: Tensor or variable.
   2129 
   2130   Returns:
   2131       A tensor.
   2132   """
   2133   return math_ops.cos(x)
   2134 
   2135 
   2136 def _regular_normalize_batch_in_training(x,
   2137                                          gamma,
   2138                                          beta,
   2139                                          reduction_axes,
   2140                                          epsilon=1e-3):
   2141   """Non-fused version of `normalize_batch_in_training`.
   2142 
   2143   Arguments:
   2144       x: Input tensor or variable.
   2145       gamma: Tensor by which to scale the input.
   2146       beta: Tensor with which to center the input.
   2147       reduction_axes: iterable of integers,
   2148           axes over which to normalize.
   2149       epsilon: Fuzz factor.
   2150 
   2151   Returns:
   2152       A tuple length of 3, `(normalized_tensor, mean, variance)`.
   2153   """
   2154   mean, var = nn.moments(x, reduction_axes, None, None, False)
   2155   normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
   2156   return normed, mean, var
   2157 
   2158 
   2159 def _broadcast_normalize_batch_in_training(x,
   2160                                            gamma,
   2161                                            beta,
   2162                                            reduction_axes,
   2163                                            epsilon=1e-3):
   2164   """Non-fused, broadcast version of `normalize_batch_in_training`.
   2165 
   2166   Arguments:
   2167       x: Input tensor or variable.
   2168       gamma: Tensor by which to scale the input.
   2169       beta: Tensor with which to center the input.
   2170       reduction_axes: iterable of integers,
   2171           axes over which to normalize.
   2172       epsilon: Fuzz factor.
   2173 
   2174   Returns:
   2175       A tuple length of 3, `(normalized_tensor, mean, variance)`.
   2176   """
   2177   mean, var = nn.moments(x, reduction_axes, None, None, False)
   2178   target_shape = []
   2179   for axis in range(ndim(x)):
   2180     if axis in reduction_axes:
   2181       target_shape.append(1)
   2182     else:
   2183       target_shape.append(array_ops.shape(x)[axis])
   2184   target_shape = array_ops.stack(target_shape)
   2185 
   2186   broadcast_mean = array_ops.reshape(mean, target_shape)
   2187   broadcast_var = array_ops.reshape(var, target_shape)
   2188   if gamma is None:
   2189     broadcast_gamma = None
   2190   else:
   2191     broadcast_gamma = array_ops.reshape(gamma, target_shape)
   2192   if beta is None:
   2193     broadcast_beta = None
   2194   else:
   2195     broadcast_beta = array_ops.reshape(beta, target_shape)
   2196 
   2197   normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
   2198                                   broadcast_beta, broadcast_gamma, epsilon)
   2199   return normed, mean, var
   2200 
   2201 
   2202 def _fused_normalize_batch_in_training(x,
   2203                                        gamma,
   2204                                        beta,
   2205                                        reduction_axes,
   2206                                        epsilon=1e-3):
   2207   """Fused version of `normalize_batch_in_training`.
   2208 
   2209   Arguments:
   2210       x: Input tensor or variable.
   2211       gamma: Tensor by which to scale the input.
   2212       beta: Tensor with which to center the input.
   2213       reduction_axes: iterable of integers,
   2214           axes over which to normalize.
   2215       epsilon: Fuzz factor.
   2216 
   2217   Returns:
   2218       A tuple length of 3, `(normalized_tensor, mean, variance)`.
   2219   """
   2220   if list(reduction_axes) == [0, 1, 2]:
   2221     normalization_axis = 3
   2222     tf_data_format = 'NHWC'
   2223   else:
   2224     normalization_axis = 1
   2225     tf_data_format = 'NCHW'
   2226 
   2227   if gamma is None:
   2228     gamma = constant_op.constant(
   2229         1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
   2230   if beta is None:
   2231     beta = constant_op.constant(
   2232         0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
   2233 
   2234   return nn.fused_batch_norm(
   2235       x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
   2236 
   2237 
   2238 @keras_export('keras.backend.normalize_batch_in_training')
   2239 def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
   2240   """Computes mean and std for batch then apply batch_normalization on batch.
   2241 
   2242   Arguments:
   2243       x: Input tensor or variable.
   2244       gamma: Tensor by which to scale the input.
   2245       beta: Tensor with which to center the input.
   2246       reduction_axes: iterable of integers,
   2247           axes over which to normalize.
   2248       epsilon: Fuzz factor.
   2249 
   2250   Returns:
   2251       A tuple length of 3, `(normalized_tensor, mean, variance)`.
   2252   """
   2253   if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
   2254     if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
   2255       return _broadcast_normalize_batch_in_training(
   2256           x, gamma, beta, reduction_axes, epsilon=epsilon)
   2257     return _fused_normalize_batch_in_training(
   2258         x, gamma, beta, reduction_axes, epsilon=epsilon)
   2259   else:
   2260     if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
   2261       return _regular_normalize_batch_in_training(
   2262           x, gamma, beta, reduction_axes, epsilon=epsilon)
   2263     else:
   2264       return _broadcast_normalize_batch_in_training(
   2265           x, gamma, beta, reduction_axes, epsilon=epsilon)
   2266 
   2267 
   2268 @keras_export('keras.backend.batch_normalization')
   2269 def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
   2270   """Applies batch normalization on x given mean, var, beta and gamma.
   2271 
   2272   I.e. returns:
   2273   `output = (x - mean) / (sqrt(var) + epsilon) * gamma + beta`
   2274 
   2275   Arguments:
   2276       x: Input tensor or variable.
   2277       mean: Mean of batch.
   2278       var: Variance of batch.
   2279       beta: Tensor with which to center the input.
   2280       gamma: Tensor by which to scale the input.
   2281       axis: Integer, the axis that should be normalized.
   2282           (typically the features axis).
   2283       epsilon: Fuzz factor.
   2284 
   2285   Returns:
   2286       A tensor.
   2287   """
   2288   if ndim(x) == 4:
   2289     # The CPU implementation of `fused_batch_norm` only supports NHWC
   2290     if axis == 1 or axis == -3:
   2291       tf_data_format = 'NCHW'
   2292     elif axis == 3 or axis == -1:
   2293       tf_data_format = 'NHWC'
   2294     else:
   2295       tf_data_format = None
   2296 
   2297     if (tf_data_format == 'NHWC' or
   2298         tf_data_format == 'NCHW' and _has_nchw_support()):
   2299       # The mean / var / beta / gamma tensors may be broadcasted
   2300       # so they may have extra axes of size 1, which should be squeezed.
   2301       if ndim(mean) > 1:
   2302         mean = array_ops.reshape(mean, [-1])
   2303       if ndim(var) > 1:
   2304         var = array_ops.reshape(var, [-1])
   2305       if beta is None:
   2306         beta = zeros_like(mean)
   2307       elif ndim(beta) > 1:
   2308         beta = array_ops.reshape(beta, [-1])
   2309       if gamma is None:
   2310         gamma = ones_like(mean)
   2311       elif ndim(gamma) > 1:
   2312         gamma = array_ops.reshape(gamma, [-1])
   2313     y, _, _ = nn.fused_batch_norm(
   2314         x,
   2315         gamma,
   2316         beta,
   2317         epsilon=epsilon,
   2318         mean=mean,
   2319         variance=var,
   2320         data_format=tf_data_format,
   2321         is_training=False
   2322     )
   2323     return y
   2324   return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
   2325 
   2326 
   2327 # SHAPE OPERATIONS
   2328 
   2329 
   2330 @keras_export('keras.backend.concatenate')
   2331 def concatenate(tensors, axis=-1):
   2332   """Concatenates a list of tensors alongside the specified axis.
   2333 
   2334   Arguments:
   2335       tensors: list of tensors to concatenate.
   2336       axis: concatenation axis.
   2337 
   2338   Returns:
   2339       A tensor.
   2340   """
   2341   if axis < 0:
   2342     rank = ndim(tensors[0])
   2343     if rank:
   2344       axis %= rank
   2345     else:
   2346       axis = 0
   2347 
   2348   if py_all(is_sparse(x) for x in tensors):
   2349     return sparse_ops.sparse_concat(axis, tensors)
   2350   else:
   2351     return array_ops.concat([to_dense(x) for x in tensors], axis)
   2352 
   2353 
   2354 @keras_export('keras.backend.reshape')
   2355 def reshape(x, shape):
   2356   """Reshapes a tensor to the specified shape.
   2357 
   2358   Arguments:
   2359       x: Tensor or variable.
   2360       shape: Target shape tuple.
   2361 
   2362   Returns:
   2363       A tensor.
   2364   """
   2365   return array_ops.reshape(x, shape)
   2366 
   2367 
   2368 @keras_export('keras.backend.permute_dimensions')
   2369 def permute_dimensions(x, pattern):
   2370   """Permutes axes in a tensor.
   2371 
   2372   Arguments:
   2373       x: Tensor or variable.
   2374       pattern: A tuple of
   2375           dimension indices, e.g. `(0, 2, 1)`.
   2376 
   2377   Returns:
   2378       A tensor.
   2379   """
   2380   return array_ops.transpose(x, perm=pattern)
   2381 
   2382 
   2383 @keras_export('keras.backend.resize_images')
   2384 def resize_images(x, height_factor, width_factor, data_format,
   2385                   interpolation='nearest'):
   2386   """Resizes the images contained in a 4D tensor.
   2387 
   2388   Arguments:
   2389       x: Tensor or variable to resize.
   2390       height_factor: Positive integer.
   2391       width_factor: Positive integer.
   2392       data_format: One of `"channels_first"`, `"channels_last"`.
   2393       interpolation: A string, one of `nearest` or `bilinear`.
   2394 
   2395   Returns:
   2396       A tensor.
   2397 
   2398   Raises:
   2399       ValueError: in case of incorrect value for
   2400         `data_format` or `interpolation`.
   2401   """
   2402   if data_format == 'channels_first':
   2403     rows, cols = 2, 3
   2404   elif data_format == 'channels_last':
   2405     rows, cols = 1, 2
   2406   else:
   2407     raise ValueError('Invalid `data_format` argument: %s' % (data_format,))
   2408 
   2409   original_shape = int_shape(x)
   2410   new_shape = array_ops.shape(x)[rows:cols + 1]
   2411   new_shape *= constant_op.constant(
   2412       np.array([height_factor, width_factor], dtype='int32'))
   2413 
   2414   if data_format == 'channels_first':
   2415     x = permute_dimensions(x, [0, 2, 3, 1])
   2416   if interpolation == 'nearest':
   2417     x = image_ops.resize_nearest_neighbor(x, new_shape)
   2418   elif interpolation == 'bilinear':
   2419     x = image_ops.resize_bilinear(x, new_shape)
   2420   else:
   2421     raise ValueError('interpolation should be one '
   2422                      'of "nearest" or "bilinear".')
   2423   if data_format == 'channels_first':
   2424     x = permute_dimensions(x, [0, 3, 1, 2])
   2425 
   2426   if original_shape[rows] is None:
   2427     new_height = None
   2428   else:
   2429     new_height = original_shape[rows] * height_factor
   2430 
   2431   if original_shape[cols] is None:
   2432     new_width = None
   2433   else:
   2434     new_width = original_shape[cols] * width_factor
   2435 
   2436   if data_format == 'channels_first':
   2437     output_shape = (None, None, new_height, new_width)
   2438   else:
   2439     output_shape = (None, new_height, new_width, None)
   2440   x.set_shape(output_shape)
   2441   return x
   2442 
   2443 
   2444 @keras_export('keras.backend.resize_volumes')
   2445 def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
   2446   """Resizes the volume contained in a 5D tensor.
   2447 
   2448   Arguments:
   2449       x: Tensor or variable to resize.
   2450       depth_factor: Positive integer.
   2451       height_factor: Positive integer.
   2452       width_factor: Positive integer.
   2453       data_format: One of `"channels_first"`, `"channels_last"`.
   2454 
   2455   Returns:
   2456       A tensor.
   2457 
   2458   Raises:
   2459       ValueError: if `data_format` is neither
   2460           `channels_last` or `channels_first`.
   2461   """
   2462   if data_format == 'channels_first':
   2463     output = repeat_elements(x, depth_factor, axis=2)
   2464     output = repeat_elements(output, height_factor, axis=3)
   2465     output = repeat_elements(output, width_factor, axis=4)
   2466     return output
   2467   elif data_format == 'channels_last':
   2468     output = repeat_elements(x, depth_factor, axis=1)
   2469     output = repeat_elements(output, height_factor, axis=2)
   2470     output = repeat_elements(output, width_factor, axis=3)
   2471     return output
   2472   else:
   2473     raise ValueError('Invalid data_format: ' + str(data_format))
   2474 
   2475 
   2476 @keras_export('keras.backend.repeat_elements')
   2477 def repeat_elements(x, rep, axis):
   2478   """Repeats the elements of a tensor along an axis, like `np.repeat`.
   2479 
   2480   If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output
   2481   will have shape `(s1, s2 * rep, s3)`.
   2482 
   2483   Arguments:
   2484       x: Tensor or variable.
   2485       rep: Python integer, number of times to repeat.
   2486       axis: Axis along which to repeat.
   2487 
   2488   Returns:
   2489       A tensor.
   2490   """
   2491   x_shape = x.shape.as_list()
   2492   # For static axis
   2493   if x_shape[axis] is not None:
   2494     # slices along the repeat axis
   2495     splits = array_ops.split(value=x,
   2496                              num_or_size_splits=x_shape[axis],
   2497                              axis=axis)
   2498     # repeat each slice the given number of reps
   2499     x_rep = [s for s in splits for _ in range(rep)]
   2500     return concatenate(x_rep, axis)
   2501 
   2502   # Here we use tf.tile to mimic behavior of np.repeat so that
   2503   # we can handle dynamic shapes (that include None).
   2504   # To do that, we need an auxiliary axis to repeat elements along
   2505   # it and then merge them along the desired axis.
   2506 
   2507   # Repeating
   2508   auxiliary_axis = axis + 1
   2509   x_shape = array_ops.shape(x)
   2510   x_rep = array_ops.expand_dims(x, axis=auxiliary_axis)
   2511   reps = np.ones(len(x.shape) + 1)
   2512   reps[auxiliary_axis] = rep
   2513   x_rep = array_ops.tile(x_rep, reps)
   2514 
   2515   # Merging
   2516   reps = np.delete(reps, auxiliary_axis)
   2517   reps[axis] = rep
   2518   reps = array_ops.constant(reps, dtype='int32')
   2519   x_shape *= reps
   2520   x_rep = array_ops.reshape(x_rep, x_shape)
   2521 
   2522   # Fix shape representation
   2523   x_shape = x.shape.as_list()
   2524   x_rep.set_shape(x_shape)
   2525   x_rep._keras_shape = tuple(x_shape)
   2526   return x_rep
   2527 
   2528 
   2529 @keras_export('keras.backend.repeat')
   2530 def repeat(x, n):
   2531   """Repeats a 2D tensor.
   2532 
   2533   if `x` has shape (samples, dim) and `n` is `2`,
   2534   the output will have shape `(samples, 2, dim)`.
   2535 
   2536   Arguments:
   2537       x: Tensor or variable.
   2538       n: Python integer, number of times to repeat.
   2539 
   2540   Returns:
   2541       A tensor.
   2542   """
   2543   assert ndim(x) == 2
   2544   x = array_ops.expand_dims(x, 1)
   2545   pattern = array_ops.stack([1, n, 1])
   2546   return array_ops.tile(x, pattern)
   2547 
   2548 
   2549 @keras_export('keras.backend.arange')
   2550 def arange(start, stop=None, step=1, dtype='int32'):
   2551   """Creates a 1D tensor containing a sequence of integers.
   2552 
   2553   The function arguments use the same convention as
   2554   Theano's arange: if only one argument is provided,
   2555   it is in fact the "stop" argument and "start" is 0.
   2556 
   2557   The default type of the returned tensor is `'int32'` to
   2558   match TensorFlow's default.
   2559 
   2560   Arguments:
   2561       start: Start value.
   2562       stop: Stop value.
   2563       step: Difference between two successive values.
   2564       dtype: Integer dtype to use.
   2565 
   2566   Returns:
   2567       An integer tensor.
   2568 
   2569   """
   2570   # Match the behavior of numpy and Theano by returning an empty sequence.
   2571   if stop is None and start < 0:
   2572     start = 0
   2573   result = math_ops.range(start, limit=stop, delta=step, name='arange')
   2574   if dtype != 'int32':
   2575     result = cast(result, dtype)
   2576   return result
   2577 
   2578 
   2579 @keras_export('keras.backend.tile')
   2580 def tile(x, n):
   2581   """Creates a tensor by tiling `x` by `n`.
   2582 
   2583   Arguments:
   2584       x: A tensor or variable
   2585       n: A list of integer. The length must be the same as the number of
   2586           dimensions in `x`.
   2587 
   2588   Returns:
   2589       A tiled tensor.
   2590   """
   2591   if isinstance(n, int):
   2592     n = [n]
   2593   return array_ops.tile(x, n)
   2594 
   2595 
   2596 @keras_export('keras.backend.flatten')
   2597 def flatten(x):
   2598   """Flatten a tensor.
   2599 
   2600   Arguments:
   2601       x: A tensor or variable.
   2602 
   2603   Returns:
   2604       A tensor, reshaped into 1-D
   2605   """
   2606   return array_ops.reshape(x, [-1])
   2607 
   2608 
   2609 @keras_export('keras.backend.batch_flatten')
   2610 def batch_flatten(x):
   2611   """Turn a nD tensor into a 2D tensor with same 0th dimension.
   2612 
   2613   In other words, it flattens each data samples of a batch.
   2614 
   2615   Arguments:
   2616       x: A tensor or variable.
   2617 
   2618   Returns:
   2619       A tensor.
   2620   """
   2621   x = array_ops.reshape(x, array_ops.stack([-1, prod(shape(x)[1:])]))
   2622   return x
   2623 
   2624 
   2625 @keras_export('keras.backend.expand_dims')
   2626 def expand_dims(x, axis=-1):
   2627   """Adds a 1-sized dimension at index "axis".
   2628 
   2629   Arguments:
   2630       x: A tensor or variable.
   2631       axis: Position where to add a new axis.
   2632 
   2633   Returns:
   2634       A tensor with expanded dimensions.
   2635   """
   2636   return array_ops.expand_dims(x, axis)
   2637 
   2638 
   2639 @keras_export('keras.backend.squeeze')
   2640 def squeeze(x, axis):
   2641   """Removes a 1-dimension from the tensor at index "axis".
   2642 
   2643   Arguments:
   2644       x: A tensor or variable.
   2645       axis: Axis to drop.
   2646 
   2647   Returns:
   2648       A tensor with the same data as `x` but reduced dimensions.
   2649   """
   2650   return array_ops.squeeze(x, [axis])
   2651 
   2652 
   2653 @keras_export('keras.backend.temporal_padding')
   2654 def temporal_padding(x, padding=(1, 1)):
   2655   """Pads the middle dimension of a 3D tensor.
   2656 
   2657   Arguments:
   2658       x: Tensor or variable.
   2659       padding: Tuple of 2 integers, how many zeros to
   2660           add at the start and end of dim 1.
   2661 
   2662   Returns:
   2663       A padded 3D tensor.
   2664   """
   2665   assert len(padding) == 2
   2666   pattern = [[0, 0], [padding[0], padding[1]], [0, 0]]
   2667   return array_ops.pad(x, pattern)
   2668 
   2669 
   2670 @keras_export('keras.backend.spatial_2d_padding')
   2671 def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
   2672   """Pads the 2nd and 3rd dimensions of a 4D tensor.
   2673 
   2674   Arguments:
   2675       x: Tensor or variable.
   2676       padding: Tuple of 2 tuples, padding pattern.
   2677       data_format: One of `channels_last` or `channels_first`.
   2678 
   2679   Returns:
   2680       A padded 4D tensor.
   2681 
   2682   Raises:
   2683       ValueError: if `data_format` is neither
   2684           `channels_last` or `channels_first`.
   2685   """
   2686   assert len(padding) == 2
   2687   assert len(padding[0]) == 2
   2688   assert len(padding[1]) == 2
   2689   if data_format is None:
   2690     data_format = image_data_format()
   2691   if data_format not in {'channels_first', 'channels_last'}:
   2692     raise ValueError('Unknown data_format: ' + str(data_format))
   2693 
   2694   if data_format == 'channels_first':
   2695     pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
   2696   else:
   2697     pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
   2698   return array_ops.pad(x, pattern)
   2699 
   2700 
   2701 @keras_export('keras.backend.spatial_3d_padding')
   2702 def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
   2703   """Pads 5D tensor with zeros along the depth, height, width dimensions.
   2704 
   2705   Pads these dimensions with respectively
   2706   "padding[0]", "padding[1]" and "padding[2]" zeros left and right.
   2707 
   2708   For 'channels_last' data_format,
   2709   the 2nd, 3rd and 4th dimension will be padded.
   2710   For 'channels_first' data_format,
   2711   the 3rd, 4th and 5th dimension will be padded.
   2712 
   2713   Arguments:
   2714       x: Tensor or variable.
   2715       padding: Tuple of 3 tuples, padding pattern.
   2716       data_format: One of `channels_last` or `channels_first`.
   2717 
   2718   Returns:
   2719       A padded 5D tensor.
   2720 
   2721   Raises:
   2722       ValueError: if `data_format` is neither
   2723           `channels_last` or `channels_first`.
   2724 
   2725   """
   2726   assert len(padding) == 3
   2727   assert len(padding[0]) == 2
   2728   assert len(padding[1]) == 2
   2729   assert len(padding[2]) == 2
   2730   if data_format is None:
   2731     data_format = image_data_format()
   2732   if data_format not in {'channels_first', 'channels_last'}:
   2733     raise ValueError('Unknown data_format: ' + str(data_format))
   2734 
   2735   if data_format == 'channels_first':
   2736     pattern = [[0, 0], [0, 0], [padding[0][0], padding[0][1]],
   2737                [padding[1][0], padding[1][1]], [padding[2][0], padding[2][1]]]
   2738   else:
   2739     pattern = [[0, 0], [padding[0][0], padding[0][1]],
   2740                [padding[1][0], padding[1][1]], [padding[2][0],
   2741                                                 padding[2][1]], [0, 0]]
   2742   return array_ops.pad(x, pattern)
   2743 
   2744 
   2745 @keras_export('keras.backend.stack')
   2746 def stack(x, axis=0):
   2747   """Stacks a list of rank `R` tensors into a rank `R+1` tensor.
   2748 
   2749   Arguments:
   2750       x: List of tensors.
   2751       axis: Axis along which to perform stacking.
   2752 
   2753   Returns:
   2754       A tensor.
   2755   """
   2756   return array_ops.stack(x, axis=axis)
   2757 
   2758 
   2759 @keras_export('keras.backend.one_hot')
   2760 def one_hot(indices, num_classes):
   2761   """Computes the one-hot representation of an integer tensor.
   2762 
   2763   Arguments:
   2764       indices: nD integer tensor of shape
   2765           `(batch_size, dim1, dim2, ... dim(n-1))`
   2766       num_classes: Integer, number of classes to consider.
   2767 
   2768   Returns:
   2769       (n + 1)D one hot representation of the input
   2770       with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`
   2771 
   2772   Returns:
   2773       The one-hot tensor.
   2774   """
   2775   return array_ops.one_hot(indices, depth=num_classes, axis=-1)
   2776 
   2777 
   2778 @keras_export('keras.backend.reverse')
   2779 def reverse(x, axes):
   2780   """Reverse a tensor along the specified axes.
   2781 
   2782   Arguments:
   2783       x: Tensor to reverse.
   2784       axes: Integer or iterable of integers.
   2785           Axes to reverse.
   2786 
   2787   Returns:
   2788       A tensor.
   2789   """
   2790   if isinstance(axes, int):
   2791     axes = [axes]
   2792   return array_ops.reverse(x, axes)
   2793 
   2794 
   2795 # VALUE MANIPULATION
   2796 
   2797 
   2798 @keras_export('keras.backend.get_value')
   2799 def get_value(x):
   2800   """Returns the value of a variable.
   2801 
   2802   Arguments:
   2803       x: input variable.
   2804 
   2805   Returns:
   2806       A Numpy array.
   2807 
   2808   Raises:
   2809       RuntimeError: If this method is called inside defun.
   2810   """
   2811   if context.executing_eagerly():
   2812     return x.numpy()
   2813   elif not getattr(x, '_in_graph_mode', True):
   2814     # This is a variable which was created in an eager context, but is being
   2815     # evaluated from a Graph.
   2816     with context.eager_mode():
   2817       return x.numpy()
   2818   elif ops.inside_function():
   2819     raise RuntimeError('Cannot get value inside Tensorflow graph function.')
   2820   return x.eval(session=get_session((x,)))
   2821 
   2822 
   2823 @keras_export('keras.backend.batch_get_value')
   2824 def batch_get_value(tensors):
   2825   """Returns the value of more than one tensor variable.
   2826 
   2827   Arguments:
   2828       tensors: list of ops to run.
   2829 
   2830   Returns:
   2831       A list of Numpy arrays.
   2832 
   2833   Raises:
   2834       RuntimeError: If this method is called inside defun.
   2835   """
   2836   if context.executing_eagerly():
   2837     return [x.numpy() for x in tensors]
   2838   elif ops.inside_function():  # pylint: disable=protected-access
   2839     raise RuntimeError('Cannot get value inside Tensorflow graph function.')
   2840   if tensors:
   2841     return get_session(tensors).run(tensors)
   2842   else:
   2843     return []
   2844 
   2845 
   2846 @keras_export('keras.backend.set_value')
   2847 def set_value(x, value):
   2848   """Sets the value of a variable, from a Numpy array.
   2849 
   2850   Arguments:
   2851       x: Tensor to set to a new value.
   2852       value: Value to set the tensor to, as a Numpy array
   2853           (of the same shape).
   2854   """
   2855   value = np.asarray(value, dtype=dtype(x))
   2856   if ops.executing_eagerly_outside_functions():
   2857     x.assign(value)
   2858   else:
   2859     with get_graph().as_default():
   2860       tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
   2861       if hasattr(x, '_assign_placeholder'):
   2862         assign_placeholder = x._assign_placeholder
   2863         assign_op = x._assign_op
   2864       else:
   2865         assign_placeholder = array_ops.placeholder(tf_dtype, shape=value.shape)
   2866         assign_op = x.assign(assign_placeholder)
   2867         x._assign_placeholder = assign_placeholder
   2868         x._assign_op = assign_op
   2869       get_session().run(assign_op, feed_dict={assign_placeholder: value})
   2870 
   2871 
   2872 @keras_export('keras.backend.batch_set_value')
   2873 def batch_set_value(tuples):
   2874   """Sets the values of many tensor variables at once.
   2875 
   2876   Arguments:
   2877       tuples: a list of tuples `(tensor, value)`.
   2878           `value` should be a Numpy array.
   2879   """
   2880   if ops.executing_eagerly_outside_functions():
   2881     for x, value in tuples:
   2882       x.assign(np.asarray(value, dtype=dtype(x)))
   2883   else:
   2884     with get_graph().as_default():
   2885       if tuples:
   2886         assign_ops = []
   2887         feed_dict = {}
   2888         for x, value in tuples:
   2889           value = np.asarray(value, dtype=dtype(x))
   2890           tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
   2891           if hasattr(x, '_assign_placeholder'):
   2892             assign_placeholder = x._assign_placeholder
   2893             assign_op = x._assign_op
   2894           else:
   2895             assign_placeholder = array_ops.placeholder(tf_dtype,
   2896                                                        shape=value.shape)
   2897             assign_op = x.assign(assign_placeholder)
   2898             x._assign_placeholder = assign_placeholder
   2899             x._assign_op = assign_op
   2900           assign_ops.append(assign_op)
   2901           feed_dict[assign_placeholder] = value
   2902         get_session().run(assign_ops, feed_dict=feed_dict)
   2903 
   2904 
   2905 @keras_export('keras.backend.print_tensor')
   2906 def print_tensor(x, message=''):
   2907   """Prints `message` and the tensor value when evaluated.
   2908 
   2909   Note that `print_tensor` returns a new tensor identical to `x`
   2910   which should be used in the following code. Otherwise the
   2911   print operation is not taken into account during evaluation.
   2912 
   2913   Example:
   2914 
   2915   ```python
   2916      >>> x = K.print_tensor(x, message="x is: ")
   2917   ```
   2918 
   2919   Arguments:
   2920       x: Tensor to print.
   2921       message: Message to print jointly with the tensor.
   2922 
   2923   Returns:
   2924       The same tensor `x`, unchanged.
   2925   """
   2926   return logging_ops.Print(x, [x], message)
   2927 
   2928 
   2929 # GRAPH MANIPULATION
   2930 
   2931 
   2932 class GraphExecutionFunction(object):
   2933   """Runs a computation graph.
   2934 
   2935   It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`.
   2936   In particular additional operations via `fetches` argument and additional
   2937   tensor substitutions via `feed_dict` arguments. Note that given
   2938   substitutions are merged with substitutions from `inputs`. Even though
   2939   `feed_dict` is passed once in the constructor (called in `model.compile()`)
   2940   we can modify the values in the dictionary. Through this feed_dict we can
   2941   provide additional substitutions besides Keras inputs.
   2942 
   2943   Arguments:
   2944       inputs: Feed placeholders to the computation graph.
   2945       outputs: Output tensors to fetch.
   2946       updates: Additional update ops to be run at function call.
   2947       name: A name to help users identify what this function does.
   2948       session_kwargs: Arguments to `tf.Session.run()`:
   2949                       `fetches`, `feed_dict`, `options`, `run_metadata`.
   2950   """
   2951 
   2952   def __init__(self, inputs, outputs, updates=None, name=None,
   2953                **session_kwargs):
   2954     updates = updates or []
   2955     if not isinstance(updates, (list, tuple)):
   2956       raise TypeError('`updates` in a Keras backend function '
   2957                       'should be a list or tuple.')
   2958     self.inputs = nest.flatten(inputs)
   2959     self._outputs_structure = outputs
   2960     self.outputs = cast_variables_to_tensor(nest.flatten(outputs))
   2961     # TODO(b/127668432): Consider using autograph to generate these
   2962     # dependencies in call.
   2963     # Index 0 = total loss or model output for `predict`.
   2964     with ops.control_dependencies([self.outputs[0]]):
   2965       updates_ops = []
   2966       for update in updates:
   2967         if isinstance(update, tuple):
   2968           p, new_p = update
   2969           updates_ops.append(state_ops.assign(p, new_p))
   2970         else:
   2971           # assumed already an op
   2972           updates_ops.append(update)
   2973       self.updates_op = control_flow_ops.group(*updates_ops)
   2974     self.name = name
   2975     # additional tensor substitutions
   2976     self.feed_dict = session_kwargs.pop('feed_dict', None)
   2977     # additional operations
   2978     self.fetches = session_kwargs.pop('fetches', [])
   2979     if not isinstance(self.fetches, list):
   2980       self.fetches = [self.fetches]
   2981     self.run_options = session_kwargs.pop('options', None)
   2982     self.run_metadata = session_kwargs.pop('run_metadata', None)
   2983     # The main use case of `fetches` being passed to a model is the ability
   2984     # to run custom updates
   2985     # This requires us to wrap fetches in `identity` ops.
   2986     self.fetches = [array_ops.identity(x) for x in self.fetches]
   2987     self.session_kwargs = session_kwargs
   2988     # This mapping keeps track of the function that should receive the
   2989     # output from a fetch in `fetches`: { fetch: function(fetch_output) }
   2990     # A Callback can use this to register a function with access to the
   2991     # output values for a fetch it added.
   2992     self.fetch_callbacks = dict()
   2993 
   2994     if session_kwargs:
   2995       raise ValueError('Some keys in session_kwargs are not supported at this '
   2996                        'time: %s' % (session_kwargs.keys(),))
   2997 
   2998     self._callable_fn = None
   2999     self._feed_arrays = None
   3000     self._feed_symbols = None
   3001     self._symbol_vals = None
   3002     self._fetches = None
   3003     self._session = None
   3004 
   3005   def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
   3006     """Generates a callable that runs the graph.
   3007 
   3008     Arguments:
   3009       feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
   3010       feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
   3011       symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
   3012       session: Session to use to generate the callable.
   3013 
   3014     Returns:
   3015       Function that runs the graph according to the above options.
   3016     """
   3017     # Prepare callable options.
   3018     callable_opts = config_pb2.CallableOptions()
   3019     # Handle external-data feed.
   3020     for x in feed_arrays:
   3021       callable_opts.feed.append(x.name)
   3022     if self.feed_dict:
   3023       for key in sorted(self.feed_dict.keys()):
   3024         callable_opts.feed.append(key.name)
   3025     # Handle symbolic feed.
   3026     for x, y in zip(feed_symbols, symbol_vals):
   3027       connection = callable_opts.tensor_connection.add()
   3028       if x.dtype != y.dtype:
   3029         y = math_ops.cast(y, dtype=x.dtype)
   3030       from_tensor = ops._as_graph_element(y)
   3031       if from_tensor is None:
   3032         from_tensor = y
   3033       connection.from_tensor = from_tensor.name  # Data tensor
   3034       connection.to_tensor = x.name  # Placeholder
   3035     # Handle fetches.
   3036     for x in self.outputs + self.fetches:
   3037       callable_opts.fetch.append(x.name)
   3038     # Handle updates.
   3039     callable_opts.target.append(self.updates_op.name)
   3040     # Handle run_options.
   3041     if self.run_options:
   3042       callable_opts.run_options.CopyFrom(self.run_options)
   3043     # Create callable.
   3044     callable_fn = session._make_callable_from_options(callable_opts)
   3045     # Cache parameters corresponding to the generated callable, so that
   3046     # we can detect future mismatches and refresh the callable.
   3047     self._callable_fn = callable_fn
   3048     self._feed_arrays = feed_arrays
   3049     self._feed_symbols = feed_symbols
   3050     self._symbol_vals = symbol_vals
   3051     self._fetches = list(self.fetches)
   3052     self._session = session
   3053 
   3054   def _call_fetch_callbacks(self, fetches_output):
   3055     for fetch, output in zip(self._fetches, fetches_output):
   3056       if fetch in self.fetch_callbacks:
   3057         self.fetch_callbacks[fetch](output)
   3058 
   3059   def __call__(self, inputs):
   3060     inputs = nest.flatten(inputs)
   3061 
   3062     session = get_session(inputs)
   3063     feed_arrays = []
   3064     array_vals = []
   3065     feed_symbols = []
   3066     symbol_vals = []
   3067     for tensor, value in zip(self.inputs, inputs):
   3068       if value is None:
   3069         continue
   3070       if is_sparse(tensor):
   3071         sparse_coo = value.tocoo()
   3072         indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
   3073                                   np.expand_dims(sparse_coo.col, 1)), 1)
   3074         value = (indices, sparse_coo.data, sparse_coo.shape)
   3075       if tensor_util.is_tensor(value):
   3076         # Case: feeding symbolic tensor.
   3077         feed_symbols.append(tensor)
   3078         symbol_vals.append(value)
   3079       else:
   3080         # Case: feeding Numpy array.
   3081         feed_arrays.append(tensor)
   3082         # We need to do array conversion and type casting at this level, since
   3083         # `callable_fn` only supports exact matches.
   3084         tensor_type = dtypes_module.as_dtype(tensor.dtype)
   3085         array_vals.append(np.asarray(value,
   3086                                      dtype=tensor_type.as_numpy_dtype))
   3087 
   3088     if self.feed_dict:
   3089       for key in sorted(self.feed_dict.keys()):
   3090         array_vals.append(
   3091             np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name))
   3092 
   3093     # Refresh callable if anything has changed.
   3094     if (self._callable_fn is None or feed_arrays != self._feed_arrays or
   3095         symbol_vals != self._symbol_vals or
   3096         feed_symbols != self._feed_symbols or self.fetches != self._fetches or
   3097         session != self._session):
   3098       self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
   3099 
   3100     fetched = self._callable_fn(*array_vals,
   3101                                 run_metadata=self.run_metadata)
   3102     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3103     return nest.pack_sequence_as(self._outputs_structure,
   3104                                  fetched[:len(self.outputs)])
   3105 
   3106 
   3107 class EagerExecutionFunction(object):
   3108   """Helper class for constructing a TF graph function from the Keras graph.
   3109 
   3110   Arguments:
   3111     inputs: Feed placeholders to the computation graph.
   3112     outputs: Output tensors to fetch.
   3113     updates: Additional update ops to be run at function call.
   3114     name: A name to help users identify what this function does.
   3115     session_kwargs: Unsupported.
   3116   """
   3117 
   3118   def __init__(self, inputs, outputs, updates=None, name=None):
   3119     self.name = name
   3120     self._outputs_structure = outputs
   3121     inputs = nest.flatten(inputs)
   3122     outputs = nest.flatten(outputs)
   3123 
   3124     updates = updates or []
   3125     if not isinstance(updates, (list, tuple)):
   3126       raise TypeError('`updates` in a Keras backend function '
   3127                       'should be a list or tuple.')
   3128 
   3129     if updates and not outputs:
   3130       # Edge case; never happens in practice
   3131       raise ValueError('Cannot create a Keras backend function with updates'
   3132                        ' but no outputs during eager execution.')
   3133 
   3134     graphs = {i.graph for i in nest.flatten([inputs, outputs, updates])
   3135               if hasattr(i, 'graph')}
   3136     if len(graphs) > 1:
   3137       raise ValueError('Cannot create an execution function which is comprised '
   3138                        'of elements from multiple graphs.')
   3139 
   3140     source_graph = graphs.pop()
   3141     global_graph = get_graph()
   3142 
   3143     updates_ops = []
   3144     legacy_update_ops = []
   3145     for update in updates:
   3146       # For legacy reasons it is allowed to pass an update as a tuple
   3147       # `(variable, new_value)` (this maps to an assign op). Otherwise it
   3148       # is assumed to already be an op -- we cannot control its execution
   3149       # order.
   3150       if isinstance(update, tuple):
   3151         legacy_update_ops.append(update)
   3152       else:
   3153         if hasattr(update, 'op'):
   3154           update = update.op
   3155         updates_ops.append(update)
   3156 
   3157     with _scratch_graph() as exec_graph:
   3158       global_graph = get_graph()
   3159       if source_graph not in (exec_graph, global_graph):
   3160         raise ValueError('Unknown graph. Aborting.')
   3161 
   3162       if source_graph is global_graph and exec_graph is not global_graph:
   3163         init_tensors = (
   3164             outputs + updates_ops + [p for [p, _] in legacy_update_ops] +
   3165             [p_new for [_, p_new] in legacy_update_ops
   3166              if isinstance(p_new, ops.Tensor)])
   3167         lifted_map = lift_to_graph.lift_to_graph(
   3168             init_tensors=init_tensors, graph=exec_graph, sources=inputs,
   3169             add_sources=True, handle_captures=True, base_graph=source_graph)
   3170 
   3171         inputs = [lifted_map[i] for i in inputs]
   3172         outputs = [lifted_map[i] for i in outputs]
   3173         updates_ops = [lifted_map[i] for i in updates_ops]
   3174         legacy_update_ops = [(lifted_map[p], lifted_map.get(p_new, p_new))
   3175                              for p, p_new in legacy_update_ops]
   3176 
   3177     # Consolidate updates
   3178     with exec_graph.as_default():
   3179       outputs = cast_variables_to_tensor(outputs)
   3180       with ops.control_dependencies(outputs):
   3181         for p, p_new in legacy_update_ops:
   3182           updates_ops.append(state_ops.assign(p, p_new))
   3183 
   3184       self.inputs, self.outputs = inputs, outputs
   3185       with ops.control_dependencies(updates_ops):
   3186         self.outputs[0] = array_ops.identity(self.outputs[0])
   3187 
   3188       exec_graph.inputs = self.inputs + list(exec_graph.captures.values())
   3189       exec_graph.outputs = self.outputs
   3190       graph_fn = eager_function.ConcreteFunction(exec_graph)
   3191 
   3192     graph_fn._num_positional_args = len(self.inputs)
   3193     graph_fn._arg_keywords = []
   3194     self._graph_fn = graph_fn
   3195 
   3196     # Handle placeholders with default
   3197     # (treated as required placeholder by graph functions)
   3198     self._placeholder_default_values = {}
   3199     with exec_graph.as_default():
   3200       for x in self.inputs:
   3201         if x.op.type == 'PlaceholderWithDefault':
   3202           self._placeholder_default_values[x] = tensor_util.constant_value(
   3203               x.op.inputs[0])
   3204 
   3205   def __call__(self, inputs):
   3206     inputs = nest.flatten(inputs)
   3207     converted_inputs = []
   3208     for tensor, value in zip(self.inputs, inputs):
   3209       if value is None:
   3210         # Assume `value` is a placeholder with default
   3211         value = self._placeholder_default_values.get(tensor, None)
   3212         if value is None:
   3213           raise ValueError(
   3214               'You must feed a value for placeholder %s' % (tensor,))
   3215       if not isinstance(value, ops.Tensor):
   3216         value = ops.convert_to_tensor(value, dtype=tensor.dtype)
   3217       if value.dtype != tensor.dtype:
   3218         # Temporary workaround due to `convert_to_tensor` not casting floats.
   3219         # See b/119637405
   3220         value = math_ops.cast(value, tensor.dtype)
   3221       converted_inputs.append(value)
   3222     outputs = self._graph_fn(*converted_inputs)
   3223     return nest.pack_sequence_as(self._outputs_structure,
   3224                                  [x.numpy() for x in outputs])
   3225 
   3226 
   3227 @keras_export('keras.backend.function')
   3228 def function(inputs, outputs, updates=None, name=None, **kwargs):
   3229   """Instantiates a Keras function.
   3230 
   3231   Arguments:
   3232       inputs: List of placeholder tensors.
   3233       outputs: List of output tensors.
   3234       updates: List of update ops.
   3235       name: String, name of function.
   3236       **kwargs: Passed to `tf.Session.run`.
   3237 
   3238   Returns:
   3239       Output values as Numpy arrays.
   3240 
   3241   Raises:
   3242       ValueError: if invalid kwargs are passed in or if in eager execution.
   3243   """
   3244   if ops.executing_eagerly_outside_functions():
   3245     if kwargs:
   3246       raise ValueError('Session keyword arguments are not support during '
   3247                        'eager execution. You passed: %s' % (kwargs,))
   3248     return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
   3249 
   3250   if kwargs:
   3251     for key in kwargs:
   3252       if (key not in tf_inspect.getfullargspec(session_module.Session.run)[0]
   3253           and key not in ['inputs', 'outputs', 'updates', 'name']):
   3254         msg = ('Invalid argument "%s" passed to K.function with TensorFlow '
   3255                'backend') % key
   3256         raise ValueError(msg)
   3257   return GraphExecutionFunction(inputs, outputs, updates=updates, **kwargs)
   3258 
   3259 
   3260 @keras_export('keras.backend.gradients')
   3261 def gradients(loss, variables):
   3262   """Returns the gradients of `loss` w.r.t. `variables`.
   3263 
   3264   Arguments:
   3265       loss: Scalar tensor to minimize.
   3266       variables: List of variables.
   3267 
   3268   Returns:
   3269       A gradients tensor.
   3270   """
   3271   return gradients_module.gradients(
   3272       loss, variables, colocate_gradients_with_ops=True)
   3273 
   3274 
   3275 @keras_export('keras.backend.stop_gradient')
   3276 def stop_gradient(variables):
   3277   """Returns `variables` but with zero gradient w.r.t. every other variable.
   3278 
   3279   Arguments:
   3280       variables: Tensor or list of tensors to consider constant with respect
   3281         to any other variable.
   3282 
   3283 
   3284   Returns:
   3285       A single tensor or a list of tensors (depending on the passed argument)
   3286       that has no gradient with respect to any other variable.
   3287   """
   3288   if isinstance(variables, (list, tuple)):
   3289     return map(array_ops.stop_gradient, variables)
   3290   return array_ops.stop_gradient(variables)
   3291 
   3292 
   3293 # CONTROL FLOW
   3294 
   3295 
   3296 @keras_export('keras.backend.rnn')
   3297 def rnn(step_function,
   3298         inputs,
   3299         initial_states,
   3300         go_backwards=False,
   3301         mask=None,
   3302         constants=None,
   3303         unroll=False,
   3304         input_length=None,
   3305         time_major=False,
   3306         zero_output_for_mask=False):
   3307   """Iterates over the time dimension of a tensor.
   3308 
   3309   Arguments:
   3310       step_function: RNN step function.
   3311           Args;
   3312               input; Tensor with shape `(samples, ...)` (no time dimension),
   3313                   representing input for the batch of samples at a certain
   3314                   time step.
   3315               states; List of tensors.
   3316           Returns;
   3317               output; Tensor with shape `(samples, output_dim)`
   3318                   (no time dimension).
   3319               new_states; List of tensors, same length and shapes
   3320                   as 'states'. The first state in the list must be the
   3321                   output tensor at the previous timestep.
   3322       inputs: Tensor of temporal data of shape `(samples, time, ...)`
   3323           (at least 3D), or nested tensors, and each of which has shape
   3324           `(samples, time, ...)`.
   3325       initial_states: Tensor with shape `(samples, state_size)`
   3326           (no time dimension), containing the initial values for the states used
   3327           in the step function. In the case that state_size is in a nested
   3328           shape, the shape of initial_states will also follow the nested
   3329           structure.
   3330       go_backwards: Boolean. If True, do the iteration over the time
   3331           dimension in reverse order and return the reversed sequence.
   3332       mask: Binary tensor with shape `(samples, time, 1)`,
   3333           with a zero for every element that is masked.
   3334       constants: List of constant values passed at each step.
   3335       unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
   3336       input_length: If specified, assume time dimension is of this length.
   3337       time_major: Boolean. If true, the inputs and outputs will be in shape
   3338           `(timesteps, batch, ...)`, whereas in the False case, it will be
   3339           `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
   3340           efficient because it avoids transposes at the beginning and end of the
   3341           RNN calculation. However, most TensorFlow data is batch-major, so by
   3342           default this function accepts input and emits output in batch-major
   3343           form.
   3344       zero_output_for_mask: Boolean. If True, the output for masked timestep
   3345           will be zeros, whereas in the False case, output from previous
   3346           timestep is returned.
   3347   Returns:
   3348       A tuple, `(last_output, outputs, new_states)`.
   3349           last_output: the latest output of the rnn, of shape `(samples, ...)`
   3350           outputs: tensor with shape `(samples, time, ...)` where each
   3351               entry `outputs[s, t]` is the output of the step function
   3352               at time `t` for sample `s`.
   3353           new_states: list of tensors, latest states returned by
   3354               the step function, of shape `(samples, ...)`.
   3355 
   3356   Raises:
   3357       ValueError: if input dimension is less than 3.
   3358       ValueError: if `unroll` is `True` but input timestep is not a fixed
   3359       number.
   3360       ValueError: if `mask` is provided (not `None`) but states is not provided
   3361           (`len(states)` == 0).
   3362   """
   3363 
   3364   def swap_batch_timestep(input_t):
   3365     # Swap the batch and timestep dim for the incoming tensor.
   3366     axes = list(range(len(input_t.shape)))
   3367     axes[0], axes[1] = 1, 0
   3368     return array_ops.transpose(input_t, axes)
   3369 
   3370   if not time_major:
   3371     inputs = nest.map_structure(swap_batch_timestep, inputs)
   3372 
   3373   flatted_inputs = nest.flatten(inputs)
   3374   time_steps = flatted_inputs[0].shape[0]
   3375   batch = flatted_inputs[0].shape[1]
   3376   time_steps_t = array_ops.shape(flatted_inputs[0])[0]
   3377 
   3378   for input_ in flatted_inputs:
   3379     input_.get_shape().with_rank_at_least(3)
   3380 
   3381   if mask is not None:
   3382     if mask.dtype != dtypes_module.bool:
   3383       mask = math_ops.cast(mask, dtypes_module.bool)
   3384     if len(mask.shape) == 2:
   3385       mask = expand_dims(mask)
   3386     if not time_major:
   3387       mask = swap_batch_timestep(mask)
   3388 
   3389   if constants is None:
   3390     constants = []
   3391 
   3392   # tf.where needs its condition tensor to be the same shape as its two
   3393   # result tensors, but in our case the condition (mask) tensor is
   3394   # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
   3395   # So we need to broadcast the mask to match the shape of inputs.
   3396   # That's what the tile call does, it just repeats the mask along its
   3397   # second dimension n times.
   3398   def _expand_mask(mask_t, input_t, fixed_dim=1):
   3399     assert not nest.is_sequence(mask_t)
   3400     assert not nest.is_sequence(input_t)
   3401     rank_diff = len(input_t.shape) - len(mask_t.shape)
   3402     for _ in range(rank_diff):
   3403       mask_t = array_ops.expand_dims(mask_t, -1)
   3404     multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
   3405     return array_ops.tile(mask_t, multiples)
   3406 
   3407   if unroll:
   3408     if not time_steps:
   3409       raise ValueError('Unrolling requires a fixed number of timesteps.')
   3410     states = tuple(initial_states)
   3411     successive_states = []
   3412     successive_outputs = []
   3413 
   3414     # Process the input tensors. The input tensor need to be split on the
   3415     # time_step dim, and reverse if go_backwards is True. In the case of nested
   3416     # input, the input is flattened and then transformed individually.
   3417     # The result of this will be a tuple of lists, each of the item in tuple is
   3418     # list of the tensor with shape (batch, feature)
   3419     def _process_single_input_t(input_t):
   3420       input_t = array_ops.unstack(input_t)  # unstack for time_step dim
   3421       if go_backwards:
   3422         input_t.reverse()
   3423       return input_t
   3424 
   3425     if nest.is_sequence(inputs):
   3426       processed_input = nest.map_structure(_process_single_input_t, inputs)
   3427     else:
   3428       processed_input = (_process_single_input_t(inputs),)
   3429 
   3430     def _get_input_tensor(time):
   3431       inp = [t_[time] for t_ in processed_input]
   3432       return nest.pack_sequence_as(inputs, inp)
   3433 
   3434     if mask is not None:
   3435       mask_list = array_ops.unstack(mask)
   3436       if go_backwards:
   3437         mask_list.reverse()
   3438 
   3439       for i in range(time_steps):
   3440         inp = _get_input_tensor(i)
   3441         mask_t = mask_list[i]
   3442         output, new_states = step_function(inp,
   3443                                            tuple(states) + tuple(constants))
   3444         tiled_mask_t = _expand_mask(mask_t, output)
   3445 
   3446         if not successive_outputs:
   3447           prev_output = zeros_like(output)
   3448         else:
   3449           prev_output = successive_outputs[-1]
   3450 
   3451         output = array_ops.where(tiled_mask_t, output, prev_output)
   3452 
   3453         return_states = []
   3454         for state, new_state in zip(states, new_states):
   3455           # (see earlier comment for tile explanation)
   3456           tiled_mask_t = _expand_mask(mask_t, new_state)
   3457           return_states.append(array_ops.where(tiled_mask_t, new_state, state))
   3458         states = return_states
   3459         successive_outputs.append(output)
   3460         successive_states.append(states)
   3461       last_output = successive_outputs[-1]
   3462       new_states = successive_states[-1]
   3463       outputs = array_ops.stack(successive_outputs)
   3464 
   3465       if zero_output_for_mask:
   3466         last_output = array_ops.where(
   3467             _expand_mask(mask_list[-1], last_output),
   3468             last_output,
   3469             zeros_like(last_output))
   3470         outputs = array_ops.where(
   3471             _expand_mask(mask, outputs, fixed_dim=2),
   3472             outputs,
   3473             zeros_like(outputs))
   3474 
   3475     else:
   3476       for i in range(time_steps):
   3477         inp = _get_input_tensor(i)
   3478         output, states = step_function(inp, tuple(states) + tuple(constants))
   3479         successive_outputs.append(output)
   3480         successive_states.append(states)
   3481       last_output = successive_outputs[-1]
   3482       new_states = successive_states[-1]
   3483       outputs = array_ops.stack(successive_outputs)
   3484 
   3485   else:
   3486     states = tuple(initial_states)
   3487 
   3488     # Create input tensor array, if the inputs is nested tensors, then it will
   3489     # be flattened first, and tensor array will be created one per flattened
   3490     # tensor.
   3491     input_ta = tuple(
   3492         tensor_array_ops.TensorArray(
   3493             dtype=inp.dtype,
   3494             size=time_steps_t,
   3495             tensor_array_name='input_ta_%s' % i)
   3496         for i, inp in enumerate(flatted_inputs))
   3497     input_ta = tuple(
   3498         ta.unstack(input_) if not go_backwards else ta
   3499         .unstack(reverse(input_, 0))
   3500         for ta, input_ in zip(input_ta, flatted_inputs))
   3501 
   3502     # Get the time(0) input and compute the output for that, the output will be
   3503     # used to determine the dtype of output tensor array. Don't read from
   3504     # input_ta due to TensorArray clear_after_read default to True.
   3505     input_time_zero = nest.pack_sequence_as(inputs,
   3506                                             [inp[0] for inp in flatted_inputs])
   3507     # output_time_zero is used to determine the cell output shape and its dtype.
   3508     # the value is discarded.
   3509     output_time_zero, _ = step_function(input_time_zero,
   3510                                         initial_states + constants)
   3511     output_ta = tuple(
   3512         tensor_array_ops.TensorArray(
   3513             dtype=out.dtype,
   3514             size=time_steps_t,
   3515             tensor_array_name='output_ta_%s' % i)
   3516         for i, out in enumerate(nest.flatten(output_time_zero)))
   3517 
   3518     time = constant_op.constant(0, dtype='int32', name='time')
   3519 
   3520     while_loop_kwargs = {
   3521         'cond': lambda time, *_: time < time_steps_t,
   3522         'maximum_iterations': input_length,
   3523         'parallel_iterations': 32,
   3524         'swap_memory': True,
   3525     }
   3526 
   3527     if mask is not None:
   3528       if not states:
   3529         raise ValueError('No initial states provided! '
   3530                          'When using masking in an RNN, you should '
   3531                          'provide initial states '
   3532                          '(and your step function should return '
   3533                          'as its first state at time `t` '
   3534                          'the output at time `t-1`).')
   3535       if go_backwards:
   3536         mask = reverse(mask, 0)
   3537 
   3538       mask_ta = tensor_array_ops.TensorArray(
   3539           dtype=dtypes_module.bool,
   3540           size=time_steps_t,
   3541           tensor_array_name='mask_ta')
   3542       mask_ta = mask_ta.unstack(mask)
   3543 
   3544       # Mask for the T output will be base on the output of T - 1. In the case
   3545       # T = 0, a zero filled tensor will be used.
   3546       flat_zero_output = tuple(array_ops.zeros_like(o)
   3547                                for o in nest.flatten(output_time_zero))
   3548       def _step(time, output_ta_t, prev_output, *states):
   3549         """RNN step function.
   3550 
   3551         Arguments:
   3552             time: Current timestep value.
   3553             output_ta_t: TensorArray.
   3554             prev_output: tuple of outputs from time - 1.
   3555             *states: List of states.
   3556 
   3557         Returns:
   3558             Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
   3559         """
   3560         current_input = tuple(ta.read(time) for ta in input_ta)
   3561         # maybe set shape.
   3562         current_input = nest.pack_sequence_as(inputs, current_input)
   3563         mask_t = mask_ta.read(time)
   3564         output, new_states = step_function(current_input,
   3565                                            tuple(states) + tuple(constants))
   3566         # mask output
   3567         flat_output = nest.flatten(output)
   3568         flat_mask_output = (flat_zero_output if zero_output_for_mask
   3569                             else nest.flatten(prev_output))
   3570         tiled_mask_t = tuple(_expand_mask(mask_t, o) for o in flat_output)
   3571         flat_new_output = tuple(
   3572             array_ops.where(m, o, zo) for m, o, zo in zip(
   3573                 tiled_mask_t, flat_output, flat_mask_output))
   3574 
   3575         # mask states
   3576         flat_state = nest.flatten(states)
   3577         flat_new_state = nest.flatten(new_states)
   3578         for state, new_state in zip(flat_state, flat_new_state):
   3579           new_state.set_shape(state.shape)
   3580         tiled_mask_t = tuple(_expand_mask(mask_t, s) for s in flat_state)
   3581         flat_final_state = tuple(
   3582             array_ops.where(m, s, ps)
   3583             for m, s, ps in zip(tiled_mask_t, flat_new_state, flat_state))
   3584         new_states = nest.pack_sequence_as(new_states, flat_final_state)
   3585 
   3586         output_ta_t = tuple(
   3587             ta.write(time, out)
   3588             for ta, out in zip(output_ta_t, flat_new_output))
   3589         return (time + 1, output_ta_t,
   3590                 tuple(flat_new_output)) + tuple(new_states)
   3591 
   3592       final_outputs = control_flow_ops.while_loop(
   3593           body=_step,
   3594           loop_vars=(time, output_ta, flat_zero_output) + states,
   3595           **while_loop_kwargs)
   3596       # Skip final_outputs[2] which is the output for final timestep.
   3597       new_states = final_outputs[3:]
   3598     else:
   3599       def _step(time, output_ta_t, *states):
   3600         """RNN step function.
   3601 
   3602         Arguments:
   3603             time: Current timestep value.
   3604             output_ta_t: TensorArray.
   3605             *states: List of states.
   3606 
   3607         Returns:
   3608             Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
   3609         """
   3610         current_input = tuple(ta.read(time) for ta in input_ta)
   3611         current_input = nest.pack_sequence_as(inputs, current_input)
   3612         output, new_states = step_function(current_input,
   3613                                            tuple(states) + tuple(constants))
   3614         flat_state = nest.flatten(states)
   3615         flat_new_state = nest.flatten(new_states)
   3616         for state, new_state in zip(flat_state, flat_new_state):
   3617           new_state.set_shape(state.shape)
   3618 
   3619         flat_output = nest.flatten(output)
   3620         output_ta_t = tuple(
   3621             ta.write(time, out) for ta, out in zip(output_ta_t, flat_output))
   3622         new_states = nest.pack_sequence_as(initial_states, flat_new_state)
   3623         return (time + 1, output_ta_t) + tuple(new_states)
   3624 
   3625       final_outputs = control_flow_ops.while_loop(
   3626           body=_step,
   3627           loop_vars=(time, output_ta) + states,
   3628           **while_loop_kwargs)
   3629       new_states = final_outputs[2:]
   3630 
   3631     output_ta = final_outputs[1]
   3632 
   3633     outputs = tuple(o.stack() for o in output_ta)
   3634     last_output = tuple(o[-1] for o in outputs)
   3635 
   3636     outputs = nest.pack_sequence_as(output_time_zero, outputs)
   3637     last_output = nest.pack_sequence_as(output_time_zero, last_output)
   3638 
   3639   # static shape inference
   3640   def set_shape(output_):
   3641     shape = output_.shape.as_list()
   3642     shape[0] = time_steps
   3643     shape[1] = batch
   3644     output_.set_shape(shape)
   3645     return output_
   3646 
   3647   outputs = nest.map_structure(set_shape, outputs)
   3648 
   3649   if not time_major:
   3650     outputs = nest.map_structure(swap_batch_timestep, outputs)
   3651 
   3652   return last_output, outputs, new_states
   3653 
   3654 
   3655 @keras_export('keras.backend.switch')
   3656 def switch(condition, then_expression, else_expression):
   3657   """Switches between two operations depending on a scalar value.
   3658 
   3659   Note that both `then_expression` and `else_expression`
   3660   should be symbolic tensors of the *same shape*.
   3661 
   3662   Arguments:
   3663       condition: tensor (`int` or `bool`).
   3664       then_expression: either a tensor, or a callable that returns a tensor.
   3665       else_expression: either a tensor, or a callable that returns a tensor.
   3666 
   3667   Returns:
   3668       The selected tensor.
   3669 
   3670   Raises:
   3671       ValueError: If rank of `condition` is greater than rank of expressions.
   3672   """
   3673   if condition.dtype != dtypes_module.bool:
   3674     condition = math_ops.cast(condition, 'bool')
   3675   cond_ndim = ndim(condition)
   3676   if not cond_ndim:
   3677     if not callable(then_expression):
   3678 
   3679       def then_expression_fn():
   3680         return then_expression
   3681     else:
   3682       then_expression_fn = then_expression
   3683     if not callable(else_expression):
   3684 
   3685       def else_expression_fn():
   3686         return else_expression
   3687     else:
   3688       else_expression_fn = else_expression
   3689     x = control_flow_ops.cond(condition, then_expression_fn, else_expression_fn)
   3690   else:
   3691     # tf.where needs its condition tensor
   3692     # to be the same shape as its two
   3693     # result tensors
   3694     if callable(then_expression):
   3695       then_expression = then_expression()
   3696     if callable(else_expression):
   3697       else_expression = else_expression()
   3698     expr_ndim = ndim(then_expression)
   3699     if cond_ndim > expr_ndim:
   3700       raise ValueError('Rank of `condition` should be less than or'
   3701                        ' equal to rank of `then_expression` and '
   3702                        '`else_expression`. ndim(condition)=' + str(cond_ndim) +
   3703                        ', ndim(then_expression)'
   3704                        '=' + str(expr_ndim))
   3705     if cond_ndim > 1:
   3706       ndim_diff = expr_ndim - cond_ndim
   3707       cond_shape = array_ops.concat(
   3708           [array_ops.shape(condition), [1] * ndim_diff], axis=0)
   3709       condition = array_ops.reshape(condition, cond_shape)
   3710       expr_shape = array_ops.shape(then_expression)
   3711       shape_diff = expr_shape - cond_shape
   3712       tile_shape = array_ops.where(shape_diff > 0, expr_shape,
   3713                                    array_ops.ones_like(expr_shape))
   3714       condition = array_ops.tile(condition, tile_shape)
   3715     x = array_ops.where(condition, then_expression, else_expression)
   3716   return x
   3717 
   3718 
   3719 @keras_export('keras.backend.in_train_phase')
   3720 def in_train_phase(x, alt, training=None):
   3721   """Selects `x` in train phase, and `alt` otherwise.
   3722 
   3723   Note that `alt` should have the *same shape* as `x`.
   3724 
   3725   Arguments:
   3726       x: What to return in train phase
   3727           (tensor or callable that returns a tensor).
   3728       alt: What to return otherwise
   3729           (tensor or callable that returns a tensor).
   3730       training: Optional scalar tensor
   3731           (or Python boolean, or Python integer)
   3732           specifying the learning phase.
   3733 
   3734   Returns:
   3735       Either `x` or `alt` based on the `training` flag.
   3736       the `training` flag defaults to `K.learning_phase()`.
   3737   """
   3738   if training is None:
   3739     training = learning_phase()
   3740 
   3741   if training == 1 or training is True:
   3742     if callable(x):
   3743       return x()
   3744     else:
   3745       return x
   3746 
   3747   elif training == 0 or training is False:
   3748     if callable(alt):
   3749       return alt()
   3750     else:
   3751       return alt
   3752 
   3753   # else: assume learning phase is a placeholder tensor.
   3754   x = switch(training, x, alt)
   3755   return x
   3756 
   3757 
   3758 @keras_export('keras.backend.in_test_phase')
   3759 def in_test_phase(x, alt, training=None):
   3760   """Selects `x` in test phase, and `alt` otherwise.
   3761 
   3762   Note that `alt` should have the *same shape* as `x`.
   3763 
   3764   Arguments:
   3765       x: What to return in test phase
   3766           (tensor or callable that returns a tensor).
   3767       alt: What to return otherwise
   3768           (tensor or callable that returns a tensor).
   3769       training: Optional scalar tensor
   3770           (or Python boolean, or Python integer)
   3771           specifying the learning phase.
   3772 
   3773   Returns:
   3774       Either `x` or `alt` based on `K.learning_phase`.
   3775   """
   3776   return in_train_phase(alt, x, training=training)
   3777 
   3778 
   3779 # NN OPERATIONS
   3780 
   3781 
   3782 @keras_export('keras.backend.relu')
   3783 def relu(x, alpha=0., max_value=None, threshold=0):
   3784   """Rectified linear unit.
   3785 
   3786   With default values, it returns element-wise `max(x, 0)`.
   3787 
   3788   Otherwise, it follows:
   3789   `f(x) = max_value` for `x >= max_value`,
   3790   `f(x) = x` for `threshold <= x < max_value`,
   3791   `f(x) = alpha * (x - threshold)` otherwise.
   3792 
   3793   Arguments:
   3794       x: A tensor or variable.
   3795       alpha: A scalar, slope of negative section (default=`0.`).
   3796       max_value: float. Saturation threshold.
   3797       threshold: float. Threshold value for thresholded activation.
   3798 
   3799   Returns:
   3800       A tensor.
   3801   """
   3802 
   3803   if alpha != 0.:
   3804     if max_value is None and threshold == 0:
   3805       return nn.leaky_relu(x, alpha=alpha)
   3806 
   3807     if threshold != 0:
   3808       negative_part = nn.relu(-x + threshold)
   3809     else:
   3810       negative_part = nn.relu(-x)
   3811 
   3812   clip_max = max_value is not None
   3813 
   3814   if threshold != 0:
   3815     # computes x for x > threshold else 0
   3816     x = x * math_ops.cast(math_ops.greater(x, threshold), floatx())
   3817   elif max_value == 6:
   3818     # if no threshold, then can use nn.relu6 native TF op for performance
   3819     x = nn.relu6(x)
   3820     clip_max = False
   3821   else:
   3822     x = nn.relu(x)
   3823 
   3824   if clip_max:
   3825     max_value = _to_tensor(max_value, x.dtype.base_dtype)
   3826     zero = _to_tensor(0., x.dtype.base_dtype)
   3827     x = clip_ops.clip_by_value(x, zero, max_value)
   3828 
   3829   if alpha != 0.:
   3830     alpha = _to_tensor(alpha, x.dtype.base_dtype)
   3831     x -= alpha * negative_part
   3832   return x
   3833 
   3834 
   3835 @keras_export('keras.backend.elu')
   3836 def elu(x, alpha=1.):
   3837   """Exponential linear unit.
   3838 
   3839   Arguments:
   3840       x: A tensor or variable to compute the activation function for.
   3841       alpha: A scalar, slope of negative section.
   3842 
   3843   Returns:
   3844       A tensor.
   3845   """
   3846   res = nn.elu(x)
   3847   if alpha == 1:
   3848     return res
   3849   else:
   3850     return array_ops.where(x > 0, res, alpha * res)
   3851 
   3852 
   3853 @keras_export('keras.backend.softmax')
   3854 def softmax(x, axis=-1):
   3855   """Softmax of a tensor.
   3856 
   3857   Arguments:
   3858       x: A tensor or variable.
   3859       axis: The dimension softmax would be performed on.
   3860           The default is -1 which indicates the last dimension.
   3861 
   3862   Returns:
   3863       A tensor.
   3864   """
   3865   return nn.softmax(x, axis=axis)
   3866 
   3867 
   3868 @keras_export('keras.backend.softplus')
   3869 def softplus(x):
   3870   """Softplus of a tensor.
   3871 
   3872   Arguments:
   3873       x: A tensor or variable.
   3874 
   3875   Returns:
   3876       A tensor.
   3877   """
   3878   return nn.softplus(x)
   3879 
   3880 
   3881 @keras_export('keras.backend.softsign')
   3882 def softsign(x):
   3883   """Softsign of a tensor.
   3884 
   3885   Arguments:
   3886       x: A tensor or variable.
   3887 
   3888   Returns:
   3889       A tensor.
   3890   """
   3891   return nn.softsign(x)
   3892 
   3893 
   3894 @keras_export('keras.backend.categorical_crossentropy')
   3895 def categorical_crossentropy(target, output, from_logits=False, axis=-1):
   3896   """Categorical crossentropy between an output tensor and a target tensor.
   3897 
   3898   Arguments:
   3899       target: A tensor of the same shape as `output`.
   3900       output: A tensor resulting from a softmax
   3901           (unless `from_logits` is True, in which
   3902           case `output` is expected to be the logits).
   3903       from_logits: Boolean, whether `output` is the
   3904           result of a softmax, or is a tensor of logits.
   3905       axis: Int specifying the channels axis. `axis=-1` corresponds to data
   3906           format `channels_last', and `axis=1` corresponds to data format
   3907           `channels_first`.
   3908 
   3909   Returns:
   3910       Output tensor.
   3911 
   3912   Raises:
   3913       ValueError: if `axis` is neither -1 nor one of the axes of `output`.
   3914   """
   3915   if not from_logits:
   3916     if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
   3917         output.op.type != 'Softmax'):
   3918       axis = axis % len(output.shape)
   3919       # scale preds so that the class probas of each sample sum to 1
   3920       output = output / math_ops.reduce_sum(output, axis, True)
   3921       # Compute cross entropy from probabilities.
   3922       epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
   3923       output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
   3924       return -math_ops.reduce_sum(target * math_ops.log(output), axis)
   3925     else:
   3926       # When softmax activation function is used for output operation, we
   3927       # use logits from the softmax function directly to compute loss in order
   3928       # to prevent collapsing zero when training.
   3929       # See b/117284466
   3930       assert len(output.op.inputs) == 1
   3931       output = output.op.inputs[0]
   3932   return nn.softmax_cross_entropy_with_logits_v2(labels=target, logits=output)
   3933 
   3934 
   3935 @keras_export('keras.backend.sparse_categorical_crossentropy')
   3936 def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
   3937   """Categorical crossentropy with integer targets.
   3938 
   3939   Arguments:
   3940       target: An integer tensor.
   3941       output: A tensor resulting from a softmax
   3942           (unless `from_logits` is True, in which
   3943           case `output` is expected to be the logits).
   3944       from_logits: Boolean, whether `output` is the
   3945           result of a softmax, or is a tensor of logits.
   3946       axis: Int specifying the channels axis. `axis=-1` corresponds to data
   3947           format `channels_last', and `axis=1` corresponds to data format
   3948           `channels_first`.
   3949 
   3950   Returns:
   3951       Output tensor.
   3952 
   3953   Raises:
   3954       ValueError: if `axis` is neither -1 nor one of the axes of `output`.
   3955   """
   3956   if not from_logits:
   3957     if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
   3958         output.op.type != 'Softmax'):
   3959       epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
   3960       output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
   3961       output = math_ops.log(output)
   3962     else:
   3963       # When softmax activation function is used for output operation, we
   3964       # use logits from the softmax function directly to compute loss in order
   3965       # to prevent collapsing zero when training.
   3966       # See b/117284466
   3967       assert len(output.op.inputs) == 1
   3968       output = output.op.inputs[0]
   3969 
   3970   rank = len(output.shape)
   3971   axis = axis % rank
   3972   if axis != rank - 1:
   3973     permutation = list(range(axis)) + list(range(axis + 1, rank)) + [axis]
   3974     output = array_ops.transpose(output, perm=permutation)
   3975 
   3976   output_shape = output.shape
   3977   targets = cast(flatten(target), 'int64')
   3978   logits = array_ops.reshape(output, [-1, int(output_shape[-1])])
   3979   res = nn.sparse_softmax_cross_entropy_with_logits(
   3980       labels=targets, logits=logits)
   3981   if len(output_shape) >= 3:
   3982     # If our output includes timesteps or spatial dimensions we need to reshape
   3983     return array_ops.reshape(res, array_ops.shape(output)[:-1])
   3984   else:
   3985     return res
   3986 
   3987 
   3988 @keras_export('keras.backend.binary_crossentropy')
   3989 def binary_crossentropy(target, output, from_logits=False):
   3990   """Binary crossentropy between an output tensor and a target tensor.
   3991 
   3992   Arguments:
   3993       target: A tensor with the same shape as `output`.
   3994       output: A tensor.
   3995       from_logits: Whether `output` is expected to be a logits tensor.
   3996           By default, we consider that `output`
   3997           encodes a probability distribution.
   3998 
   3999   Returns:
   4000       A tensor.
   4001   """
   4002   if not from_logits:
   4003     if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
   4004         output.op.type != 'Sigmoid'):
   4005       epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
   4006       output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
   4007 
   4008       # Compute cross entropy from probabilities.
   4009       bce = target * math_ops.log(output + epsilon())
   4010       bce += (1 - target) * math_ops.log(1 - output + epsilon())
   4011       return -bce
   4012     else:
   4013       # When sigmoid activation function is used for output operation, we
   4014       # use logits from the sigmoid function directly to compute loss in order
   4015       # to prevent collapsing zero when training.
   4016       assert len(output.op.inputs) == 1
   4017       output = output.op.inputs[0]
   4018   return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
   4019 
   4020 
   4021 @keras_export('keras.backend.sigmoid')
   4022 def sigmoid(x):
   4023   """Element-wise sigmoid.
   4024 
   4025   Arguments:
   4026       x: A tensor or variable.
   4027 
   4028   Returns:
   4029       A tensor.
   4030   """
   4031   return nn.sigmoid(x)
   4032 
   4033 
   4034 @keras_export('keras.backend.hard_sigmoid')
   4035 def hard_sigmoid(x):
   4036   """Segment-wise linear approximation of sigmoid.
   4037 
   4038   Faster than sigmoid.
   4039   Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
   4040   In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
   4041 
   4042   Arguments:
   4043       x: A tensor or variable.
   4044 
   4045   Returns:
   4046       A tensor.
   4047   """
   4048   x = (0.2 * x) + 0.5
   4049   zero = _to_tensor(0., x.dtype.base_dtype)
   4050   one = _to_tensor(1., x.dtype.base_dtype)
   4051   x = clip_ops.clip_by_value(x, zero, one)
   4052   return x
   4053 
   4054 
   4055 @keras_export('keras.backend.tanh')
   4056 def tanh(x):
   4057   """Element-wise tanh.
   4058 
   4059   Arguments:
   4060       x: A tensor or variable.
   4061 
   4062   Returns:
   4063       A tensor.
   4064   """
   4065   return nn.tanh(x)
   4066 
   4067 
   4068 @keras_export('keras.backend.dropout')
   4069 def dropout(x, level, noise_shape=None, seed=None):
   4070   """Sets entries in `x` to zero at random, while scaling the entire tensor.
   4071 
   4072   Arguments:
   4073       x: tensor
   4074       level: fraction of the entries in the tensor
   4075           that will be set to 0.
   4076       noise_shape: shape for randomly generated keep/drop flags,
   4077           must be broadcastable to the shape of `x`
   4078       seed: random seed to ensure determinism.
   4079 
   4080   Returns:
   4081       A tensor.
   4082   """
   4083   if seed is None:
   4084     seed = np.random.randint(10e6)
   4085   return nn.dropout_v2(x, rate=level, noise_shape=noise_shape, seed=seed)
   4086 
   4087 
   4088 @keras_export('keras.backend.l2_normalize')
   4089 def l2_normalize(x, axis=None):
   4090   """Normalizes a tensor wrt the L2 norm alongside the specified axis.
   4091 
   4092   Arguments:
   4093       x: Tensor or variable.
   4094       axis: axis along which to perform normalization.
   4095 
   4096   Returns:
   4097       A tensor.
   4098   """
   4099   return nn.l2_normalize(x, axis=axis)
   4100 
   4101 
   4102 @keras_export('keras.backend.in_top_k')
   4103 def in_top_k(predictions, targets, k):
   4104   """Returns whether the `targets` are in the top `k` `predictions`.
   4105 
   4106   Arguments:
   4107       predictions: A tensor of shape `(batch_size, classes)` and type `float32`.
   4108       targets: A 1D tensor of length `batch_size` and type `int32` or `int64`.
   4109       k: An `int`, number of top elements to consider.
   4110 
   4111   Returns:
   4112       A 1D tensor of length `batch_size` and type `bool`.
   4113       `output[i]` is `True` if `predictions[i, targets[i]]` is within top-`k`
   4114       values of `predictions[i]`.
   4115   """
   4116   return nn.in_top_k(predictions, targets, k)
   4117 
   4118 
   4119 # CONVOLUTIONS
   4120 
   4121 
   4122 def _preprocess_conv1d_input(x, data_format):
   4123   """Transpose and cast the input before the conv1d.
   4124 
   4125   Arguments:
   4126       x: input tensor.
   4127       data_format: string, `"channels_last"` or `"channels_first"`.
   4128 
   4129   Returns:
   4130       A tensor.
   4131   """
   4132   tf_data_format = 'NWC'  # to pass TF Conv2dNative operations
   4133   if data_format == 'channels_first':
   4134     if not _has_nchw_support():
   4135       x = array_ops.transpose(x, (0, 2, 1))  # NCW -> NWC
   4136     else:
   4137       tf_data_format = 'NCW'
   4138   return x, tf_data_format
   4139 
   4140 
   4141 def _preprocess_conv2d_input(x, data_format, force_transpose=False):
   4142   """Transpose and cast the input before the conv2d.
   4143 
   4144   Arguments:
   4145       x: input tensor.
   4146       data_format: string, `"channels_last"` or `"channels_first"`.
   4147       force_transpose: Boolean. If True, the input will always be transposed
   4148           from NCHW to NHWC if `data_format` is `"channels_first"`.
   4149           If False, the transposition only occurs on CPU (GPU ops are
   4150           assumed to support NCHW).
   4151 
   4152   Returns:
   4153       A tensor.
   4154   """
   4155   tf_data_format = 'NHWC'
   4156   if data_format == 'channels_first':
   4157     if not _has_nchw_support() or force_transpose:
   4158       x = array_ops.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC
   4159     else:
   4160       tf_data_format = 'NCHW'
   4161   return x, tf_data_format
   4162 
   4163 
   4164 def _preprocess_conv3d_input(x, data_format):
   4165   """Transpose and cast the input before the conv3d.
   4166 
   4167   Arguments:
   4168       x: input tensor.
   4169       data_format: string, `"channels_last"` or `"channels_first"`.
   4170 
   4171   Returns:
   4172       A tensor.
   4173   """
   4174   tf_data_format = 'NDHWC'
   4175   if data_format == 'channels_first':
   4176     if not _has_nchw_support():
   4177       x = array_ops.transpose(x, (0, 2, 3, 4, 1))
   4178     else:
   4179       tf_data_format = 'NCDHW'
   4180   return x, tf_data_format
   4181 
   4182 
   4183 def _preprocess_padding(padding):
   4184   """Convert keras' padding to TensorFlow's padding.
   4185 
   4186   Arguments:
   4187       padding: string, one of 'same' , 'valid'
   4188 
   4189   Returns:
   4190       a string, one of 'SAME', 'VALID'.
   4191 
   4192   Raises:
   4193       ValueError: if invalid `padding'`
   4194   """
   4195   if padding == 'same':
   4196     padding = 'SAME'
   4197   elif padding == 'valid':
   4198     padding = 'VALID'
   4199   else:
   4200     raise ValueError('Invalid padding: ' + str(padding))
   4201   return padding
   4202 
   4203 
   4204 @keras_export('keras.backend.conv1d')
   4205 def conv1d(x,
   4206            kernel,
   4207            strides=1,
   4208            padding='valid',
   4209            data_format=None,
   4210            dilation_rate=1):
   4211   """1D convolution.
   4212 
   4213   Arguments:
   4214       x: Tensor or variable.
   4215       kernel: kernel tensor.
   4216       strides: stride integer.
   4217       padding: string, `"same"`, `"causal"` or `"valid"`.
   4218       data_format: string, one of "channels_last", "channels_first".
   4219       dilation_rate: integer dilate rate.
   4220 
   4221   Returns:
   4222       A tensor, result of 1D convolution.
   4223 
   4224   Raises:
   4225       ValueError: if `data_format` is neither `channels_last` or
   4226       `channels_first`.
   4227   """
   4228   if data_format is None:
   4229     data_format = image_data_format()
   4230   if data_format not in {'channels_first', 'channels_last'}:
   4231     raise ValueError('Unknown data_format: ' + str(data_format))
   4232 
   4233   kernel_shape = kernel.shape.as_list()
   4234   if padding == 'causal':
   4235     # causal (dilated) convolution:
   4236     left_pad = dilation_rate * (kernel_shape[0] - 1)
   4237     x = temporal_padding(x, (left_pad, 0))
   4238     padding = 'valid'
   4239   padding = _preprocess_padding(padding)
   4240 
   4241   x, tf_data_format = _preprocess_conv1d_input(x, data_format)
   4242   x = nn.convolution(
   4243       input=x,
   4244       filter=kernel,
   4245       dilation_rate=dilation_rate,
   4246       strides=strides,
   4247       padding=padding,
   4248       data_format=tf_data_format)
   4249   if data_format == 'channels_first' and tf_data_format == 'NWC':
   4250     x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
   4251   return x
   4252 
   4253 
   4254 @keras_export('keras.backend.conv2d')
   4255 def conv2d(x,
   4256            kernel,
   4257            strides=(1, 1),
   4258            padding='valid',
   4259            data_format=None,
   4260            dilation_rate=(1, 1)):
   4261   """2D convolution.
   4262 
   4263   Arguments:
   4264       x: Tensor or variable.
   4265       kernel: kernel tensor.
   4266       strides: strides tuple.
   4267       padding: string, `"same"` or `"valid"`.
   4268       data_format: `"channels_last"` or `"channels_first"`.
   4269           Whether to use Theano or TensorFlow data format
   4270           for inputs/kernels/outputs.
   4271       dilation_rate: tuple of 2 integers.
   4272 
   4273   Returns:
   4274       A tensor, result of 2D convolution.
   4275 
   4276   Raises:
   4277       ValueError: if `data_format` is neither `channels_last` or
   4278       `channels_first`.
   4279   """
   4280   if data_format is None:
   4281     data_format = image_data_format()
   4282   if data_format not in {'channels_first', 'channels_last'}:
   4283     raise ValueError('Unknown data_format: ' + str(data_format))
   4284 
   4285   x, tf_data_format = _preprocess_conv2d_input(x, data_format)
   4286   padding = _preprocess_padding(padding)
   4287   x = nn.convolution(
   4288       input=x,
   4289       filter=kernel,
   4290       dilation_rate=dilation_rate,
   4291       strides=strides,
   4292       padding=padding,
   4293       data_format=tf_data_format)
   4294   if data_format == 'channels_first' and tf_data_format == 'NHWC':
   4295     x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
   4296   return x
   4297 
   4298 
   4299 @keras_export('keras.backend.conv2d_transpose')
   4300 def conv2d_transpose(x,
   4301                      kernel,
   4302                      output_shape,
   4303                      strides=(1, 1),
   4304                      padding='valid',
   4305                      data_format=None,
   4306                      dilation_rate=(1, 1)):
   4307   """2D deconvolution (i.e.
   4308 
   4309   transposed convolution).
   4310 
   4311   Arguments:
   4312       x: Tensor or variable.
   4313       kernel: kernel tensor.
   4314       output_shape: 1D int tensor for the output shape.
   4315       strides: strides tuple.
   4316       padding: string, `"same"` or `"valid"`.
   4317       data_format: string, `"channels_last"` or `"channels_first"`.
   4318           Whether to use Theano or TensorFlow/CNTK data format
   4319           for inputs/kernels/outputs.
   4320       dilation_rate: Tuple of 2 integers.
   4321 
   4322   Returns:
   4323       A tensor, result of transposed 2D convolution.
   4324 
   4325   Raises:
   4326       ValueError: if `data_format` is neither `channels_last` or
   4327       `channels_first`.
   4328   """
   4329   if data_format is None:
   4330     data_format = image_data_format()
   4331   if data_format not in {'channels_first', 'channels_last'}:
   4332     raise ValueError('Unknown data_format: ' + str(data_format))
   4333   if isinstance(output_shape, (tuple, list)):
   4334     output_shape = array_ops.stack(output_shape)
   4335 
   4336   # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
   4337   if data_format == 'channels_first' and dilation_rate != (1, 1):
   4338     force_transpose = True
   4339   else:
   4340     force_transpose = False
   4341 
   4342   x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
   4343 
   4344   if data_format == 'channels_first' and tf_data_format == 'NHWC':
   4345     output_shape = (output_shape[0], output_shape[2], output_shape[3],
   4346                     output_shape[1])
   4347   if output_shape[0] is None:
   4348     output_shape = (array_ops.shape(x)[0],) + tuple(output_shape[1:])
   4349     output_shape = array_ops.stack(list(output_shape))
   4350 
   4351   padding = _preprocess_padding(padding)
   4352   if tf_data_format == 'NHWC':
   4353     strides = (1,) + strides + (1,)
   4354   else:
   4355     strides = (1, 1) + strides
   4356 
   4357   if dilation_rate == (1, 1):
   4358     x = nn.conv2d_transpose(x, kernel, output_shape, strides,
   4359                             padding=padding,
   4360                             data_format=tf_data_format)
   4361   else:
   4362     assert dilation_rate[0] == dilation_rate[1]
   4363     x = nn.atrous_conv2d_transpose(
   4364         x,
   4365         kernel,
   4366         output_shape,
   4367         rate=dilation_rate[0],
   4368         padding=padding)
   4369   if data_format == 'channels_first' and tf_data_format == 'NHWC':
   4370     x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
   4371   return x
   4372 
   4373 
   4374 def separable_conv1d(x,
   4375                      depthwise_kernel,
   4376                      pointwise_kernel,
   4377                      strides=1,
   4378                      padding='valid',
   4379                      data_format=None,
   4380                      dilation_rate=1):
   4381   """1D convolution with separable filters.
   4382 
   4383   Arguments:
   4384       x: input tensor
   4385       depthwise_kernel: convolution kernel for the depthwise convolution.
   4386       pointwise_kernel: kernel for the 1x1 convolution.
   4387       strides: stride integer.
   4388       padding: string, `"same"` or `"valid"`.
   4389       data_format: string, `"channels_last"` or `"channels_first"`.
   4390       dilation_rate: integer dilation rate.
   4391 
   4392   Returns:
   4393       Output tensor.
   4394 
   4395   Raises:
   4396       ValueError: if `data_format` is neither `channels_last` or
   4397       `channels_first`.
   4398   """
   4399   if data_format is None:
   4400     data_format = image_data_format()
   4401   if data_format not in {'channels_first', 'channels_last'}:
   4402     raise ValueError('Unknown data_format: ' + str(data_format))
   4403 
   4404   if isinstance(strides, int):
   4405     strides = (strides,)
   4406   if isinstance(dilation_rate, int):
   4407     dilation_rate = (dilation_rate,)
   4408 
   4409   x, tf_data_format = _preprocess_conv1d_input(x, data_format)
   4410   padding = _preprocess_padding(padding)
   4411   if not isinstance(strides, tuple):
   4412     strides = tuple(strides)
   4413   if tf_data_format == 'NWC':
   4414     spatial_start_dim = 1
   4415     strides = (1,) + strides * 2 + (1,)
   4416   else:
   4417     spatial_start_dim = 2
   4418     strides = (1, 1) + strides * 2
   4419   x = array_ops.expand_dims(x, spatial_start_dim)
   4420   depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0)
   4421   pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0)
   4422   dilation_rate = (1,) + dilation_rate
   4423 
   4424   x = nn.separable_conv2d(
   4425       x,
   4426       depthwise_kernel,
   4427       pointwise_kernel,
   4428       strides=strides,
   4429       padding=padding,
   4430       rate=dilation_rate,
   4431       data_format=tf_data_format)
   4432 
   4433   x = array_ops.squeeze(x, [spatial_start_dim])
   4434 
   4435   if data_format == 'channels_first' and tf_data_format == 'NWC':
   4436     x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
   4437 
   4438   return x
   4439 
   4440 
   4441 @keras_export('keras.backend.separable_conv2d')
   4442 def separable_conv2d(x,
   4443                      depthwise_kernel,
   4444                      pointwise_kernel,
   4445                      strides=(1, 1),
   4446                      padding='valid',
   4447                      data_format=None,
   4448                      dilation_rate=(1, 1)):
   4449   """2D convolution with separable filters.
   4450 
   4451   Arguments:
   4452       x: input tensor
   4453       depthwise_kernel: convolution kernel for the depthwise convolution.
   4454       pointwise_kernel: kernel for the 1x1 convolution.
   4455       strides: strides tuple (length 2).
   4456       padding: string, `"same"` or `"valid"`.
   4457       data_format: string, `"channels_last"` or `"channels_first"`.
   4458       dilation_rate: tuple of integers,
   4459           dilation rates for the separable convolution.
   4460 
   4461   Returns:
   4462       Output tensor.
   4463 
   4464   Raises:
   4465       ValueError: if `data_format` is neither `channels_last` or
   4466       `channels_first`.
   4467       ValueError: if `strides` is not a tuple of 2 integers.
   4468   """
   4469   if data_format is None:
   4470     data_format = image_data_format()
   4471   if data_format not in {'channels_first', 'channels_last'}:
   4472     raise ValueError('Unknown data_format: ' + str(data_format))
   4473   if len(strides) != 2:
   4474     raise ValueError('`strides` must be a tuple of 2 integers.')
   4475 
   4476   x, tf_data_format = _preprocess_conv2d_input(x, data_format)
   4477   padding = _preprocess_padding(padding)
   4478   if not isinstance(strides, tuple):
   4479     strides = tuple(strides)
   4480   if tf_data_format == 'NHWC':
   4481     strides = (1,) + strides + (1,)
   4482   else:
   4483     strides = (1, 1) + strides
   4484 
   4485   x = nn.separable_conv2d(
   4486       x,
   4487       depthwise_kernel,
   4488       pointwise_kernel,
   4489       strides=strides,
   4490       padding=padding,
   4491       rate=dilation_rate,
   4492       data_format=tf_data_format)
   4493   if data_format == 'channels_first' and tf_data_format == 'NHWC':
   4494     x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
   4495   return x
   4496 
   4497 
   4498 def depthwise_conv2d(x,
   4499                      depthwise_kernel,
   4500                      strides=(1, 1),
   4501                      padding='valid',
   4502                      data_format=None,
   4503                      dilation_rate=(1, 1)):
   4504   """2D convolution with separable filters.
   4505 
   4506   Arguments:
   4507       x: input tensor
   4508       depthwise_kernel: convolution kernel for the depthwise convolution.
   4509       strides: strides tuple (length 2).
   4510       padding: string, `"same"` or `"valid"`.
   4511       data_format: string, `"channels_last"` or `"channels_first"`.
   4512       dilation_rate: tuple of integers,
   4513           dilation rates for the separable convolution.
   4514 
   4515   Returns:
   4516       Output tensor.
   4517 
   4518   Raises:
   4519       ValueError: if `data_format` is neither `channels_last` or
   4520       `channels_first`.
   4521   """
   4522   if data_format is None:
   4523     data_format = image_data_format()
   4524   if data_format not in {'channels_first', 'channels_last'}:
   4525     raise ValueError('Unknown data_format: ' + str(data_format))
   4526 
   4527   x, tf_data_format = _preprocess_conv2d_input(x, data_format)
   4528   padding = _preprocess_padding(padding)
   4529   if tf_data_format == 'NHWC':
   4530     strides = (1,) + strides + (1,)
   4531   else:
   4532     strides = (1, 1) + strides
   4533 
   4534   x = nn.depthwise_conv2d(
   4535       x,
   4536       depthwise_kernel,
   4537       strides=strides,
   4538       padding=padding,
   4539       rate=dilation_rate,
   4540       data_format=tf_data_format)
   4541   if data_format == 'channels_first' and tf_data_format == 'NHWC':
   4542     x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
   4543   return x
   4544 
   4545 
   4546 @keras_export('keras.backend.conv3d')
   4547 def conv3d(x,
   4548            kernel,
   4549            strides=(1, 1, 1),
   4550            padding='valid',
   4551            data_format=None,
   4552            dilation_rate=(1, 1, 1)):
   4553   """3D convolution.
   4554 
   4555   Arguments:
   4556       x: Tensor or variable.
   4557       kernel: kernel tensor.
   4558       strides: strides tuple.
   4559       padding: string, `"same"` or `"valid"`.
   4560       data_format: string, `"channels_last"` or `"channels_first"`.
   4561           Whether to use Theano or TensorFlow/CNTK data format
   4562           for inputs/kernels/outputs.
   4563       dilation_rate: tuple of 3 integers.
   4564 
   4565   Returns:
   4566       A tensor, result of 3D convolution.
   4567 
   4568   Raises:
   4569       ValueError: if `data_format` is neither `channels_last` or
   4570       `channels_first`.
   4571   """
   4572   if data_format is None:
   4573     data_format = image_data_format()
   4574   if data_format not in {'channels_first', 'channels_last'}:
   4575     raise ValueError('Unknown data_format: ' + str(data_format))
   4576 
   4577   x, tf_data_format = _preprocess_conv3d_input(x, data_format)
   4578   padding = _preprocess_padding(padding)
   4579   x = nn.convolution(
   4580       input=x,
   4581       filter=kernel,
   4582       dilation_rate=dilation_rate,
   4583       strides=strides,
   4584       padding=padding,
   4585       data_format=tf_data_format)
   4586   if data_format == 'channels_first' and tf_data_format == 'NDHWC':
   4587     x = array_ops.transpose(x, (0, 4, 1, 2, 3))
   4588   return x
   4589 
   4590 
   4591 def conv3d_transpose(x,
   4592                      kernel,
   4593                      output_shape,
   4594                      strides=(1, 1, 1),
   4595                      padding='valid',
   4596                      data_format=None):
   4597   """3D deconvolution (i.e.
   4598 
   4599   transposed convolution).
   4600 
   4601   Arguments:
   4602       x: input tensor.
   4603       kernel: kernel tensor.
   4604       output_shape: 1D int tensor for the output shape.
   4605       strides: strides tuple.
   4606       padding: string, "same" or "valid".
   4607       data_format: string, `"channels_last"` or `"channels_first"`.
   4608           Whether to use Theano or TensorFlow/CNTK data format
   4609           for inputs/kernels/outputs.
   4610 
   4611   Returns:
   4612       A tensor, result of transposed 3D convolution.
   4613 
   4614   Raises:
   4615       ValueError: if `data_format` is neither `channels_last` or
   4616       `channels_first`.
   4617   """
   4618   if data_format is None:
   4619     data_format = image_data_format()
   4620   if data_format not in {'channels_first', 'channels_last'}:
   4621     raise ValueError('Unknown data_format: ' + str(data_format))
   4622   if isinstance(output_shape, (tuple, list)):
   4623     output_shape = array_ops.stack(output_shape)
   4624 
   4625   x, tf_data_format = _preprocess_conv3d_input(x, data_format)
   4626 
   4627   if data_format == 'channels_first' and tf_data_format == 'NDHWC':
   4628     output_shape = (output_shape[0], output_shape[2], output_shape[3],
   4629                     output_shape[4], output_shape[1])
   4630   if output_shape[0] is None:
   4631     output_shape = (array_ops.shape(x)[0],) + tuple(output_shape[1:])
   4632     output_shape = array_ops.stack(list(output_shape))
   4633 
   4634   padding = _preprocess_padding(padding)
   4635   if tf_data_format == 'NDHWC':
   4636     strides = (1,) + strides + (1,)
   4637   else:
   4638     strides = (1, 1) + strides
   4639 
   4640   x = nn.conv3d_transpose(
   4641       x,
   4642       kernel,
   4643       output_shape,
   4644       strides,
   4645       padding=padding,
   4646       data_format=tf_data_format)
   4647   if data_format == 'channels_first' and tf_data_format == 'NDHWC':
   4648     x = array_ops.transpose(x, (0, 4, 1, 2, 3))
   4649   return x
   4650 
   4651 
   4652 @keras_export('keras.backend.pool2d')
   4653 def pool2d(x,
   4654            pool_size,
   4655            strides=(1, 1),
   4656            padding='valid',
   4657            data_format=None,
   4658            pool_mode='max'):
   4659   """2D Pooling.
   4660 
   4661   Arguments:
   4662       x: Tensor or variable.
   4663       pool_size: tuple of 2 integers.
   4664       strides: tuple of 2 integers.
   4665       padding: string, `"same"` or `"valid"`.
   4666       data_format: string, `"channels_last"` or `"channels_first"`.
   4667       pool_mode: string, `"max"` or `"avg"`.
   4668 
   4669   Returns:
   4670       A tensor, result of 2D pooling.
   4671 
   4672   Raises:
   4673       ValueError: if `data_format` is neither `"channels_last"` or
   4674       `"channels_first"`.
   4675       ValueError: if `pool_size` is not a tuple of 2 integers.
   4676       ValueError: if `strides` is not a tuple of 2 integers.
   4677       ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
   4678   """
   4679   if data_format is None:
   4680     data_format = image_data_format()
   4681   if data_format not in {'channels_first', 'channels_last'}:
   4682     raise ValueError('Unknown data_format: ' + str(data_format))
   4683   if len(pool_size) != 2:
   4684     raise ValueError('`pool_size` must be a tuple of 2 integers.')
   4685   if len(strides) != 2:
   4686     raise ValueError('`strides` must be a tuple of 2 integers.')
   4687 
   4688   x, tf_data_format = _preprocess_conv2d_input(x, data_format)
   4689   padding = _preprocess_padding(padding)
   4690   if tf_data_format == 'NHWC':
   4691     strides = (1,) + strides + (1,)
   4692     pool_size = (1,) + pool_size + (1,)
   4693   else:
   4694     strides = (1, 1) + strides
   4695     pool_size = (1, 1) + pool_size
   4696 
   4697   if pool_mode == 'max':
   4698     x = nn.max_pool(
   4699         x, pool_size, strides, padding=padding, data_format=tf_data_format)
   4700   elif pool_mode == 'avg':
   4701     x = nn.avg_pool(
   4702         x, pool_size, strides, padding=padding, data_format=tf_data_format)
   4703   else:
   4704     raise ValueError('Invalid pooling mode: ' + str(pool_mode))
   4705 
   4706   if data_format == 'channels_first' and tf_data_format == 'NHWC':
   4707     x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
   4708   return x
   4709 
   4710 
   4711 @keras_export('keras.backend.pool3d')
   4712 def pool3d(x,
   4713            pool_size,
   4714            strides=(1, 1, 1),
   4715            padding='valid',
   4716            data_format=None,
   4717            pool_mode='max'):
   4718   """3D Pooling.
   4719 
   4720   Arguments:
   4721       x: Tensor or variable.
   4722       pool_size: tuple of 3 integers.
   4723       strides: tuple of 3 integers.
   4724       padding: string, `"same"` or `"valid"`.
   4725       data_format: string, `"channels_last"` or `"channels_first"`.
   4726       pool_mode: string, `"max"` or `"avg"`.
   4727 
   4728   Returns:
   4729       A tensor, result of 3D pooling.
   4730 
   4731   Raises:
   4732       ValueError: if `data_format` is neither `"channels_last"` or
   4733       `"channels_first"`.
   4734       ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
   4735   """
   4736   if data_format is None:
   4737     data_format = image_data_format()
   4738   if data_format not in {'channels_first', 'channels_last'}:
   4739     raise ValueError('Unknown data_format: ' + str(data_format))
   4740 
   4741   x, tf_data_format = _preprocess_conv3d_input(x, data_format)
   4742   padding = _preprocess_padding(padding)
   4743   if tf_data_format == 'NDHWC':
   4744     strides = (1,) + strides + (1,)
   4745     pool_size = (1,) + pool_size + (1,)
   4746   else:
   4747     strides = (1, 1) + strides
   4748     pool_size = (1, 1) + pool_size
   4749 
   4750   if pool_mode == 'max':
   4751     x = nn.max_pool3d(
   4752         x, pool_size, strides, padding=padding, data_format=tf_data_format)
   4753   elif pool_mode == 'avg':
   4754     x = nn.avg_pool3d(
   4755         x, pool_size, strides, padding=padding, data_format=tf_data_format)
   4756   else:
   4757     raise ValueError('Invalid pooling mode: ' + str(pool_mode))
   4758 
   4759   if data_format == 'channels_first' and tf_data_format == 'NDHWC':
   4760     x = array_ops.transpose(x, (0, 4, 1, 2, 3))
   4761   return x
   4762 
   4763 
   4764 def local_conv(inputs,
   4765                kernel,
   4766                kernel_size,
   4767                strides,
   4768                output_shape,
   4769                data_format=None):
   4770   """Apply N-D convolution with un-shared weights.
   4771 
   4772   Arguments:
   4773       inputs: (N+2)-D tensor with shape
   4774           (batch_size, channels_in, d_in1, ..., d_inN)
   4775           if data_format='channels_first', or
   4776           (batch_size, d_in1, ..., d_inN, channels_in)
   4777           if data_format='channels_last'.
   4778       kernel: the unshared weight for N-D convolution,
   4779           with shape (output_items, feature_dim, channels_out), where
   4780           feature_dim = np.prod(kernel_size) * channels_in,
   4781           output_items = np.prod(output_shape).
   4782       kernel_size: a tuple of N integers, specifying the
   4783           spatial dimensions of the N-D convolution window.
   4784       strides: a tuple of N integers, specifying the strides
   4785           of the convolution along the spatial dimensions.
   4786       output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
   4787           dimensionality of the output.
   4788       data_format: string, "channels_first" or "channels_last".
   4789 
   4790   Returns:
   4791       An (N+2)-D tensor with shape:
   4792       (batch_size, channels_out) + output_shape
   4793       if data_format='channels_first', or:
   4794       (batch_size,) + output_shape + (channels_out,)
   4795       if data_format='channels_last'.
   4796 
   4797   Raises:
   4798       ValueError: if `data_format` is neither
   4799       `channels_last` nor `channels_first`.
   4800   """
   4801   if data_format is None:
   4802     data_format = image_data_format()
   4803   if data_format not in {'channels_first', 'channels_last'}:
   4804     raise ValueError('Unknown data_format: ' + str(data_format))
   4805 
   4806   kernel_shape = int_shape(kernel)
   4807   feature_dim = kernel_shape[1]
   4808   channels_out = kernel_shape[-1]
   4809   ndims = len(output_shape)
   4810   spatial_dimensions = list(range(ndims))
   4811 
   4812   xs = []
   4813   output_axes_ticks = [range(axis_max) for axis_max in output_shape]
   4814   for position in itertools.product(*output_axes_ticks):
   4815     slices = [slice(None)]
   4816 
   4817     if data_format == 'channels_first':
   4818       slices.append(slice(None))
   4819 
   4820     slices.extend([slice(position[d] * strides[d],
   4821                          position[d] * strides[d] + kernel_size[d])
   4822                    for d in spatial_dimensions])
   4823 
   4824     if data_format == 'channels_last':
   4825       slices.append(slice(None))
   4826 
   4827     xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
   4828 
   4829   x_aggregate = concatenate(xs, axis=0)
   4830   output = batch_dot(x_aggregate, kernel)
   4831   output = reshape(output, output_shape + (-1, channels_out))
   4832 
   4833   if data_format == 'channels_first':
   4834     permutation = [ndims, ndims + 1] + spatial_dimensions
   4835   else:
   4836     permutation = [ndims] + spatial_dimensions + [ndims + 1]
   4837 
   4838   return permute_dimensions(output, permutation)
   4839 
   4840 
   4841 @keras_export('keras.backend.local_conv1d')
   4842 def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
   4843   """Apply 1D conv with un-shared weights.
   4844 
   4845   Arguments:
   4846       inputs: 3D tensor with shape:
   4847           (batch_size, steps, input_dim)
   4848           if data_format is "channels_last" or
   4849           (batch_size, input_dim, steps)
   4850           if data_format is "channels_first".
   4851       kernel: the unshared weight for convolution,
   4852           with shape (output_length, feature_dim, filters).
   4853       kernel_size: a tuple of a single integer,
   4854           specifying the length of the 1D convolution window.
   4855       strides: a tuple of a single integer,
   4856           specifying the stride length of the convolution.
   4857       data_format: the data format, channels_first or channels_last.
   4858 
   4859   Returns:
   4860       A 3d tensor with shape:
   4861       (batch_size, output_length, filters)
   4862       if data_format='channels_first'
   4863       or 3D tensor with shape:
   4864       (batch_size, filters, output_length)
   4865       if data_format='channels_last'.
   4866   """
   4867   output_shape = (kernel.shape[0],)
   4868   return local_conv(inputs,
   4869                     kernel,
   4870                     kernel_size,
   4871                     strides,
   4872                     output_shape,
   4873                     data_format)
   4874 
   4875 
   4876 @keras_export('keras.backend.local_conv2d')
   4877 def local_conv2d(inputs,
   4878                  kernel,
   4879                  kernel_size,
   4880                  strides,
   4881                  output_shape,
   4882                  data_format=None):
   4883   """Apply 2D conv with un-shared weights.
   4884 
   4885   Arguments:
   4886       inputs: 4D tensor with shape:
   4887           (batch_size, filters, new_rows, new_cols)
   4888           if data_format='channels_first'
   4889           or 4D tensor with shape:
   4890           (batch_size, new_rows, new_cols, filters)
   4891           if data_format='channels_last'.
   4892       kernel: the unshared weight for convolution,
   4893           with shape (output_items, feature_dim, filters).
   4894       kernel_size: a tuple of 2 integers, specifying the
   4895           width and height of the 2D convolution window.
   4896       strides: a tuple of 2 integers, specifying the strides
   4897           of the convolution along the width and height.
   4898       output_shape: a tuple with (output_row, output_col).
   4899       data_format: the data format, channels_first or channels_last.
   4900 
   4901   Returns:
   4902       A 4D tensor with shape:
   4903       (batch_size, filters, new_rows, new_cols)
   4904       if data_format='channels_first'
   4905       or 4D tensor with shape:
   4906       (batch_size, new_rows, new_cols, filters)
   4907       if data_format='channels_last'.
   4908   """
   4909   return local_conv(inputs,
   4910                     kernel,
   4911                     kernel_size,
   4912                     strides,
   4913                     output_shape,
   4914                     data_format)
   4915 
   4916 
   4917 @keras_export('keras.backend.bias_add')
   4918 def bias_add(x, bias, data_format=None):
   4919   """Adds a bias vector to a tensor.
   4920 
   4921   Arguments:
   4922       x: Tensor or variable.
   4923       bias: Bias tensor to add.
   4924       data_format: string, `"channels_last"` or `"channels_first"`.
   4925 
   4926   Returns:
   4927       Output tensor.
   4928 
   4929   Raises:
   4930       ValueError: In one of the two cases below:
   4931                   1. invalid `data_format` argument.
   4932                   2. invalid bias shape.
   4933                      the bias should be either a vector or
   4934                      a tensor with ndim(x) - 1 dimension
   4935   """
   4936   if data_format is None:
   4937     data_format = image_data_format()
   4938   if data_format not in {'channels_first', 'channels_last'}:
   4939     raise ValueError('Unknown data_format: ' + str(data_format))
   4940   bias_shape = int_shape(bias)
   4941   if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
   4942     raise ValueError(
   4943         'Unexpected bias dimensions %d, expect to be 1 or %d dimensions' %
   4944         (len(bias_shape), ndim(x)))
   4945   # pylint: disable=g-no-augmented-assignment
   4946   if ndim(x) == 5:
   4947     if data_format == 'channels_first':
   4948       if len(bias_shape) == 1:
   4949         x = x + reshape(bias, (1, bias_shape[0], 1, 1, 1))
   4950       else:
   4951         x = x + reshape(bias, (1, bias_shape[3]) + bias_shape[:3])
   4952     elif data_format == 'channels_last':
   4953       if len(bias_shape) == 1:
   4954         x = x + reshape(bias, (1, 1, 1, bias_shape[0]))
   4955       else:
   4956         x = x + reshape(bias, (1,) + bias_shape)
   4957   elif ndim(x) == 4:
   4958     if data_format == 'channels_first':
   4959       if len(bias_shape) == 1:
   4960         if _has_nchw_support():
   4961           x = nn.bias_add(x, bias, data_format='NCHW')
   4962         else:
   4963           x = x + reshape(bias, (1, bias_shape[0], 1, 1))
   4964       else:
   4965         x = x + reshape(bias, (1, bias_shape[2]) + bias_shape[:2])
   4966     elif data_format == 'channels_last':
   4967       if len(bias_shape) == 1:
   4968         x = nn.bias_add(x, bias, data_format='NHWC')
   4969       else:
   4970         x = x + reshape(bias, (1,) + bias_shape)
   4971   elif ndim(x) == 3:
   4972     if data_format == 'channels_first':
   4973       if len(bias_shape) == 1:
   4974         x = x + reshape(bias, (1, bias_shape[0], 1))
   4975       else:
   4976         x = x + reshape(bias, (1, bias_shape[1], bias_shape[0]))
   4977     elif data_format == 'channels_last':
   4978       if len(bias_shape) == 1:
   4979         x = x + reshape(bias, (1, 1, bias_shape[0]))
   4980       else:
   4981         x = x + reshape(bias, (1,) + bias_shape)
   4982   else:
   4983     x = nn.bias_add(x, bias)
   4984   # pylint: enable=g-no-augmented-assignment
   4985   return x
   4986 
   4987 
   4988 # RANDOMNESS
   4989 
   4990 
   4991 @keras_export('keras.backend.random_normal')
   4992 def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
   4993   """Returns a tensor with normal distribution of values.
   4994 
   4995   Arguments:
   4996       shape: A tuple of integers, the shape of tensor to create.
   4997       mean: A float, mean of the normal distribution to draw samples.
   4998       stddev: A float, standard deviation of the normal distribution
   4999           to draw samples.
   5000       dtype: String, dtype of returned tensor.
   5001       seed: Integer, random seed.
   5002 
   5003   Returns:
   5004       A tensor.
   5005   """
   5006   if dtype is None:
   5007     dtype = floatx()
   5008   if seed is None:
   5009     seed = np.random.randint(10e6)
   5010   return random_ops.random_normal(
   5011       shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
   5012 
   5013 
   5014 @keras_export('keras.backend.random_uniform')
   5015 def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
   5016   """Returns a tensor with uniform distribution of values.
   5017 
   5018   Arguments:
   5019       shape: A tuple of integers, the shape of tensor to create.
   5020       minval: A float, lower boundary of the uniform distribution
   5021           to draw samples.
   5022       maxval: A float, upper boundary of the uniform distribution
   5023           to draw samples.
   5024       dtype: String, dtype of returned tensor.
   5025       seed: Integer, random seed.
   5026 
   5027   Returns:
   5028       A tensor.
   5029   """
   5030   if dtype is None:
   5031     dtype = floatx()
   5032   if seed is None:
   5033     seed = np.random.randint(10e6)
   5034   return random_ops.random_uniform(
   5035       shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)
   5036 
   5037 
   5038 @keras_export('keras.backend.random_binomial')
   5039 def random_binomial(shape, p=0.0, dtype=None, seed=None):
   5040   """Returns a tensor with random binomial distribution of values.
   5041 
   5042   Arguments:
   5043       shape: A tuple of integers, the shape of tensor to create.
   5044       p: A float, `0. <= p <= 1`, probability of binomial distribution.
   5045       dtype: String, dtype of returned tensor.
   5046       seed: Integer, random seed.
   5047 
   5048   Returns:
   5049       A tensor.
   5050   """
   5051   if dtype is None:
   5052     dtype = floatx()
   5053   if seed is None:
   5054     seed = np.random.randint(10e6)
   5055   return array_ops.where(
   5056       random_ops.random_uniform(shape, dtype=dtype, seed=seed) <= p,
   5057       array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
   5058 
   5059 
   5060 @keras_export('keras.backend.truncated_normal')
   5061 def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
   5062   """Returns a tensor with truncated random normal distribution of values.
   5063 
   5064   The generated values follow a normal distribution
   5065   with specified mean and standard deviation,
   5066   except that values whose magnitude is more than
   5067   two standard deviations from the mean are dropped and re-picked.
   5068 
   5069   Arguments:
   5070       shape: A tuple of integers, the shape of tensor to create.
   5071       mean: Mean of the values.
   5072       stddev: Standard deviation of the values.
   5073       dtype: String, dtype of returned tensor.
   5074       seed: Integer, random seed.
   5075 
   5076   Returns:
   5077       A tensor.
   5078   """
   5079   if dtype is None:
   5080     dtype = floatx()
   5081   if seed is None:
   5082     seed = np.random.randint(10e6)
   5083   return random_ops.truncated_normal(
   5084       shape, mean, stddev, dtype=dtype, seed=seed)
   5085 
   5086 
   5087 # CTC
   5088 # TensorFlow has a native implementation, but it uses sparse tensors
   5089 # and therefore requires a wrapper for Keras. The functions below convert
   5090 # dense to sparse tensors and also wraps up the beam search code that is
   5091 # in TensorFlow's CTC implementation
   5092 
   5093 
   5094 @keras_export('keras.backend.ctc_label_dense_to_sparse')
   5095 def ctc_label_dense_to_sparse(labels, label_lengths):
   5096   """Converts CTC labels from dense to sparse.
   5097 
   5098   Arguments:
   5099       labels: dense CTC labels.
   5100       label_lengths: length of the labels.
   5101 
   5102   Returns:
   5103       A sparse tensor representation of the labels.
   5104   """
   5105   label_shape = array_ops.shape(labels)
   5106   num_batches_tns = array_ops.stack([label_shape[0]])
   5107   max_num_labels_tns = array_ops.stack([label_shape[1]])
   5108 
   5109   def range_less_than(_, current_input):
   5110     return array_ops.expand_dims(
   5111         math_ops.range(label_shape[1]), 0) < array_ops.fill(
   5112             max_num_labels_tns, current_input)
   5113 
   5114   init = math_ops.cast(
   5115       array_ops.fill([1, label_shape[1]], 0), dtypes_module.bool)
   5116   dense_mask = functional_ops.scan(
   5117       range_less_than, label_lengths, initializer=init, parallel_iterations=1)
   5118   dense_mask = dense_mask[:, 0, :]
   5119 
   5120   label_array = array_ops.reshape(
   5121       array_ops.tile(math_ops.range(0, label_shape[1]), num_batches_tns),
   5122       label_shape)
   5123   label_ind = array_ops.boolean_mask(label_array, dense_mask)
   5124 
   5125   batch_array = array_ops.transpose(
   5126       array_ops.reshape(
   5127           array_ops.tile(math_ops.range(0, label_shape[0]), max_num_labels_tns),
   5128           reverse(label_shape, 0)))
   5129   batch_ind = array_ops.boolean_mask(batch_array, dense_mask)
   5130   indices = array_ops.transpose(
   5131       array_ops.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1]))
   5132 
   5133   vals_sparse = array_ops.gather_nd(labels, indices)
   5134 
   5135   return sparse_tensor.SparseTensor(
   5136       math_ops.cast(indices, dtypes_module.int64), vals_sparse,
   5137       math_ops.cast(label_shape, dtypes_module.int64))
   5138 
   5139 
   5140 @keras_export('keras.backend.ctc_batch_cost')
   5141 def ctc_batch_cost(y_true, y_pred, input_length, label_length):
   5142   """Runs CTC loss algorithm on each batch element.
   5143 
   5144   Arguments:
   5145       y_true: tensor `(samples, max_string_length)`
   5146           containing the truth labels.
   5147       y_pred: tensor `(samples, time_steps, num_categories)`
   5148           containing the prediction, or output of the softmax.
   5149       input_length: tensor `(samples, 1)` containing the sequence length for
   5150           each batch item in `y_pred`.
   5151       label_length: tensor `(samples, 1)` containing the sequence length for
   5152           each batch item in `y_true`.
   5153 
   5154   Returns:
   5155       Tensor with shape (samples,1) containing the
   5156           CTC loss of each element.
   5157   """
   5158   label_length = math_ops.cast(
   5159       array_ops.squeeze(label_length, axis=-1), dtypes_module.int32)
   5160   input_length = math_ops.cast(
   5161       array_ops.squeeze(input_length, axis=-1), dtypes_module.int32)
   5162   sparse_labels = math_ops.cast(
   5163       ctc_label_dense_to_sparse(y_true, label_length), dtypes_module.int32)
   5164 
   5165   y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
   5166 
   5167   return array_ops.expand_dims(
   5168       ctc.ctc_loss(
   5169           inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1)
   5170 
   5171 
   5172 @keras_export('keras.backend.ctc_decode')
   5173 def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
   5174   """Decodes the output of a softmax.
   5175 
   5176   Can use either greedy search (also known as best path)
   5177   or a constrained dictionary search.
   5178 
   5179   Arguments:
   5180       y_pred: tensor `(samples, time_steps, num_categories)`
   5181           containing the prediction, or output of the softmax.
   5182       input_length: tensor `(samples, )` containing the sequence length for
   5183           each batch item in `y_pred`.
   5184       greedy: perform much faster best-path search if `true`.
   5185           This does not use a dictionary.
   5186       beam_width: if `greedy` is `false`: a beam search decoder will be used
   5187           with a beam of this width.
   5188       top_paths: if `greedy` is `false`,
   5189           how many of the most probable paths will be returned.
   5190 
   5191   Returns:
   5192       Tuple:
   5193           List: if `greedy` is `true`, returns a list of one element that
   5194               contains the decoded sequence.
   5195               If `false`, returns the `top_paths` most probable
   5196               decoded sequences.
   5197               Important: blank labels are returned as `-1`.
   5198           Tensor `(top_paths, )` that contains
   5199               the log probability of each decoded sequence.
   5200   """
   5201   y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
   5202   input_length = math_ops.cast(input_length, dtypes_module.int32)
   5203 
   5204   if greedy:
   5205     (decoded, log_prob) = ctc.ctc_greedy_decoder(
   5206         inputs=y_pred, sequence_length=input_length)
   5207   else:
   5208     (decoded, log_prob) = ctc.ctc_beam_search_decoder(
   5209         inputs=y_pred,
   5210         sequence_length=input_length,
   5211         beam_width=beam_width,
   5212         top_paths=top_paths)
   5213   decoded_dense = [
   5214       sparse_ops.sparse_to_dense(
   5215           st.indices, st.dense_shape, st.values, default_value=-1)
   5216       for st in decoded
   5217   ]
   5218   return (decoded_dense, log_prob)
   5219 
   5220 
   5221 # HIGH ORDER FUNCTIONS
   5222 
   5223 
   5224 @keras_export('keras.backend.map_fn')
   5225 def map_fn(fn, elems, name=None, dtype=None):
   5226   """Map the function fn over the elements elems and return the outputs.
   5227 
   5228   Arguments:
   5229       fn: Callable that will be called upon each element in elems
   5230       elems: tensor
   5231       name: A string name for the map node in the graph
   5232       dtype: Output data type.
   5233 
   5234   Returns:
   5235       Tensor with dtype `dtype`.
   5236   """
   5237   return map_fn_lib.map_fn(fn, elems, name=name, dtype=dtype)
   5238 
   5239 
   5240 @keras_export('keras.backend.foldl')
   5241 def foldl(fn, elems, initializer=None, name=None):
   5242   """Reduce elems using fn to combine them from left to right.
   5243 
   5244   Arguments:
   5245       fn: Callable that will be called upon each element in elems and an
   5246           accumulator, for instance `lambda acc, x: acc + x`
   5247       elems: tensor
   5248       initializer: The first value used (`elems[0]` in case of None)
   5249       name: A string name for the foldl node in the graph
   5250 
   5251   Returns:
   5252       Tensor with same type and shape as `initializer`.
   5253   """
   5254   return functional_ops.foldl(fn, elems, initializer=initializer, name=name)
   5255 
   5256 
   5257 @keras_export('keras.backend.foldr')
   5258 def foldr(fn, elems, initializer=None, name=None):
   5259   """Reduce elems using fn to combine them from right to left.
   5260 
   5261   Arguments:
   5262       fn: Callable that will be called upon each element in elems and an
   5263           accumulator, for instance `lambda acc, x: acc + x`
   5264       elems: tensor
   5265       initializer: The first value used (`elems[-1]` in case of None)
   5266       name: A string name for the foldr node in the graph
   5267 
   5268   Returns:
   5269       Same type and shape as initializer
   5270   """
   5271   return functional_ops.foldr(fn, elems, initializer=initializer, name=name)
   5272 
   5273 # Load Keras default configuration from config file if present.
   5274 # Set Keras base dir path given KERAS_HOME env variable, if applicable.
   5275 # Otherwise either ~/.keras or /tmp.
   5276 if 'KERAS_HOME' in os.environ:
   5277   _keras_dir = os.environ.get('KERAS_HOME')
   5278 else:
   5279   _keras_base_dir = os.path.expanduser('~')
   5280   _keras_dir = os.path.join(_keras_base_dir, '.keras')
   5281 _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
   5282 if os.path.exists(_config_path):
   5283   try:
   5284     _config = json.load(open(_config_path))
   5285   except ValueError:
   5286     _config = {}
   5287   _floatx = _config.get('floatx', floatx())
   5288   assert _floatx in {'float16', 'float32', 'float64'}
   5289   _epsilon = _config.get('epsilon', epsilon())
   5290   assert isinstance(_epsilon, float)
   5291   _image_data_format = _config.get('image_data_format', image_data_format())
   5292   assert _image_data_format in {'channels_last', 'channels_first'}
   5293   set_floatx(_floatx)
   5294   set_epsilon(_epsilon)
   5295   set_image_data_format(_image_data_format)
   5296 
   5297 # Save config file.
   5298 if not os.path.exists(_keras_dir):
   5299   try:
   5300     os.makedirs(_keras_dir)
   5301   except OSError:
   5302     # Except permission denied and potential race conditions
   5303     # in multi-threaded environments.
   5304     pass
   5305 
   5306 if not os.path.exists(_config_path):
   5307   _config = {
   5308       'floatx': floatx(),
   5309       'epsilon': epsilon(),
   5310       'backend': 'tensorflow',
   5311       'image_data_format': image_data_format()
   5312   }
   5313   try:
   5314     with open(_config_path, 'w') as f:
   5315       f.write(json.dumps(_config, indent=4))
   5316   except IOError:
   5317     # Except permission denied.
   5318     pass
   5319 
   5320 
   5321 def in_multi_worker_mode():
   5322   """Whether we are operating in a Multi-Worker setting."""
   5323   tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
   5324   cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
   5325   return tf_config and 'master' not in cluster_spec.jobs
   5326 
   5327 
   5328 def configure_and_create_distributed_session(distribution_strategy):
   5329   """Configure session config and create a session with it."""
   5330 
   5331   def _create_session(distribution_strategy):
   5332     """Create the Distributed Strategy session."""
   5333     session_config = get_default_session_config()
   5334 
   5335     # If a session already exists, merge in its config; in the case there is a
   5336     # conflict, take values of the existing config.
   5337     global _SESSION
   5338     if getattr(_SESSION, 'session', None) and _SESSION.session._config:
   5339       session_config.MergeFrom(_SESSION.session._config)
   5340 
   5341     if is_tpu_strategy(distribution_strategy):
   5342       # TODO(priyag, yuefengz): Remove this workaround when Distribute
   5343       # Coordinator is integrated with keras and we can create a session from
   5344       # there.
   5345       distribution_strategy.configure(session_config)
   5346       master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
   5347       session = session_module.Session(config=session_config, target=master)
   5348     else:
   5349       worker_context = dc_context.get_current_worker_context()
   5350       if worker_context:
   5351         dc_session_config = worker_context.session_config
   5352         # Merge the default session config to the one from distribute
   5353         # coordinator, which is fine for now since they don't have
   5354         # conflicting configurations.
   5355         dc_session_config.MergeFrom(session_config)
   5356         session = session_module.Session(
   5357             config=dc_session_config, target=worker_context.master_target)
   5358       else:
   5359         distribution_strategy.configure(session_config)
   5360         session = session_module.Session(config=session_config)
   5361 
   5362     set_session(session)
   5363 
   5364   if in_multi_worker_mode():
   5365     dc.run_distribute_coordinator(
   5366         _create_session,
   5367         distribution_strategy,
   5368         mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
   5369   else:
   5370     _create_session(distribution_strategy)
   5371 
   5372 
   5373 def is_tpu_strategy(strategy):
   5374   """We're executing TPU Strategy."""
   5375   return strategy is not None and strategy.__class__.__name__ == 'TPUStrategy'
   5376 
   5377 
   5378 def cast_variables_to_tensor(tensors):
   5379 
   5380   def _cast_variables_to_tensor(tensor):
   5381     if isinstance(tensor, variables_module.Variable):
   5382       return array_ops.identity(tensor)
   5383     return tensor
   5384 
   5385   return nest.map_structure(_cast_variables_to_tensor, tensors)
   5386