Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for StringToNumber op from parsing_ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import dtypes
     22 from tensorflow.python.framework import test_util
     23 from tensorflow.python.ops import array_ops
     24 from tensorflow.python.ops import parsing_ops
     25 from tensorflow.python.platform import test
     26 
     27 _ERROR_MESSAGE = "StringToNumberOp could not correctly convert string: "
     28 
     29 
     30 class StringToNumberOpTest(test.TestCase):
     31 
     32   def _test(self, tf_type, good_pairs, bad_pairs):
     33     with self.cached_session():
     34       # Build a small testing graph.
     35       input_string = array_ops.placeholder(dtypes.string)
     36       output = parsing_ops.string_to_number(
     37           input_string, out_type=tf_type)
     38 
     39       # Check all the good input/output pairs.
     40       for instr, outnum in good_pairs:
     41         result, = output.eval(feed_dict={input_string: [instr]})
     42         self.assertAllClose([outnum], [result])
     43 
     44       # Check that the bad inputs produce the right errors.
     45       for instr, outstr in bad_pairs:
     46         with self.assertRaisesOpError(outstr):
     47           output.eval(feed_dict={input_string: [instr]})
     48 
     49   @test_util.run_deprecated_v1
     50   def testToFloat(self):
     51     self._test(dtypes.float32,
     52                [("0", 0), ("3", 3), ("-1", -1),
     53                 ("1.12", 1.12), ("0xF", 15), ("   -10.5", -10.5),
     54                 ("3.40282e+38", 3.40282e+38),
     55                 # Greater than max value of float.
     56                 ("3.40283e+38", float("INF")),
     57                 ("-3.40283e+38", float("-INF")),
     58                 # Less than min value of float.
     59                 ("NAN", float("NAN")),
     60                 ("INF", float("INF"))],
     61                [("10foobar", _ERROR_MESSAGE + "10foobar")])
     62 
     63   @test_util.run_deprecated_v1
     64   def testToDouble(self):
     65     self._test(dtypes.float64,
     66                [("0", 0), ("3", 3), ("-1", -1),
     67                 ("1.12", 1.12), ("0xF", 15), ("   -10.5", -10.5),
     68                 ("3.40282e+38", 3.40282e+38),
     69                 # Greater than max value of float.
     70                 ("3.40283e+38", 3.40283e+38),
     71                 # Less than min value of float.
     72                 ("-3.40283e+38", -3.40283e+38),
     73                 ("NAN", float("NAN")),
     74                 ("INF", float("INF"))],
     75                [("10foobar", _ERROR_MESSAGE + "10foobar")])
     76 
     77   @test_util.run_deprecated_v1
     78   def testToInt32(self):
     79     self._test(dtypes.int32,
     80                [("0", 0), ("3", 3), ("-1", -1),
     81                 ("    -10", -10),
     82                 ("-2147483648", -2147483648),
     83                 ("2147483647", 2147483647)],
     84                [   # Less than min value of int32.
     85                    ("-2147483649", _ERROR_MESSAGE + "-2147483649"),
     86                    # Greater than max value of int32.
     87                    ("2147483648", _ERROR_MESSAGE + "2147483648"),
     88                    ("2.9", _ERROR_MESSAGE + "2.9"),
     89                    ("10foobar", _ERROR_MESSAGE + "10foobar")])
     90 
     91   @test_util.run_deprecated_v1
     92   def testToInt64(self):
     93     self._test(dtypes.int64,
     94                [("0", 0), ("3", 3), ("-1", -1),
     95                 ("    -10", -10),
     96                 ("-2147483648", -2147483648),
     97                 ("2147483647", 2147483647),
     98                 ("-2147483649", -2147483649),  # Less than min value of int32.
     99                 ("2147483648", 2147483648)],  # Greater than max value of int32.
    100                [("2.9", _ERROR_MESSAGE + "2.9"),
    101                 ("10foobar", _ERROR_MESSAGE + "10foobar")])
    102 
    103 
if __name__ == "__main__":
  # Only when executed as a script: call main() of the `test` module imported
  # from tensorflow.python.platform.
  test.main()
    106