Home | History | Annotate | Download | only in estimator
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 r"""System for specifying garbage collection (GC) of path based data.
     17 
     18 This framework allows for GC of data specified by path names, for example files
     19 on disk.  gc.Path objects each represent a single item stored at a path and may
     20 be a base directory,
     21   /tmp/exports/0/...
     22   /tmp/exports/1/...
     23   ...
     24 or a fully qualified file,
     25   /tmp/train-1.ckpt
     26   /tmp/train-2.ckpt
     27   ...
     28 
     29 A gc filter function takes and returns a list of gc.Path items.  Filter
     30 functions are responsible for selecting Path items for preservation or deletion.
     31 Note that functions should always return a sorted list.
     32 
     33 For example,
     34   base_dir = "/tmp"
     35   # Create the directories.
     36   for e in range(10):
     37     os.mkdir("%s/%d" % (base_dir, e), 0o755)
     38 
     39   # Create a simple parser that pulls the export_version from the directory.
     40   path_regex = "^" + re.escape(base_dir) + "/(\\d+)$"
     41   def parser(path):
     42     match = re.match(path_regex, path.path)
     43     if not match:
     44       return None
     45     return path._replace(export_version=int(match.group(1)))
     46 
     47   path_list = gc._get_paths("/tmp", parser)  # contains all ten Paths
     48 
     49   every_fifth = gc._mod_export_version(5)
     50   print(every_fifth(path_list))  # shows ["/tmp/0", "/tmp/5"]
     51 
     52   largest_three = gc._largest_export_versions(3)
     53   print(largest_three(path_list))  # shows ["/tmp/7", "/tmp/8", "/tmp/9"]
     54 
     55   both = gc._union(every_fifth, largest_three)
     56   print(both(path_list))  # shows ["/tmp/0", "/tmp/5",
     57                           #        "/tmp/7", "/tmp/8", "/tmp/9"]
     58   # Delete everything not in 'both'.
     59   to_delete = gc._negation(both)
     60   for p in to_delete(path_list):
     61     gfile.DeleteRecursively(p.path)  # deletes:  "/tmp/1", "/tmp/2",
     62                                      # "/tmp/3", "/tmp/4", "/tmp/6"
     63 """
     64 
     65 from __future__ import absolute_import
     66 from __future__ import division
     67 from __future__ import print_function
     68 
     69 import collections
     70 import heapq
     71 import math
     72 import os
     73 
     74 from tensorflow.python.platform import gfile
     75 from tensorflow.python.util import compat
     76 
     77 Path = collections.namedtuple('Path', 'path export_version')
     78 
     79 
     80 def _largest_export_versions(n):
     81   """Creates a filter that keeps the largest n export versions.
     82 
     83   Args:
     84     n: number of versions to keep.
     85 
     86   Returns:
     87     A filter function that keeps the n largest paths.
     88   """
     89   def keep(paths):
     90     heap = []
     91     for idx, path in enumerate(paths):
     92       if path.export_version is not None:
     93         heapq.heappush(heap, (path.export_version, idx))
     94     keepers = [paths[i] for _, i in heapq.nlargest(n, heap)]
     95     return sorted(keepers)
     96 
     97   return keep
     98 
     99 
    100 def _one_of_every_n_export_versions(n):
    101   """Creates a filter that keeps one of every n export versions.
    102 
    103   Args:
    104     n: interval size.
    105 
    106   Returns:
    107     A filter function that keeps exactly one path from each interval
    108     [0, n], (n, 2n], (2n, 3n], etc...  If more than one path exists in an
    109     interval the largest is kept.
    110   """
    111   def keep(paths):
    112     """A filter function that keeps exactly one out of every n paths."""
    113 
    114     keeper_map = {}  # map from interval to largest path seen in that interval
    115     for p in paths:
    116       if p.export_version is None:
    117         # Skip missing export_versions.
    118         continue
    119       # Find the interval (with a special case to map export_version = 0 to
    120       # interval 0.
    121       interval = math.floor(
    122           (p.export_version - 1) / n) if p.export_version else 0
    123       existing = keeper_map.get(interval, None)
    124       if (not existing) or (existing.export_version < p.export_version):
    125         keeper_map[interval] = p
    126     return sorted(keeper_map.values())
    127 
    128   return keep
    129 
    130 
    131 def _mod_export_version(n):
    132   """Creates a filter that keeps every export that is a multiple of n.
    133 
    134   Args:
    135     n: step size.
    136 
    137   Returns:
    138     A filter function that keeps paths where export_version % n == 0.
    139   """
    140   def keep(paths):
    141     keepers = []
    142     for p in paths:
    143       if p.export_version % n == 0:
    144         keepers.append(p)
    145     return sorted(keepers)
    146   return keep
    147 
    148 
    149 def _union(lf, rf):
    150   """Creates a filter that keeps the union of two filters.
    151 
    152   Args:
    153     lf: first filter
    154     rf: second filter
    155 
    156   Returns:
    157     A filter function that keeps the n largest paths.
    158   """
    159   def keep(paths):
    160     l = set(lf(paths))
    161     r = set(rf(paths))
    162     return sorted(list(l|r))
    163   return keep
    164 
    165 
    166 def _negation(f):
    167   """Negate a filter.
    168 
    169   Args:
    170     f: filter function to invert
    171 
    172   Returns:
    173     A filter function that returns the negation of f.
    174   """
    175   def keep(paths):
    176     l = set(paths)
    177     r = set(f(paths))
    178     return sorted(list(l-r))
    179   return keep
    180 
    181 
    182 def _get_paths(base_dir, parser):
    183   """Gets a list of Paths in a given directory.
    184 
    185   Args:
    186     base_dir: directory.
    187     parser: a function which gets the raw Path and can augment it with
    188       information such as the export_version, or ignore the path by returning
    189       None.  An example parser may extract the export version from a path
    190       such as "/tmp/exports/100" an another may extract from a full file
    191       name such as "/tmp/checkpoint-99.out".
    192 
    193   Returns:
    194     A list of Paths contained in the base directory with the parsing function
    195     applied.
    196     By default the following fields are populated,
    197       - Path.path
    198     The parsing function is responsible for populating,
    199       - Path.export_version
    200   """
    201   raw_paths = gfile.ListDirectory(base_dir)
    202   paths = []
    203   for r in raw_paths:
    204     p = parser(Path(os.path.join(compat.as_str_any(base_dir),
    205                                  compat.as_str_any(r)),
    206                     None))
    207     if p:
    208       paths.append(p)
    209   return sorted(paths)
    210