Home | History | Annotate | Download | only in util
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
"""Utilities for deprecating functions, classes, arguments and argument values."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import functools
     23 import re
     24 
     25 from tensorflow.python.platform import tf_logging as logging
     26 from tensorflow.python.util import decorator_utils
     27 from tensorflow.python.util import is_in_graph_mode
     28 from tensorflow.python.util import tf_contextlib
     29 from tensorflow.python.util import tf_decorator
     30 from tensorflow.python.util import tf_inspect
     31 
     32 
     33 # Allow deprecation warnings to be silenced temporarily with a context manager.
     34 _PRINT_DEPRECATION_WARNINGS = True
     35 
     36 # Remember which deprecation warnings have been printed already.
     37 _PRINTED_WARNING = {}
     38 
     39 
     40 def _add_deprecated_function_notice_to_docstring(doc, date, instructions):
     41   """Adds a deprecation notice to a docstring for deprecated functions."""
     42   main_text = ['THIS FUNCTION IS DEPRECATED. It will be removed %s.' %
     43                ('in a future version' if date is None else ('after %s' % date))]
     44   if instructions:
     45     main_text.append('Instructions for updating:')
     46   return decorator_utils.add_notice_to_docstring(
     47       doc, instructions,
     48       'DEPRECATED FUNCTION',
     49       '(deprecated)', main_text)
     50 
     51 
     52 def _add_deprecated_arg_notice_to_docstring(doc, date, instructions):
     53   """Adds a deprecation notice to a docstring for deprecated arguments."""
     54   return decorator_utils.add_notice_to_docstring(
     55       doc, instructions,
     56       'DEPRECATED FUNCTION ARGUMENTS',
     57       '(deprecated arguments)', [
     58           'SOME ARGUMENTS ARE DEPRECATED. '
     59           'They will be removed %s.' % (
     60               'in a future version' if date is None else ('after %s' % date)),
     61           'Instructions for updating:'])
     62 
     63 
     64 def _validate_deprecation_args(date, instructions):
     65   if date is not None and not re.match(r'20\d\d-[01]\d-[0123]\d', date):
     66     raise ValueError('Date must be YYYY-MM-DD.')
     67   if not instructions:
     68     raise ValueError('Don\'t deprecate things without conversion instructions!')
     69 
     70 
     71 def _call_location(outer=False):
     72   """Returns call location given level up from current call."""
     73   frame = tf_inspect.currentframe()
     74   if frame:
     75     # CPython internals are available, use them for performance.
     76     # walk back two frames to get to deprecated function caller.
     77     frame = frame.f_back
     78     if frame.f_back:
     79       frame = frame.f_back
     80     if outer and frame.f_back:
     81       frame = frame.f_back
     82     return '%s:%d' % (frame.f_code.co_filename, frame.f_lineno)
     83   else:
     84     # Slow fallback path
     85     stack = tf_inspect.stack(0)  # 0 avoids generating unused context
     86     entry = stack[3 if outer else 2]
     87     return '%s:%d' % (entry[1], entry[2])
     88 
     89 
     90 def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
     91   """Deprecate a symbol in favor of a new name with identical semantics.
     92 
     93   This function is meant to be used when defining a backwards-compatibility
     94   alias for a symbol which has been moved. For example:
     95 
     96   module1.py:
     97   ```python
     98   class NewNameForClass: pass
     99   ```
    100 
    101   module2.py:
    102   ```python
    103   import module1
    104 
    105   DeprecatedNameForClass = deprecated_alias(
    106     deprecated_name='module2.DeprecatedNameForClass',
    107     name='module1.NewNameForClass',
    108     module1.NewNameForClass)
    109   ```
    110 
    111   This function works for classes and functions.
    112 
    113   For classes, it creates a new class which is functionally identical (it
    114   inherits from the original, and overrides its constructor), but which prints
    115   a deprecation warning when an instance is created. It also adds a deprecation
    116   notice to the class' docstring.
    117 
    118   For functions, it returns a function wrapped by `tf_decorator.make_decorator`.
    119   That function prints a warning when used, and has a deprecation notice in its
    120   docstring. This is more or less equivalent (the deprecation warning has
    121   slightly different text) to writing:
    122 
    123   ```python
    124   @deprecated
    125   def deprecated_alias(original_args):
    126     real_function(original_args)
    127   ```
    128 
    129   Args:
    130     deprecated_name: The name of the symbol that is being deprecated, to be used
    131       in the warning message. This should be its fully qualified name to avoid
    132       confusion.
    133     name: The name of the symbol that is to be used instead of the deprecated
    134       name. This should be a fully qualified name to avoid confusion.
    135     func_or_class: The (non-deprecated) class or function for which a deprecated
    136       alias should be created.
    137     warn_once: If True (the default), only print a deprecation warning the first
    138       time this function is used, or the class is instantiated.
    139 
    140   Returns:
    141     A wrapped version of `func_or_class` which prints a deprecation warning on
    142     use and has a modified docstring.
    143   """
    144   if tf_inspect.isclass(func_or_class):
    145 
    146     # Make a new class with __init__ wrapped in a warning.
    147     class NewClass(func_or_class):  # pylint: disable=missing-docstring
    148       __doc__ = decorator_utils.add_notice_to_docstring(
    149           func_or_class.__doc__, 'Please use %s instead.' % name,
    150           'DEPRECATED CLASS',
    151           '(deprecated)', ['THIS CLASS IS DEPRECATED. '
    152                            'It will be removed in a future version. '])
    153       __name__ = func_or_class.__name__
    154       __module__ = _call_location(outer=True)
    155 
    156       def __init__(self, *args, **kwargs):
    157         if hasattr(NewClass.__init__, '__func__'):
    158           # Python 2
    159           NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__
    160         else:
    161           # Python 3
    162           NewClass.__init__.__doc__ = func_or_class.__init__.__doc__
    163 
    164         if _PRINT_DEPRECATION_WARNINGS:
    165           # We're making the alias as we speak. The original may have other
    166           # aliases, so we cannot use it to check for whether it's already been
    167           # warned about.
    168           if NewClass.__init__ not in _PRINTED_WARNING:
    169             if warn_once:
    170               _PRINTED_WARNING[NewClass.__init__] = True
    171             logging.warning(
    172                 'From %s: The name %s is deprecated. Please use %s instead.\n',
    173                 _call_location(), deprecated_name, name)
    174         super(NewClass, self).__init__(*args, **kwargs)
    175 
    176     return NewClass
    177   else:
    178     decorator_utils.validate_callable(func_or_class, 'deprecated')
    179 
    180     # Make a wrapper for the original
    181     @functools.wraps(func_or_class)
    182     def new_func(*args, **kwargs):  # pylint: disable=missing-docstring
    183       if _PRINT_DEPRECATION_WARNINGS:
    184         # We're making the alias as we speak. The original may have other
    185         # aliases, so we cannot use it to check for whether it's already been
    186         # warned about.
    187         if new_func not in _PRINTED_WARNING:
    188           if warn_once:
    189             _PRINTED_WARNING[new_func] = True
    190           logging.warning(
    191               'From %s: The name %s is deprecated. Please use %s instead.\n',
    192               _call_location(), deprecated_name, name)
    193       return func_or_class(*args, **kwargs)
    194     return tf_decorator.make_decorator(
    195         func_or_class, new_func, 'deprecated',
    196         _add_deprecated_function_notice_to_docstring(
    197             func_or_class.__doc__, None, 'Please use %s instead.' % name))
    198 
    199 
    200 def deprecated(date, instructions, warn_once=True):
    201   """Decorator for marking functions or methods deprecated.
    202 
    203   This decorator logs a deprecation warning whenever the decorated function is
    204   called. It has the following format:
    205 
    206     <function> (from <module>) is deprecated and will be removed after <date>.
    207     Instructions for updating:
    208     <instructions>
    209 
    210   If `date` is None, 'after <date>' is replaced with 'in a future version'.
    211   <function> will include the class name if it is a method.
    212 
    213   It also edits the docstring of the function: ' (deprecated)' is appended
    214   to the first line of the docstring and a deprecation notice is prepended
    215   to the rest of the docstring.
    216 
    217   Args:
    218     date: String or None. The date the function is scheduled to be removed.
    219       Must be ISO 8601 (YYYY-MM-DD), or None.
    220     instructions: String. Instructions on how to update code using the
    221       deprecated function.
    222     warn_once: Boolean. Set to `True` to warn only the first time the decorated
    223       function is called. Otherwise, every call will log a warning.
    224 
    225   Returns:
    226     Decorated function or method.
    227 
    228   Raises:
    229     ValueError: If date is not None or in ISO 8601 format, or instructions are
    230       empty.
    231   """
    232   _validate_deprecation_args(date, instructions)
    233 
    234   def deprecated_wrapper(func):
    235     """Deprecation wrapper."""
    236     decorator_utils.validate_callable(func, 'deprecated')
    237     @functools.wraps(func)
    238     def new_func(*args, **kwargs):  # pylint: disable=missing-docstring
    239       if _PRINT_DEPRECATION_WARNINGS:
    240         if func not in _PRINTED_WARNING:
    241           if warn_once:
    242             _PRINTED_WARNING[func] = True
    243           logging.warning(
    244               'From %s: %s (from %s) is deprecated and will be removed %s.\n'
    245               'Instructions for updating:\n%s',
    246               _call_location(), decorator_utils.get_qualified_name(func),
    247               func.__module__,
    248               'in a future version' if date is None else ('after %s' % date),
    249               instructions)
    250       return func(*args, **kwargs)
    251     return tf_decorator.make_decorator(
    252         func, new_func, 'deprecated',
    253         _add_deprecated_function_notice_to_docstring(func.__doc__, date,
    254                                                      instructions))
    255   return deprecated_wrapper
    256 
    257 
    258 DeprecatedArgSpec = collections.namedtuple(
    259     'DeprecatedArgSpec', ['position', 'has_ok_value', 'ok_value'])
    260 
    261 
    262 def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples,
    263                     **kwargs):
    264   """Decorator for marking specific function arguments as deprecated.
    265 
    266   This decorator logs a deprecation warning whenever the decorated function is
    267   called with the deprecated argument. It has the following format:
    268 
    269     Calling <function> (from <module>) with <arg> is deprecated and will be
    270     removed after <date>. Instructions for updating:
    271       <instructions>
    272 
    273   If `date` is None, 'after <date>' is replaced with 'in a future version'.
    274   <function> includes the class name if it is a method.
    275 
    276   It also edits the docstring of the function: ' (deprecated arguments)' is
    277   appended to the first line of the docstring and a deprecation notice is
    278   prepended to the rest of the docstring.
    279 
    280   Args:
    281     date: String or None. The date the function is scheduled to be removed.
    282       Must be ISO 8601 (YYYY-MM-DD), or None.
    283     instructions: String. Instructions on how to update code using the
    284       deprecated function.
    285     *deprecated_arg_names_or_tuples: String or 2-Tuple(String,
    286       [ok_vals]).  The string is the deprecated argument name.
    287       Optionally, an ok-value may be provided.  If the user provided
    288       argument equals this value, the warning is suppressed.
    289     **kwargs: If `warn_once=False` is passed, every call with a deprecated
    290       argument will log a warning. The default behavior is to only warn the
    291       first time the function is called with any given deprecated argument.
    292       All other kwargs raise `ValueError`.
    293 
    294   Returns:
    295     Decorated function or method.
    296 
    297   Raises:
    298     ValueError: If date is not None or in ISO 8601 format, instructions are
    299       empty, the deprecated arguments are not present in the function
    300       signature, the second element of a deprecated_tuple is not a
    301       list, or if a kwarg other than `warn_once` is passed.
    302   """
    303   _validate_deprecation_args(date, instructions)
    304   if not deprecated_arg_names_or_tuples:
    305     raise ValueError('Specify which argument is deprecated.')
    306   if kwargs and list(kwargs.keys()) != ['warn_once']:
    307     kwargs.pop('warn_once', None)
    308     raise ValueError('Illegal argument to deprecated_args: %s' % kwargs)
    309   warn_once = kwargs.get('warn_once', True)
    310 
    311   def _get_arg_names_to_ok_vals():
    312     """Returns a dict mapping arg_name to DeprecatedArgSpec w/o position."""
    313     d = {}
    314     for name_or_tuple in deprecated_arg_names_or_tuples:
    315       if isinstance(name_or_tuple, tuple):
    316         d[name_or_tuple[0]] = DeprecatedArgSpec(-1, True, name_or_tuple[1])
    317       else:
    318         d[name_or_tuple] = DeprecatedArgSpec(-1, False, None)
    319     return d
    320 
    321   def _get_deprecated_positional_arguments(names_to_ok_vals, arg_spec):
    322     """Builds a dictionary from deprecated arguments to their spec.
    323 
    324     Returned dict is keyed by argument name.
    325     Each value is a DeprecatedArgSpec with the following fields:
    326        position: The zero-based argument position of the argument
    327          within the signature.  None if the argument isn't found in
    328          the signature.
    329        ok_values:  Values of this argument for which warning will be
    330          suppressed.
    331 
    332     Args:
    333       names_to_ok_vals: dict from string arg_name to a list of values,
    334         possibly empty, which should not elicit a warning.
    335       arg_spec: Output from tf_inspect.getargspec on the called function.
    336 
    337     Returns:
    338       Dictionary from arg_name to DeprecatedArgSpec.
    339     """
    340     arg_name_to_pos = dict(
    341         (name, pos) for (pos, name) in enumerate(arg_spec.args))
    342     deprecated_positional_args = {}
    343     for arg_name, spec in iter(names_to_ok_vals.items()):
    344       if arg_name in arg_name_to_pos:
    345         pos = arg_name_to_pos[arg_name]
    346         deprecated_positional_args[arg_name] = DeprecatedArgSpec(
    347             pos, spec.has_ok_value, spec.ok_value)
    348     return deprecated_positional_args
    349 
    350   def deprecated_wrapper(func):
    351     """Deprecation decorator."""
    352     decorator_utils.validate_callable(func, 'deprecated_args')
    353     deprecated_arg_names = _get_arg_names_to_ok_vals()
    354 
    355     arg_spec = tf_inspect.getargspec(func)
    356     deprecated_positions = _get_deprecated_positional_arguments(
    357         deprecated_arg_names, arg_spec)
    358 
    359     is_varargs_deprecated = arg_spec.varargs in deprecated_arg_names
    360     is_kwargs_deprecated = arg_spec.keywords in deprecated_arg_names
    361 
    362     if (len(deprecated_positions) + is_varargs_deprecated + is_kwargs_deprecated
    363         != len(deprecated_arg_names_or_tuples)):
    364       known_args = arg_spec.args + [arg_spec.varargs, arg_spec.keywords]
    365       missing_args = [arg_name for arg_name in deprecated_arg_names
    366                       if arg_name not in known_args]
    367       raise ValueError('The following deprecated arguments are not present '
    368                        'in the function signature: %s. '
    369                        'Found next arguments: %s.' % (missing_args, known_args))
    370 
    371     def _same_value(a, b):
    372       """A comparison operation that works for multiple object types.
    373 
    374       Returns True for two empty lists, two numeric values with the
    375       same value, etc.
    376 
    377       Returns False for (pd.DataFrame, None), and other pairs which
    378       should not be considered equivalent.
    379 
    380       Args:
    381         a: value one of the comparison.
    382         b: value two of the comparison.
    383 
    384       Returns:
    385         A boolean indicating whether the two inputs are the same value
    386         for the purposes of deprecation.
    387       """
    388       if a is b:
    389         return True
    390       try:
    391         equality = a == b
    392         if isinstance(equality, bool):
    393           return equality
    394       except TypeError:
    395         return False
    396       return False
    397 
    398     @functools.wraps(func)
    399     def new_func(*args, **kwargs):
    400       """Deprecation wrapper."""
    401       # TODO(apassos) figure out a way to have reasonable performance with
    402       # deprecation warnings and eager mode.
    403       if is_in_graph_mode.IS_IN_GRAPH_MODE() and _PRINT_DEPRECATION_WARNINGS:
    404         invalid_args = []
    405         named_args = tf_inspect.getcallargs(func, *args, **kwargs)
    406         for arg_name, spec in iter(deprecated_positions.items()):
    407           if (spec.position < len(args) and
    408               not (spec.has_ok_value and
    409                    _same_value(named_args[arg_name], spec.ok_value))):
    410             invalid_args.append(arg_name)
    411         if is_varargs_deprecated and len(args) > len(arg_spec.args):
    412           invalid_args.append(arg_spec.varargs)
    413         if is_kwargs_deprecated and kwargs:
    414           invalid_args.append(arg_spec.keywords)
    415         for arg_name in deprecated_arg_names:
    416           if (arg_name in kwargs and
    417               not (deprecated_positions[arg_name].has_ok_value and
    418                    _same_value(named_args[arg_name],
    419                                deprecated_positions[arg_name].ok_value))):
    420             invalid_args.append(arg_name)
    421         for arg_name in invalid_args:
    422           if (func, arg_name) not in _PRINTED_WARNING:
    423             if warn_once:
    424               _PRINTED_WARNING[(func, arg_name)] = True
    425             logging.warning(
    426                 'From %s: calling %s (from %s) with %s is deprecated and will '
    427                 'be removed %s.\nInstructions for updating:\n%s',
    428                 _call_location(), decorator_utils.get_qualified_name(func),
    429                 func.__module__, arg_name,
    430                 'in a future version' if date is None else ('after %s' % date),
    431                 instructions)
    432       return func(*args, **kwargs)
    433     return tf_decorator.make_decorator(func, new_func, 'deprecated',
    434                                        _add_deprecated_arg_notice_to_docstring(
    435                                            func.__doc__, date, instructions))
    436   return deprecated_wrapper
    437 
    438 
    439 def deprecated_arg_values(date, instructions, warn_once=True,
    440                           **deprecated_kwargs):
    441   """Decorator for marking specific function argument values as deprecated.
    442 
    443   This decorator logs a deprecation warning whenever the decorated function is
    444   called with the deprecated argument values. It has the following format:
    445 
    446     Calling <function> (from <module>) with <arg>=<value> is deprecated and
    447     will be removed after <date>. Instructions for updating:
    448       <instructions>
    449 
    450   If `date` is None, 'after <date>' is replaced with 'in a future version'.
    451   <function> will include the class name if it is a method.
    452 
    453   It also edits the docstring of the function: ' (deprecated arguments)' is
    454   appended to the first line of the docstring and a deprecation notice is
    455   prepended to the rest of the docstring.
    456 
    457   Args:
    458     date: String or None. The date the function is scheduled to be removed.
    459       Must be ISO 8601 (YYYY-MM-DD), or None
    460     instructions: String. Instructions on how to update code using the
    461       deprecated function.
    462     warn_once: If `True`, warn only the first time this function is called with
    463       deprecated argument values. Otherwise, every call (with a deprecated
    464       argument value) will log a warning.
    465     **deprecated_kwargs: The deprecated argument values.
    466 
    467   Returns:
    468     Decorated function or method.
    469 
    470   Raises:
    471     ValueError: If date is not None or in ISO 8601 format, or instructions are
    472       empty.
    473   """
    474   _validate_deprecation_args(date, instructions)
    475   if not deprecated_kwargs:
    476     raise ValueError('Specify which argument values are deprecated.')
    477 
    478   def deprecated_wrapper(func):
    479     """Deprecation decorator."""
    480     decorator_utils.validate_callable(func, 'deprecated_arg_values')
    481     @functools.wraps(func)
    482     def new_func(*args, **kwargs):
    483       """Deprecation wrapper."""
    484       if _PRINT_DEPRECATION_WARNINGS:
    485         named_args = tf_inspect.getcallargs(func, *args, **kwargs)
    486         for arg_name, arg_value in deprecated_kwargs.items():
    487           if arg_name in named_args and named_args[arg_name] == arg_value:
    488             if (func, arg_name) not in _PRINTED_WARNING:
    489               if warn_once:
    490                 _PRINTED_WARNING[(func, arg_name)] = True
    491               logging.warning(
    492                   'From %s: calling %s (from %s) with %s=%s is deprecated and '
    493                   'will be removed %s.\nInstructions for updating:\n%s',
    494                   _call_location(), decorator_utils.get_qualified_name(func),
    495                   func.__module__, arg_name, arg_value, 'in a future version'
    496                   if date is None else ('after %s' % date), instructions)
    497       return func(*args, **kwargs)
    498     return tf_decorator.make_decorator(func, new_func, 'deprecated',
    499                                        _add_deprecated_arg_notice_to_docstring(
    500                                            func.__doc__, date, instructions))
    501   return deprecated_wrapper
    502 
    503 
    504 def deprecated_argument_lookup(new_name, new_value, old_name, old_value):
    505   """Looks up deprecated argument name and ensures both are not used.
    506 
    507   Args:
    508     new_name: new name of argument
    509     new_value: value of new argument (or None if not used)
    510     old_name: old name of argument
    511     old_value: value of old argument (or None if not used)
    512   Returns:
    513     The effective argument that should be used.
    514   Raises:
    515     ValueError: if new_value and old_value are both non-null
    516   """
    517   if old_value is not None:
    518     if new_value is not None:
    519       raise ValueError("Cannot specify both '%s' and '%s'" %
    520                        (old_name, new_name))
    521     return old_value
    522   return new_value
    523 
    524 
    525 def rewrite_argument_docstring(old_doc, old_argument, new_argument):
    526   return old_doc.replace('`%s`' % old_argument, '`%s`' % new_argument).replace(
    527       '%s:' % old_argument, '%s:' % new_argument)
    528 
    529 
    530 @tf_contextlib.contextmanager
    531 def silence():
    532   """Temporarily silence deprecation warnings."""
    533   global _PRINT_DEPRECATION_WARNINGS
    534   print_deprecation_warnings = _PRINT_DEPRECATION_WARNINGS
    535   _PRINT_DEPRECATION_WARNINGS = False
    536   yield
    537   _PRINT_DEPRECATION_WARNINGS = print_deprecation_warnings
    538