# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for the DataProvider.

A DataProvider is a class which provides some predefined data types from some
source (TFRecord, etc). The most basic function of a data provider is the
`get` operation, where one requests one or more types of data, or 'items':

  provider.get(items=['image', 'sentence', 'class'])

More concretely, a data provider (a subclass of DataProvider) returns a
single tensor for each requested item (data type):

  provider = MyDataProvider(...)
  image, sentence, clazz = provider.get(['image', 'sentence', 'class'])

In this example, the provider `MyDataProvider` must know how to load each item.
A data provider may be written in a way that the logic necessary to map from
each item to tensor is completely encapsulated within the data_provider itself.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc


class DataProvider(object):
  """Maps a list of requested data items to tensors from a data source.

  The class stores a dictionary from item names to tensors and answers `get`
  requests by looking the requested names up in that dictionary. No assumption
  is made about the source of the data nor the mechanism for providing it.
  """
  # Python 2 style metaclass declaration; Python 3 ignores this attribute.
  __metaclass__ = abc.ABCMeta

  def __init__(self, items_to_tensors, num_samples):
    """Constructs the Data Provider.

    Args:
      items_to_tensors: a dictionary of names to tensors.
      num_samples: the number of samples in the dataset being provided.
    """
    self._items_to_tensors = items_to_tensors
    self._num_samples = num_samples

  def get(self, items):
    """Returns a list of tensors specified by the given list of items.

    The list of items is arbitrary; different data providers satisfy different
    lists of items. For example the Pascal VOC might accept items 'image' and
    'semantics', whereas the NYUDepthV2 data provider might accept items
    'image', 'depths' and 'normals'.

    Args:
      items: a list or tuple of strings, each of which indicates a particular
        data type.

    Returns:
      a list of tensors, whose length matches the length of `items`, where each
      tensor corresponds to each item.

    Raises:
      ValueError: if `items` is not a list or tuple, or if any of the items
        cannot be satisfied.
    """
    self._validate_items(items)
    return [self._items_to_tensors[item] for item in items]

  def list_items(self):
    """Returns the list of item names that can be provided by the data provider.

    Returns:
      a new list of item names that can be passed to get(items).
    """
    # Materialize the keys: on Python 3 `dict.keys()` is a view, not the list
    # this method promises, and it renders as `dict_keys([...])` in messages.
    return list(self._items_to_tensors.keys())

  def num_samples(self):
    """Returns the number of data samples in the dataset.

    Returns:
      the `num_samples` value that was given to the constructor.
    """
    return self._num_samples

  def _validate_items(self, items):
    """Verifies that each given item is a member of the list from list_items().

    Args:
      items: a list or tuple of strings.

    Raises:
      ValueError: if `items` is not a tuple or list or if any of the elements of
        `items` is not found in the list provided by self.list_items().
    """
    if not isinstance(items, (list, tuple)):
      raise ValueError('items must be a list or tuple')

    # Go through list_items() rather than the private dict so that subclasses
    # overriding list_items() are respected.
    valid_items = self.list_items()
    for item in items:
      if item not in valid_items:
        raise ValueError('Item [%s] is invalid. Valid entries include: %s' %
                         (item, valid_items))