# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset utilities and synthetic/reference datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv
from os import path

import numpy as np

from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.learn.python.learn.datasets import mnist
from tensorflow.contrib.learn.python.learn.datasets import synthetic
from tensorflow.contrib.learn.python.learn.datasets import text_datasets

# Export load_iris and load_boston.
load_iris = base.load_iris
load_boston = base.load_boston

# Registry of all available reference datasets, keyed by name.
# Note: the loaders do not all return the same type (see per-entry comments).
DATASETS = {
    # Returns base.Dataset.
    'iris': base.load_iris,
    'boston': base.load_boston,
    # Returns base.Datasets (train/validation/test sets).
    'mnist': mnist.load_mnist,
    'dbpedia': text_datasets.load_dbpedia,
}

# Registry of all synthetic dataset generators, keyed by name.
SYNTHETIC = {
    # All of these will return ['data', 'target'] -> base.Dataset
    'circles': synthetic.circles,
    'spirals': synthetic.spirals,
}


def load_dataset(name, size='small', test_with_fake_data=False):
  """Loads dataset by name.

  Args:
    name: Name of the dataset to load; must be a key of `DATASETS`.
    size: Size of the dataset to load. Only forwarded to the 'dbpedia'
      loader; ignored for every other dataset.
    test_with_fake_data: If true, load with fake dataset. Only forwarded to
      the 'dbpedia' loader; ignored for every other dataset.

  Returns:
    Features and labels for given dataset. Can be numpy or iterator.

  Raises:
    ValueError: if `name` is not found.
  """
  if name not in DATASETS:
    raise ValueError('Name of dataset is not found: %s' % name)
  loader = DATASETS[name]
  # Only the dbpedia loader is passed the size / fake-data options; every
  # other registered loader is called with no arguments.
  if name == 'dbpedia':
    return loader(size, test_with_fake_data)
  return loader()


def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs):
  """Creates binary synthetic datasets.

  Args:
    name: str, name of the dataset to generate; must be a key of `SYNTHETIC`.
    n_samples: int, number of datapoints to generate.
    noise: float or None, standard deviation of the Gaussian noise added.
    seed: int or None, seed for noise.
    *args: extra positional arguments forwarded to the selected generator.
    **kwargs: extra keyword arguments forwarded to the selected generator.

  Returns:
    Shuffled features and labels for given synthetic dataset of type
    `base.Dataset`.

  Raises:
    ValueError: Raised if `name` not found.

  Note:
    - This is a generic synthetic data generator - individual generators might
      have more parameters! See documentation for individual parameters.
    - Note that the `noise` parameter uses `numpy.random.normal` and depends on
      `numpy`'s seed.

  TODO:
    - Support multiclass datasets.
    - Need shuffling routine. Currently synthetic datasets are reshuffled to
      avoid train/test correlation, but that hurts reproducibility.
  """
  if name not in SYNTHETIC:
    raise ValueError(
        'Synthetic dataset not found or not implemented: %s' % name)
  # NOTE(review): `*args` is unpacked positionally even though it appears after
  # the keyword arguments, so Python binds it to the generator's leading
  # positional parameters first. If those are `n_samples`/`noise`/`seed`, any
  # non-empty `args` raises TypeError ("got multiple values for argument") -
  # confirm against the `synthetic` generator signatures.
  return SYNTHETIC[name](
      n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs)