# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""logit_fn tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.estimator.python.estimator import logit_fns
from tensorflow.python.client import session
from tensorflow.python.estimator import model_fn
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test

# Error pattern that both invalid-result tests expect from call_logit_fn.
_INVALID_RESULT_REGEX = (
    'logit_fn should return a Tensor or a dictionary mapping '
    'strings to Tensors')


def _two_tensor_features():
  """Returns a features dict holding constant tensors under 'f1' and 'f2'."""
  return {
      'f1': constant_op.constant([[2., 3.]]),
      'f2': constant_op.constant([[4., 5.]]),
  }


class LogitFnTest(test.TestCase):
  """Tests for logit_fns.call_logit_fn with valid and invalid logit_fns."""

  def test_simple_call_logit_fn(self):
    """Expects the EVAL mode to reach a logit_fn that declares `mode`."""

    def dummy_logit_fn(features, mode):
      # TRAIN selects 'f1'; any other mode selects 'f2'.
      key = 'f1' if mode == model_fn.ModeKeys.TRAIN else 'f2'
      return features[key]

    logits = logit_fns.call_logit_fn(
        dummy_logit_fn, _two_tensor_features(), model_fn.ModeKeys.EVAL,
        'fake_params', 'fake_config')
    with session.Session():
      # EVAL was passed, so the 'f2' tensor is the expected result.
      self.assertAllClose([[4., 5.]], logits.eval())

  def test_simple_call_multi_logit_fn(self):
    """Expects a dict of per-head tensors from a features-only logit_fn."""

    def dummy_logit_fn(features):
      # Keys deliberately mix a unicode literal and a plain str literal.
      return {u'head1': features['f1'], 'head2': features['f2']}

    logits_by_head = logit_fns.call_logit_fn(
        dummy_logit_fn, _two_tensor_features(), model_fn.ModeKeys.TRAIN,
        'fake_params', 'fake_config')
    expected_by_head = (
        ('head1', [[2., 3.]]),
        ('head2', [[4., 5.]]),
    )
    with session.Session():
      for head, expected in expected_by_head:
        self.assertAllClose(expected, logits_by_head[head].eval())

  def test_invalid_logit_fn_results(self):
    """Expects ValueError when the logit_fn returns a list of tensors."""

    def invalid_logit_fn(features, params):
      scale = params['input_multiplier']
      return [features[key] * scale for key in ('f1', 'f2')]

    features = _two_tensor_features()
    hyperparams = {'learning_rate': 0.001, 'input_multiplier': 2.0}
    with self.assertRaisesRegexp(ValueError, _INVALID_RESULT_REGEX):
      logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode',
                              hyperparams, 'fake_config')

  def test_invalid_logit_fn_results_dict(self):
    """Expects ValueError when a returned dict holds a non-Tensor value."""

    def invalid_logit_fn(features):
      return {'head1': features['f1'], 'head2': features['f2']}

    # 'f2' is a plain string, so the dict built by invalid_logit_fn contains
    # a value that is not a Tensor.
    mixed_features = {
        'f1': constant_op.constant([[2., 3.]]),
        'f2': 'some string',
    }
    with self.assertRaisesRegexp(ValueError, _INVALID_RESULT_REGEX):
      logit_fns.call_logit_fn(invalid_logit_fn, mixed_features, 'fake_mode',
                              'fake_params', 'fake_config')


if __name__ == '__main__':
  test.main()