# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BigQueryReader Op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import re
import socket
import threading

from six.moves import SimpleHTTPServer
from six.moves import socketserver

from tensorflow.contrib.cloud.python.ops import bigquery_reader_ops as cloud
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat

_PROJECT = "test-project"
_DATASET = "test-dataset"
_TABLE = "test-table"
# List representation of the test rows in the 'test-table' in BigQuery.
# The schema for each row is: [int64, string, float].
# The values for rows are generated such that some columns have null values.
# The general formula here is:
# - The int64 column is present in every row.
# - The string column is only available in even rows.
# - The float column is only available in every third row.
_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1],
         [4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None],
         [8, "s_8", None], [9, None, 9.1]]
# Schema for 'test-table'.
# The schema currently has three columns: int64, string, and float.
_SCHEMA = {
    "kind": "bigquery#table",
    "id": "test-project:test-dataset.test-table",
    "schema": {
        "fields": [{
            "name": "int64_col",
            "type": "INTEGER",
            "mode": "NULLABLE"
        }, {
            "name": "string_col",
            "type": "STRING",
            "mode": "NULLABLE"
        }, {
            "name": "float_col",
            "type": "FLOAT",
            "mode": "NULLABLE"
        }]
    }
}


def _ConvertRowToExampleProto(row):
  """Converts the input row to an Example proto.

  Args:
    row: Input Row instance.

  Returns:
    An Example proto initialized with row values.
  """
  example = example_pb2.Example()
  example.features.feature["int64_col"].int64_list.value.append(row[0])
  if row[1] is not None:
    example.features.feature["string_col"].bytes_list.value.append(
        compat.as_bytes(row[1]))
  if row[2] is not None:
    example.features.feature["float_col"].float_list.value.append(row[2])
  return example
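

# For illustration only (not consumed by the tests): applied to
# _ROWS[0] == [0, "s_0", 0.1], the helper above yields an Example roughly
# equivalent to this text proto. Rows with None entries simply omit the
# corresponding feature.
#
#   features {
#     feature { key: "float_col" value { float_list { value: 0.1 } } }
#     feature { key: "int64_col" value { int64_list { value: 0 } } }
#     feature { key: "string_col" value { bytes_list { value: "s_0" } } }
#   }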


class IPv6TCPServer(socketserver.TCPServer):
  address_family = socket.AF_INET6


class FakeBigQueryServer(threading.Thread):
  """Fake http server to return schema and data for sample table."""

  def __init__(self, address, port):
    """Creates a FakeBigQueryServer.

    Args:
      address: Server address.
      port: Server port. Pass 0 to automatically pick an unused port.
    """
    threading.Thread.__init__(self)
    self.handler = BigQueryRequestHandler
    try:
      self.httpd = socketserver.TCPServer((address, port), self.handler)
      self.host_port = "{}:{}".format(*self.httpd.server_address)
    except IOError:
      # Fall back to IPv6 if binding over IPv4 fails.
      self.httpd = IPv6TCPServer((address, port), self.handler)
      self.host_port = "[{}]:{}".format(*self.httpd.server_address)

  def run(self):
    self.httpd.serve_forever()

  def shutdown(self):
    self.httpd.shutdown()
    self.httpd.socket.close()


class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
  """Responds to BigQuery HTTP requests.

  Attributes:
    num_rows: Number of rows in the underlying table served by this class.
  """

  num_rows = 0

  def do_GET(self):
    if "data?maxResults=" not in self.path:
      # This is a schema request.
      _SCHEMA["numRows"] = self.num_rows
      response = json.dumps(_SCHEMA)
    else:
      # This is a data request, e.g. ".../data?maxResults=2&startIndex=4".
      #
      # Extract max results and start index.
      max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0])
      start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0])

      # Send the rows as JSON.
      rows = []
      for row in _ROWS[start_index:start_index + max_results]:
        row_json = {
            "f": [{
                "v": str(row[0])
            }, {
                "v": str(row[1]) if row[1] is not None else None
            }, {
                "v": str(row[2]) if row[2] is not None else None
            }]
        }
        rows.append(row_json)
      response = json.dumps({
          "kind": "bigquery#table",
          "id": "test-project:test-dataset.test-table",
          "rows": rows
      })
    self.send_response(200)
    self.end_headers()
    self.wfile.write(compat.as_bytes(response))


def _SetUpQueue(reader):
  """Sets up a queue for a reader."""
  queue = data_flow_ops.FIFOQueue(8, [types_pb2.DT_STRING], shapes=())
  key, value = reader.read(queue)
  queue.enqueue_many(reader.partitions()).run()
  queue.close().run()
  return key, value
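

# Note on the pattern in _SetUpQueue above: the reader does not scan the
# table directly. reader.partitions() returns one serialized work unit per
# table partition; those are enqueued into a FIFOQueue, and running the ops
# returned by reader.read(queue) dequeues work units as needed, emitting one
# (row id, serialized Example proto) record per step.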


class BigQueryReaderOpsTest(test.TestCase):

  def setUp(self):
    super(BigQueryReaderOpsTest, self).setUp()
    self.server = FakeBigQueryServer("localhost", 0)
    self.server.start()
    logging.info("server address is %s", self.server.host_port)

    # An override to bypass the GCP auth token retrieval logic
    # in google_auth_provider.cc.
    os.environ["GOOGLE_AUTH_TOKEN_FOR_TESTING"] = "not-used"

  def tearDown(self):
    self.server.shutdown()
    super(BigQueryReaderOpsTest, self).tearDown()

  def _ReadAndCheckRowsUsingFeatures(self, num_rows):
    self.server.handler.num_rows = num_rows

    with self.test_session() as sess:
      feature_configs = {
          "int64_col":
              parsing_ops.FixedLenFeature(
                  [1], dtype=dtypes.int64),
          "string_col":
              parsing_ops.FixedLenFeature(
                  [1], dtype=dtypes.string, default_value="s_default"),
      }
      reader = cloud.BigQueryReader(
          project_id=_PROJECT,
          dataset_id=_DATASET,
          table_id=_TABLE,
          num_partitions=4,
          features=feature_configs,
          timestamp_millis=1,
          test_end_point=self.server.host_port)

      key, value = _SetUpQueue(reader)

      seen_rows = []
      features = parsing_ops.parse_example(
          array_ops.reshape(value, [1]), feature_configs)
      for _ in range(num_rows):
        int_value, str_value = sess.run(
            [features["int64_col"], features["string_col"]])

        # Parse values returned from the session.
        self.assertEqual(int_value.shape, (1, 1))
        self.assertEqual(str_value.shape, (1, 1))
        int64_col = int_value[0][0]
        string_col = str_value[0][0]
        seen_rows.append(int64_col)

        # Compare.
        expected_row = _ROWS[int64_col]
        self.assertEqual(int64_col, expected_row[0])
        self.assertEqual(
            compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1]
            else "s_default")

      self.assertItemsEqual(seen_rows, range(num_rows))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

  def testReadingSingleRowUsingFeatures(self):
    self._ReadAndCheckRowsUsingFeatures(1)

  def testReadingMultipleRowsUsingFeatures(self):
    self._ReadAndCheckRowsUsingFeatures(10)

  def testReadingMultipleRowsUsingColumns(self):
    num_rows = 10
    self.server.handler.num_rows = num_rows

    with self.test_session() as sess:
      reader = cloud.BigQueryReader(
          project_id=_PROJECT,
          dataset_id=_DATASET,
          table_id=_TABLE,
          num_partitions=4,
          columns=["int64_col", "float_col", "string_col"],
          timestamp_millis=1,
          test_end_point=self.server.host_port)
      key, value = _SetUpQueue(reader)
      seen_rows = []
      for row_index in range(num_rows):
        returned_row_id, example_proto = sess.run([key, value])
        example = example_pb2.Example()
        example.ParseFromString(example_proto)
        self.assertIn("int64_col", example.features.feature)
        feature = example.features.feature["int64_col"]
        self.assertEqual(len(feature.int64_list.value), 1)
        int64_col = feature.int64_list.value[0]
        seen_rows.append(int64_col)

        # Create our expected Example.
        expected_example = _ConvertRowToExampleProto(_ROWS[int64_col])

        # Compare.
        self.assertProtoEquals(expected_example, example)
        self.assertEqual(row_index, int(returned_row_id))

      self.assertItemsEqual(seen_rows, range(num_rows))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])


if __name__ == "__main__":
  test.main()