Home | History | Annotate | Download | only in kernel_tests
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Tests for DecodeCSV op from parsing_ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.ops import parsing_ops
     24 from tensorflow.python.platform import test
     25 
     26 
     27 class DecodeCSVOpTest(test.TestCase):
     28 
     29   def _test(self, args, expected_out=None, expected_err_re=None):
     30     with self.test_session() as sess:
     31       decode = parsing_ops.decode_csv(**args)
     32 
     33       if expected_err_re is None:
     34         out = sess.run(decode)
     35 
     36         for i, field in enumerate(out):
     37           if field.dtype == np.float32 or field.dtype == np.float64:
     38             self.assertAllClose(field, expected_out[i])
     39           else:
     40             self.assertAllEqual(field, expected_out[i])
     41 
     42       else:
     43         with self.assertRaisesOpError(expected_err_re):
     44           sess.run(decode)
     45 
     46   def testSimple(self):
     47     args = {
     48         "records": ["1", "2", '"3"'],
     49         "record_defaults": [[1]],
     50     }
     51 
     52     expected_out = [[1, 2, 3]]
     53 
     54     self._test(args, expected_out)
     55 
     56   def testSimpleNoQuoteDelimiter(self):
     57     args = {
     58         "records": ["1", "2", '"3"'],
     59         "record_defaults": [[""]],
     60         "use_quote_delim": False,
     61     }
     62 
     63     expected_out = [[b"1", b"2", b'"3"']]
     64 
     65     self._test(args, expected_out)
     66 
     67   def testScalar(self):
     68     args = {"records": '1,""', "record_defaults": [[3], [4]]}
     69 
     70     expected_out = [1, 4]
     71 
     72     self._test(args, expected_out)
     73 
     74   def test2D(self):
     75     args = {"records": [["1", "2"], ['""', "4"]], "record_defaults": [[5]]}
     76     expected_out = [[[1, 2], [5, 4]]]
     77 
     78     self._test(args, expected_out)
     79 
     80   def test2DNoQuoteDelimiter(self):
     81     args = {"records": [["1", "2"], ['""', '"']],
     82             "record_defaults": [[""]],
     83             "use_quote_delim": False}
     84     expected_out = [[[b"1", b"2"], [b'""', b'"']]]
     85 
     86     self._test(args, expected_out)
     87 
     88   def testDouble(self):
     89     args = {
     90         "records": ["1.0", "-1.79e+308", '"1.79e+308"'],
     91         "record_defaults": [np.array(
     92             [], dtype=np.double)],
     93     }
     94 
     95     expected_out = [[1.0, -1.79e+308, 1.79e+308]]
     96 
     97     self._test(args, expected_out)
     98 
     99   def testInt64(self):
    100     args = {
    101         "records": ["1", "2", '"2147483648"'],
    102         "record_defaults": [np.array(
    103             [], dtype=np.int64)],
    104     }
    105 
    106     expected_out = [[1, 2, 2147483648]]
    107 
    108     self._test(args, expected_out)
    109 
    110   def testComplexString(self):
    111     args = {
    112         "records": ['"1.0"', '"ab , c"', '"a\nbc"', '"ab""c"', " abc "],
    113         "record_defaults": [["1"]]
    114     }
    115 
    116     expected_out = [[b"1.0", b"ab , c", b"a\nbc", b'ab"c', b" abc "]]
    117 
    118     self._test(args, expected_out)
    119 
    120   def testMultiRecords(self):
    121     args = {
    122         "records": ["1.0,4,aa", "0.2,5,bb", "3,6,cc"],
    123         "record_defaults": [[1.0], [1], ["aa"]]
    124     }
    125 
    126     expected_out = [[1.0, 0.2, 3], [4, 5, 6], [b"aa", b"bb", b"cc"]]
    127 
    128     self._test(args, expected_out)
    129 
    130   def testNA(self):
    131     args = {
    132         "records": ["2.0,NA,aa", "NA,5,bb", "3,6,NA"],
    133         "record_defaults": [[0.0], [0], [""]],
    134         "na_value": "NA"
    135     }
    136 
    137     expected_out = [[2.0, 0.0, 3], [0, 5, 6], [b"aa", b"bb", b""]]
    138 
    139     self._test(args, expected_out)
    140 
    141   def testWithDefaults(self):
    142     args = {
    143         "records": [",1,", "0.2,3,bcd", "3.0,,"],
    144         "record_defaults": [[1.0], [0], ["a"]]
    145     }
    146 
    147     expected_out = [[1.0, 0.2, 3.0], [1, 3, 0], [b"a", b"bcd", b"a"]]
    148 
    149     self._test(args, expected_out)
    150 
    151   def testWithDefaultsAndNoQuoteDelimiter(self):
    152     args = {
    153         "records": [",1,", "0.2,3,bcd", '3.0,,"'],
    154         "record_defaults": [[1.0], [0], ["a"]],
    155         "use_quote_delim": False,
    156     }
    157 
    158     expected_out = [[1.0, 0.2, 3.0], [1, 3, 0], [b"a", b"bcd", b"\""]]
    159 
    160     self._test(args, expected_out)
    161 
    162   def testWithTabDelim(self):
    163     args = {
    164         "records": ["1\t1", "0.2\t3", "3.0\t"],
    165         "record_defaults": [[1.0], [0]],
    166         "field_delim": "\t"
    167     }
    168 
    169     expected_out = [[1.0, 0.2, 3.0], [1, 3, 0]]
    170 
    171     self._test(args, expected_out)
    172 
    173   def testWithoutDefaultsError(self):
    174     args = {
    175         "records": [",1", "0.2,3", "3.0,"],
    176         "record_defaults": [[1.0], np.array(
    177             [], dtype=np.int32)]
    178     }
    179 
    180     self._test(
    181         args, expected_err_re="Field 1 is required but missing in record 2!")
    182 
    183   def testWrongFieldIntError(self):
    184     args = {
    185         "records": [",1", "0.2,234a", "3.0,2"],
    186         "record_defaults": [[1.0], np.array(
    187             [], dtype=np.int32)]
    188     }
    189 
    190     self._test(
    191         args, expected_err_re="Field 1 in record 1 is not a valid int32: 234a")
    192 
    193   def testOutOfRangeError(self):
    194     args = {
    195         "records": ["1", "9999999999999999999999999", "3"],
    196         "record_defaults": [[1]]
    197     }
    198 
    199     self._test(
    200         args, expected_err_re="Field 0 in record 1 is not a valid int32: ")
    201 
    202   def testWrongFieldFloatError(self):
    203     args = {
    204         "records": [",1", "0.2,2", "3.0adf,3"],
    205         "record_defaults": [[1.0], np.array(
    206             [], dtype=np.int32)]
    207     }
    208 
    209     self._test(
    210         args, expected_err_re="Field 0 in record 2 is not a valid float: ")
    211 
    212   def testWrongFieldStringError(self):
    213     args = {"records": ['"1,a,"', "0.22", 'a"bc'], "record_defaults": [["a"]]}
    214 
    215     self._test(
    216         args, expected_err_re="Unquoted fields cannot have quotes/CRLFs inside")
    217 
    218   def testWrongDefaults(self):
    219     args = {"records": [",1", "0.2,2", "3.0adf,3"], "record_defaults": [[1.0]]}
    220 
    221     self._test(args, expected_err_re="Expect 1 fields but have 2 in record 0")
    222 
    223   def testShortQuotedString(self):
    224     args = {
    225         "records": ["\""],
    226         "record_defaults": [["default"]],
    227     }
    228 
    229     self._test(
    230         args, expected_err_re="Quoted field has to end with quote followed.*")
    231 
    232 
    233 if __name__ == "__main__":
    234   test.main()
    235