Home | History | Annotate | Download | only in tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the behavior of the auto-compilation pass."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 from six.moves import xrange  # pylint: disable=redefined-builtin
     23 
     24 from tensorflow.compiler.tests.xla_test import XLATestCase
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import math_ops
     30 from tensorflow.python.platform import googletest
     31 
# Fully-qualified device string for the local host CPU. Tests wrap ops in
# `ops.device(CPU_DEVICE)` to place them on the CPU instead of on the XLA
# device selected by `self.test_scope()`.
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
     33 
     34 
     35 class ClusteringTest(XLATestCase):
     36 
     37   def testAdd(self):
     38     val1 = np.array([4, 3, 2, 1], dtype=np.float32)
     39     val2 = np.array([5, 6, 7, 8], dtype=np.float32)
     40     expected = val1 + val2
     41     with self.test_session():
     42       with self.test_scope():
     43         input1 = constant_op.constant(val1, name="const1")
     44         input2 = constant_op.constant(val2, name="const2")
     45         output = math_ops.add(input1, input2)
     46       result = output.eval()
     47     self.assertAllClose(result, expected, rtol=1e-3)
     48 
     49   def testAddFromCpuMultiple(self):
     50     val1 = np.array([4, 3, 2, 1]).astype(np.float32)
     51     val2 = np.array([5, 6, 7, 8]).astype(np.float32)
     52     expected = val1 + val2
     53     with self.test_session():
     54       with ops.device(CPU_DEVICE):
     55         input1 = constant_op.constant(val1, name="const1")
     56         input2 = constant_op.constant(val2, name="const2")
     57       with self.test_scope():
     58         output = math_ops.add(input1, input2)
     59       for _ in xrange(10):
     60         result = output.eval()
     61         self.assertAllClose(result, expected, rtol=1e-3)
     62 
     63   def testDeadlock(self):
     64     # Builds a graph of the form:
     65     #  x -> y
     66     #       | \
     67     #       z -> w
     68     # where x and z are placed on the CPU and y and w are placed on the XLA
     69     # device. If y and w are clustered for compilation, then the graph will
     70     # deadlock since the clustered graph will contain a self-loop.
     71     with self.test_session() as sess:
     72       with ops.device(CPU_DEVICE):
     73         x = array_ops.placeholder(dtypes.float32, [2])
     74       with self.test_scope():
     75         y = x * 2
     76       with ops.device(CPU_DEVICE):
     77         z = y * y
     78       with self.test_scope():
     79         w = y + z
     80       result = sess.run(w, {x: [1.5, 0.5]})
     81     self.assertAllClose(result, [12., 2.], rtol=1e-3)
     82 
     83   def testHostMemory(self):
     84     with self.test_session() as sess:
     85       x = array_ops.placeholder(dtypes.int32)
     86       with self.test_scope():
     87         y = x + 1
     88       with ops.device(CPU_DEVICE):
     89         # Place a computation on the CPU, so y and w cannot be merged into the
     90         # same JIT compilation.
     91         z = y * 2
     92       with self.test_scope():
     93         # Argument 'y' is a non-constant output of a previous cluster. Make sure
     94         # it is properly copied to host memory so it can be used as a
     95         # compile-time constant input for this cluster.
     96         w = array_ops.reshape(z, y)
     97       result = sess.run(w, {x: [1, 0]})
     98       expected = np.array([[4], [2]], dtype=np.int32)
     99       self.assertAllClose(expected, result, rtol=1e-3)
    100 
    101 
if __name__ == "__main__":
  # When this file is executed directly, hand control to googletest.main(),
  # which runs the test cases defined above.
  googletest.main()
    104