# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BigQuery reading support for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.cloud.python.ops import gen_bigquery_reader_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops


class BigQueryReader(io_ops.ReaderBase):
  """A Reader that outputs keys and tf.Example values from a BigQuery table.

  Example use:
    ```python
    # Assume a BigQuery table has the following schema:
    #     name      STRING,
    #     age       INT,
    #     state     STRING

    # Create the dict of features to pass to parse_example.
    features = dict(
      name=tf.FixedLenFeature([1], tf.string),
      age=tf.FixedLenFeature([1], tf.int64),  # int features use tf.int64.
      state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK"))

    # Create a Reader.
    reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT,
                                                dataset_id=DATASET,
                                                table_id=TABLE,
                                                timestamp_millis=TIME,
                                                num_partitions=NUM_PARTITIONS,
                                                features=features)

    # Populate a queue with the BigQuery Table partitions.
    queue = tf.train.string_input_producer(reader.partitions())

    # Read and parse examples. `read` returns a scalar serialized proto, so it
    # is wrapped in a list to form the 1-D batch that `parse_example` expects.
    row_id, examples_serialized = reader.read(queue)
    examples = tf.parse_example([examples_serialized], features=features)

    # Process the Tensors examples["name"], examples["age"], etc...
    ```

  Note that a snapshot timestamp is required to create a reader. This enables
  the reader to read from a consistent snapshot of the table. For more
  information, see 'Table Decorators' in the BigQuery docs.
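
  For example, a minimal sketch of computing the `TIME` placeholder used in
  the example above, snapshotting the table as it exists "now" (in
  milliseconds since the Unix epoch), could be:

    ```python
    import time

    # Absolute snapshot timestamp in milliseconds since the epoch; relative
    # (negative or zero) timestamps are not allowed.
    TIME = int(time.time() * 1000)
    ```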

  See ReaderBase for supported methods.
  """

  def __init__(self,
               project_id,
               dataset_id,
               table_id,
               timestamp_millis,
               num_partitions,
               features=None,
               columns=None,
               test_end_point=None,
               name=None):
    """Creates a BigQueryReader.

    Args:
      project_id: GCP project ID.
      dataset_id: BigQuery dataset ID.
      table_id: BigQuery table ID.
      timestamp_millis: timestamp to snapshot the table in milliseconds since
        the epoch. Relative (negative or zero) snapshot times are not allowed.
        For more details, see 'Table Decorators' in the BigQuery docs.
      num_partitions: Number of non-overlapping partitions to read from.
      features: parse_example-compatible dict mapping keys to `VarLenFeature`
        and `FixedLenFeature` objects. The keys are read as columns from the
        table.
      columns: list of columns to read; may be set if and only if `features`
        is None.
      test_end_point: Used only for testing purposes (optional).
      name: a name for the operation (optional).

    Raises:
      TypeError: If `features` is neither None nor a dict, if `columns` is
        neither None nor a list, or if neither or both of `features` and
        `columns` are set.
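
    For example, a minimal sketch of reading raw columns instead of parsed
    features (PROJECT, DATASET, TABLE, TIME and NUM_PARTITIONS are placeholder
    values, as in the class-level example) could look like:

      ```python
      reader = bigquery_reader_ops.BigQueryReader(
          project_id=PROJECT,
          dataset_id=DATASET,
          table_id=TABLE,
          timestamp_millis=TIME,
          num_partitions=NUM_PARTITIONS,
          columns=["name", "age", "state"])
      ```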
    """
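    # Exactly one of `features` or `columns` may be supplied; whichever is
    # given determines which columns are requested from BigQuery.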
    if (features is None) == (columns is None):
      raise TypeError("exactly one of features and columns must be set.")

    if features is not None:
      if not isinstance(features, dict):
        raise TypeError("features must be a dict.")
      self._columns = list(features.keys())
    elif columns is not None:
      if not isinstance(columns, list):
        raise TypeError("columns must be a list.")
      self._columns = columns

    self._project_id = project_id
    self._dataset_id = dataset_id
    self._table_id = table_id
    self._timestamp_millis = timestamp_millis
    self._num_partitions = num_partitions
    self._test_end_point = test_end_point

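    # Create the underlying reader op. Rows read from the table snapshot are
    # emitted as (row key, serialized tf.Example) pairs.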
    reader = gen_bigquery_reader_ops.big_query_reader(
        name=name,
        project_id=self._project_id,
        dataset_id=self._dataset_id,
        table_id=self._table_id,
        timestamp_millis=self._timestamp_millis,
        columns=self._columns,
        test_end_point=self._test_end_point)
    super(BigQueryReader, self).__init__(reader)

  def partitions(self, name=None):
    """Returns serialized BigQueryTablePartition messages.

    These messages represent a non-overlapping division of a table for a
    bulk read.

    Args:
      name: a name for the operation (optional).

    Returns:
      `1-D` string `Tensor` of serialized `BigQueryTablePartition` messages.
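
    For example, as in the class-level usage example, the returned partitions
    are typically fed into a queue of work items for one or more readers:

      ```python
      queue = tf.train.string_input_producer(reader.partitions())
      ```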
    """
    return gen_bigquery_reader_ops.generate_big_query_reader_partitions(
        name=name,
        project_id=self._project_id,
        dataset_id=self._dataset_id,
        table_id=self._table_id,
        timestamp_millis=self._timestamp_millis,
        num_partitions=self._num_partitions,
        test_end_point=self._test_end_point,
        columns=self._columns)


ops.NotDifferentiable("BigQueryReader")