Home | History | Annotate | Download | only in estimator
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """logit_fn tests."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.estimator.python.estimator import logit_fns
     22 from tensorflow.python.client import session
     23 from tensorflow.python.estimator import model_fn
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.platform import test
     26 
     27 
     28 class LogitFnTest(test.TestCase):
     29 
     30   def test_simple_call_logit_fn(self):
     31     def dummy_logit_fn(features, mode):
     32       if mode == model_fn.ModeKeys.TRAIN:
     33         return features['f1']
     34       else:
     35         return features['f2']
     36     features = {
     37         'f1': constant_op.constant([[2., 3.]]),
     38         'f2': constant_op.constant([[4., 5.]])
     39     }
     40     logit_fn_result = logit_fns.call_logit_fn(
     41         dummy_logit_fn, features, model_fn.ModeKeys.EVAL, 'fake_params',
     42         'fake_config')
     43     with session.Session():
     44       self.assertAllClose([[4., 5.]], logit_fn_result.eval())
     45 
     46   def test_simple_call_multi_logit_fn(self):
     47 
     48     def dummy_logit_fn(features):
     49       return {u'head1': features['f1'], 'head2': features['f2']}
     50 
     51     features = {
     52         'f1': constant_op.constant([[2., 3.]]),
     53         'f2': constant_op.constant([[4., 5.]])
     54     }
     55     logit_fn_result = logit_fns.call_logit_fn(dummy_logit_fn, features,
     56                                               model_fn.ModeKeys.TRAIN,
     57                                               'fake_params', 'fake_config')
     58     with session.Session():
     59       self.assertAllClose([[2., 3.]], logit_fn_result['head1'].eval())
     60       self.assertAllClose([[4., 5.]], logit_fn_result['head2'].eval())
     61 
     62   def test_invalid_logit_fn_results(self):
     63 
     64     def invalid_logit_fn(features, params):
     65       return [
     66           features['f1'] * params['input_multiplier'],
     67           features['f2'] * params['input_multiplier']
     68       ]
     69 
     70     features = {
     71         'f1': constant_op.constant([[2., 3.]]),
     72         'f2': constant_op.constant([[4., 5.]])
     73     }
     74     params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
     75     with self.assertRaisesRegexp(
     76         ValueError, 'logit_fn should return a Tensor or a dictionary mapping '
     77                     'strings to Tensors'):
     78       logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
     79                               'fake_config')
     80 
     81   def test_invalid_logit_fn_results_dict(self):
     82 
     83     def invalid_logit_fn(features):
     84       return {'head1': features['f1'], 'head2': features['f2']}
     85 
     86     features = {'f1': constant_op.constant([[2., 3.]]), 'f2': 'some string'}
     87     with self.assertRaisesRegexp(
     88         ValueError, 'logit_fn should return a Tensor or a dictionary mapping '
     89                     'strings to Tensors'):
     90       logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode',
     91                               'fake_params', 'fake_config')
     92 
     93 
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point (`test.main`).
  test.main()
     96