# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.

A useful reference for derivative formulas is "An extended collection of
matrix derivative results for forward and reverse mode algorithmic
differentiation" by Mike Giles:
http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf

A detailed derivation of formulas for backpropagating through spectral layers
(SVD and Eig) is given by Ionescu, Vantzos & Sminchisescu:
https://arxiv.org/pdf/1509.07838v4.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.matmul(
      ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True)
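
# Note (editorial sketch, not part of the original file): for C = A^{-1} the
# differential is dC = -A^{-1} @ dA @ A^{-1}, so reverse mode gives
# grad_A = -A^{-H} @ grad_C @ A^{-H}. With ainv = A^{-1} and grad = grad_C this
# is exactly the expression returned above; the adjoint_a / adjoint_b flags
# avoid materializing the transposed inverse explicitly.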


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv
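
# Note (editorial sketch): the expression above is Jacobi's formula,
# d det(A) = det(A) * trace(A^{-1} @ dA), which in reverse mode gives
# grad_A = grad * det(A) * A^{-H}. Illustrative check (hypothetical snippet,
# assuming a TF 1.x graph-mode session):
#   a = tf.constant([[2., 1.], [1., 2.]])              # det(a) == 3
#   da = tf.gradients(tf.matrix_determinant(a), a)[0]  # det(a) * inv(a)^T
#   # da should evaluate to [[2., -1.], [-1., 2.]]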


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5
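
# Note (editorial sketch): in the code above, `middle` is Phi(L^H @ grad) where
# Phi keeps the lower triangle and halves the diagonal; the two matmuls with
# l_inverse then apply L^{-H} on the left and L^{-1} on the right, matching the
# formula quoted in the comment. The final (grad_a + grad_a^H) / 2 spreads the
# result over both halves of the symmetric input matrix.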


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""
  q, r = op.outputs
  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if r.shape[-2].value != r.shape[-1].value:
    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                              "or full_matrices is true and ncols != nrows.")

  qdq = math_ops.matmul(q, dq, adjoint_a=True)
  qdq_ = qdq - _linalg.adjoint(qdq)
  rdr = math_ops.matmul(r, dr, adjoint_b=True)
  rdr_ = rdr - _linalg.adjoint(rdr)
  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
  return grad_a + grad_b
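
# Note (editorial sketch): in _QrGrad, `tril` gathers the skew parts of
# Q^H @ dq and R @ dr^H that are coupled through A = Q @ R, while
# _TriangularSolve(x, r) evaluates x @ R^{-H} with a triangular solve instead
# of forming R^{-1}. grad_a propagates perturbations within the column space
# of Q; grad_b handles the component of dq orthogonal to that column space.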


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  return (grad_a, grad_b)
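
# Note (editorial sketch): for X = A^{-1} @ B the differential is
# dX = A^{-1} @ (dB - dA @ X), so grad_B = A^{-H} @ grad (one extra solve
# against the same matrix) and grad_A = -grad_B @ X^H; the adjoint_a branch is
# the same identity written for A^H in place of A.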


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  #   a) Output the Cholesky factorization from the forward op instead of
  #      recomputing it here.
  #   b) Implement a symmetric rank-k update op instead of computing
  #      x*z + transpose(x*z). This pattern occurs elsewhere in TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
       X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
       min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A^T * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))
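
# Note (editorial sketch): both branches above backsubstitute through the
# Cholesky factor of the regularized Gramian (A^H @ A + lambda * I in the tall
# case, A @ A^H + lambda * I in the wide case) rather than forming an explicit
# inverse; when the static shape is unknown, the tall-vs-wide decision is
# deferred to runtime via control_flow_ops.cond.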


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  return (grad_a, grad_b)
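
# Note (editorial sketch): this mirrors _MatrixSolveGrad with triangular
# solves; the trailing matrix_band_part masks grad_a so that only the triangle
# actually read by the forward op receives a nonzero gradient.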


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_j - e_i) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
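
# Note (editorial sketch): the compute_v branch above implements
#   grad_A = V @ (diag(grad_e) + F * (V^H @ grad_v)) @ V^H
# with F[i, j] = 1 / (e[j] - e[i]) off the diagonal, which is why the gradient
# blows up for (nearly) repeated eigenvalues. The closing band_part / set_diag
# pair folds the symmetric gradient onto the lower triangle read by the
# forward op.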


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file).  A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  # TODO(rmlarsen): Make this work with complex types.
  if a.dtype.is_complex:
    raise NotImplementedError(
        "SVD gradient is not implemented for complex types and "
        "compute_uv=True.")
  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape[-2].merge_with(grad_u_shape[-2])
  n = a_shape[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape[-2].value
  n = a_shape[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to arbitrary rotation in a (k-dimensional) subspace.
    # In practice, this can lead to numerical instability when singular
    # values are close but not exactly equal.
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a
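
# Note (editorial sketch): in _SvdGrad, term1 carries the contributions that
# factor through U and the leading right singular vectors V1 (the diagonal
# grad_s plus the F-coupled terms), while term2 carries the part of grad_v
# orthogonal to V1 (and, with full_matrices=True, its coupling to V2), scaled
# by S^{-1}. The m > n case is handled by differentiating the SVD of A^H and
# transposing the result at the end.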
    386