Home | History | Annotate | Download | only in datasets
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import six
     21 
     22 import numpy as np
     23 from tensorflow.python.platform import test
     24 from tensorflow.contrib.learn.python.learn import datasets
     25 from tensorflow.contrib.learn.python.learn.datasets import synthetic
     26 
     27 
     28 class SyntheticTest(test.TestCase):
     29   """Test synthetic dataset generation"""
     30 
     31   def test_make_dataset(self):
     32     """Test if the synthetic routine wrapper complains about the name"""
     33     self.assertRaises(
     34         ValueError, datasets.make_dataset, name='_non_existing_name')
     35 
     36   def test_all_datasets_callable(self):
     37     """Test if all methods inside the `SYNTHETIC` are callable"""
     38     self.assertIsInstance(datasets.SYNTHETIC, dict)
     39     if len(datasets.SYNTHETIC) > 0:
     40       for name, method in six.iteritems(datasets.SYNTHETIC):
     41         self.assertTrue(callable(method))
     42 
     43   def test_circles(self):
     44     """Test if the circles are generated correctly
     45 
     46     Tests:
     47       - return type is `Dataset`
     48       - returned `data` shape is (n_samples, n_features)
     49       - returned `target` shape is (n_samples,)
     50       - set of unique classes range is [0, n_classes)
     51 
     52     TODO:
     53       - all points have the same radius, if no `noise` specified
     54     """
     55     n_samples = 100
     56     n_classes = 2
     57     circ = synthetic.circles(
     58         n_samples=n_samples, noise=None, n_classes=n_classes)
     59     self.assertIsInstance(circ, datasets.base.Dataset)
     60     self.assertTupleEqual(circ.data.shape, (n_samples, 2))
     61     self.assertTupleEqual(circ.target.shape, (n_samples,))
     62     self.assertSetEqual(set(circ.target), set(range(n_classes)))
     63 
     64   def test_circles_replicable(self):
     65     """Test if the data generation is replicable with a specified `seed`
     66 
     67     Tests:
     68       - return the same value if raised with the same seed
     69       - return different values if noise or seed is different
     70     """
     71     seed = 42
     72     noise = 0.1
     73     circ0 = synthetic.circles(
     74         n_samples=100, noise=noise, n_classes=2, seed=seed)
     75     circ1 = synthetic.circles(
     76         n_samples=100, noise=noise, n_classes=2, seed=seed)
     77     np.testing.assert_array_equal(circ0.data, circ1.data)
     78     np.testing.assert_array_equal(circ0.target, circ1.target)
     79 
     80     circ1 = synthetic.circles(
     81         n_samples=100, noise=noise, n_classes=2, seed=seed + 1)
     82     self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data,
     83                       circ1.data)
     84     self.assertRaises(AssertionError, np.testing.assert_array_equal,
     85                       circ0.target, circ1.target)
     86 
     87     circ1 = synthetic.circles(
     88         n_samples=100, noise=noise / 2., n_classes=2, seed=seed)
     89     self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data,
     90                       circ1.data)
     91 
     92   def test_spirals(self):
     93     """Test if the circles are generated correctly
     94 
     95     Tests:
     96       - if mode is unknown, ValueError is raised
     97       - return type is `Dataset`
     98       - returned `data` shape is (n_samples, n_features)
     99       - returned `target` shape is (n_samples,)
    100       - set of unique classes range is [0, n_classes)
    101     """
    102     self.assertRaises(
    103         ValueError, synthetic.spirals, mode='_unknown_mode_spiral_')
    104     n_samples = 100
    105     modes = ('archimedes', 'bernoulli', 'fermat')
    106     for mode in modes:
    107       spir = synthetic.spirals(n_samples=n_samples, noise=None, mode=mode)
    108       self.assertIsInstance(spir, datasets.base.Dataset)
    109       self.assertTupleEqual(spir.data.shape, (n_samples, 2))
    110       self.assertTupleEqual(spir.target.shape, (n_samples,))
    111       self.assertSetEqual(set(spir.target), set(range(2)))
    112 
    113   def test_spirals_replicable(self):
    114     """Test if the data generation is replicable with a specified `seed`
    115 
    116     Tests:
    117       - return the same value if raised with the same seed
    118       - return different values if noise or seed is different
    119     """
    120     seed = 42
    121     noise = 0.1
    122     modes = ('archimedes', 'bernoulli', 'fermat')
    123     for mode in modes:
    124       spir0 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed)
    125       spir1 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed)
    126       np.testing.assert_array_equal(spir0.data, spir1.data)
    127       np.testing.assert_array_equal(spir0.target, spir1.target)
    128 
    129       spir1 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed + 1)
    130       self.assertRaises(AssertionError, np.testing.assert_array_equal,
    131                         spir0.data, spir1.data)
    132       self.assertRaises(AssertionError, np.testing.assert_array_equal,
    133                         spir0.target, spir1.target)
    134 
    135       spir1 = synthetic.spirals(n_samples=1000, noise=noise / 2., seed=seed)
    136       self.assertRaises(AssertionError, np.testing.assert_array_equal,
    137                         spir0.data, spir1.data)
    138 
    139   def test_spirals_synthetic(self):
    140     synthetic.spirals(3)
    141 
    142 
# Invoke `test.main()` only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == '__main__':
  test.main()
    145