# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for warm_starting_util with Distribution Strategy.

These tests are located here instead of as part of `WarmStartingUtilTest`
because they need access to distribution strategies which are only present in
contrib right now.
TODO(priyag): Move the tests to core `WarmStartingUtilTest` when distribution
strategy moves out of contrib.
22 """ 23 24 from __future__ import absolute_import 25 from __future__ import division 26 from __future__ import print_function 27 28 import os 29 from absl.testing import parameterized 30 31 from tensorflow.contrib.distribute.python import combinations 32 from tensorflow.python.framework import ops 33 from tensorflow.python.ops import variable_scope 34 from tensorflow.python.ops import variables 35 from tensorflow.python.platform import test 36 from tensorflow.python.training import saver as saver_lib 37 from tensorflow.python.training import warm_starting_util as ws_util 38 39 40 class WarmStartingUtilWithDistributionStrategyTest( 41 test.TestCase, parameterized.TestCase): 42 43 @combinations.generate(combinations.combine( 44 distribution=[combinations.default_strategy, 45 combinations.one_device_strategy, 46 combinations.mirrored_strategy_with_gpu_and_cpu, 47 combinations.mirrored_strategy_with_two_gpus, 48 combinations.core_mirrored_strategy_with_gpu_and_cpu, 49 combinations.core_mirrored_strategy_with_two_gpus], 50 save_with_distribution=[True, False], 51 restore_with_distribution=[True, False], 52 mode=["graph"])) 53 def testWarmStart(self, distribution, save_with_distribution, 54 restore_with_distribution): 55 56 var_name = "v" 57 original_value = [[1., 2.], [3., 4.]] 58 59 # Create variable and save checkpoint from which to warm-start. 60 def create_var(g): 61 with self.session(graph=g) as sess: 62 var = variable_scope.get_variable(var_name, initializer=original_value) 63 sess.run(variables.global_variables_initializer()) 64 saver = saver_lib.Saver() 65 ckpt_prefix = os.path.join(self.get_temp_dir(), "model") 66 saver.save(sess, ckpt_prefix, global_step=0) 67 return var, sess.run(var) 68 69 if save_with_distribution: 70 with ops.Graph().as_default() as g, distribution.scope(): 71 _, prev_init_val = create_var(g) 72 else: 73 with ops.Graph().as_default() as g: 74 _, prev_init_val = create_var(g) 75 76 # Verify we initialized the values correctly. 
77 self.assertAllEqual(original_value, prev_init_val) 78 79 def warm_start(g): 80 with self.session(graph=g) as sess: 81 # Initialize with zeros. 82 var = variable_scope.get_variable( 83 var_name, initializer=[[0., 0.], [0., 0.]]) 84 ws_util.warm_start(self.get_temp_dir()) 85 sess.run(variables.global_variables_initializer()) 86 # Verify weights were correctly warm-started to previous values. 87 self.assertAllEqual(original_value, self.evaluate(var)) 88 89 # Warm start in a new graph. 90 if restore_with_distribution: 91 with ops.Graph().as_default() as g, distribution.scope(): 92 warm_start(g) 93 else: 94 with ops.Graph().as_default() as g: 95 warm_start(g) 96 97 98 if __name__ == "__main__": 99 test.main() 100