Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Script Language Operators. See the @{$python/script_ops} guide.
     17 
     18 @@py_func
     19 """
     20 
     21 # pylint: disable=g-bad-name
     22 from __future__ import absolute_import
     23 from __future__ import division
     24 from __future__ import print_function
     25 
     26 import threading
     27 
     28 import numpy as np
     29 import six
     30 
     31 from tensorflow.python import pywrap_tensorflow
     32 from tensorflow.python.eager import context
     33 from tensorflow.python.framework import function
     34 from tensorflow.python.framework import ops
     35 from tensorflow.python.ops import gen_script_ops
     36 from tensorflow.python.util import nest
     37 from tensorflow.python.util.tf_export import tf_export
     38 
     39 
     40 class EagerFunc(object):
     41   """A wrapper for a function owned by an EagerPyFunc."""
     42 
     43   def __init__(self, func, Tout):
     44     """Constructs an EagerFunc.
     45 
     46     Args:
     47       func: The function to wrap.
     48       Tout: A list of datatypes for the output; an empty list if the output is
     49             None.
     50     """
     51     self._func = func
     52     self._out_dtypes = Tout
     53 
     54   def __call__(self, on_gpu, args):
     55     """Passes `args` to `self._func`, which is executed eagerly."""
     56     with context.eager_mode():
     57       ret = self._func(*args)
     58       maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu()
     59       if isinstance(ret, (tuple, list)):
     60         return [
     61             maybe_copy_to_gpu(ops.convert_to_tensor(x, dtype=dtype))
     62             for (x, dtype) in zip(ret, self._out_dtypes)
     63         ]
     64       elif ret is None:
     65         return ret
     66       else:
     67         return maybe_copy_to_gpu(
     68             ops.convert_to_tensor(ret, dtype=self._out_dtypes[0]))
     69 
     70 
     71 class FuncRegistry(object):
     72   """A helper class to keep track of registered py functions.
     73 
     74   FuncRegistry keeps a map from unique tokens (string) to python
     75   functions, which takes numpy arrays and outputs numpy arrays.
     76   """
     77 
     78   def __init__(self):
     79     self._lock = threading.Lock()
     80     self._unique_id = 0  # GUARDED_BY(self._lock)
     81     self._funcs = {}
     82 
     83   def insert(self, func):
     84     """Registers `func` and returns a unique token for this entry."""
     85     token = self._next_unique_token()
     86     self._funcs[token] = func
     87     return token
     88 
     89   def remove(self, token):
     90     """Removes the registered function corresponding to `token`."""
     91     self._funcs.pop(token, None)
     92 
     93   @staticmethod
     94   def _convert(value, dtype=None):
     95     """Converts an arg to numpy, avoiding dangerous string and unicode dtypes.
     96 
     97     Numpy pads with zeros when using string and unicode dtypes if different
     98     components of a tensor have different lengths.  This is bad: ignoring the
     99     padding is wrong for text data, and removing the padding is wrong for binary
    100     data.  To avoid this bug, we redo the conversion using an object dtype.
    101     Additionally, we convert unicode strings to (byte-)strings for
    102     compatibility.
    103 
    104     Args:
    105       value: Value to convert to a numpy array.
    106       dtype: (Optional.) Desired NumPy type for the returned value.
    107 
    108     Returns:
    109       A numpy array.
    110     """
    111     result = np.asarray(value, dtype=dtype, order="C")
    112     if result.dtype.char == "S" and result is not value:
    113       return np.asarray(value, order="C", dtype=object)
    114     elif result.dtype.char == "U" and result is not value:
    115       value = np.vectorize(lambda x: x.encode("utf8"))(value)
    116       return np.asarray(value, order="C", dtype=object)
    117     elif result.dtype.char == "U":
    118       return result.astype(np.bytes_)
    119     else:
    120       return result
    121 
    122   def __call__(self, token, on_gpu, args):
    123     """Calls the registered function for `token` with args.
    124 
    125     Args:
    126       token: A key into this `FuncRegistry` identifying which function to call.
    127       on_gpu: A boolean indicating whether or not `token`'s corresponding
    128         operation was placed on GPU; only used if the function registered for
    129         `token` is an `EagerPyFunc`.
    130       args: The arguments to pass to the function registered for `token`.
    131 
    132     Returns:
    133       The output of the function registered for `token`.
    134 
    135     Raises:
    136       ValueError: if no function is registered for `token`.
    137     """
    138     func = self._funcs[token]
    139     if func is None:
    140       raise ValueError("callback %s is not found" % token)
    141     if isinstance(func, EagerFunc):
    142       return func(on_gpu, args)
    143     else:
    144       ret = func(*args)
    145       # Strings seem to lead to a memory leak here if they're not wrapped in a
    146       # list.
    147       if isinstance(ret, six.binary_type):
    148         ret = [ret]
    149       # Ensures that we return either a single numpy array or a list of numpy
    150       # arrays.
    151       if isinstance(ret, (tuple, list)):
    152         return [self._convert(x) for x in ret]
    153       else:
    154         return self._convert(ret)
    155 
    156   def size(self):
    157     """Returns how many functions are currently registered."""
    158     return len(self._funcs)
    159 
    160   def _next_unique_token(self):
    161     """Returns a unique token."""
    162     with self._lock:
    163       uid = self._unique_id
    164       self._unique_id += 1
    165     return "pyfunc_%d" % uid
    166 
# Global registry for py functions. The tokens handed out by
# `_internal_py_func` are keys into this single module-level instance.
_py_funcs = FuncRegistry()

# Hand the registry to the native layer.
# NOTE(review): presumably this is what lets the py_func kernels call back
# into `_py_funcs(token, on_gpu, args)` -- confirm against pywrap_tensorflow.
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
    171 
    172 
    173 class CleanupFunc(object):
    174   """A helper class to remove a registered function from _py_funcs."""
    175 
    176   def __init__(self, token):
    177     self._token = token
    178 
    179   def __del__(self):
    180     if _py_funcs is not None:
    181       # If _py_funcs is None, the program is most likely in shutdown, and the
    182       # _py_funcs object has been destroyed already.
    183       _py_funcs.remove(self._token)
    184 
    185 
    186 def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
    187   """See documentation for py_func and eager_py_func."""
    188 
    189   is_list_or_tuple = False
    190   if isinstance(Tout, (list, tuple)):
    191     is_list_or_tuple = True
    192   else:
    193     Tout = [Tout]
    194 
    195   if eager:
    196     func = EagerFunc(func, Tout)
    197 
    198   token = _py_funcs.insert(func)
    199   # We tie the registered function's lifetime with the current default graph,
    200   # i.e., when the current graph is destroyed, we remove its py funcs.
    201   graph = ops.get_default_graph()
    202 
    203   # pylint: disable=protected-access
    204   while isinstance(graph, function._FuncGraph):
    205     # If the py_func was declared inside a _FuncGraph, its lifetime should be
    206     # bound to that of the outer graph instead.
    207     graph = graph._outer_graph
    208 
    209   cleanup = CleanupFunc(token)
    210 
    211   # TODO(zhifengc): Consider adding a Graph method to collect
    212   # `cleanup` objects in one of its member.
    213   if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"):
    214     graph._cleanup_py_funcs_used_in_graph = []
    215 
    216   # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph
    217   # will be destroyed and their __del__ will remove the 'token' from
    218   # the funcs registry.
    219   graph._cleanup_py_funcs_used_in_graph.append(cleanup)
    220   # pylint: enable=protected-access
    221 
    222   # pylint: disable=protected-access
    223   if eager:
    224     result = gen_script_ops._eager_py_func(
    225         input=inp, token=token, Tout=Tout, name=name)
    226   else:
    227     if stateful:
    228       result = gen_script_ops._py_func(
    229           input=inp, token=token, Tout=Tout, name=name)
    230     else:
    231       result = gen_script_ops._py_func_stateless(
    232           input=inp, token=token, Tout=Tout, name=name)
    233   # pylint: enable=protected-access
    234   return result if is_list_or_tuple else result[0]
    235 
    236 
    237 def eager_py_func(func, inp, Tout, name=None):
    238   """Wraps a python function into a TensorFlow op.
    239 
    240   When the returned op is executed, `func` is invoked with eager execution
    241   enabled. Inputs are Tensor objects and func must return None or objects
    242   that may be converted to Tensor objects.
    243 
    244   This function has the same limitations as `py_func` with respect to
    245   serialization and distribution.
    246 
    247   Args:
    248     func: A Python function which accepts a list of `Tensor` objects
    249       having element types that match the corresponding `tf.Tensor` objects
    250       in `inp` and returns a list of `Tensor` objects (or a single
    251       `Tensor`, or `None`) having element types that match the
    252       corresponding values in `Tout`.
    253     inp: A list of `Tensor` objects.
    254     Tout: A list or tuple of tensorflow data types or a single tensorflow data
    255       type if there is only one, indicating what `func` returns; an empty list
    256       if no value is returned (i.e., if the return value is `None`).
    257     name: A name for the operation (optional).
    258 
    259   Returns:
    260     A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
    261     if `func` returns None.
    262   """
    263   return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
    264 
    265 
    266 @tf_export("py_func")
    267 def py_func(func, inp, Tout, stateful=True, name=None):
    268   """Wraps a python function and uses it as a TensorFlow op.
    269 
    270   Given a python function `func`, which takes numpy arrays as its
    271   arguments and returns numpy arrays as its outputs, wrap this function as an
    272   operation in a TensorFlow graph. The following snippet constructs a simple
    273   TensorFlow graph that invokes the `np.sinh()` NumPy function as a operation
    274   in the graph:
    275 
    276   ```python
    277   def my_func(x):
    278     # x will be a numpy array with the contents of the placeholder below
    279     return np.sinh(x)
    280   input = tf.placeholder(tf.float32)
    281   y = tf.py_func(my_func, [input], tf.float32)
    282   ```
    283 
    284   **N.B.** The `tf.py_func()` operation has the following known limitations:
    285 
    286   * The body of the function (i.e. `func`) will not be serialized in a
    287     `GraphDef`. Therefore, you should not use this function if you need to
    288     serialize your model and restore it in a different environment.
    289 
    290   * The operation must run in the same address space as the Python program
    291     that calls `tf.py_func()`. If you are using distributed TensorFlow, you
    292     must run a `tf.train.Server` in the same process as the program that calls
    293     `tf.py_func()` and you must pin the created operation to a device in that
    294     server (e.g. using `with tf.device():`).
    295 
    296   Args:
    297     func: A Python function, which accepts `ndarray` objects as arguments and
    298       returns a list of `ndarray` objects (or a single `ndarray`). This function
    299       must accept as many arguments as there are tensors in `inp`, and these
    300       argument types will match the corresponding `tf.Tensor` objects
    301       in `inp`. The returns `ndarray`s must match the number and types defined
    302       `Tout`.
    303       Important Note: Input and output numpy `ndarray`s of `func` are not
    304       guaranteed to be copies. In some cases their underlying memory will be
    305       shared with the corresponding TensorFlow tensors.
    306       In-place modification or storing `func` input or return values in
    307       python datastructures without explicit (np.)copy
    308       can have non-deterministic consequences.
    309     inp: A list of `Tensor` objects.
    310     Tout: A list or tuple of tensorflow data types or a single tensorflow data
    311       type if there is only one, indicating what `func` returns.
    312     stateful: (Boolean.) If True, the function should be considered stateful.
    313       If a function is stateless, when given the same input it will return the
    314       same output and have no observable side effects. Optimizations such as
    315       common subexpression elimination are only performed on stateless
    316       operations.
    317     name: A name for the operation (optional).
    318 
    319   Returns:
    320     A list of `Tensor` or a single `Tensor` which `func` computes.
    321   """
    322   if context.in_eager_mode():
    323     result = func(*[x.numpy() for x in inp])
    324     result = nest.flatten(result)
    325 
    326     return [x if x is None else ops.convert_to_tensor(x) for x in result]
    327 
    328   return _internal_py_func(
    329       func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
    330 
    331 
# Register the "PyFunc" and "PyFuncStateless" op types as not differentiable.
# NOTE(review): presumably these are the op types created by
# `gen_script_ops._py_func` and `gen_script_ops._py_func_stateless` -- confirm.
ops.NotDifferentiable("PyFunc")
ops.NotDifferentiable("PyFuncStateless")
    334