Home | History | Annotate | Download | only in python
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for warm_starting_util with Distribution Strategy.
     16 
     17 These tests are located here instead of as part of `WarmStartingUtilTest`
     18 because they need access to distribution strategies which are only present in
     19 contrib right now.
     20 TODO(priyag): Move the tests to core `WarmStartingUtilTest` when distribution
     21 strategy moves out of contrib.
     22 """
     23 
     24 from __future__ import absolute_import
     25 from __future__ import division
     26 from __future__ import print_function
     27 
     28 import os
     29 from absl.testing import parameterized
     30 
     31 from tensorflow.contrib.distribute.python import combinations
     32 from tensorflow.python.framework import ops
     33 from tensorflow.python.ops import variable_scope
     34 from tensorflow.python.ops import variables
     35 from tensorflow.python.platform import test
     36 from tensorflow.python.training import saver as saver_lib
     37 from tensorflow.python.training import warm_starting_util as ws_util
     38 
     39 
     40 class WarmStartingUtilWithDistributionStrategyTest(
     41     test.TestCase, parameterized.TestCase):
     42 
     43   @combinations.generate(combinations.combine(
     44       distribution=[combinations.default_strategy,
     45                     combinations.one_device_strategy,
     46                     combinations.mirrored_strategy_with_gpu_and_cpu,
     47                     combinations.mirrored_strategy_with_two_gpus,
     48                     combinations.core_mirrored_strategy_with_gpu_and_cpu,
     49                     combinations.core_mirrored_strategy_with_two_gpus],
     50       save_with_distribution=[True, False],
     51       restore_with_distribution=[True, False],
     52       mode=["graph"]))
     53   def testWarmStart(self, distribution, save_with_distribution,
     54                     restore_with_distribution):
     55 
     56     var_name = "v"
     57     original_value = [[1., 2.], [3., 4.]]
     58 
     59     # Create variable and save checkpoint from which to warm-start.
     60     def create_var(g):
     61       with self.session(graph=g) as sess:
     62         var = variable_scope.get_variable(var_name, initializer=original_value)
     63         sess.run(variables.global_variables_initializer())
     64         saver = saver_lib.Saver()
     65         ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
     66         saver.save(sess, ckpt_prefix, global_step=0)
     67         return var, sess.run(var)
     68 
     69     if save_with_distribution:
     70       with ops.Graph().as_default() as g, distribution.scope():
     71         _, prev_init_val = create_var(g)
     72     else:
     73       with ops.Graph().as_default() as g:
     74         _, prev_init_val = create_var(g)
     75 
     76     # Verify we initialized the values correctly.
     77     self.assertAllEqual(original_value, prev_init_val)
     78 
     79     def warm_start(g):
     80       with self.session(graph=g) as sess:
     81         # Initialize with zeros.
     82         var = variable_scope.get_variable(
     83             var_name, initializer=[[0., 0.], [0., 0.]])
     84         ws_util.warm_start(self.get_temp_dir())
     85         sess.run(variables.global_variables_initializer())
     86         # Verify weights were correctly warm-started to previous values.
     87         self.assertAllEqual(original_value, self.evaluate(var))
     88 
     89     # Warm start in a new graph.
     90     if restore_with_distribution:
     91       with ops.Graph().as_default() as g, distribution.scope():
     92         warm_start(g)
     93     else:
     94       with ops.Graph().as_default() as g:
     95         warm_start(g)
     96 
     97 
     98 if __name__ == "__main__":
     99   test.main()
    100