Home | History | Annotate | Download | only in estimators
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Utils for Estimator."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.util import tf_inspect
     22 
     23 
     24 def assert_estimator_contract(tester, estimator_class):
     25   """Asserts whether given estimator satisfies the expected contract.
     26 
     27   This doesn't check every details of contract. This test is used for that a
     28   function is not forgotten to implement in a precanned Estimator.
     29 
     30   Args:
     31     tester: A tf.test.TestCase.
     32     estimator_class: 'type' object of pre-canned estimator.
     33   """
     34   attributes = tf_inspect.getmembers(estimator_class)
     35   attribute_names = [a[0] for a in attributes]
     36 
     37   tester.assertTrue('config' in attribute_names)
     38   tester.assertTrue('evaluate' in attribute_names)
     39   tester.assertTrue('export' in attribute_names)
     40   tester.assertTrue('fit' in attribute_names)
     41   tester.assertTrue('get_variable_names' in attribute_names)
     42   tester.assertTrue('get_variable_value' in attribute_names)
     43   tester.assertTrue('model_dir' in attribute_names)
     44   tester.assertTrue('predict' in attribute_names)
     45 
     46 
     47 def assert_in_range(min_value, max_value, key, metrics):
     48   actual_value = metrics[key]
     49   if actual_value < min_value:
     50     raise ValueError('%s: %s < %s.' % (key, actual_value, min_value))
     51   if actual_value > max_value:
     52     raise ValueError('%s: %s > %s.' % (key, actual_value, max_value))
     53