# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for linear.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import linear_testing_utils
from tensorflow.python.platform import test


def _linear_regressor_fn(*args, **kwargs):
  """Returns a `linear.LinearRegressor` constructed from the given arguments.

  Every positional and keyword argument is forwarded unchanged. This factory
  is handed to the `linear_testing_utils` regressor base test classes.
  """
  regressor_cls = linear.LinearRegressor
  return regressor_cls(*args, **kwargs)


def _linear_classifier_fn(*args, **kwargs):
  """Returns a `linear.LinearClassifier` constructed from the given arguments.

  Every positional and keyword argument is forwarded unchanged. This factory
  is handed to the `linear_testing_utils` classifier base test classes.
  """
  classifier_cls = linear.LinearClassifier
  return classifier_cls(*args, **kwargs)


# Tests for Linear Regressor.
def _init_regressor_test_case(test_case, shared_base, method_name):
  """Initializes `test_case` as both a `test.TestCase` and `shared_base`.

  The two bases take different constructor arguments (a test method name vs.
  the estimator factory). Each is therefore initialized explicitly, in the
  order `test.TestCase` first and then `shared_base` with
  `_linear_regressor_fn`, instead of through a cooperative `super()` chain.
  """
  test.TestCase.__init__(test_case, method_name)
  shared_base.__init__(test_case, _linear_regressor_fn)


class LinearRegressorPartitionerTest(
    linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
  """`BaseLinearRegressorPartitionerTest` bound to `_linear_regressor_fn`."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    _init_regressor_test_case(
        self, linear_testing_utils.BaseLinearRegressorPartitionerTest,
        methodName)


class LinearRegressorEvaluationTest(
    linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
  """`BaseLinearRegressorEvaluationTest` bound to `_linear_regressor_fn`."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    _init_regressor_test_case(
        self, linear_testing_utils.BaseLinearRegressorEvaluationTest,
        methodName)


class LinearRegressorPredictTest(
    linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
  """`BaseLinearRegressorPredictTest` bound to `_linear_regressor_fn`."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    _init_regressor_test_case(
        self, linear_testing_utils.BaseLinearRegressorPredictTest,
        methodName)


class LinearRegressorIntegrationTest(
    linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
  """`BaseLinearRegressorIntegrationTest` bound to `_linear_regressor_fn`."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    _init_regressor_test_case(
        self, linear_testing_utils.BaseLinearRegressorIntegrationTest,
        methodName)


class LinearRegressorTrainingTest(
    linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
  """`BaseLinearRegressorTrainingTest` bound to `_linear_regressor_fn`."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    _init_regressor_test_case(
        self, linear_testing_utils.BaseLinearRegressorTrainingTest,
        methodName)


# Tests for Linear Classifier.
def _init_classifier_test_case(test_case, shared_base, method_name):
  """Initializes `test_case` as both a `test.TestCase` and `shared_base`.

  `test.TestCase` is initialized first with the test method name. Then
  `shared_base` receives `_linear_classifier_fn` through its
  `linear_classifier_fn` keyword argument. The bases take different
  constructor arguments, so no cooperative `super()` chain is used.
  """
  test.TestCase.__init__(test_case, method_name)
  shared_base.__init__(test_case, linear_classifier_fn=_linear_classifier_fn)


class LinearClassifierTrainingTest(
    linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
  """`BaseLinearClassifierTrainingTest` bound to `_linear_classifier_fn`."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    _init_classifier_test_case(
        self, linear_testing_utils.BaseLinearClassifierTrainingTest,
        methodName)


class LinearClassifierEvaluationTest(
    linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
  """`BaseLinearClassifierEvaluationTest` bound to `_linear_classifier_fn`."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    _init_classifier_test_case(
        self, linear_testing_utils.BaseLinearClassifierEvaluationTest,
        methodName)


class LinearClassifierPredictTest(
    linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
  """`BaseLinearClassifierPredictTest` bound to `_linear_classifier_fn`."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    _init_classifier_test_case(
        self, linear_testing_utils.BaseLinearClassifierPredictTest,
        methodName)


class LinearClassifierIntegrationTest(
    linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
  """`BaseLinearClassifierIntegrationTest` bound to `_linear_classifier_fn`."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    _init_classifier_test_case(
        self, linear_testing_utils.BaseLinearClassifierIntegrationTest,
        methodName)


# Tests for Linear logit_fn.
class LinearLogitFnTest(linear_testing_utils.BaseLinearLogitFnTest,
                        test.TestCase):
  """Runs `BaseLinearLogitFnTest`, whose constructor takes no factory."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    logit_fn_base = linear_testing_utils.BaseLinearLogitFnTest
    test.TestCase.__init__(self, methodName)
    logit_fn_base.__init__(self)


# Tests for warm-starting with Linear logit_fn.
class LinearWarmStartingTest(linear_testing_utils.BaseLinearWarmStartingTest,
                             test.TestCase):
  """`BaseLinearWarmStartingTest` bound to both linear estimator factories."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    warm_start_base = linear_testing_utils.BaseLinearWarmStartingTest
    test.TestCase.__init__(self, methodName)
    # Factories are passed positionally: classifier first, regressor second.
    warm_start_base.__init__(
        self, _linear_classifier_fn, _linear_regressor_fn)


if __name__ == '__main__':
  test.main()