# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ExportStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.python.platform import test


class ExportStrategyTest(test.TestCase):
  """Checks ExportStrategy against export functions of several signatures.

  NOTE(review): the parameter names of the nested export functions below
  (`checkpoint_path`, `eval_result`, `strip_default_attrs`) look significant;
  ExportStrategy presumably inspects them to decide what to pass. Confirm
  against the ExportStrategy implementation before renaming any of them.
  """

  def test_no_optional_args_export(self):
    """Export function taking only `estimator` and `export_path`."""
    expected_path = '/path/to/model'

    def _plain_export(estimator, export_path):
      received = (estimator, export_path)
      self.assertTupleEqual(received, (None, None))
      return expected_path

    strategy = export_strategy.ExportStrategy('foo', _plain_export)
    expected_fields = ('foo', _plain_export, None)
    self.assertTupleEqual(strategy, expected_fields)

    exported = strategy.export(None, None)
    self.assertIs(exported, expected_path)

  def test_checkpoint_export(self):
    """Export function that additionally accepts `checkpoint_path`."""
    expected_path = '/path/to/checkpoint_model'

    def _export_with_ckpt(estimator, export_path, checkpoint_path):
      received = (estimator, export_path)
      self.assertTupleEqual(received, (None, None))
      self.assertEqual(checkpoint_path, 'checkpoint')
      return expected_path

    strategy = export_strategy.ExportStrategy('foo', _export_with_ckpt)
    expected_fields = ('foo', _export_with_ckpt, None)
    self.assertTupleEqual(strategy, expected_fields)

    exported = strategy.export(None, None, 'checkpoint')
    self.assertIs(exported, expected_path)

  def test_checkpoint_eval_export(self):
    """Export function accepting both `checkpoint_path` and `eval_result`."""
    expected_path = '/path/to/checkpoint_eval_model'

    def _export_with_ckpt_and_eval(estimator, export_path, checkpoint_path,
                                   eval_result):
      received = (estimator, export_path)
      self.assertTupleEqual(received, (None, None))
      self.assertEqual(checkpoint_path, 'checkpoint')
      self.assertEqual(eval_result, 'eval')
      return expected_path

    strategy = export_strategy.ExportStrategy('foo',
                                              _export_with_ckpt_and_eval)
    expected_fields = ('foo', _export_with_ckpt_and_eval, None)
    self.assertTupleEqual(strategy, expected_fields)

    exported = strategy.export(None, None, 'checkpoint', 'eval')
    self.assertIs(exported, expected_path)

  def test_eval_only_export(self):
    """Export function with `eval_result` but no `checkpoint_path` raises."""

    def _export_with_eval_only(estimator, export_path, eval_result):
      # The arguments are intentionally unused; only the signature matters.
      del estimator, export_path, eval_result

    strategy = export_strategy.ExportStrategy('foo', _export_with_eval_only)
    expected_fields = ('foo', _export_with_eval_only, None)
    self.assertTupleEqual(strategy, expected_fields)

    expected_message = ('An export_fn accepting '
                        'eval_result must also accept '
                        'checkpoint_path')
    with self.assertRaisesRegexp(ValueError, expected_message):
      strategy.export(None, None, eval_result='eval')

  def test_strip_default_attr_export(self):
    """Third constructor argument reaches the fn as `strip_default_attrs`."""
    expected_path = '/path/to/strip_default_attrs_model'

    def _export_with_strip_flag(estimator, export_path, strip_default_attrs):
      received = (estimator, export_path)
      self.assertTupleEqual(received, (None, None))
      self.assertTrue(strip_default_attrs)
      return expected_path

    strategy = export_strategy.ExportStrategy('foo', _export_with_strip_flag,
                                              True)
    expected_fields = ('foo', _export_with_strip_flag, True)
    self.assertTupleEqual(strategy, expected_fields)

    exported = strategy.export(None, None)
    self.assertIs(exported, expected_path)


if __name__ == '__main__':
  test.main()