Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Solvers for linear least-squares."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 
     23 from tensorflow.contrib.solvers.python.ops import util
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import control_flow_ops
     29 from tensorflow.python.ops import math_ops
     30 
     31 
def cgls(operator, rhs, tol=1e-6, max_iter=20, name="cgls"):
  r"""Conjugate gradient least squares solver.

  Solves a linear least squares problem \\(||A x - rhs||_2\\) for a single
  right-hand side, using an iterative, matrix-free algorithm where the action
  of the matrix A is represented by `operator`. The CGLS algorithm implicitly
  applies the symmetric conjugate gradient algorithm to the normal equations
  \\(A^* A x = A^* rhs\\). The iteration terminates when either
  the number of iterations exceeds `max_iter` or when the norm of the conjugate
  residual (residual of the normal equations) has been reduced to `tol` times
  its initial value, i.e.
  \\(||A^* (rhs - A x_k)|| <= tol ||A^* rhs||\\).

  Args:
    operator: An object representing a linear operator with attributes:
      - shape: Either a list of integers or a 1-D `Tensor` of type `int32` of
        length 2. `shape[0]` is the dimension of the co-domain of the
        operator, `shape[1]` is the dimension of the domain of the operator.
        In other words, if operator represents an M x N matrix A, `shape` must
        contain `[M, N]`.
      - dtype: The datatype of input to and output from `apply` and
        `apply_adjoint`.
      - apply: Callable object taking a vector `x` as input and returning a
        vector with the result of applying the operator to `x`, i.e. if
       `operator` represents matrix `A`, `apply` should return `A * x`.
      - apply_adjoint: Callable object taking a vector `x` as input and
        returning a vector with the result of applying the adjoint operator
        to `x`, i.e. if `operator` represents matrix `A`, `apply_adjoint` should
        return `conj(transpose(A)) * x`.

    rhs: A rank-1 `Tensor` of shape `[M]` containing the right-hand side vector.
    tol: A float scalar convergence tolerance.
    max_iter: An integer giving the maximum number of iterations.
    name: A name scope for the operation.


  Returns:
    output: A namedtuple representing the final state with fields:
      - i: A scalar `int32` `Tensor`. Number of iterations executed.
      - x: A rank-1 `Tensor` of shape `[N]` containing the computed solution.
      - r: A rank-1 `Tensor` of shape `[M]` containing the residual vector.
      - p: A rank-1 `Tensor` of shape `[N]`. The next descent direction.
      - gamma: \\(||A^* r||_2^2\\)
  """
  # ephemeral class holding CGLS state.
  # Fields: iteration count i, current solution x, residual r = rhs - A x,
  # descent direction p, and gamma = ||A^* r||^2 (squared norm of the
  # normal-equations residual).
  cgls_state = collections.namedtuple("CGLSState",
                                      ["i", "x", "r", "p", "gamma"])

  def stopping_criterion(i, state):
    # NOTE: this closure reads `tol` at loop-execution time, by which point
    # `tol` has been rebound below to tol^2 * gamma0. So this compares the
    # squared normal-equations residual against a fraction of its initial
    # value, which matches the relative criterion in the docstring.
    return math_ops.logical_and(i < max_iter, state.gamma > tol)

  # TODO(rmlarsen): add preconditioning
  def cgls_step(i, state):
    # One CGLS iteration: step along descent direction p, then compute the
    # new normal-equations residual and the next conjugate direction.
    q = operator.apply(state.p)
    # Step length minimizing ||rhs - A (x + alpha p)||^2 along p.
    alpha = state.gamma / util.l2norm_squared(q)
    x = state.x + alpha * state.p
    r = state.r - alpha * q
    # s = A^* r is the residual of the normal equations A^* A x = A^* rhs.
    s = operator.apply_adjoint(r)
    gamma = util.l2norm_squared(s)
    beta = gamma / state.gamma
    # New direction is s made conjugate to the previous direction.
    p = s + beta * state.p
    return i + 1, cgls_state(i + 1, x, r, p, gamma)

  with ops.name_scope(name):
    # n = [N], the domain dimension; the solution x lives in this space.
    n = operator.shape[1:]
    # Work with column vectors (trailing unit dimension) internally; the
    # final state is squeezed back to rank-1 below.
    rhs = array_ops.expand_dims(rhs, -1)
    s0 = operator.apply_adjoint(rhs)
    # gamma0 = ||A^* rhs||^2, the initial normal-equations residual norm.
    gamma0 = util.l2norm_squared(s0)
    # Fold the relative tolerance into an absolute threshold on gamma
    # (squared quantities, hence tol * tol). Read by stopping_criterion.
    tol = tol * tol * gamma0
    x = array_ops.expand_dims(
        array_ops.zeros(
            n, dtype=rhs.dtype.base_dtype), -1)
    i = constant_op.constant(0, dtype=dtypes.int32)
    # With x0 = 0: r0 = rhs and p0 = s0 = A^* rhs.
    state = cgls_state(i=i, x=x, r=rhs, p=s0, gamma=gamma0)
    _, state = control_flow_ops.while_loop(stopping_criterion, cgls_step,
                                           [i, state])
    # Squeeze the trailing unit dimension so x, r, p are rank-1 as documented.
    return cgls_state(
        state.i,
        x=array_ops.squeeze(state.x),
        r=array_ops.squeeze(state.r),
        p=array_ops.squeeze(state.p),
        gamma=state.gamma)
    114