# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset utilities and synthetic/reference datasets."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import csv
     22 from os import path
     23 
     24 import numpy as np
     25 
     26 from tensorflow.contrib.learn.python.learn.datasets import base
     27 from tensorflow.contrib.learn.python.learn.datasets import mnist
     28 from tensorflow.contrib.learn.python.learn.datasets import synthetic
     29 from tensorflow.contrib.learn.python.learn.datasets import text_datasets
     30 
     31 # Export load_iris and load_boston.
     32 load_iris = base.load_iris
     33 load_boston = base.load_boston
     34 
     35 # List of all available datasets.
     36 # Note, currently they may return different types.
     37 DATASETS = {
     38     # Returns base.Dataset.
     39     'iris': base.load_iris,
     40     'boston': base.load_boston,
     41     # Returns base.Datasets (train/validation/test sets).
     42     'mnist': mnist.load_mnist,
     43     'dbpedia': text_datasets.load_dbpedia,
     44 }
     45 
     46 # List of all synthetic datasets
     47 SYNTHETIC = {
     48     # All of these will return ['data', 'target'] -> base.Dataset
     49     'circles': synthetic.circles,
     50     'spirals': synthetic.spirals
     51 }
     52 
     53 
     54 def load_dataset(name, size='small', test_with_fake_data=False):
     55   """Loads dataset by name.
     56 
     57   Args:
     58     name: Name of the dataset to load.
     59     size: Size of the dataset to load.
     60     test_with_fake_data: If true, load with fake dataset.
     61 
     62   Returns:
     63     Features and labels for given dataset. Can be numpy or iterator.
     64 
     65   Raises:
     66     ValueError: if `name` is not found.
     67   """
     68   if name not in DATASETS:
     69     raise ValueError('Name of dataset is not found: %s' % name)
     70   if name == 'dbpedia':
     71     return DATASETS[name](size, test_with_fake_data)
     72   else:
     73     return DATASETS[name]()
     74 
     75 
     76 def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs):
     77   """Creates binary synthetic datasets
     78 
     79   Args:
     80     name: str, name of the dataset to generate
     81     n_samples: int, number of datapoints to generate
     82     noise: float or None, standard deviation of the Gaussian noise added
     83     seed: int or None, seed for noise
     84 
     85   Returns:
     86     Shuffled features and labels for given synthetic dataset of type
     87     `base.Dataset`
     88 
     89   Raises:
     90     ValueError: Raised if `name` not found
     91 
     92   Note:
     93     - This is a generic synthetic data generator - individual generators might
     94     have more parameters!
     95       See documentation for individual parameters
     96     - Note that the `noise` parameter uses `numpy.random.normal` and depends on
     97     `numpy`'s seed
     98 
     99   TODO:
    100     - Support multiclass datasets
    101     - Need shuffling routine. Currently synthetic datasets are reshuffled to
    102     avoid train/test correlation,
    103       but that hurts reprodusability
    104   """
    105   # seed = kwargs.pop('seed', None)
    106   if name not in SYNTHETIC:
    107     raise ValueError('Synthetic dataset not found or not implemeted: %s' % name)
    108   else:
    109     return SYNTHETIC[name](
    110         n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs)
    111