Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Solvers for linear equations."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 
     23 from tensorflow.contrib.solvers.python.ops import util
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import control_flow_ops
     29 from tensorflow.python.ops import linalg_ops
     30 from tensorflow.python.ops import math_ops
     31 from tensorflow.python.ops import linalg_ops
     32 
     33 
     34 def conjugate_gradient(operator,
     35                        rhs,
     36                        preconditioner=None,
     37                        x=None,
     38                        tol=1e-4,
     39                        max_iter=20,
     40                        name="conjugate_gradient"):
     41   r"""Conjugate gradient solver.
     42 
     43   Solves a linear system of equations `A*x = rhs` for selfadjoint, positive
     44   definite matrix `A` and righ-hand side vector `rhs`, using an iterative,
     45   matrix-free algorithm where the action of the matrix A is represented by
     46   `operator`. The iteration terminates when either the number of iterations
     47   exceeds `max_iter` or when the residual norm has been reduced to `tol`
     48   times its initial value, i.e. \\(||rhs - A x_k|| <= tol ||rhs||\\).
     49 
     50   Args:
     51     operator: An object representing a linear operator with attributes:
     52       - shape: Either a list of integers or a 1-D `Tensor` of type `int32` of
     53         length 2. `shape[0]` is the dimension on the domain of the operator,
     54         `shape[1]` is the dimension of the co-domain of the operator. On other
     55         words, if operator represents an N x N matrix A, `shape` must contain
     56         `[N, N]`.
     57       - dtype: The datatype of input to and output from `apply`.
     58       - apply: Callable object taking a vector `x` as input and returning a
     59         vector with the result of applying the operator to `x`, i.e. if
     60        `operator` represents matrix `A`, `apply` should return `A * x`.
     61     rhs: A rank-1 `Tensor` of shape `[N]` containing the right-hand size vector.
     62     preconditioner: An object representing a linear operator, see `operator`
     63       for detail. The preconditioner should approximate the inverse of `A`.
     64       An efficient preconditioner could dramatically improve the rate of
     65       convergence. If `preconditioner` represents matrix `M`(`M` approximates
     66       `A^{-1}`), the algorithm uses `preconditioner.apply(x)` to estimate
     67       `A^{-1}x`. For this to be useful, the cost of applying `M` should be
     68       much lower than computing `A^{-1}` directly.
     69     x: A rank-1 `Tensor` of shape `[N]` containing the initial guess for the
     70       solution.
     71     tol: A float scalar convergence tolerance.
     72     max_iter: An integer giving the maximum number of iterations.
     73     name: A name scope for the operation.
     74 
     75   Returns:
     76     output: A namedtuple representing the final state with fields:
     77       - i: A scalar `int32` `Tensor`. Number of iterations executed.
     78       - x: A rank-1 `Tensor` of shape `[N]` containing the computed solution.
     79       - r: A rank-1 `Tensor` of shape `[M]` containing the residual vector.
     80       - p: A rank-1 `Tensor` of shape `[N]`. `A`-conjugate basis vector.
     81       - gamma: \\(r \dot M \dot r\\), equivalent to  \\(||r||_2^2\\) when
     82         `preconditioner=None`.
     83   """
     84   # ephemeral class holding CG state.
     85   cg_state = collections.namedtuple("CGState", ["i", "x", "r", "p", "gamma"])
     86 
     87   def stopping_criterion(i, state):
     88     return math_ops.logical_and(i < max_iter, linalg_ops.norm(state.r) > tol)
     89 
     90   def cg_step(i, state):  # pylint: disable=missing-docstring
     91     z = operator.apply(state.p)
     92     alpha = state.gamma / util.dot(state.p, z)
     93     x = state.x + alpha * state.p
     94     r = state.r - alpha * z
     95     if preconditioner is None:
     96       gamma = util.dot(r, r)
     97       beta = gamma / state.gamma
     98       p = r + beta * state.p
     99     else:
    100       q = preconditioner.apply(r)
    101       gamma = util.dot(r, q)
    102       beta = gamma / state.gamma
    103       p = q + beta * state.p
    104     return i + 1, cg_state(i + 1, x, r, p, gamma)
    105 
    106   with ops.name_scope(name):
    107     n = operator.shape[1:]
    108     rhs = array_ops.expand_dims(rhs, -1)
    109     if x is None:
    110       x = array_ops.expand_dims(
    111           array_ops.zeros(n, dtype=rhs.dtype.base_dtype), -1)
    112       r0 = rhs
    113     else:
    114       x = array_ops.expand_dims(x, -1)
    115       r0 = rhs - operator.apply(x)
    116     if preconditioner is None:
    117       p0 = r0
    118     else:
    119       p0 = preconditioner.apply(r0)
    120     gamma0 = util.dot(r0, p0)
    121     tol *= linalg_ops.norm(r0)
    122     i = constant_op.constant(0, dtype=dtypes.int32)
    123     state = cg_state(i=i, x=x, r=r0, p=p0, gamma=gamma0)
    124     _, state = control_flow_ops.while_loop(stopping_criterion, cg_step,
    125                                            [i, state])
    126     return cg_state(
    127         state.i,
    128         x=array_ops.squeeze(state.x),
    129         r=array_ops.squeeze(state.r),
    130         p=array_ops.squeeze(state.p),
    131         gamma=state.gamma)
    132