# (removed code-browser navigation header; file lives in tensorflow/python/ops/linalg/)
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """`LinearOperator` acting like a diagonal matrix."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import dtypes
     22 from tensorflow.python.framework import ops
     23 from tensorflow.python.ops import array_ops
     24 from tensorflow.python.ops import check_ops
     25 from tensorflow.python.ops import math_ops
     26 from tensorflow.python.ops.linalg import linalg_impl as linalg
     27 from tensorflow.python.ops.linalg import linear_operator
     28 from tensorflow.python.ops.linalg import linear_operator_util
     29 from tensorflow.python.util.tf_export import tf_export
     30 
     31 __all__ = ["LinearOperatorDiag",]
     32 
     33 
     34 @tf_export("linalg.LinearOperatorDiag")
     35 class LinearOperatorDiag(linear_operator.LinearOperator):
     36   """`LinearOperator` acting like a [batch] square diagonal matrix.
     37 
     38   This operator acts like a [batch] diagonal matrix `A` with shape
     39   `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
     40   batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
     41   an `N x N` matrix.  This matrix `A` is not materialized, but for
     42   purposes of broadcasting this shape will be relevant.
     43 
     44   `LinearOperatorDiag` is initialized with a (batch) vector.
     45 
     46   ```python
     47   # Create a 2 x 2 diagonal linear operator.
     48   diag = [1., -1.]
     49   operator = LinearOperatorDiag(diag)
     50 
     51   operator.to_dense()
     52   ==> [[1.,  0.]
     53        [0., -1.]]
     54 
     55   operator.shape
     56   ==> [2, 2]
     57 
     58   operator.log_abs_determinant()
     59   ==> scalar Tensor
     60 
     61   x = ... Shape [2, 4] Tensor
     62   operator.matmul(x)
     63   ==> Shape [2, 4] Tensor
     64 
     65   # Create a [2, 3] batch of 4 x 4 linear operators.
     66   diag = tf.random_normal(shape=[2, 3, 4])
     67   operator = LinearOperatorDiag(diag)
     68 
     69   # Create a shape [2, 1, 4, 2] vector.  Note that this shape is compatible
     70   # since the batch dimensions, [2, 1], are brodcast to
     71   # operator.batch_shape = [2, 3].
     72   y = tf.random_normal(shape=[2, 1, 4, 2])
     73   x = operator.solve(y)
     74   ==> operator.matmul(x) = y
     75   ```
     76 
     77   #### Shape compatibility
     78 
     79   This operator acts on [batch] matrix with compatible shape.
     80   `x` is a batch matrix with compatible shape for `matmul` and `solve` if
     81 
     82   ```
     83   operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
     84   x.shape =   [C1,...,Cc] + [N, R],
     85   and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
     86   ```
     87 
     88   #### Performance
     89 
     90   Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`,
     91   and `x.shape = [N, R]`.  Then
     92 
     93   * `operator.matmul(x)` involves `N * R` multiplications.
     94   * `operator.solve(x)` involves `N` divisions and `N * R` multiplications.
     95   * `operator.determinant()` involves a size `N` `reduce_prod`.
     96 
     97   If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
     98   `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
     99 
    100   #### Matrix property hints
    101 
    102   This `LinearOperator` is initialized with boolean flags of the form `is_X`,
    103   for `X = non_singular, self_adjoint, positive_definite, square`.
    104   These have the following meaning:
    105 
    106   * If `is_X == True`, callers should expect the operator to have the
    107     property `X`.  This is a promise that should be fulfilled, but is *not* a
    108     runtime assert.  For example, finite floating point precision may result
    109     in these promises being violated.
    110   * If `is_X == False`, callers should expect the operator to not have `X`.
    111   * If `is_X == None` (the default), callers should have no expectation either
    112     way.
    113   """
    114 
    115   def __init__(self,
    116                diag,
    117                is_non_singular=None,
    118                is_self_adjoint=None,
    119                is_positive_definite=None,
    120                is_square=None,
    121                name="LinearOperatorDiag"):
    122     r"""Initialize a `LinearOperatorDiag`.
    123 
    124     Args:
    125       diag:  Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
    126         The diagonal of the operator.  Allowed dtypes: `float16`, `float32`,
    127           `float64`, `complex64`, `complex128`.
    128       is_non_singular:  Expect that this operator is non-singular.
    129       is_self_adjoint:  Expect that this operator is equal to its hermitian
    130         transpose.  If `diag.dtype` is real, this is auto-set to `True`.
    131       is_positive_definite:  Expect that this operator is positive definite,
    132         meaning the quadratic form `x^H A x` has positive real part for all
    133         nonzero `x`.  Note that we do not require the operator to be
    134         self-adjoint to be positive-definite.  See:
    135         https://en.wikipedia.org/wiki/Positive-definite_matrix\
    136             #Extension_for_non_symmetric_matrices
    137       is_square:  Expect that this operator acts like square [batch] matrices.
    138       name: A name for this `LinearOperator`.
    139 
    140     Raises:
    141       TypeError:  If `diag.dtype` is not an allowed type.
    142       ValueError:  If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
    143     """
    144 
    145     with ops.name_scope(name, values=[diag]):
    146       self._diag = ops.convert_to_tensor(diag, name="diag")
    147       self._check_diag(self._diag)
    148 
    149       # Check and auto-set hints.
    150       if not self._diag.dtype.is_complex:
    151         if is_self_adjoint is False:
    152           raise ValueError("A real diagonal operator is always self adjoint.")
    153         else:
    154           is_self_adjoint = True
    155 
    156       if is_square is False:
    157         raise ValueError("Only square diagonal operators currently supported.")
    158       is_square = True
    159 
    160       super(LinearOperatorDiag, self).__init__(
    161           dtype=self._diag.dtype,
    162           graph_parents=[self._diag],
    163           is_non_singular=is_non_singular,
    164           is_self_adjoint=is_self_adjoint,
    165           is_positive_definite=is_positive_definite,
    166           is_square=is_square,
    167           name=name)
    168 
    169   def _check_diag(self, diag):
    170     """Static check of diag."""
    171     allowed_dtypes = [
    172         dtypes.float16,
    173         dtypes.float32,
    174         dtypes.float64,
    175         dtypes.complex64,
    176         dtypes.complex128,
    177     ]
    178 
    179     dtype = diag.dtype
    180     if dtype not in allowed_dtypes:
    181       raise TypeError(
    182           "Argument diag must have dtype in %s.  Found: %s"
    183           % (allowed_dtypes, dtype))
    184 
    185     if diag.get_shape().ndims is not None and diag.get_shape().ndims < 1:
    186       raise ValueError("Argument diag must have at least 1 dimension.  "
    187                        "Found: %s" % diag)
    188 
    189   def _shape(self):
    190     # If d_shape = [5, 3], we return [5, 3, 3].
    191     d_shape = self._diag.get_shape()
    192     return d_shape.concatenate(d_shape[-1:])
    193 
    194   def _shape_tensor(self):
    195     d_shape = array_ops.shape(self._diag)
    196     k = d_shape[-1]
    197     return array_ops.concat((d_shape, [k]), 0)
    198 
    199   def _assert_non_singular(self):
    200     return linear_operator_util.assert_no_entries_with_modulus_zero(
    201         self._diag,
    202         message="Singular operator:  Diagonal contained zero values.")
    203 
    204   def _assert_positive_definite(self):
    205     if self.dtype.is_complex:
    206       message = (
    207           "Diagonal operator had diagonal entries with non-positive real part, "
    208           "thus was not positive definite.")
    209     else:
    210       message = (
    211           "Real diagonal operator had non-positive diagonal entries, "
    212           "thus was not positive definite.")
    213 
    214     return check_ops.assert_positive(
    215         math_ops.real(self._diag),
    216         message=message)
    217 
    218   def _assert_self_adjoint(self):
    219     return linear_operator_util.assert_zero_imag_part(
    220         self._diag,
    221         message=(
    222             "This diagonal operator contained non-zero imaginary values.  "
    223             " Thus it was not self-adjoint."))
    224 
    225   def _matmul(self, x, adjoint=False, adjoint_arg=False):
    226     diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    227     x = linalg.adjoint(x) if adjoint_arg else x
    228     diag_mat = array_ops.expand_dims(diag_term, -1)
    229     return diag_mat * x
    230 
    231   def _determinant(self):
    232     return math_ops.reduce_prod(self._diag, reduction_indices=[-1])
    233 
    234   def _log_abs_determinant(self):
    235     return math_ops.reduce_sum(
    236         math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])
    237 
    238   def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    239     diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    240     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    241     inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1)
    242     return rhs * inv_diag_mat
    243 
    244   def _to_dense(self):
    245     return array_ops.matrix_diag(self._diag)
    246 
    247   def _diag_part(self):
    248     return self.diag
    249 
    250   def _add_to_tensor(self, x):
    251     x_diag = array_ops.matrix_diag_part(x)
    252     new_diag = self._diag + x_diag
    253     return array_ops.matrix_set_diag(x, new_diag)
    254 
    255   @property
    256   def diag(self):
    257     return self._diag
    258