Home | History | Annotate | Download | only in util
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Utilities for tf.data options."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 
     22 def _internal_attr_name(name):
     23   return "_" + name
     24 
     25 
     26 class OptionsBase(object):
     27   """Base class for representing a set of tf.data options.
     28 
     29   Attributes:
     30     _options: Stores the option values.
     31   """
     32 
     33   def __init__(self):
     34     # NOTE: Cannot use `self._options` here as we override `__setattr__`
     35     object.__setattr__(self, "_options", {})
     36 
     37   def __eq__(self, other):
     38     if not isinstance(other, self.__class__):
     39       return NotImplemented
     40     for name in set(self._options) | set(other._options):  # pylint: disable=protected-access
     41       if getattr(self, name) != getattr(other, name):
     42         return False
     43     return True
     44 
     45   def __ne__(self, other):
     46     if isinstance(other, self.__class__):
     47       return not self.__eq__(other)
     48     else:
     49       return NotImplemented
     50 
     51   def __setattr__(self, name, value):
     52     if hasattr(self, name):
     53       object.__setattr__(self, name, value)
     54     else:
     55       raise AttributeError(
     56           "Cannot set the property %s on %s." % (name, type(self).__name__))
     57 
     58 
     59 def create_option(name, ty, docstring, default_factory=lambda: None):
     60   """Creates a type-checked property.
     61 
     62   Args:
     63     name: The name to use.
     64     ty: The type to use. The type of the property will be validated when it
     65       is set.
     66     docstring: The docstring to use.
     67     default_factory: A callable that takes no arguments and returns a default
     68       value to use if not set.
     69 
     70   Returns:
     71     A type-checked property.
     72   """
     73 
     74   def get_fn(option):
     75     # pylint: disable=protected-access
     76     if name not in option._options:
     77       option._options[name] = default_factory()
     78     return option._options.get(name)
     79 
     80   def set_fn(option, value):
     81     if not isinstance(value, ty):
     82       raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" %
     83                       (name, ty, value, type(value)))
     84     option._options[name] = value  # pylint: disable=protected-access
     85 
     86   return property(get_fn, set_fn, None, docstring)
     87 
     88 
     89 def merge_options(*options_list):
     90   """Merges the given options, returning the result as a new options object.
     91 
     92   The input arguments are expected to have a matching type that derives from
     93   `OptionsBase` (and thus each represent a set of options). The method outputs
     94   an object of the same type created by merging the sets of options represented
     95   by the input arguments.
     96 
     97   The sets of options can be merged as long as there does not exist an option
     98   with different non-default values.
     99 
    100   If an option is an instance of `OptionsBase` itself, then this method is
    101   applied recursively to the set of options represented by this option.
    102 
    103   Args:
    104     *options_list: options to merge
    105 
    106   Raises:
    107     TypeError: if the input arguments are incompatible or not derived from
    108       `OptionsBase`
    109     ValueError: if the given options cannot be merged
    110 
    111   Returns:
    112     A new options object which is the result of merging the given options.
    113   """
    114   if len(options_list) < 1:
    115     raise ValueError("At least one options should be provided")
    116   result_type = type(options_list[0])
    117 
    118   for options in options_list:
    119     if not isinstance(options, result_type):
    120       raise TypeError("Incompatible options type: %r vs %r" % (type(options),
    121                                                                result_type))
    122 
    123   if not isinstance(options_list[0], OptionsBase):
    124     raise TypeError("The inputs should inherit from `OptionsBase`")
    125 
    126   default_options = result_type()
    127   result = result_type()
    128   for options in options_list:
    129     # Iterate over all set options and merge the into the result.
    130     for name in options._options:  # pylint: disable=protected-access
    131       this = getattr(result, name)
    132       that = getattr(options, name)
    133       default = getattr(default_options, name)
    134       if that == default:
    135         continue
    136       elif this == default:
    137         setattr(result, name, that)
    138       elif isinstance(this, OptionsBase):
    139         setattr(result, name, merge_options(this, that))
    140       elif this != that:
    141         raise ValueError(
    142             "Cannot merge incompatible values (%r and %r) of option: %s" %
    143             (this, that, name))
    144   return result
    145