Home | History | Annotate | Download | only in utils
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for learn.utils.gc."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import re
     23 
     24 from six.moves import xrange  # pylint: disable=redefined-builtin
     25 
     26 from tensorflow.contrib.learn.python.learn.utils import gc
     27 from tensorflow.python.framework import test_util
     28 from tensorflow.python.platform import gfile
     29 from tensorflow.python.platform import test
     30 from tensorflow.python.util import compat
     31 
     32 
     33 def _create_parser(base_dir):
     34   # create a simple parser that pulls the export_version from the directory.
     35   def parser(path):
     36     # Modify the path object for RegEx match for Windows Paths
     37     if os.name == "nt":
     38       match = re.match(
     39           "^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$",
     40           compat.as_str_any(path.path).replace("\\", "/"))
     41     else:
     42       match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
     43                        compat.as_str_any(path.path))
     44     if not match:
     45       return None
     46     return path._replace(export_version=int(match.group(1)))
     47 
     48   return parser
     49 
     50 
     51 class GcTest(test_util.TensorFlowTestCase):
     52 
     53   def testLargestExportVersions(self):
     54     paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
     55     newest = gc.largest_export_versions(2)
     56     n = newest(paths)
     57     self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
     58 
     59   def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
     60     paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
     61     newest = gc.largest_export_versions(2)
     62     n = newest(paths)
     63     self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
     64 
     65   def testModExportVersion(self):
     66     paths = [
     67         gc.Path("/foo", 4),
     68         gc.Path("/foo", 5),
     69         gc.Path("/foo", 6),
     70         gc.Path("/foo", 9)
     71     ]
     72     mod = gc.mod_export_version(2)
     73     self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
     74     mod = gc.mod_export_version(3)
     75     self.assertEqual(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
     76 
     77   def testOneOfEveryNExportVersions(self):
     78     paths = [
     79         gc.Path("/foo", 0),
     80         gc.Path("/foo", 1),
     81         gc.Path("/foo", 3),
     82         gc.Path("/foo", 5),
     83         gc.Path("/foo", 6),
     84         gc.Path("/foo", 7),
     85         gc.Path("/foo", 8),
     86         gc.Path("/foo", 33)
     87     ]
     88     one_of = gc.one_of_every_n_export_versions(3)
     89     self.assertEqual(
     90         one_of(paths), [
     91             gc.Path("/foo", 3),
     92             gc.Path("/foo", 6),
     93             gc.Path("/foo", 8),
     94             gc.Path("/foo", 33)
     95         ])
     96 
     97   def testOneOfEveryNExportVersionsZero(self):
     98     # Zero is a special case since it gets rolled into the first interval.
     99     # Test that here.
    100     paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
    101     one_of = gc.one_of_every_n_export_versions(3)
    102     self.assertEqual(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)])
    103 
    104   def testUnion(self):
    105     paths = []
    106     for i in xrange(10):
    107       paths.append(gc.Path("/foo", i))
    108     f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
    109     self.assertEqual(
    110         f(paths), [
    111             gc.Path("/foo", 0),
    112             gc.Path("/foo", 3),
    113             gc.Path("/foo", 6),
    114             gc.Path("/foo", 7),
    115             gc.Path("/foo", 8),
    116             gc.Path("/foo", 9)
    117         ])
    118 
    119   def testNegation(self):
    120     paths = [
    121         gc.Path("/foo", 4),
    122         gc.Path("/foo", 5),
    123         gc.Path("/foo", 6),
    124         gc.Path("/foo", 9)
    125     ]
    126     mod = gc.negation(gc.mod_export_version(2))
    127     self.assertEqual(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
    128     mod = gc.negation(gc.mod_export_version(3))
    129     self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
    130 
    131   def testPathsWithParse(self):
    132     base_dir = os.path.join(test.get_temp_dir(), "paths_parse")
    133     self.assertFalse(gfile.Exists(base_dir))
    134     for p in xrange(3):
    135       gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
    136     # add a base_directory to ignore
    137     gfile.MakeDirs(os.path.join(base_dir, "ignore"))
    138 
    139     self.assertEqual(
    140         gc.get_paths(base_dir, _create_parser(base_dir)), [
    141             gc.Path(os.path.join(base_dir, "0"), 0),
    142             gc.Path(os.path.join(base_dir, "1"), 1),
    143             gc.Path(os.path.join(base_dir, "2"), 2)
    144         ])
    145 
    146   def testMixedStrTypes(self):
    147     temp_dir = compat.as_bytes(test.get_temp_dir())
    148 
    149     for sub_dir in ["str", b"bytes", u"unicode"]:
    150       base_dir = os.path.join(
    151           (temp_dir
    152            if isinstance(sub_dir, bytes) else temp_dir.decode()), sub_dir)
    153       self.assertFalse(gfile.Exists(base_dir))
    154       gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
    155       gc.get_paths(base_dir, _create_parser(base_dir))
    156 
    157 
if __name__ == "__main__":
  # Entry point when this file is executed directly: hand control to
  # `test.main()` from tensorflow.python.platform.test.
  test.main()
    160