# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.

A useful reference for derivative formulas is "An extended collection of
matrix derivative results for forward and reverse mode algorithmic
differentiation" by Mike Giles:
http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf

A detailed derivation of formulas for backpropagating through spectral layers
(SVD and Eig) is given by Ionescu, Vantzos & Sminchisescu:
https://arxiv.org/pdf/1509.07838v4.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.matmul(
      ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True)


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad * c, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones) - 1/2 * eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(
      l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5
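

# Illustrative sketch (exposition only, not a registered gradient): the helper
# below is a literal transcription of the formula quoted in the comment inside
# _CholeskyGrad,
#   l^{-H} @ ((l^{H} @ grad) * (tril(ones) - 1/2 * eye)) @ l^{-1},
# using an explicit mask instead of matrix_set_diag/matrix_band_part. The
# registered gradient above computes the same quantity more cheaply and then
# symmetrizes it. The function name is hypothetical.
def _cholesky_grad_reference(l, grad):
  """Direct form of the Cholesky gradient expression, for exposition only."""
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  eye = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype)
  # Mask that keeps the strict lower triangle and halves the diagonal.
  mask = array_ops.matrix_band_part(array_ops.ones_like(l), -1, 0) - 0.5 * eye
  middle = mask * math_ops.matmul(l, grad, adjoint_a=True)
  # Apply l^{-H} on the left and l^{-1} on the right via triangular solves.
  left = linalg_ops.matrix_triangular_solve(l, middle, adjoint=True)
  return _linalg.adjoint(
      linalg_ops.matrix_triangular_solve(l, _linalg.adjoint(left),
                                         adjoint=True))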


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""
  q, r = op.outputs
  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if r.shape[-2].value != r.shape[-1].value:
    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                              "or full_matrices is true and ncols != nrows.")

  qdq = math_ops.matmul(q, dq, adjoint_a=True)
  qdq_ = qdq - _linalg.adjoint(qdq)
  rdr = math_ops.matmul(r, dr, adjoint_b=True)
  rdr_ = rdr - _linalg.adjoint(rdr)
  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
  return grad_a + grad_b


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  return (grad_a, grad_b)
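

# Illustrative sketch (exposition only, not a registered gradient): for
# C = A^{-1} B, the reverse-mode identities used in _MatrixSolveGrad above are
#   grad_B = A^{-H} grad    and    grad_A = -grad_B C^{H}.
# Spelled out with an explicit inverse they read as below; the registered
# gradient instead reuses matrix_solve with adjoint=True, which avoids forming
# the inverse. The function name is hypothetical.
def _matrix_solve_grad_reference(a, grad, c):
  """Reference form of the MatrixSolve gradient (adjoint=False case)."""
  a_inv_h = _linalg.adjoint(linalg_ops.matrix_inverse(a))
  grad_b = math_ops.matmul(a_inv_h, grad)
  grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  return grad_a, grad_b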


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  # a) Output the Cholesky factorization from the forward op instead of
  #    recomputing it here.
  # b) Implement a symmetric rank-k update op instead of computing
  #    x*z + transpose(x*z). This pattern occurs in other places in
  #    TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
      X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
      min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = (-math_ops.matmul(a, zx_sym) +
              math_ops.matmul(b, z, adjoint_b=True))
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the
    second kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    which (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  return (grad_a, grad_b)
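

# Illustrative sketch (exposition only, not part of the op): the forward map
# that _Overdetermined in _MatrixSolveLsGrad backprops through is the
# regularized normal-equations solve
#   X = (A^H A + lambda I)^{-1} A^H B,
# shown here with public Cholesky primitives instead of the internal
# _RegularizedGramianCholesky op. The function name is hypothetical, and
# `l2_regularizer` is assumed to be a scalar already cast to a.dtype.
def _normal_equations_solve_reference(a, b, l2_regularizer):
  """Least-squares solve via the normal equations of the first kind."""
  n = array_ops.shape(a)[-1]
  batch_shape = array_ops.shape(a)[:-2]
  gramian = (math_ops.matmul(a, a, adjoint_a=True) +
             l2_regularizer * linalg_ops.eye(
                 n, batch_shape=batch_shape, dtype=a.dtype))
  chol = linalg_ops.cholesky(gramian)
  return linalg_ops.cholesky_solve(chol, math_ops.matmul(a, b, adjoint_a=True))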


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[..., :, :] * v[..., :, i] = e[..., i] * v[..., :, i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i, j) = (i != j ? 1 / (e_j - e_i) : 0), matching
      # the broadcasting expression below.
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here
    # we symmetrize and take the lower triangle.
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(
        grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
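

# Illustrative sketch (exposition only, not a registered gradient): the
# pairwise factor built inside _SelfAdjointEigV2Grad is the matrix
#   F[..., i, j] = 1 / (e[..., j] - e[..., i]) for i != j,  F[..., i, i] = 0,
# factored out here under a hypothetical name. It is exactly this term that
# blows up when eigenvalues are repeated (or numerically close), since some
# off-diagonal denominators vanish.
def _eig_pairwise_factor(e):
  """Builds F[i, j] = 1 / (e_j - e_i) off the diagonal and 0 on it."""
  # Broadcasting: expand_dims(e, -2) - expand_dims(e, -1) has [i, j] entry
  # e[j] - e[i]; the diagonal (i == j) is then zeroed out.
  return array_ops.matrix_set_diag(
      math_ops.reciprocal(
          array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
      array_ops.zeros_like(e))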


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  # TODO(rmlarsen): Make this work with complex types.
  if a.dtype.is_complex:
    raise NotImplementedError(
        "SVD gradient is not implemented for complex types and "
        "compute_uv=True.")
  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape[-2].merge_with(grad_u_shape[-2])
  n = a_shape[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape[-2].value
  n = a_shape[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a
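

# Illustrative sketch (exposition only, not a registered gradient): for the
# compute_uv=False branch of _SvdGrad above, the reverse-mode rule is simply
#   grad_A = U @ diag(grad_s) @ V^H,
# i.e. each singular value s_k contributes grad_s[k] * outer(u_k, v_k) to the
# gradient. The helper below (hypothetical name) spells that rule out given a
# fresh thin SVD of `a`; the registered gradient does the same after
# recomputing the SVD with compute_uv=True to obtain the factors it needs.
def _svd_singular_value_grad_reference(a, grad_s):
  """Gradient of the singular values only, for exposition."""
  _, u, v = linalg_ops.svd(a, compute_uv=True)
  return math_ops.matmul(
      u, math_ops.matmul(array_ops.matrix_diag(grad_s), v, adjoint_b=True))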