# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for StringToNumber op from parsing_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test

# Prefix of the op error for an unparseable input; the tests below append the
# offending input string to it to form the expected error message.
_ERROR_MESSAGE = "StringToNumberOp could not correctly convert string: "


class StringToNumberOpTest(test.TestCase):
  """Checks parsing_ops.string_to_number for float and integer out_types."""

  def _test(self, tf_type, good_pairs, bad_pairs):
    """Feeds each input string through string_to_number and checks the result.

    Args:
      tf_type: dtype passed as `out_type` to parsing_ops.string_to_number.
      good_pairs: iterable of (input_string, expected_number) pairs that must
        convert successfully to `expected_number`.
      bad_pairs: iterable of (input_string, expected_error) pairs whose
        conversion must fail with an op error matching `expected_error`.
    """
    with self.cached_session():
      # Build a small testing graph fed through a string placeholder.
      input_string = array_ops.placeholder(dtypes.string)
      output = parsing_ops.string_to_number(input_string, out_type=tf_type)

      # Integer results must match exactly: assertAllClose applies a relative
      # tolerance, which could accept e.g. 2147483647 for 2147483648. Float
      # results keep assertAllClose; the NAN cases below rely on it treating
      # NaN as equal to NaN.
      if tf_type.is_integer:
        check = self.assertAllEqual
      else:
        check = self.assertAllClose

      # Check all the good input/output pairs.
      for instr, outnum in good_pairs:
        result, = output.eval(feed_dict={input_string: [instr]})
        check([outnum], [result])

      # Check that the bad inputs produce the right errors.
      for instr, outstr in bad_pairs:
        with self.assertRaisesOpError(outstr):
          output.eval(feed_dict={input_string: [instr]})

  @test_util.run_deprecated_v1
  def testToFloat(self):
    self._test(dtypes.float32,
               [("0", 0), ("3", 3), ("-1", -1),
                ("1.12", 1.12), ("0xF", 15), (" -10.5", -10.5),
                ("3.40282e+38", 3.40282e+38),
                # Greater than max value of float32: expected to become +inf.
                ("3.40283e+38", float("INF")),
                # Less than lowest (most negative) value of float32: -inf.
                ("-3.40283e+38", float("-INF")),
                ("NAN", float("NAN")),
                ("INF", float("INF"))],
               [("10foobar", _ERROR_MESSAGE + "10foobar")])

  @test_util.run_deprecated_v1
  def testToDouble(self):
    self._test(dtypes.float64,
               [("0", 0), ("3", 3), ("-1", -1),
                ("1.12", 1.12), ("0xF", 15), (" -10.5", -10.5),
                ("3.40282e+38", 3.40282e+38),
                # Greater than max value of float32, still finite in float64.
                ("3.40283e+38", 3.40283e+38),
                # Less than lowest value of float32, still finite in float64.
                ("-3.40283e+38", -3.40283e+38),
                ("NAN", float("NAN")),
                ("INF", float("INF"))],
               [("10foobar", _ERROR_MESSAGE + "10foobar")])

  @test_util.run_deprecated_v1
  def testToInt32(self):
    self._test(dtypes.int32,
               [("0", 0), ("3", 3), ("-1", -1),
                (" -10", -10),
                ("-2147483648", -2147483648),
                ("2147483647", 2147483647)],
               [  # Less than min value of int32.
                   ("-2147483649", _ERROR_MESSAGE + "-2147483649"),
                   # Greater than max value of int32.
                   ("2147483648", _ERROR_MESSAGE + "2147483648"),
                   ("2.9", _ERROR_MESSAGE + "2.9"),
                   ("10foobar", _ERROR_MESSAGE + "10foobar")])

  @test_util.run_deprecated_v1
  def testToInt64(self):
    self._test(dtypes.int64,
               [("0", 0), ("3", 3), ("-1", -1),
                (" -10", -10),
                ("-2147483648", -2147483648),
                ("2147483647", 2147483647),
                ("-2147483649", -2147483649),  # Less than min value of int32.
                ("2147483648", 2147483648)],  # Greater than max value of int32.
               [("2.9", _ERROR_MESSAGE + "2.9"),
                ("10foobar", _ERROR_MESSAGE + "10foobar")])


if __name__ == "__main__":
  test.main()