1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import six 21 22 import numpy as np 23 from tensorflow.python.platform import test 24 from tensorflow.contrib.learn.python.learn import datasets 25 from tensorflow.contrib.learn.python.learn.datasets import synthetic 26 27 28 class SyntheticTest(test.TestCase): 29 """Test synthetic dataset generation""" 30 31 def test_make_dataset(self): 32 """Test if the synthetic routine wrapper complains about the name""" 33 self.assertRaises( 34 ValueError, datasets.make_dataset, name='_non_existing_name') 35 36 def test_all_datasets_callable(self): 37 """Test if all methods inside the `SYNTHETIC` are callable""" 38 self.assertIsInstance(datasets.SYNTHETIC, dict) 39 if len(datasets.SYNTHETIC) > 0: 40 for name, method in six.iteritems(datasets.SYNTHETIC): 41 self.assertTrue(callable(method)) 42 43 def test_circles(self): 44 """Test if the circles are generated correctly 45 46 Tests: 47 - return type is `Dataset` 48 - returned `data` shape is (n_samples, n_features) 49 - returned `target` shape is (n_samples,) 50 - set of unique classes range is [0, n_classes) 51 52 TODO: 53 - all points have the same radius, if no `noise` specified 54 """ 55 n_samples = 100 56 n_classes = 2 57 circ = synthetic.circles( 58 n_samples=n_samples, noise=None, n_classes=n_classes) 59 self.assertIsInstance(circ, datasets.base.Dataset) 60 self.assertTupleEqual(circ.data.shape, (n_samples, 2)) 61 self.assertTupleEqual(circ.target.shape, (n_samples,)) 62 self.assertSetEqual(set(circ.target), set(range(n_classes))) 63 64 def test_circles_replicable(self): 65 """Test if the data generation is replicable with a specified `seed` 66 67 Tests: 68 - return the same value if raised with the same seed 69 - return different values if noise or seed is different 70 """ 71 seed = 42 72 noise = 0.1 73 circ0 = synthetic.circles( 74 n_samples=100, noise=noise, n_classes=2, seed=seed) 75 circ1 = synthetic.circles( 76 n_samples=100, noise=noise, n_classes=2, seed=seed) 77 np.testing.assert_array_equal(circ0.data, circ1.data) 78 np.testing.assert_array_equal(circ0.target, circ1.target) 79 80 circ1 = synthetic.circles( 81 n_samples=100, noise=noise, n_classes=2, seed=seed + 1) 82 self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data, 83 circ1.data) 84 self.assertRaises(AssertionError, np.testing.assert_array_equal, 85 circ0.target, circ1.target) 86 87 circ1 = synthetic.circles( 88 n_samples=100, noise=noise / 2., n_classes=2, seed=seed) 89 self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data, 90 circ1.data) 91 92 def test_spirals(self): 93 """Test if the circles are generated correctly 94 95 Tests: 96 - if mode is unknown, ValueError is raised 97 - return type is `Dataset` 98 - returned `data` shape is (n_samples, n_features) 99 - returned `target` shape is (n_samples,) 100 - set of unique classes range is [0, n_classes) 101 """ 102 self.assertRaises( 103 ValueError, synthetic.spirals, mode='_unknown_mode_spiral_') 104 n_samples = 100 105 modes = ('archimedes', 'bernoulli', 'fermat') 106 for mode in modes: 107 spir = synthetic.spirals(n_samples=n_samples, noise=None, mode=mode) 108 self.assertIsInstance(spir, datasets.base.Dataset) 109 self.assertTupleEqual(spir.data.shape, (n_samples, 2)) 110 self.assertTupleEqual(spir.target.shape, (n_samples,)) 111 self.assertSetEqual(set(spir.target), set(range(2))) 112 113 def test_spirals_replicable(self): 114 """Test if the data generation is replicable with a specified `seed` 115 116 Tests: 117 - return the same value if raised with the same seed 118 - return different values if noise or seed is different 119 """ 120 seed = 42 121 noise = 0.1 122 modes = ('archimedes', 'bernoulli', 'fermat') 123 for mode in modes: 124 spir0 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed) 125 spir1 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed) 126 np.testing.assert_array_equal(spir0.data, spir1.data) 127 np.testing.assert_array_equal(spir0.target, spir1.target) 128 129 spir1 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed + 1) 130 self.assertRaises(AssertionError, np.testing.assert_array_equal, 131 spir0.data, spir1.data) 132 self.assertRaises(AssertionError, np.testing.assert_array_equal, 133 spir0.target, spir1.target) 134 135 spir1 = synthetic.spirals(n_samples=1000, noise=noise / 2., seed=seed) 136 self.assertRaises(AssertionError, np.testing.assert_array_equal, 137 spir0.data, spir1.data) 138 139 def test_spirals_synthetic(self): 140 synthetic.spirals(3) 141 142 143 if __name__ == '__main__': 144 test.main() 145