Home | History | Annotate | Download | only in util
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for operator dispatch."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import ops
     22 from tensorflow.python.framework import test_util
     23 from tensorflow.python.ops import gen_math_ops
     24 from tensorflow.python.ops import math_ops
     25 from tensorflow.python.platform import googletest
     26 from tensorflow.python.util import dispatch
     27 from tensorflow.python.util.tf_export import tf_export
     28 
     29 
     30 class CustomTensor(object):
     31   """A fake composite tensor class, for testing type-based dispatching."""
     32 
     33   def __init__(self, tensor, score):
     34     self.tensor = ops.convert_to_tensor(tensor)
     35     self.score = score
     36 
     37 
     38 @tf_export("test_op")
     39 @dispatch.add_dispatch_support
     40 def test_op(x, y, z):
     41   """A fake op for testing dispatch of Python ops."""
     42   return x + (2 * y) + (3 * z)
     43 
     44 
     45 @test_util.run_all_in_graph_and_eager_modes
     46 class DispatchTest(test_util.TensorFlowTestCase):
     47 
     48   def testAddDispatchForTypes_With_CppOp(self):
     49     original_handlers = gen_math_ops.add._tf_dispatchers[:]
     50 
     51     # Override the behavior of gen_math_ops.add.
     52     @dispatch.dispatch_for_types(gen_math_ops.add, CustomTensor)
     53     def custom_add(x, y, name=None):  # pylint: disable=unused-variable
     54       return CustomTensor(gen_math_ops.add(x.tensor, y.tensor, name),
     55                           (x.score+y.score) / 2.0)
     56     self.assertEqual(len(math_ops.add._tf_dispatchers),
     57                      len(original_handlers) + 1)
     58 
     59     # Test that we see the overridden behavior when using CustomTensors.
     60     x = CustomTensor([1, 2, 3], 2.0)
     61     y = CustomTensor([7, 8, 2], 0.0)
     62     x_plus_y = gen_math_ops.add(x, y)
     63     self.assertAllEqual(self.evaluate(x_plus_y.tensor), [8, 10, 5])
     64     self.assertNear(x_plus_y.score, 1.0, 0.001)
     65 
     66     # Test that we still get the right behavior when using normal Tensors.
     67     a = [1, 2, 3]
     68     b = [4, 5, 6]
     69     a_plus_b = gen_math_ops.add(a, b)
     70     self.assertAllEqual(a_plus_b, [5, 7, 9])
     71 
     72     # Test that we still get a TypeError or ValueError if we pass some
     73     # type that's not supported by any dispatcher.
     74     with self.assertRaises((TypeError, ValueError)):
     75       gen_math_ops.add(a, None)
     76 
     77     # Clean up
     78     gen_math_ops.add._tf_dispatchers = original_handlers
     79 
     80   def testAddDispatchForTypes_With_PythonOp(self):
     81     original_handlers = test_op._tf_dispatchers[:]
     82 
     83     @dispatch.dispatch_for_types(test_op, CustomTensor)
     84     def override_for_test_op(x, y, z):  # pylint: disable=unused-variable
     85       return CustomTensor(test_op(x.tensor, y.tensor, z.tensor),
     86                           (x.score + y.score + z.score) / 3.0)
     87 
     88     x = CustomTensor([1, 2, 3], 0.2)
     89     y = CustomTensor([7, 8, 2], 0.4)
     90     z = CustomTensor([0, 1, 2], 0.6)
     91 
     92     result = test_op(x, y, z)
     93     self.assertAllEqual(self.evaluate(result.tensor), [15, 21, 13])
     94     self.assertNear(result.score, 0.4, 0.001)
     95 
     96     # Clean up
     97     test_op._tf_dispatchers = original_handlers
     98 
     99   def testDispatchForTypes_SignatureMismatch(self):
    100     with self.assertRaisesRegexp(AssertionError, "The decorated function's "
    101                                  "signature must exactly match.*"):
    102       @dispatch.dispatch_for_types(test_op, CustomTensor)
    103       def override_for_test_op(a, b, c):  # pylint: disable=unused-variable
    104         return CustomTensor(test_op(a.tensor, b.tensor, c.tensor),
    105                             (a.score + b.score + c.score) / 3.0)
    106 
    107   def testDispatchForTypes_OpDoesNotSupportDispatch(self):
    108     def some_op(x, y):
    109       return x + y
    110 
    111     with self.assertRaisesRegexp(AssertionError, "Dispatching not enabled for"):
    112       @dispatch.dispatch_for_types(some_op, CustomTensor)
    113       def override_for_some_op(x, y):  # pylint: disable=unused-variable
    114         return x if x.score > 0 else y
    115 
    116 
# Invoke googletest's main entry point when this file is executed directly as
# a script (and not when it is imported as a module).
if __name__ == "__main__":
  googletest.main()
    119 
    120 
    121