# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BigQueryReader Op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import re
import socket
import threading

from six.moves import SimpleHTTPServer
from six.moves import socketserver

from tensorflow.contrib.cloud.python.ops import bigquery_reader_ops as cloud
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat

_PROJECT = "test-project"
_DATASET = "test-dataset"
_TABLE = "test-table"
# List representation of the test rows in the 'test-table' in BigQuery.
# The schema for each row is: [int64, string, float].
# Row values are generated so that some columns contain null values:
#   - The int64 column is present in every row.
#   - The string column is present only in even rows.
#   - The float column is present only in every third row.
_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1],
         [4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None],
         [8, "s_8", None], [9, None, 9.1]]
# Schema for 'test-table'.
# The schema currently has three nullable columns: int64, string, and float.
_SCHEMA = {
    "kind": "bigquery#table",
    "id": "test-project:test-dataset.test-table",
    "schema": {
        "fields": [{
            "name": "int64_col",
            "type": "INTEGER",
            "mode": "NULLABLE"
        }, {
            "name": "string_col",
            "type": "STRING",
            "mode": "NULLABLE"
        }, {
            "name": "float_col",
            "type": "FLOAT",
            "mode": "NULLABLE"
        }]
    }
}


def _ConvertRowToExampleProto(row):
  """Converts the input row to an Example proto.

  Args:
    row: Input Row instance.

  Returns:
    An Example proto initialized with row values.
  """

  example = example_pb2.Example()
  example.features.feature["int64_col"].int64_list.value.append(row[0])
  if row[1] is not None:
    example.features.feature["string_col"].bytes_list.value.append(
        compat.as_bytes(row[1]))
  if row[2] is not None:
    example.features.feature["float_col"].float_list.value.append(row[2])
  return example


class IPv6TCPServer(socketserver.TCPServer):
  """TCPServer that listens on an IPv6 socket."""
  address_family = socket.AF_INET6


class FakeBigQueryServer(threading.Thread):
  """Fake HTTP server that returns schema and data for the sample table."""

  def __init__(self, address, port):
    """Creates a FakeBigQueryServer.

    Args:
      address: Server address.
      port: Server port. Pass 0 to automatically pick an unused port.
    """
    threading.Thread.__init__(self)
    self.handler = BigQueryRequestHandler
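    # Bind an IPv4 socket first; if that fails (e.g. on an IPv6-only host),
    # fall back to IPv6 and bracket the address in the host:port string.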
    try:
      self.httpd = socketserver.TCPServer((address, port), self.handler)
      self.host_port = "{}:{}".format(*self.httpd.server_address)
    except IOError:
      self.httpd = IPv6TCPServer((address, port), self.handler)
      self.host_port = "[{}]:{}".format(*self.httpd.server_address)

  def run(self):
    self.httpd.serve_forever()

  def shutdown(self):
    self.httpd.shutdown()
    self.httpd.socket.close()


class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
  """Responds to BigQuery HTTP requests.

  Attributes:
    num_rows: Number of rows in the underlying table served by this handler.
  """

  num_rows = 0

  def do_GET(self):
    if "data?maxResults=" not in self.path:
      # This is a schema request.
      _SCHEMA["numRows"] = self.num_rows
      response = json.dumps(_SCHEMA)
    else:
      # This is a data request.
      #
      # Extract max results and start index.
      max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0])
      start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0])

      # Send the rows as JSON.
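      # Each row mimics the shape of a BigQuery tabledata response: cells are
      # listed under "f" and each cell's value under "v" (None for NULL).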
      rows = []
      for row in _ROWS[start_index:start_index + max_results]:
        row_json = {
            "f": [{
                "v": str(row[0])
            }, {
                "v": str(row[1]) if row[1] is not None else None
            }, {
                "v": str(row[2]) if row[2] is not None else None
            }]
        }
        rows.append(row_json)
      response = json.dumps({
          "kind": "bigquery#table",
          "id": "test-project:test-dataset.test-table",
          "rows": rows
      })
    self.send_response(200)
    self.end_headers()
    self.wfile.write(compat.as_bytes(response))


def _SetUpQueue(reader):
  """Sets up a queue for a reader."""
  queue = data_flow_ops.FIFOQueue(8, [types_pb2.DT_STRING], shapes=())
  key, value = reader.read(queue)
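  # Enqueue every partition produced by the reader, then close the queue so
  # reads fail once all partitions have been consumed.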
  queue.enqueue_many(reader.partitions()).run()
  queue.close().run()
  return key, value


class BigQueryReaderOpsTest(test.TestCase):

  def setUp(self):
    super(BigQueryReaderOpsTest, self).setUp()
    self.server = FakeBigQueryServer("localhost", 0)
    self.server.start()
    logging.info("server address is %s", self.server.host_port)

    # An override to bypass the GCP auth token retrieval logic
    # in google_auth_provider.cc.
    os.environ["GOOGLE_AUTH_TOKEN_FOR_TESTING"] = "not-used"

  def tearDown(self):
    self.server.shutdown()
    super(BigQueryReaderOpsTest, self).tearDown()

  def _ReadAndCheckRowsUsingFeatures(self, num_rows):
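    # The fake server reports this count as the table's numRows, which bounds
    # how many of the predefined _ROWS the reader will fetch.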
    self.server.handler.num_rows = num_rows

    with self.test_session() as sess:
      feature_configs = {
          "int64_col":
              parsing_ops.FixedLenFeature(
                  [1], dtype=dtypes.int64),
          "string_col":
              parsing_ops.FixedLenFeature(
                  [1], dtype=dtypes.string, default_value="s_default"),
      }
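      # Point the reader at the fake in-process server instead of the real
      # BigQuery endpoint.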
      reader = cloud.BigQueryReader(
          project_id=_PROJECT,
          dataset_id=_DATASET,
          table_id=_TABLE,
          num_partitions=4,
          features=feature_configs,
          timestamp_millis=1,
          test_end_point=self.server.host_port)

      key, value = _SetUpQueue(reader)

      seen_rows = []
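      # `value` holds one row serialized as an Example proto; parse it with
      # the same feature spec that was given to the reader.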
      features = parsing_ops.parse_example(
          array_ops.reshape(value, [1]), feature_configs)
      for _ in range(num_rows):
        int_value, str_value = sess.run(
            [features["int64_col"], features["string_col"]])

        # Parse values returned from the session.
        self.assertEqual(int_value.shape, (1, 1))
        self.assertEqual(str_value.shape, (1, 1))
        int64_col = int_value[0][0]
        string_col = str_value[0][0]
        seen_rows.append(int64_col)

        # Compare.
        expected_row = _ROWS[int64_col]
        self.assertEqual(int64_col, expected_row[0])
        self.assertEqual(
            compat.as_str(string_col),
            ("s_%d" % int64_col) if expected_row[1] else "s_default")

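      # Rows can come back in any order across the 4 partitions, so only
      # verify that every expected row was seen exactly once.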
      self.assertItemsEqual(seen_rows, range(num_rows))

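      # _SetUpQueue closed the partition queue, so reading past the last row
      # should fail.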
      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

  def testReadingSingleRowUsingFeatures(self):
    self._ReadAndCheckRowsUsingFeatures(1)

  def testReadingMultipleRowsUsingFeatures(self):
    self._ReadAndCheckRowsUsingFeatures(10)

  def testReadingMultipleRowsUsingColumns(self):
    num_rows = 10
    self.server.handler.num_rows = num_rows

    with self.test_session() as sess:
      reader = cloud.BigQueryReader(
          project_id=_PROJECT,
          dataset_id=_DATASET,
          table_id=_TABLE,
          num_partitions=4,
          columns=["int64_col", "float_col", "string_col"],
          timestamp_millis=1,
          test_end_point=self.server.host_port)
      key, value = _SetUpQueue(reader)
      seen_rows = []
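      # Each read returns the row index as the key and the row serialized as
      # an Example proto as the value.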
      for row_index in range(num_rows):
        returned_row_id, example_proto = sess.run([key, value])
        example = example_pb2.Example()
        example.ParseFromString(example_proto)
        self.assertIn("int64_col", example.features.feature)
        feature = example.features.feature["int64_col"]
        self.assertEqual(len(feature.int64_list.value), 1)
        int64_col = feature.int64_list.value[0]
        seen_rows.append(int64_col)

        # Create our expected Example.
        expected_example = _ConvertRowToExampleProto(_ROWS[int64_col])

        # Compare.
        self.assertProtoEquals(example, expected_example)
        self.assertEqual(row_index, int(returned_row_id))

      self.assertItemsEqual(seen_rows, range(num_rows))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])


if __name__ == "__main__":
  test.main()