# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util import tf_inspect


def assert_estimator_contract(tester, estimator_class):
  """Asserts whether given estimator satisfies the expected contract.

  This doesn't check every detail of the contract. It only verifies that no
  required method or attribute was forgotten when implementing a pre-canned
  Estimator.

  Args:
    tester: A tf.test.TestCase.
    estimator_class: 'type' object of pre-canned estimator.
  """
  attribute_names = {name for name, _ in tf_inspect.getmembers(estimator_class)}

  required_names = (
      'config',
      'evaluate',
      'export',
      'fit',
      'get_variable_names',
      'get_variable_value',
      'model_dir',
      'predict',
  )
  for required in required_names:
    # assertIn (unlike assertTrue(x in y)) names the missing attribute in the
    # failure message; checks still run in order and stop at the first miss.
    tester.assertIn(required, attribute_names)


def assert_in_range(min_value, max_value, key, metrics):
  """Raises if `metrics[key]` lies outside `[min_value, max_value]`.

  Both bounds are inclusive: a value equal to either bound is accepted.

  Args:
    min_value: Inclusive lower bound.
    max_value: Inclusive upper bound.
    key: Key of the metric to check.
    metrics: Mapping indexed by `key` (e.g. a dict of metric results); a
      missing key propagates whatever `metrics[key]` raises.

  Raises:
    ValueError: If the value is below `min_value`, above `max_value`, or does
      not compare as in range at all (e.g. NaN).
  """
  actual_value = metrics[key]
  if actual_value < min_value:
    raise ValueError('%s: %s < %s.' % (key, actual_value, min_value))
  if actual_value > max_value:
    raise ValueError('%s: %s > %s.' % (key, actual_value, max_value))
  if not min_value <= actual_value <= max_value:
    # Every ordering comparison with NaN is False, so a NaN metric slips past
    # both checks above; reject it explicitly instead of passing silently.
    raise ValueError('%s: %s is not in [%s, %s].' %
                     (key, actual_value, min_value, max_value))