Home | History | Annotate | Download | only in data
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Contains code for the DataProvider.
     16 
A DataProvider is a class which provides some predefined data types from some
source (TFRecord, etc.). The most basic function of a data provider is the
`get` operation, where one requests one or more types of data, or 'items':
     21 
     22   provider.get(items=['image', 'sentence', 'class'])
     23 
More concretely, a data provider (a subclass of DataProvider) returns a
single tensor for each requested item (data type):
     26 
     27   provider = MyDataProvider(...)
     28   image, sentence, clazz = provider.get(['image', 'sentence', 'class'])
     29 
     30 In this example, the provider `MyDataProvider` must know how to load each item.
A data provider may be written so that the logic needed to map each item to a
tensor is completely encapsulated within the data provider itself.
     33 """
     34 
     35 from __future__ import absolute_import
     36 from __future__ import division
     37 from __future__ import print_function
     38 
     39 import abc
     40 
     41 
     42 class DataProvider(object):
     43   """Maps a list of requested data items to tensors from a data source.
     44 
     45   All data providers must inherit from DataProvider and implement the Get
     46   method which returns arbitrary types of data. No assumption is made about the
     47   source of the data nor the mechanism for providing it.
     48   """
     49   __metaclass__ = abc.ABCMeta
     50 
     51   def __init__(self, items_to_tensors, num_samples):
     52     """Constructs the Data Provider.
     53 
     54     Args:
     55       items_to_tensors: a dictionary of names to tensors.
     56       num_samples: the number of samples in the dataset being provided.
     57     """
     58     self._items_to_tensors = items_to_tensors
     59     self._num_samples = num_samples
     60 
     61   def get(self, items):
     62     """Returns a list of tensors specified by the given list of items.
     63 
     64     The list of items is arbitrary different data providers satisfy different
     65     lists of items. For example the Pascal VOC might accept items 'image' and
     66     'semantics', whereas the NYUDepthV2 data provider might accept items
     67     'image', 'depths' and 'normals'.
     68 
     69     Args:
     70       items: a list of strings, each of which indicate a particular data type.
     71 
     72     Returns:
     73       a list of tensors, whose length matches the length of `items`, where each
     74       tensor corresponds to each item.
     75 
     76     Raises:
     77       ValueError: if any of the items cannot be satisfied.
     78     """
     79     self._validate_items(items)
     80     return [self._items_to_tensors[item] for item in items]
     81 
     82   def list_items(self):
     83     """Returns the list of item names that can be provided by the data provider.
     84 
     85     Returns:
     86       a list of item names that can be passed to Get([items]).
     87     """
     88     return self._items_to_tensors.keys()
     89 
     90   def num_samples(self):
     91     """Returns the number of data samples in the dataset.
     92 
     93     Returns:
     94       a positive whole number.
     95     """
     96     return self._num_samples
     97 
     98   def _validate_items(self, items):
     99     """Verifies that each given item is a member of the list from ListItems().
    100 
    101     Args:
    102       items: a list or tuple of strings.
    103 
    104     Raises:
    105       ValueError: if `items` is not a tuple or list or if any of the elements of
    106         `items` is not found in the list provided by self.ListItems().
    107     """
    108     if not isinstance(items, (list, tuple)):
    109       raise ValueError('items must be a list or tuple')
    110 
    111     valid_items = self.list_items()
    112     for item in items:
    113       if item not in valid_items:
    114         raise ValueError('Item [%s] is invalid. Valid entries include: %s' %
    115                          (item, valid_items))
    116