Home | History | Annotate | Download | only in canned
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for linear.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.estimator.canned import linear
     22 from tensorflow.python.estimator.canned import linear_testing_utils
     23 from tensorflow.python.platform import test
     24 
     25 
     26 def _linear_regressor_fn(*args, **kwargs):
     27   return linear.LinearRegressor(*args, **kwargs)
     28 
     29 
     30 def _linear_classifier_fn(*args, **kwargs):
     31   return linear.LinearClassifier(*args, **kwargs)
     32 
     33 
     34 # Tests for Linear Regressor.
     35 
     36 
     37 class LinearRegressorPartitionerTest(
     38     linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
     39 
     40   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     41     test.TestCase.__init__(self, methodName)
     42     linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
     43         self, _linear_regressor_fn)
     44 
     45 
     46 class LinearRegressorEvaluationTest(
     47     linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
     48 
     49   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     50     test.TestCase.__init__(self, methodName)
     51     linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
     52         self, _linear_regressor_fn)
     53 
     54 
     55 class LinearRegressorPredictTest(
     56     linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
     57 
     58   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     59     test.TestCase.__init__(self, methodName)
     60     linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
     61         self, _linear_regressor_fn)
     62 
     63 
     64 class LinearRegressorIntegrationTest(
     65     linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
     66 
     67   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     68     test.TestCase.__init__(self, methodName)
     69     linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
     70         self, _linear_regressor_fn)
     71 
     72 
     73 class LinearRegressorTrainingTest(
     74     linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
     75 
     76   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     77     test.TestCase.__init__(self, methodName)
     78     linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
     79         self, _linear_regressor_fn)
     80 
     81 
     82 # Tests for Linear Classifier.
     83 
     84 
     85 class LinearClassifierTrainingTest(
     86     linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
     87 
     88   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     89     test.TestCase.__init__(self, methodName)
     90     linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
     91         self, linear_classifier_fn=_linear_classifier_fn)
     92 
     93 
     94 class LinearClassifierEvaluationTest(
     95     linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
     96 
     97   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     98     test.TestCase.__init__(self, methodName)
     99     linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
    100         self, linear_classifier_fn=_linear_classifier_fn)
    101 
    102 
    103 class LinearClassifierPredictTest(
    104     linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
    105 
    106   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    107     test.TestCase.__init__(self, methodName)
    108     linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
    109         self, linear_classifier_fn=_linear_classifier_fn)
    110 
    111 
    112 class LinearClassifierIntegrationTest(
    113     linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
    114 
    115   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    116     test.TestCase.__init__(self, methodName)
    117     linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
    118         self, linear_classifier_fn=_linear_classifier_fn)
    119 
    120 
    121 # Tests for Linear logit_fn.
    122 class LinearLogitFnTest(linear_testing_utils.BaseLinearLogitFnTest,
    123                         test.TestCase):
    124 
    125   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    126     test.TestCase.__init__(self, methodName)
    127     linear_testing_utils.BaseLinearLogitFnTest.__init__(self)
    128 
    129 
    130 # Tests for warm-starting with Linear logit_fn.
    131 class LinearWarmStartingTest(linear_testing_utils.BaseLinearWarmStartingTest,
    132                              test.TestCase):
    133 
    134   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    135     test.TestCase.__init__(self, methodName)
    136     linear_testing_utils.BaseLinearWarmStartingTest.__init__(
    137         self, _linear_classifier_fn, _linear_regressor_fn)
    138 
    139 
# Invoke `test.main()` only when this file is run as a script, not on import.
if __name__ == '__main__':
  test.main()
    142