1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for DecodeCSV op from parsing_ops.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.ops import parsing_ops 24 from tensorflow.python.platform import test 25 26 27 class DecodeCSVOpTest(test.TestCase): 28 29 def _test(self, args, expected_out=None, expected_err_re=None): 30 with self.test_session() as sess: 31 decode = parsing_ops.decode_csv(**args) 32 33 if expected_err_re is None: 34 out = sess.run(decode) 35 36 for i, field in enumerate(out): 37 if field.dtype == np.float32 or field.dtype == np.float64: 38 self.assertAllClose(field, expected_out[i]) 39 else: 40 self.assertAllEqual(field, expected_out[i]) 41 42 else: 43 with self.assertRaisesOpError(expected_err_re): 44 sess.run(decode) 45 46 def testSimple(self): 47 args = { 48 "records": ["1", "2", '"3"'], 49 "record_defaults": [[1]], 50 } 51 52 expected_out = [[1, 2, 3]] 53 54 self._test(args, expected_out) 55 56 def testSimpleNoQuoteDelimiter(self): 57 args = { 58 "records": ["1", "2", '"3"'], 59 "record_defaults": [[""]], 60 "use_quote_delim": False, 61 } 62 63 expected_out = [[b"1", b"2", b'"3"']] 64 65 self._test(args, expected_out) 66 67 def testScalar(self): 68 args = {"records": '1,""', "record_defaults": [[3], [4]]} 69 70 expected_out = [1, 4] 71 72 self._test(args, expected_out) 73 74 def test2D(self): 75 args = {"records": [["1", "2"], ['""', "4"]], "record_defaults": [[5]]} 76 expected_out = [[[1, 2], [5, 4]]] 77 78 self._test(args, expected_out) 79 80 def test2DNoQuoteDelimiter(self): 81 args = {"records": [["1", "2"], ['""', '"']], 82 "record_defaults": [[""]], 83 "use_quote_delim": False} 84 expected_out = [[[b"1", b"2"], [b'""', b'"']]] 85 86 self._test(args, expected_out) 87 88 def testDouble(self): 89 args = { 90 "records": ["1.0", "-1.79e+308", '"1.79e+308"'], 91 "record_defaults": [np.array( 92 [], dtype=np.double)], 93 } 94 95 expected_out = [[1.0, -1.79e+308, 1.79e+308]] 96 97 self._test(args, expected_out) 98 99 def testInt64(self): 100 args = { 101 "records": ["1", "2", '"2147483648"'], 102 "record_defaults": [np.array( 103 [], dtype=np.int64)], 104 } 105 106 expected_out = [[1, 2, 2147483648]] 107 108 self._test(args, expected_out) 109 110 def testComplexString(self): 111 args = { 112 "records": ['"1.0"', '"ab , c"', '"a\nbc"', '"ab""c"', " abc "], 113 "record_defaults": [["1"]] 114 } 115 116 expected_out = [[b"1.0", b"ab , c", b"a\nbc", b'ab"c', b" abc "]] 117 118 self._test(args, expected_out) 119 120 def testMultiRecords(self): 121 args = { 122 "records": ["1.0,4,aa", "0.2,5,bb", "3,6,cc"], 123 "record_defaults": [[1.0], [1], ["aa"]] 124 } 125 126 expected_out = [[1.0, 0.2, 3], [4, 5, 6], [b"aa", b"bb", b"cc"]] 127 128 self._test(args, expected_out) 129 130 def testNA(self): 131 args = { 132 "records": ["2.0,NA,aa", "NA,5,bb", "3,6,NA"], 133 "record_defaults": [[0.0], [0], [""]], 134 "na_value": "NA" 135 } 136 137 expected_out = [[2.0, 0.0, 3], [0, 5, 6], [b"aa", b"bb", b""]] 138 139 self._test(args, expected_out) 140 141 def testWithDefaults(self): 142 args = { 143 "records": [",1,", "0.2,3,bcd", "3.0,,"], 144 "record_defaults": [[1.0], [0], ["a"]] 145 } 146 147 expected_out = [[1.0, 0.2, 3.0], [1, 3, 0], [b"a", b"bcd", b"a"]] 148 149 self._test(args, expected_out) 150 151 def testWithDefaultsAndNoQuoteDelimiter(self): 152 args = { 153 "records": [",1,", "0.2,3,bcd", '3.0,,"'], 154 "record_defaults": [[1.0], [0], ["a"]], 155 "use_quote_delim": False, 156 } 157 158 expected_out = [[1.0, 0.2, 3.0], [1, 3, 0], [b"a", b"bcd", b"\""]] 159 160 self._test(args, expected_out) 161 162 def testWithTabDelim(self): 163 args = { 164 "records": ["1\t1", "0.2\t3", "3.0\t"], 165 "record_defaults": [[1.0], [0]], 166 "field_delim": "\t" 167 } 168 169 expected_out = [[1.0, 0.2, 3.0], [1, 3, 0]] 170 171 self._test(args, expected_out) 172 173 def testWithoutDefaultsError(self): 174 args = { 175 "records": [",1", "0.2,3", "3.0,"], 176 "record_defaults": [[1.0], np.array( 177 [], dtype=np.int32)] 178 } 179 180 self._test( 181 args, expected_err_re="Field 1 is required but missing in record 2!") 182 183 def testWrongFieldIntError(self): 184 args = { 185 "records": [",1", "0.2,234a", "3.0,2"], 186 "record_defaults": [[1.0], np.array( 187 [], dtype=np.int32)] 188 } 189 190 self._test( 191 args, expected_err_re="Field 1 in record 1 is not a valid int32: 234a") 192 193 def testOutOfRangeError(self): 194 args = { 195 "records": ["1", "9999999999999999999999999", "3"], 196 "record_defaults": [[1]] 197 } 198 199 self._test( 200 args, expected_err_re="Field 0 in record 1 is not a valid int32: ") 201 202 def testWrongFieldFloatError(self): 203 args = { 204 "records": [",1", "0.2,2", "3.0adf,3"], 205 "record_defaults": [[1.0], np.array( 206 [], dtype=np.int32)] 207 } 208 209 self._test( 210 args, expected_err_re="Field 0 in record 2 is not a valid float: ") 211 212 def testWrongFieldStringError(self): 213 args = {"records": ['"1,a,"', "0.22", 'a"bc'], "record_defaults": [["a"]]} 214 215 self._test( 216 args, expected_err_re="Unquoted fields cannot have quotes/CRLFs inside") 217 218 def testWrongDefaults(self): 219 args = {"records": [",1", "0.2,2", "3.0adf,3"], "record_defaults": [[1.0]]} 220 221 self._test(args, expected_err_re="Expect 1 fields but have 2 in record 0") 222 223 def testShortQuotedString(self): 224 args = { 225 "records": ["\""], 226 "record_defaults": [["default"]], 227 } 228 229 self._test( 230 args, expected_err_re="Quoted field has to end with quote followed.*") 231 232 233 if __name__ == "__main__": 234 test.main() 235