1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for learn.utils.gc.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 import re 23 24 from six.moves import xrange # pylint: disable=redefined-builtin 25 26 from tensorflow.contrib.learn.python.learn.utils import gc 27 from tensorflow.python.framework import test_util 28 from tensorflow.python.platform import gfile 29 from tensorflow.python.platform import test 30 from tensorflow.python.util import compat 31 32 33 def _create_parser(base_dir): 34 # create a simple parser that pulls the export_version from the directory. 35 def parser(path): 36 # Modify the path object for RegEx match for Windows Paths 37 if os.name == "nt": 38 match = re.match( 39 "^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$", 40 compat.as_str_any(path.path).replace("\\", "/")) 41 else: 42 match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$", 43 compat.as_str_any(path.path)) 44 if not match: 45 return None 46 return path._replace(export_version=int(match.group(1))) 47 48 return parser 49 50 51 class GcTest(test_util.TensorFlowTestCase): 52 53 def testLargestExportVersions(self): 54 paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)] 55 newest = gc.largest_export_versions(2) 56 n = newest(paths) 57 self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)]) 58 59 def testLargestExportVersionsDoesNotDeleteZeroFolder(self): 60 paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)] 61 newest = gc.largest_export_versions(2) 62 n = newest(paths) 63 self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)]) 64 65 def testModExportVersion(self): 66 paths = [ 67 gc.Path("/foo", 4), 68 gc.Path("/foo", 5), 69 gc.Path("/foo", 6), 70 gc.Path("/foo", 9) 71 ] 72 mod = gc.mod_export_version(2) 73 self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)]) 74 mod = gc.mod_export_version(3) 75 self.assertEqual(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)]) 76 77 def testOneOfEveryNExportVersions(self): 78 paths = [ 79 gc.Path("/foo", 0), 80 gc.Path("/foo", 1), 81 gc.Path("/foo", 3), 82 gc.Path("/foo", 5), 83 gc.Path("/foo", 6), 84 gc.Path("/foo", 7), 85 gc.Path("/foo", 8), 86 gc.Path("/foo", 33) 87 ] 88 one_of = gc.one_of_every_n_export_versions(3) 89 self.assertEqual( 90 one_of(paths), [ 91 gc.Path("/foo", 3), 92 gc.Path("/foo", 6), 93 gc.Path("/foo", 8), 94 gc.Path("/foo", 33) 95 ]) 96 97 def testOneOfEveryNExportVersionsZero(self): 98 # Zero is a special case since it gets rolled into the first interval. 99 # Test that here. 100 paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)] 101 one_of = gc.one_of_every_n_export_versions(3) 102 self.assertEqual(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)]) 103 104 def testUnion(self): 105 paths = [] 106 for i in xrange(10): 107 paths.append(gc.Path("/foo", i)) 108 f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3)) 109 self.assertEqual( 110 f(paths), [ 111 gc.Path("/foo", 0), 112 gc.Path("/foo", 3), 113 gc.Path("/foo", 6), 114 gc.Path("/foo", 7), 115 gc.Path("/foo", 8), 116 gc.Path("/foo", 9) 117 ]) 118 119 def testNegation(self): 120 paths = [ 121 gc.Path("/foo", 4), 122 gc.Path("/foo", 5), 123 gc.Path("/foo", 6), 124 gc.Path("/foo", 9) 125 ] 126 mod = gc.negation(gc.mod_export_version(2)) 127 self.assertEqual(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)]) 128 mod = gc.negation(gc.mod_export_version(3)) 129 self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)]) 130 131 def testPathsWithParse(self): 132 base_dir = os.path.join(test.get_temp_dir(), "paths_parse") 133 self.assertFalse(gfile.Exists(base_dir)) 134 for p in xrange(3): 135 gfile.MakeDirs(os.path.join(base_dir, "%d" % p)) 136 # add a base_directory to ignore 137 gfile.MakeDirs(os.path.join(base_dir, "ignore")) 138 139 self.assertEqual( 140 gc.get_paths(base_dir, _create_parser(base_dir)), [ 141 gc.Path(os.path.join(base_dir, "0"), 0), 142 gc.Path(os.path.join(base_dir, "1"), 1), 143 gc.Path(os.path.join(base_dir, "2"), 2) 144 ]) 145 146 def testMixedStrTypes(self): 147 temp_dir = compat.as_bytes(test.get_temp_dir()) 148 149 for sub_dir in ["str", b"bytes", u"unicode"]: 150 base_dir = os.path.join( 151 (temp_dir 152 if isinstance(sub_dir, bytes) else temp_dir.decode()), sub_dir) 153 self.assertFalse(gfile.Exists(base_dir)) 154 gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42")) 155 gc.get_paths(base_dir, _create_parser(base_dir)) 156 157 158 if __name__ == "__main__": 159 test.main() 160