# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """The Python API for TensorFlow's Cloud Bigtable integration.
     16 
     17 TensorFlow has support for reading from and writing to Cloud Bigtable. To use
     18 TensorFlow + Cloud Bigtable integration, first create a BigtableClient to
     19 configure your connection to Cloud Bigtable, and then create a BigtableTable
     20 object to allow you to create numerous `tf.data.Dataset`s to read data, or
     21 write a `tf.data.Dataset` object to the underlying Cloud Bigtable table.
     22 
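For example, a typical read pipeline might look like the following (a minimal
sketch; the project, instance, table, and column names are placeholders, and
`BigtableClient` is assumed to be importable from this package):

```python
from tensorflow.contrib.bigtable import BigtableClient

client = BigtableClient("my-project", "my-instance")
table = client.table("my-table")
# Read the latest value of column "c1" in family "cf1" for every row whose
# key starts with "train_".
dataset = table.scan_prefix("train_", cf1="c1")
```
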
For background on Cloud Bigtable, see: https://cloud.google.com/bigtable .
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six import iteritems
from six import string_types

from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
from tensorflow.contrib.util import loader
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import resource_loader

_bigtable_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_bigtable.so"))


class BigtableClient(object):
  """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.

  BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the
  `table` method to open a Bigtable table.
  """

  def __init__(self,
               project_id,
               instance_id,
               connection_pool_size=None,
               max_receive_message_size=None):
    """Creates a BigtableClient that can be used to open connections to tables.

    Args:
      project_id: A string representing the GCP project id to connect to.
      instance_id: A string representing the Bigtable instance to connect to.
      connection_pool_size: (Optional.) A number representing the number of
        concurrent connections to the Cloud Bigtable service to make.
      max_receive_message_size: (Optional.) The maximum bytes received in a
        single gRPC response.

    Raises:
      ValueError: If the arguments are invalid (e.g. wrong type, or out of
        expected ranges, such as negative values).
    """
    if not isinstance(project_id, str):
      raise ValueError("`project_id` must be a string")
    self._project_id = project_id

    if not isinstance(instance_id, str):
      raise ValueError("`instance_id` must be a string")
    self._instance_id = instance_id

    if connection_pool_size is None:
      connection_pool_size = -1
    elif connection_pool_size < 1:
      raise ValueError("`connection_pool_size` must be positive")

    if max_receive_message_size is None:
      max_receive_message_size = -1
    elif max_receive_message_size < 1:
      raise ValueError("`max_receive_message_size` must be positive")

    self._connection_pool_size = connection_pool_size

    self._resource = gen_bigtable_ops.bigtable_client(
        project_id, instance_id, connection_pool_size, max_receive_message_size)

  def table(self, name, snapshot=None):
    """Opens a table and returns a `tf.contrib.bigtable.BigtableTable` object.

    Args:
      name: A `tf.string` `tf.Tensor` name of the table to open.
      snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to
        request the creation of a snapshot. (Note: currently unimplemented.)

    Returns:
      A `tf.contrib.bigtable.BigtableTable` Python object representing the
      operations available on the table.
    """
    # TODO(saeta): Implement snapshot functionality.
    table = gen_bigtable_ops.bigtable_table(self._resource, name)
    return BigtableTable(name, snapshot, table)


class BigtableTable(object):
  """Entry point for reading and writing data in Cloud Bigtable.

  This BigtableTable class is the Python representation of the Cloud Bigtable
  table within TensorFlow. Methods on this class allow data to be read from and
  written to the Cloud Bigtable service in a flexible and high-performance
  manner.
  """

  # TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface.
  # TODO(saeta): Consider variant tensors instead of resources (while supporting
  #    connection pooling).

  def __init__(self, name, snapshot, resource):
    self._name = name
    self._snapshot = snapshot
    self._resource = resource

  def lookup_columns(self, *args, **kwargs):
    """Retrieves the values of columns for a dataset of keys.

    Example usage:

    ```python
    table = bigtable_client.table("my_table")
    key_dataset = table.keys_by_prefix_dataset("imagenet")
    images = key_dataset.apply(table.lookup_columns(("cf1", "image"),
                                                    ("cf2", "label"),
                                                    ("cf2", "boundingbox")))
    training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
    ```

    Alternatively, you can use keyword arguments to specify the columns to
    capture. Example (same as above, rewritten):

    ```python
    table = bigtable_client.table("my_table")
    key_dataset = table.keys_by_prefix_dataset("imagenet")
    images = key_dataset.apply(table.lookup_columns(
        cf1="image", cf2=("label", "boundingbox")))
    training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
    ```

    Note: certain `kwargs` keys are reserved, and thus, some column families
    cannot be identified using the `kwargs` syntax. Instead, please use the
    `args` syntax. This list includes:

      - 'name'

    Note: this list can change at any time.

    Args:
      *args: A list of tuples containing (column family, column name) pairs.
      **kwargs: Column families (keys) and column qualifiers (values).

    Returns:
      A function that can be passed to `tf.data.Dataset.apply` to retrieve the
      values of columns for the rows.
    """
    table = self  # Capture self
    normalized = args
    if normalized is None:
      normalized = []
    if isinstance(normalized, tuple):
      normalized = list(normalized)
    for key, value in iteritems(kwargs):
      if key == "name":
        continue
      if isinstance(value, string_types):
        normalized.append((key, value))
        continue
      for col in value:
        normalized.append((key, col))

    def _apply_fn(dataset):
      # TODO(saeta): Verify dataset's types are correct!
      return _BigtableLookupDataset(dataset, table, normalized)

    return _apply_fn

  def keys_by_range_dataset(self, start, end):
    """Retrieves all row keys between start and end.

    Note: it does NOT retrieve the values of columns.

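    For example (a minimal sketch; the table and row keys are placeholders):

    ```python
    table = bigtable_client.table("my_table")
    # All row keys from "train_0000" (inclusive) up to, but not including,
    # "train_1000".
    keys = table.keys_by_range_dataset("train_0000", "train_1000")
    ```
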
    Args:
      start: The start row key. The row keys for rows after start (inclusive)
        will be retrieved.
      end: (Optional.) The end row key. Rows up to (but not including) end will
        be retrieved. If end is None, all subsequent row keys will be retrieved.

    Returns:
      A `tf.data.Dataset` containing `tf.string` Tensors corresponding to all
      of the row keys between `start` and `end`.
    """
    # TODO(saeta): Make inclusive / exclusive configurable?
    if end is None:
      end = ""
    return dataset_ops.DatasetV1Adapter(
        _BigtableRangeKeyDataset(self, start, end))

  def keys_by_prefix_dataset(self, prefix):
    """Retrieves the row keys matching a given prefix.

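    For example (a minimal sketch; the table and prefix are placeholders):

    ```python
    table = bigtable_client.table("my_table")
    keys = table.keys_by_prefix_dataset("train_")
    ```
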
    Args:
      prefix: All row keys that begin with `prefix` in the table will be
        retrieved.

    Returns:
      A `tf.data.Dataset` containing `tf.string` Tensors corresponding to all
      of the row keys matching that prefix.
    """
    return dataset_ops.DatasetV1Adapter(_BigtablePrefixKeyDataset(self, prefix))

  def sample_keys(self):
    """Retrieves a sampling of row keys from the Bigtable table.

    This dataset is most often used in conjunction with
    `tf.data.experimental.parallel_interleave` to construct a set of ranges for
    scanning in parallel.

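    For example (a minimal sketch; the table name is a placeholder):

    ```python
    table = bigtable_client.table("my_table")
    # A small dataset of row keys that approximately partition the table;
    # consecutive keys can serve as the boundaries of parallel scan ranges.
    split_keys = table.sample_keys()
    ```
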
    Returns:
      A `tf.data.Dataset` returning string row keys.
    """
    return dataset_ops.DatasetV1Adapter(_BigtableSampleKeysDataset(self))

  def scan_prefix(self, prefix, probability=None, columns=None, **kwargs):
    """Retrieves rows (including values) from the Bigtable service.

    Rows with row-key prefixed by `prefix` will be retrieved.

    Specifying the columns to retrieve for each row is done by either using
    kwargs or in the columns parameter. To retrieve values of the columns "c1",
    and "c2" from the column family "cfa", and the value of the column "c3"
    from column family "cfb", the following datasets (`ds1`, and `ds2`) are
    equivalent:

    ```
    table = # ...
    ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"),
                                                   ("cfa", "c2"),
                                                   ("cfb", "c3")])
    ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
    ```

    Note: only the latest value of a cell will be retrieved.

    Args:
      prefix: The prefix all row keys must match to be retrieved for prefix-
        based scans.
      probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
        A non-1 value indicates to probabilistically sample rows with the
        provided probability.
      columns: The columns to read. Note: most commonly, they are expressed as
        kwargs. Use the columns value if you are using column families that are
        reserved. The value of columns and kwargs are merged. Columns is a list
        of tuples of strings ("column_family", "column_qualifier").
      **kwargs: The column families and columns to read. Keys are treated as
        column_families, and values can be either lists of strings, or strings
        that are treated as the column qualifier (column name).

    Returns:
      A `tf.data.Dataset` returning the row keys and the cell contents.

    Raises:
      ValueError: If the configured probability is unexpected.
    """
    probability = _normalize_probability(probability)
    normalized = _normalize_columns(columns, kwargs)
    return dataset_ops.DatasetV1Adapter(
        _BigtableScanDataset(self, prefix, "", "", normalized, probability))

  def scan_range(self, start, end, probability=None, columns=None, **kwargs):
    """Retrieves rows (including values) from the Bigtable service.

    Rows with row-keys between `start` and `end` will be retrieved.

    Specifying the columns to retrieve for each row is done by either using
    kwargs or in the columns parameter. To retrieve values of the columns "c1",
    and "c2" from the column family "cfa", and the value of the column "c3"
    from column family "cfb", the following datasets (`ds1`, and `ds2`) are
    equivalent:

    ```
    table = # ...
    ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"),
                                                            ("cfa", "c2"),
                                                            ("cfb", "c3")])
    ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3")
    ```

    Note: only the latest value of a cell will be retrieved.

    Args:
      start: The start of the range when scanning by range.
      end: (Optional.) The end of the range when scanning by range.
      probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
        A non-1 value indicates to probabilistically sample rows with the
        provided probability.
      columns: The columns to read. Note: most commonly, they are expressed as
        kwargs. Use the columns value if you are using column families that are
        reserved. The value of columns and kwargs are merged. Columns is a list
        of tuples of strings ("column_family", "column_qualifier").
      **kwargs: The column families and columns to read. Keys are treated as
        column_families, and values can be either lists of strings, or strings
        that are treated as the column qualifier (column name).

    Returns:
      A `tf.data.Dataset` returning the row keys and the cell contents.

    Raises:
      ValueError: If the configured probability is unexpected.
    """
    probability = _normalize_probability(probability)
    normalized = _normalize_columns(columns, kwargs)
    return dataset_ops.DatasetV1Adapter(
        _BigtableScanDataset(self, "", start, end, normalized, probability))

  def parallel_scan_prefix(self,
                           prefix,
                           num_parallel_scans=None,
                           probability=None,
                           columns=None,
                           **kwargs):
    """Retrieves rows (including values) from the Bigtable service at high speed.

    Rows with row-key prefixed by `prefix` will be retrieved. This method is
    similar to `scan_prefix`, but by contrast performs multiple sub-scans in
    parallel in order to achieve higher performance.

    Note: The dataset produced by this method is not deterministic!

    Specifying the columns to retrieve for each row is done by either using
    kwargs or in the columns parameter. To retrieve values of the columns "c1",
    and "c2" from the column family "cfa", and the value of the column "c3"
    from column family "cfb", the following datasets (`ds1`, and `ds2`) are
    equivalent:

    ```
    table = # ...
    ds1 = table.parallel_scan_prefix("row_prefix", columns=[("cfa", "c1"),
                                                            ("cfa", "c2"),
                                                            ("cfb", "c3")])
    ds2 = table.parallel_scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
    ```

    Note: only the latest value of a cell will be retrieved.

    Args:
      prefix: The prefix all row keys must match to be retrieved for prefix-
        based scans.
      num_parallel_scans: (Optional.) The number of concurrent scans against the
        Cloud Bigtable instance.
      probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
        A non-1 value indicates to probabilistically sample rows with the
        provided probability.
      columns: The columns to read. Note: most commonly, they are expressed as
        kwargs. Use the columns value if you are using column families that are
        reserved. The value of columns and kwargs are merged. Columns is a list
        of tuples of strings ("column_family", "column_qualifier").
      **kwargs: The column families and columns to read. Keys are treated as
        column_families, and values can be either lists of strings, or strings
        that are treated as the column qualifier (column name).

    Returns:
      A `tf.data.Dataset` returning the row keys and the cell contents.

    Raises:
      ValueError: If the configured probability is unexpected.
    """
    probability = _normalize_probability(probability)
    normalized = _normalize_columns(columns, kwargs)
    ds = dataset_ops.DatasetV1Adapter(
        _BigtableSampleKeyPairsDataset(self, prefix, "", ""))
    return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
                                            normalized)

  def parallel_scan_range(self,
                          start,
                          end,
                          num_parallel_scans=None,
                          probability=None,
                          columns=None,
                          **kwargs):
    """Retrieves rows (including values) from the Bigtable service.

    Rows with row-keys between `start` and `end` will be retrieved. This method
    is similar to `scan_range`, but by contrast performs multiple sub-scans in
    parallel in order to achieve higher performance.

    Note: The dataset produced by this method is not deterministic!

    Specifying the columns to retrieve for each row is done by either using
    kwargs or in the columns parameter. To retrieve values of the columns "c1",
    and "c2" from the column family "cfa", and the value of the column "c3"
    from column family "cfb", the following datasets (`ds1`, and `ds2`) are
    equivalent:

    ```
    table = # ...
    ds1 = table.parallel_scan_range("row_start",
                                    "row_end",
                                    columns=[("cfa", "c1"),
                                             ("cfa", "c2"),
                                             ("cfb", "c3")])
    ds2 = table.parallel_scan_range("row_start", "row_end",
                                    cfa=["c1", "c2"], cfb="c3")
    ```

    Note: only the latest value of a cell will be retrieved.

    Args:
      start: The start of the range when scanning by range.
      end: (Optional.) The end of the range when scanning by range.
      num_parallel_scans: (Optional.) The number of concurrent scans against the
        Cloud Bigtable instance.
      probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
        A non-1 value indicates to probabilistically sample rows with the
        provided probability.
      columns: The columns to read. Note: most commonly, they are expressed as
        kwargs. Use the columns value if you are using column families that are
        reserved. The value of columns and kwargs are merged. Columns is a list
        of tuples of strings ("column_family", "column_qualifier").
      **kwargs: The column families and columns to read. Keys are treated as
        column_families, and values can be either lists of strings, or strings
        that are treated as the column qualifier (column name).

    Returns:
      A `tf.data.Dataset` returning the row keys and the cell contents.

    Raises:
      ValueError: If the configured probability is unexpected.
    """
    probability = _normalize_probability(probability)
    normalized = _normalize_columns(columns, kwargs)
    ds = dataset_ops.DatasetV1Adapter(
        _BigtableSampleKeyPairsDataset(self, "", start, end))
    return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
                                            normalized)

  def write(self, dataset, column_families, columns, timestamp=None):
    """Writes a dataset to the table.

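    For example, to write a row key plus two cell values per element into the
    columns "c1" and "c2" of column family "cf1" (a minimal sketch assuming
    TensorFlow 1.x graph mode; the dataset contents and all names are
    placeholders, and `table` was opened via `client.table(...)`):

    ```python
    ds = tf.data.Dataset.from_tensor_slices(
        (["row1", "row2"], ["a1", "a2"], ["b1", "b2"]))
    write_op = table.write(ds, column_families=["cf1", "cf1"],
                           columns=["c1", "c2"])
    with tf.Session() as sess:
      sess.run(write_op)
    ```
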
    Args:
      dataset: A `tf.data.Dataset` to be written to this table. It must produce
        a list of number-of-columns+1 elements, all of which must be strings.
        The first value will be used as the row key, and subsequent values will
        be used as cell values for the corresponding columns from the
        corresponding column_families and columns entries.
      column_families: A `tf.Tensor` of `tf.string`s corresponding to the
        column family names to store the dataset's elements into.
      columns: A `tf.Tensor` of `tf.string`s corresponding to the column names
        to store the dataset's elements into.
      timestamp: (Optional.) An int64 timestamp to write all the values at.
        Leave as None to use server-provided timestamps.

    Returns:
      A `tf.Operation` that can be run to perform the write.

    Raises:
      ValueError: If there are unexpected or incompatible types, or if the
        number of columns and column_families does not match the output of
        `dataset`.
    """
    if timestamp is None:
      timestamp = -1  # Bigtable server provided timestamp.
    for tensor_type in nest.flatten(
        dataset_ops.get_legacy_output_types(dataset)):
      if tensor_type != dtypes.string:
        raise ValueError("Not all elements of the dataset were `tf.string`")
    for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)):
      if not shape.is_compatible_with(tensor_shape.scalar()):
        raise ValueError("Not all elements of the dataset were scalars")
    if len(column_families) != len(columns):
      raise ValueError("len(column_families) != len(columns)")
    if len(nest.flatten(
        dataset_ops.get_legacy_output_types(dataset))) != len(columns) + 1:
      raise ValueError("A column name must be specified for every component of "
                       "the dataset elements. (e.g.: len(columns) != "
                       "len(dataset.output_types))")
    return gen_bigtable_ops.dataset_to_bigtable(
        self._resource,
        dataset._variant_tensor,  # pylint: disable=protected-access
        column_families,
        columns,
        timestamp)

  def _make_parallel_scan_dataset(self, ds, num_parallel_scans,
                                  normalized_probability, normalized_columns):
    """Builds a parallel dataset from a given range.

    Args:
      ds: A `_BigtableSampleKeyPairsDataset` returning ranges of keys to use.
      num_parallel_scans: The number of concurrent parallel scans to use.
      normalized_probability: A number between 0 and 1 for the keep probability.
      normalized_columns: The column families and column qualifiers to retrieve.

    Returns:
      A `tf.data.Dataset` representing the result of the parallel scan.
    """
    if num_parallel_scans is None:
      num_parallel_scans = 50

    ds = ds.shuffle(buffer_size=10000)  # TODO(saeta): Make configurable.

    def _interleave_fn(start, end):
      return _BigtableScanDataset(
          self,
          prefix="",
          start=start,
          end=end,
          normalized=normalized_columns,
          probability=normalized_probability)

    # Note: prefetch_input_elements must be set to avoid RPC timeouts.
    ds = ds.apply(
        interleave_ops.parallel_interleave(
            _interleave_fn,
            cycle_length=num_parallel_scans,
            sloppy=True,
            prefetch_input_elements=1))
    return ds


def _normalize_probability(probability):
  if probability is None:
    probability = 1.0
  if isinstance(probability, float) and (probability <= 0.0 or
                                         probability > 1.0):
    raise ValueError("probability must be in the range (0, 1].")
  return probability


def _normalize_columns(columns, provided_kwargs):
  """Converts arguments (columns and a kwargs dict) to the C++ representation.

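  For example (a sketch of the normalization; the names are placeholders):

  ```python
  _normalize_columns([("cfa", "c1")], {"cfb": ("c2", "c3")})
  # ==> [("cfa", "c1"), ("cfb", "c2"), ("cfb", "c3")]
  ```
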
  Args:
    columns: a data structure containing the column families and qualifiers to
      retrieve. Valid types include (1) None, (2) a list of tuples, and (3) a
      tuple of strings.
    provided_kwargs: a dictionary containing the column families and qualifiers
      to retrieve.

  Returns:
    A list of pairs of column family+qualifier to retrieve.

  Raises:
    ValueError: If there are no cells to retrieve or the columns are in an
      incorrect format.
  """
  normalized = columns
  if normalized is None:
    normalized = []
  if isinstance(normalized, tuple):
    if len(normalized) == 2:
      normalized = [normalized]
    else:
      raise ValueError("columns was a tuple of inappropriate length")
  for key, value in iteritems(provided_kwargs):
    if key == "name":
      continue
    if isinstance(value, string_types):
      normalized.append((key, value))
      continue
    for col in value:
      normalized.append((key, col))
  if not normalized:
    raise ValueError("At least one column + column family must be specified.")
  return normalized


class _BigtableKeyDataset(dataset_ops.DatasetSource):
  """_BigtableKeyDataset is an abstract class representing the keys of a table.
  """

  def __init__(self, table, variant_tensor):
    """Constructs a _BigtableKeyDataset.

    Args:
      table: a `BigtableTable` object.
      variant_tensor: DT_VARIANT representation of the dataset.
    """
    super(_BigtableKeyDataset, self).__init__(variant_tensor)
    self._table = table

  @property
  def _element_structure(self):
    return structure.TensorStructure(dtypes.string, [])


class _BigtablePrefixKeyDataset(_BigtableKeyDataset):
  """_BigtablePrefixKeyDataset represents looking up keys by prefix.
  """

  def __init__(self, table, prefix):
    self._prefix = prefix
    variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset(
        table=table._resource,  # pylint: disable=protected-access
        prefix=self._prefix)
    super(_BigtablePrefixKeyDataset, self).__init__(table, variant_tensor)


class _BigtableRangeKeyDataset(_BigtableKeyDataset):
  """_BigtableRangeKeyDataset represents looking up keys by range.
  """

  def __init__(self, table, start, end):
    self._start = start
    self._end = end
    variant_tensor = gen_bigtable_ops.bigtable_range_key_dataset(
        table=table._resource,  # pylint: disable=protected-access
        start_key=self._start,
        end_key=self._end)
    super(_BigtableRangeKeyDataset, self).__init__(table, variant_tensor)


class _BigtableSampleKeysDataset(_BigtableKeyDataset):
  """_BigtableSampleKeysDataset represents a sampling of row keys.
  """

  # TODO(saeta): Expose the data size offsets into the keys.

  def __init__(self, table):
    variant_tensor = gen_bigtable_ops.bigtable_sample_keys_dataset(
        table=table._resource)  # pylint: disable=protected-access
    super(_BigtableSampleKeysDataset, self).__init__(table, variant_tensor)


class _BigtableLookupDataset(dataset_ops.DatasetSource):
  """_BigtableLookupDataset represents a dataset that retrieves values for keys.
  """

  def __init__(self, dataset, table, normalized):
    self._num_outputs = len(normalized) + 1  # 1 for row key
    self._dataset = dataset
    self._table = table
    self._normalized = normalized
    self._column_families = [i[0] for i in normalized]
    self._columns = [i[1] for i in normalized]
    variant_tensor = gen_bigtable_ops.bigtable_lookup_dataset(
        keys_dataset=self._dataset._variant_tensor,  # pylint: disable=protected-access
        table=self._table._resource,  # pylint: disable=protected-access
        column_families=self._column_families,
        columns=self._columns)
    super(_BigtableLookupDataset, self).__init__(variant_tensor)

  @property
  def _element_structure(self):
    return structure.NestedStructure(tuple(
        [structure.TensorStructure(dtypes.string, [])] * self._num_outputs))


class _BigtableScanDataset(dataset_ops.DatasetSource):
  """_BigtableScanDataset represents a dataset that retrieves keys and values.
  """

  def __init__(self, table, prefix, start, end, normalized, probability):
    self._table = table
    self._prefix = prefix
    self._start = start
    self._end = end
    self._column_families = [i[0] for i in normalized]
    self._columns = [i[1] for i in normalized]
    self._probability = probability
    self._num_outputs = len(normalized) + 1  # 1 for row key
    variant_tensor = gen_bigtable_ops.bigtable_scan_dataset(
        table=self._table._resource,  # pylint: disable=protected-access
        prefix=self._prefix,
        start_key=self._start,
        end_key=self._end,
        column_families=self._column_families,
        columns=self._columns,
        probability=self._probability)
    super(_BigtableScanDataset, self).__init__(variant_tensor)

  @property
  def _element_structure(self):
    return structure.NestedStructure(
        tuple(
            [structure.TensorStructure(dtypes.string, [])] * self._num_outputs))


class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
  """_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table.
  """

  def __init__(self, table, prefix, start, end):
    self._table = table
    self._prefix = prefix
    self._start = start
    self._end = end
    variant_tensor = gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
        table=self._table._resource,  # pylint: disable=protected-access
        prefix=self._prefix,
        start_key=self._start,
        end_key=self._end)
    super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)

  @property
  def _element_structure(self):
    return structure.NestedStructure(
        (structure.TensorStructure(dtypes.string, []),
         structure.TensorStructure(dtypes.string, [])))