1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Solvers for linear equations.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import collections 22 23 from tensorflow.contrib.solvers.python.ops import util 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import ops 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import control_flow_ops 29 from tensorflow.python.ops import linalg_ops 30 from tensorflow.python.ops import math_ops 31 from tensorflow.python.ops import linalg_ops 32 33 34 def conjugate_gradient(operator, 35 rhs, 36 preconditioner=None, 37 x=None, 38 tol=1e-4, 39 max_iter=20, 40 name="conjugate_gradient"): 41 r"""Conjugate gradient solver. 42 43 Solves a linear system of equations `A*x = rhs` for selfadjoint, positive 44 definite matrix `A` and righ-hand side vector `rhs`, using an iterative, 45 matrix-free algorithm where the action of the matrix A is represented by 46 `operator`. The iteration terminates when either the number of iterations 47 exceeds `max_iter` or when the residual norm has been reduced to `tol` 48 times its initial value, i.e. \\(||rhs - A x_k|| <= tol ||rhs||\\). 49 50 Args: 51 operator: An object representing a linear operator with attributes: 52 - shape: Either a list of integers or a 1-D `Tensor` of type `int32` of 53 length 2. `shape[0]` is the dimension on the domain of the operator, 54 `shape[1]` is the dimension of the co-domain of the operator. On other 55 words, if operator represents an N x N matrix A, `shape` must contain 56 `[N, N]`. 57 - dtype: The datatype of input to and output from `apply`. 58 - apply: Callable object taking a vector `x` as input and returning a 59 vector with the result of applying the operator to `x`, i.e. if 60 `operator` represents matrix `A`, `apply` should return `A * x`. 61 rhs: A rank-1 `Tensor` of shape `[N]` containing the right-hand size vector. 62 preconditioner: An object representing a linear operator, see `operator` 63 for detail. The preconditioner should approximate the inverse of `A`. 64 An efficient preconditioner could dramatically improve the rate of 65 convergence. If `preconditioner` represents matrix `M`(`M` approximates 66 `A^{-1}`), the algorithm uses `preconditioner.apply(x)` to estimate 67 `A^{-1}x`. For this to be useful, the cost of applying `M` should be 68 much lower than computing `A^{-1}` directly. 69 x: A rank-1 `Tensor` of shape `[N]` containing the initial guess for the 70 solution. 71 tol: A float scalar convergence tolerance. 72 max_iter: An integer giving the maximum number of iterations. 73 name: A name scope for the operation. 74 75 Returns: 76 output: A namedtuple representing the final state with fields: 77 - i: A scalar `int32` `Tensor`. Number of iterations executed. 78 - x: A rank-1 `Tensor` of shape `[N]` containing the computed solution. 79 - r: A rank-1 `Tensor` of shape `[M]` containing the residual vector. 80 - p: A rank-1 `Tensor` of shape `[N]`. `A`-conjugate basis vector. 81 - gamma: \\(r \dot M \dot r\\), equivalent to \\(||r||_2^2\\) when 82 `preconditioner=None`. 83 """ 84 # ephemeral class holding CG state. 85 cg_state = collections.namedtuple("CGState", ["i", "x", "r", "p", "gamma"]) 86 87 def stopping_criterion(i, state): 88 return math_ops.logical_and(i < max_iter, linalg_ops.norm(state.r) > tol) 89 90 def cg_step(i, state): # pylint: disable=missing-docstring 91 z = operator.apply(state.p) 92 alpha = state.gamma / util.dot(state.p, z) 93 x = state.x + alpha * state.p 94 r = state.r - alpha * z 95 if preconditioner is None: 96 gamma = util.dot(r, r) 97 beta = gamma / state.gamma 98 p = r + beta * state.p 99 else: 100 q = preconditioner.apply(r) 101 gamma = util.dot(r, q) 102 beta = gamma / state.gamma 103 p = q + beta * state.p 104 return i + 1, cg_state(i + 1, x, r, p, gamma) 105 106 with ops.name_scope(name): 107 n = operator.shape[1:] 108 rhs = array_ops.expand_dims(rhs, -1) 109 if x is None: 110 x = array_ops.expand_dims( 111 array_ops.zeros(n, dtype=rhs.dtype.base_dtype), -1) 112 r0 = rhs 113 else: 114 x = array_ops.expand_dims(x, -1) 115 r0 = rhs - operator.apply(x) 116 if preconditioner is None: 117 p0 = r0 118 else: 119 p0 = preconditioner.apply(r0) 120 gamma0 = util.dot(r0, p0) 121 tol *= linalg_ops.norm(r0) 122 i = constant_op.constant(0, dtype=dtypes.int32) 123 state = cg_state(i=i, x=x, r=r0, p=p0, gamma=gamma0) 124 _, state = control_flow_ops.while_loop(stopping_criterion, cg_step, 125 [i, state]) 126 return cg_state( 127 state.i, 128 x=array_ops.squeeze(state.x), 129 r=array_ops.squeeze(state.r), 130 p=array_ops.squeeze(state.p), 131 gamma=state.gamma) 132