1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.compare_and_bitpack_op.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.ops import math_ops 24 from tensorflow.python.platform import test 25 26 27 class CompareAndBitpackTest(test.TestCase): 28 29 def _testCompareAndBitpack(self, 30 x, threshold, 31 truth, 32 expected_err_re=None): 33 with self.test_session(use_gpu=True): 34 ans = math_ops.compare_and_bitpack(x, threshold) 35 if expected_err_re is None: 36 tf_ans = ans.eval() 37 self.assertShapeEqual(truth, ans) 38 self.assertAllEqual(tf_ans, truth) 39 else: 40 with self.assertRaisesOpError(expected_err_re): 41 ans.eval() 42 43 def _testBasic(self, dtype): 44 rows = 371 45 cols = 294 46 x = np.random.randn(rows, cols * 8) 47 if dtype == np.bool: 48 x = x > 0 49 else: 50 x = x.astype(dtype) 51 threshold = dtype(0) 52 # np.packbits flattens the tensor, so we reshape it back to the 53 # expected dimensions. 54 truth = np.packbits(x > threshold).reshape(rows, cols) 55 self._testCompareAndBitpack(x, threshold, truth) 56 57 def testBasicFloat32(self): 58 self._testBasic(np.float32) 59 60 def testBasicFloat64(self): 61 self._testBasic(np.float64) 62 63 def testBasicFloat16(self): 64 self._testBasic(np.float16) 65 66 def testBasicBool(self): 67 self._testBasic(np.bool) 68 69 def testBasicInt8(self): 70 self._testBasic(np.int8) 71 72 def testBasicInt16(self): 73 self._testBasic(np.int16) 74 75 def testBasicInt32(self): 76 self._testBasic(np.int32) 77 78 def testBasicInt64(self): 79 self._testBasic(np.int64) 80 81 82 if __name__ == "__main__": 83 test.main() 84