Home | History | Annotate | Download | only in tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Test cases for ternary operators."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.compiler.tests.xla_test import XLATestCase
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import math_ops
     27 from tensorflow.python.platform import googletest
     28 
     29 
     30 class TernaryOpsTest(XLATestCase):
     31 
     32   def _testTernary(self, op, a, b, c, expected):
     33     with self.test_session() as session:
     34       with self.test_scope():
     35         pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
     36         pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
     37         pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c")
     38         output = op(pa, pb, pc)
     39       result = session.run(output, {pa: a, pb: b, pc: c})
     40       self.assertAllClose(result, expected, rtol=1e-3)
     41 
     42   def testLinspace(self):
     43     self._testTernary(
     44         math_ops.linspace,
     45         np.float32(1),
     46         np.float32(2),
     47         np.int32(1),
     48         expected=np.array([1], dtype=np.float32))
     49     self._testTernary(
     50         math_ops.linspace,
     51         np.float32(1),
     52         np.float32(4),
     53         np.int32(3),
     54         expected=np.array([1, 2.5, 4], dtype=np.float32))
     55 
     56   def testRange(self):
     57     self._testTernary(
     58         math_ops.range,
     59         np.int32(1),
     60         np.int32(2),
     61         np.int32(1),
     62         expected=np.array([1], dtype=np.int32))
     63     self._testTernary(
     64         math_ops.range,
     65         np.int32(1),
     66         np.int32(7),
     67         np.int32(2),
     68         expected=np.array([1, 3, 5], dtype=np.int32))
     69 
     70   def testSelect(self):
     71     self._testTernary(
     72         array_ops.where,
     73         np.array(0, dtype=np.bool),
     74         np.array(2, dtype=np.float32),
     75         np.array(7, dtype=np.float32),
     76         expected=np.array(7, dtype=np.float32))
     77 
     78     self._testTernary(
     79         array_ops.where,
     80         np.array(1, dtype=np.bool),
     81         np.array([1, 2, 3, 4], dtype=np.float32),
     82         np.array([5, 6, 7, 8], dtype=np.float32),
     83         expected=np.array([1, 2, 3, 4], dtype=np.float32))
     84 
     85     self._testTernary(
     86         array_ops.where,
     87         np.array(0, dtype=np.bool),
     88         np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
     89         np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
     90         expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32))
     91 
     92     self._testTernary(
     93         array_ops.where,
     94         np.array([0, 1, 1, 0], dtype=np.bool),
     95         np.array([1, 2, 3, 4], dtype=np.float32),
     96         np.array([5, 6, 7, 8], dtype=np.float32),
     97         expected=np.array([5, 2, 3, 8], dtype=np.float32))
     98 
     99     self._testTernary(
    100         array_ops.where,
    101         np.array([0, 1, 0], dtype=np.bool),
    102         np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
    103         np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
    104         expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32))
    105 
    106   def testSlice(self):
    107     for dtype in self.numeric_types:
    108       self._testTernary(
    109           array_ops.slice,
    110           np.array([[], [], []], dtype=dtype),
    111           np.array([1, 0], dtype=np.int32),
    112           np.array([2, 0], dtype=np.int32),
    113           expected=np.array([[], []], dtype=dtype))
    114 
    115       self._testTernary(
    116           array_ops.slice,
    117           np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype),
    118           np.array([0, 1], dtype=np.int32),
    119           np.array([2, 1], dtype=np.int32),
    120           expected=np.array([[2], [5]], dtype=dtype))
    121 
    122 
# Hand control to the googletest runner when this file is executed directly.
if __name__ == "__main__":
  googletest.main()
    125