Home | History | Annotate | Download | only in learn_io
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """tf.learn IO operation tests."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import random
     22 
     23 # pylint: disable=wildcard-import
     24 from tensorflow.contrib.learn.python import learn
     25 from tensorflow.contrib.learn.python.learn import datasets
     26 from tensorflow.contrib.learn.python.learn.estimators._sklearn import accuracy_score
     27 from tensorflow.contrib.learn.python.learn.learn_io import *
     28 from tensorflow.python.platform import test
     29 
     30 # pylint: enable=wildcard-import
     31 
     32 
     33 class IOTest(test.TestCase):
     34   # pylint: disable=undefined-variable
     35   """tf.learn IO operation tests."""
     36 
     37   def test_pandas_dataframe(self):
     38     if HAS_PANDAS:
     39       import pandas as pd  # pylint: disable=g-import-not-at-top
     40       random.seed(42)
     41       iris = datasets.load_iris()
     42       data = pd.DataFrame(iris.data)
     43       labels = pd.DataFrame(iris.target)
     44       classifier = learn.LinearClassifier(
     45           feature_columns=learn.infer_real_valued_columns_from_input(data),
     46           n_classes=3)
     47       classifier.fit(data, labels, steps=100)
     48       score = accuracy_score(labels[0], list(classifier.predict_classes(data)))
     49       self.assertGreater(score, 0.5, "Failed with score = {0}".format(score))
     50     else:
     51       print("No pandas installed. pandas-related tests are skipped.")
     52 
     53   def test_pandas_series(self):
     54     if HAS_PANDAS:
     55       import pandas as pd  # pylint: disable=g-import-not-at-top
     56       random.seed(42)
     57       iris = datasets.load_iris()
     58       data = pd.DataFrame(iris.data)
     59       labels = pd.Series(iris.target)
     60       classifier = learn.LinearClassifier(
     61           feature_columns=learn.infer_real_valued_columns_from_input(data),
     62           n_classes=3)
     63       classifier.fit(data, labels, steps=100)
     64       score = accuracy_score(labels, list(classifier.predict_classes(data)))
     65       self.assertGreater(score, 0.5, "Failed with score = {0}".format(score))
     66 
     67   def test_string_data_formats(self):
     68     if HAS_PANDAS:
     69       import pandas as pd  # pylint: disable=g-import-not-at-top
     70       with self.assertRaises(ValueError):
     71         learn.io.extract_pandas_data(pd.DataFrame({"Test": ["A", "B"]}))
     72       with self.assertRaises(ValueError):
     73         learn.io.extract_pandas_labels(pd.DataFrame({"Test": ["A", "B"]}))
     74 
     75   def test_dask_io(self):
     76     if HAS_DASK and HAS_PANDAS:
     77       import pandas as pd  # pylint: disable=g-import-not-at-top
     78       import dask.dataframe as dd  # pylint: disable=g-import-not-at-top
     79       # test dask.dataframe
     80       df = pd.DataFrame(
     81           dict(
     82               a=list("aabbcc"), b=list(range(6))),
     83           index=pd.date_range(
     84               start="20100101", periods=6))
     85       ddf = dd.from_pandas(df, npartitions=3)
     86       extracted_ddf = extract_dask_data(ddf)
     87       self.assertEqual(
     88           extracted_ddf.divisions, (0, 2, 4, 6),
     89           "Failed with divisions = {0}".format(extracted_ddf.divisions))
     90       self.assertEqual(
     91           extracted_ddf.columns.tolist(), ["a", "b"],
     92           "Failed with columns = {0}".format(extracted_ddf.columns))
     93       # test dask.series
     94       labels = ddf["a"]
     95       extracted_labels = extract_dask_labels(labels)
     96       self.assertEqual(
     97           extracted_labels.divisions, (0, 2, 4, 6),
     98           "Failed with divisions = {0}".format(extracted_labels.divisions))
     99       # labels should only have one column
    100       with self.assertRaises(ValueError):
    101         extract_dask_labels(ddf)
    102     else:
    103       print("No dask installed. dask-related tests are skipped.")
    104 
    105   def test_dask_iris_classification(self):
    106     if HAS_DASK and HAS_PANDAS:
    107       import pandas as pd  # pylint: disable=g-import-not-at-top
    108       import dask.dataframe as dd  # pylint: disable=g-import-not-at-top
    109       random.seed(42)
    110       iris = datasets.load_iris()
    111       data = pd.DataFrame(iris.data)
    112       data = dd.from_pandas(data, npartitions=2)
    113       labels = pd.DataFrame(iris.target)
    114       labels = dd.from_pandas(labels, npartitions=2)
    115       classifier = learn.LinearClassifier(
    116           feature_columns=learn.infer_real_valued_columns_from_input(data),
    117           n_classes=3)
    118       classifier.fit(data, labels, steps=100)
    119       predictions = data.map_partitions(classifier.predict).compute()
    120       score = accuracy_score(labels.compute(), predictions)
    121       self.assertGreater(score, 0.5, "Failed with score = {0}".format(score))
    122 
    123 
    124 if __name__ == "__main__":
    125   test.main()
    126