1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================= 15 """Tests for tensorflow.python.training.saver.py.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 23 from tensorflow.python.client import session 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import ops 27 from tensorflow.python.ops import partitioned_variables 28 from tensorflow.python.ops import variables 29 from tensorflow.python.platform import test 30 from tensorflow.python.training import saver 31 32 33 class SaverLargePartitionedVariableTest(test.TestCase): 34 35 # Need to do this in a separate test because of the amount of memory needed 36 # to run this test. 37 def testLargePartitionedVariables(self): 38 save_path = os.path.join(self.get_temp_dir(), "large_variable") 39 var_name = "my_var" 40 # Saving large partition variable. 41 with session.Session("", graph=ops.Graph()) as sess: 42 with ops.device("/cpu:0"): 43 # Create a partitioned variable which is larger than int32 size but 44 # split into smaller sized variables. 45 init = lambda shape, dtype, partition_info: constant_op.constant( 46 True, dtype, shape) 47 partitioned_var = partitioned_variables.create_partitioned_variables( 48 [1 << 31], [4], init, dtype=dtypes.bool, name=var_name) 49 variables.global_variables_initializer().run() 50 save = saver.Saver(partitioned_var) 51 val = save.save(sess, save_path) 52 self.assertEqual(save_path, val) 53 54 55 if __name__ == "__main__": 56 test.main() 57