# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add one or more `LinearOperators` efficiently."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_full_matrix
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_lower_triangular

__all__ = []


def add_operators(operators,
                  operator_name=None,
                  addition_tiers=None,
                  name=None):
  """Efficiently add one or more linear operators.

  Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
  operators `[B1, B2,...]` such that

  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```

  The operators `Bk` result by adding some of the `Ak`, as allowed by
  `addition_tiers`.

  Example of efficiently adding diagonal operators:

  ```python
  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")

  # Use two tiers, the first containing an Adder that returns a Diag.  Since
  # both A1 and A2 are Diag, they can use this Adder.  The second tier will
  # not be used.
  addition_tiers = [
      [_AddAndReturnDiag()],
      [_AddAndReturnMatrix()]]
  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)

  len(B_list)
  ==> 1

  B_list[0].__class__.__name__
  ==> 'LinearOperatorDiag'

  B_list[0].to_dense()
  ==> [[3., 0.],
       [0., 3.]]

  B_list[0].name
  ==> 'Add/A1__A2/'
  ```

  Args:
    operators:  Iterable of `LinearOperator` objects with the same `dtype`,
      domain and range dimensions, and broadcastable batch shapes.
    operator_name:  String name for a returned `LinearOperator`.  Defaults to
      a concatenation like "Add/A__B/", which indicates the order of addition
      steps.
    addition_tiers:  List of tiers, like `[tier_0, tier_1, ...]`, where
      `tier_i` is a list of `Adder` objects.  This function attempts to do all
      additions in tier `i` before trying tier `i + 1`.
    name:  A name for this `Op`.  Defaults to `add_operators`.

  Returns:
    List of `LinearOperator` subclasses.  The classes of the returned
      operators, and the order of addition, may change as new (and better)
      addition strategies emerge.

  Raises:
    ValueError:  If the `operators` argument is empty.
    ValueError:  If shapes are incompatible.
  """
  # Default setting
  if addition_tiers is None:
    addition_tiers = _DEFAULT_ADDITION_TIERS

  # Argument checking.
  check_ops.assert_proper_iterable(operators)
  operators = list(reversed(operators))
  if len(operators) < 1:
    raise ValueError(
        "Argument 'operators' must contain at least one operator.  "
        "Found: %s" % operators)
  if not all(
      isinstance(op, linear_operator.LinearOperator) for op in operators):
    raise TypeError(
        "Argument 'operators' must contain only LinearOperator instances.  "
        "Found: %s" % operators)
  _static_check_for_same_dimensions(operators)
  _static_check_for_broadcastable_batch_shape(operators)

  graph_parents = []
  for operator in operators:
    graph_parents.extend(operator.graph_parents)

  with ops.name_scope(name or "add_operators", values=graph_parents):

    # Additions done in one of the tiers.  Try tier 0, 1, ...
    ops_to_try_at_next_tier = list(operators)
    for tier in addition_tiers:
      ops_to_try_at_this_tier = ops_to_try_at_next_tier
      ops_to_try_at_next_tier = []
      while ops_to_try_at_this_tier:
        op1 = ops_to_try_at_this_tier.pop()
        op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
        if op2 is not None:
          # Will try to add the result of this again at this same tier.
          new_operator = adder.add(op1, op2, operator_name)
          ops_to_try_at_this_tier.append(new_operator)
        else:
          ops_to_try_at_next_tier.append(op1)

    return ops_to_try_at_next_tier


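# Example (an illustrative sketch, not part of the public API): with the
# default tiers, two diagonal operators collapse to a single Diag, while
# mixing in a dense operator falls through to the final dense-matrix tier.
#
#   d1 = linear_operator_diag.LinearOperatorDiag([1., 2.])
#   d2 = linear_operator_diag.LinearOperatorDiag([10., 10.])
#   add_operators([d1, d2])  # [LinearOperatorDiag], acting as diag([11., 12.])
#
#   full = linear_operator_full_matrix.LinearOperatorFullMatrix(
#       [[1., 2.], [3., 4.]])
#   add_operators([d1, d2, full])  # [LinearOperatorFullMatrix]

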
def _pop_a_match_at_tier(op1, operator_list, tier):
  # Search from the back of the list to the front in order to create a nice
  # default order of operations.
  for i in range(1, len(operator_list) + 1):
    op2 = operator_list[-i]
    for adder in tier:
      if adder.can_add(op1, op2):
        return operator_list.pop(-i), adder
  return None, None


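# Worked trace (hypothetical ops and tier): if every adder in `tier` can add
# anything, then
#
#   remaining = [op_c, op_b]  # The back of this list is searched first.
#   op2, adder = _pop_a_match_at_tier(op_a, remaining, tier)
#   # Now op2 is op_b and remaining == [op_c].  Together with the pop() in
#   # add_operators, this yields a left-to-right default addition order.

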
def _infer_hints_allowing_override(op1, op2, hints):
  """Infer hints from op1 and op2.  The hints argument acts as an override.

  Args:
    op1:  LinearOperator
    op2:  LinearOperator
    hints:  _Hints object holding "is_X" boolean hints to use for the returned
      operator.
      If some hint is None, try to set it using op1 and op2.  If the
      hint is provided, ignore op1 and op2 hints.  This allows an override
      of previous hints, but does not allow forbidden hints (e.g. you still
      cannot say a real diagonal operator is not self-adjoint).

  Returns:
    _Hints object.
  """
  hints = hints or _Hints()
  # If A, B are self-adjoint, then so is A + B.
  if hints.is_self_adjoint is None:
    is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
  else:
    is_self_adjoint = hints.is_self_adjoint

  # If A, B are positive definite, then so is A + B.
  if hints.is_positive_definite is None:
    is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
  else:
    is_positive_definite = hints.is_positive_definite

  # A positive definite operator is always non-singular.
  if is_positive_definite and hints.is_non_singular is None:
    is_non_singular = True
  else:
    is_non_singular = hints.is_non_singular

  return _Hints(
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite)


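# Worked example (hypothetical operators): if op1 and op2 were both built with
# is_self_adjoint=True and is_positive_definite=True, and hints is None, then
#
#   inferred = _infer_hints_allowing_override(op1, op2, None)
#   # inferred.is_self_adjoint == True, inferred.is_positive_definite == True,
#   # and, since positive definite implies invertible,
#   # inferred.is_non_singular == True.

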
def _static_check_for_same_dimensions(operators):
  """Raises ValueError if operators have different (static) dimensions."""
  if len(operators) < 2:
    return

  domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
                       if op.domain_dimension.value is not None]
  if len(set(value for name, value in domain_dimensions)) > 1:
    raise ValueError("Operators must have the same domain dimension. Found: %s"
                     % domain_dimensions)

  range_dimensions = [(op.name, op.range_dimension.value) for op in operators
                      if op.range_dimension.value is not None]
  if len(set(value for name, value in range_dimensions)) > 1:
    raise ValueError("Operators must have the same range dimension. Found: %s" %
                     range_dimensions)

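# For example (hypothetical): mixing a 2 x 2 and a 3 x 3 operator raises.
#
#   _static_check_for_same_dimensions([
#       linear_operator_identity.LinearOperatorIdentity(num_rows=2),
#       linear_operator_identity.LinearOperatorIdentity(num_rows=3)])
#   # ValueError: Operators must have the same domain dimension. ...

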
def _static_check_for_broadcastable_batch_shape(operators):
  """Raises ValueError if operators have non-broadcastable batch shapes."""
  if len(operators) < 2:
    return

  # This will fail if they cannot be broadcast together.
  batch_shape = operators[0].batch_shape
  for op in operators[1:]:
    batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)


class _Hints(object):
  """Holds 'is_X' flags that every LinearOperator is initialized with."""

  def __init__(self,
               is_non_singular=None,
               is_positive_definite=None,
               is_self_adjoint=None):
    self.is_non_singular = is_non_singular
    self.is_positive_definite = is_positive_definite
    self.is_self_adjoint = is_self_adjoint


################################################################################
# Classes to add two linear operators.
################################################################################


@six.add_metaclass(abc.ABCMeta)
class _Adder(object):
  """Abstract base class to add two operators.

  Each `Adder` acts independently, adding everything it can, paying no
  attention to whether another `Adder` could have done the addition more
  efficiently.
  """

  @property
  def name(self):
    return self.__class__.__name__

  @abc.abstractmethod
  def can_add(self, op1, op2):
    """Returns `True` if this `Adder` can add `op1` and `op2`.  Else `False`."""
    pass

  @abc.abstractmethod
  def _add(self, op1, op2, operator_name, hints):
    # Derived classes can assume op1 and op2 have been validated, e.g. they
    # have the same dtype, and their domain/range dimensions match.
    pass

  def add(self, op1, op2, operator_name, hints=None):
    """Return a new `LinearOperator` acting like `op1 + op2`.

    Args:
      op1:  `LinearOperator`
      op2:  `LinearOperator`, with `shape` and `dtype` such that adding to
        `op1` is allowed.
      operator_name:  `String` name to give to the returned `LinearOperator`.
      hints:  `_Hints` object.  The returned `LinearOperator` will be created
        with these hints.

    Returns:
      `LinearOperator`
    """
    updated_hints = _infer_hints_allowing_override(op1, op2, hints)

    if operator_name is None:
      operator_name = "Add/" + op1.name + "__" + op2.name + "/"

    values = op1.graph_parents + op2.graph_parents
    scope_name = self.name
    if scope_name.startswith("_"):
      scope_name = scope_name[1:]
    with ops.name_scope(scope_name, values=values):
      return self._add(op1, op2, operator_name, updated_hints)


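# Example (a hedged sketch of the Adder contract): a minimal custom Adder that
# only combines two LinearOperatorDiag instances, mirroring _AddAndReturnDiag
# below.  It could be passed in a custom `addition_tiers` argument.
#
#   class _AddDiagsOnly(_Adder):
#
#     def can_add(self, op1, op2):
#       return (isinstance(op1, linear_operator_diag.LinearOperatorDiag) and
#               isinstance(op2, linear_operator_diag.LinearOperatorDiag))
#
#     def _add(self, op1, op2, operator_name, hints):
#       return linear_operator_diag.LinearOperatorDiag(
#           diag=op1.diag_part() + op2.diag_part(),
#           is_non_singular=hints.is_non_singular,
#           is_self_adjoint=hints.is_self_adjoint,
#           is_positive_definite=hints.is_positive_definite,
#           name=operator_name)
#
#   add_operators(my_ops, addition_tiers=[[_AddDiagsOnly()],
#                                         [_AddAndReturnMatrix()]])

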
class _AddAndReturnScaledIdentity(_Adder):
  """Handles additions resulting in an Identity family member.

  The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`)
  family is closed under addition.  This `Adder` respects that, and returns an
  Identity family member.
  """

  def can_add(self, op1, op2):
    types = {_type(op1), _type(op2)}
    return not types.difference(_IDENTITY_FAMILY)

  def _add(self, op1, op2, operator_name, hints):
    # Will build a LinearOperatorScaledIdentity.

    if _type(op1) == _SCALED_IDENTITY:
      multiplier_1 = op1.multiplier
    else:
      multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)

    if _type(op2) == _SCALED_IDENTITY:
      multiplier_2 = op2.multiplier
    else:
      multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)

    return linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=op1.range_dimension_tensor(),
        multiplier=multiplier_1 + multiplier_2,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


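# For instance (hypothetical): under this Adder,
#
#   op1 = linear_operator_identity.LinearOperatorIdentity(num_rows=2)
#   op2 = linear_operator_identity.LinearOperatorScaledIdentity(
#       num_rows=2, multiplier=3.)
#   # adder.add(op1, op2, None) acts like 4 * I, i.e. a
#   # LinearOperatorScaledIdentity with multiplier 4.

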
class _AddAndReturnDiag(_Adder):
  """Handles additions resulting in a Diag operator."""

  def can_add(self, op1, op2):
    types = {_type(op1), _type(op2)}
    return not types.difference(_DIAG_LIKE)

  def _add(self, op1, op2, operator_name, hints):
    return linear_operator_diag.LinearOperatorDiag(
        diag=op1.diag_part() + op2.diag_part(),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


class _AddAndReturnTriL(_Adder):
  """Handles additions resulting in a TriL operator."""

  def can_add(self, op1, op2):
    types = {_type(op1), _type(op2)}
    return not types.difference(_DIAG_LIKE.union({_TRIL}))

  def _add(self, op1, op2, operator_name, hints):
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      op_add_to_tensor, op_other = op1, op2
    else:
      op_add_to_tensor, op_other = op2, op1

    return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
        tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


class _AddAndReturnMatrix(_Adder):
  """Handles additions resulting in a `LinearOperatorFullMatrix`."""

  def can_add(self, op1, op2):  # pylint: disable=unused-argument
    return isinstance(op1, linear_operator.LinearOperator) and isinstance(
        op2, linear_operator.LinearOperator)

  def _add(self, op1, op2, operator_name, hints):
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      op_add_to_tensor, op_other = op1, op2
    else:
      op_add_to_tensor, op_other = op2, op1
    return linear_operator_full_matrix.LinearOperatorFullMatrix(
        matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


################################################################################
# Constants designating types of LinearOperators
################################################################################

# Type name constants for LinearOperator classes.
_IDENTITY = "identity"
_SCALED_IDENTITY = "scaled_identity"
_DIAG = "diag"
_TRIL = "tril"
_MATRIX = "matrix"

# Groups of operators.
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
# Operators with an efficient .add_to_tensor() method.
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE


def _type(operator):
  """Returns the type name constant (e.g. _TRIL) for operator."""
  if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
    return _DIAG
  if isinstance(operator,
                linear_operator_lower_triangular.LinearOperatorLowerTriangular):
    return _TRIL
  if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
    return _MATRIX
  if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
    return _IDENTITY
  if isinstance(operator,
                linear_operator_identity.LinearOperatorScaledIdentity):
    return _SCALED_IDENTITY
  raise TypeError("Operator type unknown: %s" % operator)


################################################################################
# Addition tiers:
# We attempt to use Adders in tier K before tier K + 1.
#
# Organize tiers to
#   (i) reduce O(..) complexity of forming the final operator, and
#   (ii) produce the "most efficient" final operator.
# Dev notes:
#  * Results of addition at tier K will be added again at tier K or higher.
#  * Tiers may change, and the public docstring warns users that the class
#    and order of addition may change.
################################################################################

# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
# dense matrix.  So it is sometimes very inefficient.
_DEFAULT_ADDITION_TIERS = [
    [_AddAndReturnScaledIdentity()],
    [_AddAndReturnDiag()],
    [_AddAndReturnTriL()],
    [_AddAndReturnMatrix()],
]
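
# Example (an illustrative trace): under the default tiers, adding
# [Identity, ScaledIdentity, Diag, TriL] proceeds roughly as
#   tier 0: Identity + ScaledIdentity  -> ScaledIdentity
#   tier 1: ScaledIdentity + Diag      -> Diag
#   tier 2: Diag + TriL                -> TriL
# so add_operators returns a single LinearOperatorLowerTriangular.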
    433