# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a diagonal matrix."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorDiag",]


@tf_export("linalg.LinearOperatorDiag")
class LinearOperatorDiag(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square diagonal matrix.

  This operator acts like a [batch] diagonal matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorDiag` is initialized with a (batch) vector.

  ```python
  # Create a 2 x 2 diagonal linear operator.
  diag = [1., -1.]
  operator = LinearOperatorDiag(diag)

  operator.to_dense()
  ==> [[1.,  0.]
       [0., -1.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  diag = tf.random_normal(shape=[2, 3, 4])
  operator = LinearOperatorDiag(diag)

  # Create a shape [2, 1, 4, 2] vector.  Note that this shape is compatible
  # since the batch dimensions, [2, 1], are broadcast to
  # operator.batch_shape = [2, 3].
  y = tf.random_normal(shape=[2, 1, 4, 2])
  x = operator.solve(y)
  ==> operator.matmul(x) = y
  ```

  #### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` involves `N * R` multiplications.
  * `operator.solve(x)` involves `N` divisions and `N * R` multiplications.
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               diag,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorDiag"):
    r"""Initialize a `LinearOperatorDiag`.

    Args:
      diag:  Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
        The diagonal of the operator.  Allowed dtypes: `float16`, `float32`,
        `float64`, `complex64`, `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `diag.dtype` is real, this is auto-set to `True`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name:  A name for this `LinearOperator`.

    Raises:
      TypeError:  If `diag.dtype` is not an allowed type.
      ValueError:  If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
    """

    with ops.name_scope(name, values=[diag]):
      self._diag = ops.convert_to_tensor(diag, name="diag")
      self._check_diag(self._diag)

      # Check and auto-set hints.
      if not self._diag.dtype.is_complex:
        # A real diagonal matrix is trivially equal to its hermitian transpose.
        if is_self_adjoint is False:
          raise ValueError("A real diagonal operator is always self adjoint.")
        else:
          is_self_adjoint = True

      if is_square is False:
        raise ValueError("Only square diagonal operators currently supported.")
      is_square = True

      super(LinearOperatorDiag, self).__init__(
          dtype=self._diag.dtype,
          graph_parents=[self._diag],
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          name=name)

  def _check_diag(self, diag):
    """Static check of diag."""
    allowed_dtypes = [
        dtypes.float16,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
    ]

    dtype = diag.dtype
    if dtype not in allowed_dtypes:
      raise TypeError(
          "Argument diag must have dtype in %s.  Found: %s"
          % (allowed_dtypes, dtype))

    # Only a static (graph-construction-time) rank check is possible here;
    # an unknown rank passes and is left to runtime ops to validate.
    if diag.get_shape().ndims is not None and diag.get_shape().ndims < 1:
      raise ValueError("Argument diag must have at least 1 dimension.  "
                       "Found: %s" % diag)

  def _shape(self):
    # If d_shape = [5, 3], we return [5, 3, 3]: the operator is a square
    # matrix whose side equals the length of the diagonal.
    d_shape = self._diag.get_shape()
    return d_shape.concatenate(d_shape[-1:])

  def _shape_tensor(self):
    # Dynamic analogue of _shape: append the diagonal length to itself.
    d_shape = array_ops.shape(self._diag)
    k = d_shape[-1]
    return array_ops.concat((d_shape, [k]), 0)

  def _assert_non_singular(self):
    # A diagonal matrix is singular iff any diagonal entry is zero.
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._diag,
        message="Singular operator:  Diagonal contained zero values.")

  def _assert_positive_definite(self):
    if self.dtype.is_complex:
      message = (
          "Diagonal operator had diagonal entries with non-positive real part, "
          "thus was not positive definite.")
    else:
      message = (
          "Real diagonal operator had non-positive diagonal entries, "
          "thus was not positive definite.")

    # Positive-definite iff every diagonal entry has positive real part.
    return check_ops.assert_positive(
        math_ops.real(self._diag),
        message=message)

  def _assert_self_adjoint(self):
    # A diagonal matrix is self-adjoint iff its diagonal is real.
    return linear_operator_util.assert_zero_imag_part(
        self._diag,
        message=(
            "This diagonal operator contained non-zero imaginary values.  "
            " Thus it was not self-adjoint."))

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Adjoint of a diagonal matrix is its elementwise complex conjugate.
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    x = linalg.adjoint(x) if adjoint_arg else x
    # Expand [..., N] -> [..., N, 1] so the product broadcasts over the
    # columns of x, scaling row i of x by diag[i].
    diag_mat = array_ops.expand_dims(diag_term, -1)
    return diag_mat * x

  def _determinant(self):
    # det(diag(d)) = prod(d).
    return math_ops.reduce_prod(self._diag, axis=[-1])

  def _log_abs_determinant(self):
    # log|det| = sum(log|d_i|), computed elementwise for stability.
    return math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._diag)), axis=[-1])

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # Solving with diag(d) is elementwise division of each row of rhs by d.
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1)
    return rhs * inv_diag_mat

  def _to_dense(self):
    return array_ops.matrix_diag(self._diag)

  def _diag_part(self):
    return self.diag

  def _add_to_tensor(self, x):
    # Add the operator's diagonal onto the diagonal of dense tensor x.
    x_diag = array_ops.matrix_diag_part(x)
    new_diag = self._diag + x_diag
    return array_ops.matrix_set_diag(x, new_diag)

  @property
  def diag(self):
    # The (batch) vector forming the diagonal of this operator.
    return self._diag