Home | History | Annotate | Download | only in learn
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for ExportStrategy."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.learn.python.learn import export_strategy
     22 from tensorflow.python.platform import test
     23 
     24 
     25 class ExportStrategyTest(test.TestCase):
     26 
     27   def test_no_optional_args_export(self):
     28     model_path = '/path/to/model'
     29     def _export_fn(estimator, export_path):
     30       self.assertTupleEqual((estimator, export_path), (None, None))
     31       return model_path
     32 
     33     strategy = export_strategy.ExportStrategy('foo', _export_fn)
     34     self.assertTupleEqual(strategy, ('foo', _export_fn, None))
     35     self.assertIs(strategy.export(None, None), model_path)
     36 
     37   def test_checkpoint_export(self):
     38     ckpt_model_path = '/path/to/checkpoint_model'
     39     def _ckpt_export_fn(estimator, export_path, checkpoint_path):
     40       self.assertTupleEqual((estimator, export_path), (None, None))
     41       self.assertEqual(checkpoint_path, 'checkpoint')
     42       return ckpt_model_path
     43 
     44     strategy = export_strategy.ExportStrategy('foo', _ckpt_export_fn)
     45     self.assertTupleEqual(strategy, ('foo', _ckpt_export_fn, None))
     46     self.assertIs(strategy.export(None, None, 'checkpoint'), ckpt_model_path)
     47 
     48   def test_checkpoint_eval_export(self):
     49     ckpt_eval_model_path = '/path/to/checkpoint_eval_model'
     50     def _ckpt_eval_export_fn(estimator, export_path, checkpoint_path,
     51                              eval_result):
     52       self.assertTupleEqual((estimator, export_path), (None, None))
     53       self.assertEqual(checkpoint_path, 'checkpoint')
     54       self.assertEqual(eval_result, 'eval')
     55       return ckpt_eval_model_path
     56 
     57     strategy = export_strategy.ExportStrategy('foo', _ckpt_eval_export_fn)
     58     self.assertTupleEqual(strategy, ('foo', _ckpt_eval_export_fn, None))
     59     self.assertIs(strategy.export(None, None, 'checkpoint', 'eval'),
     60                   ckpt_eval_model_path)
     61 
     62   def test_eval_only_export(self):
     63     def _eval_export_fn(estimator, export_path, eval_result):
     64       del estimator, export_path, eval_result
     65 
     66     strategy = export_strategy.ExportStrategy('foo', _eval_export_fn)
     67     self.assertTupleEqual(strategy, ('foo', _eval_export_fn, None))
     68     with self.assertRaisesRegexp(ValueError, 'An export_fn accepting '
     69                                  'eval_result must also accept '
     70                                  'checkpoint_path'):
     71       strategy.export(None, None, eval_result='eval')
     72 
     73   def test_strip_default_attr_export(self):
     74     strip_default_attrs_model_path = '/path/to/strip_default_attrs_model'
     75     def _strip_default_attrs_export_fn(estimator, export_path,
     76                                        strip_default_attrs):
     77       self.assertTupleEqual((estimator, export_path), (None, None))
     78       self.assertTrue(strip_default_attrs)
     79       return strip_default_attrs_model_path
     80 
     81     strategy = export_strategy.ExportStrategy('foo',
     82                                               _strip_default_attrs_export_fn,
     83                                               True)
     84     self.assertTupleEqual(strategy,
     85                           ('foo', _strip_default_attrs_export_fn, True))
     86     self.assertIs(strategy.export(None, None), strip_default_attrs_model_path)
     87 
if __name__ == '__main__':
  # Entry point when this file is run directly: hand control to TensorFlow's
  # test runner, imported above as tensorflow.python.platform.test.
  test.main()
     90