# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to warm-start TF.Learn Estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import six

from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_ops
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export


@tf_export("estimator.VocabInfo")
class VocabInfo(
    collections.namedtuple("VocabInfo", [
        "new_vocab",
        "new_vocab_size",
        "num_oov_buckets",
        "old_vocab",
        "old_vocab_size",
        "backup_initializer",
    ])):
  """Vocabulary information for WarmStartSettings.

  See `tf.estimator.WarmStartSettings` for examples of using
  VocabInfo to warm-start.

  Attributes:
    new_vocab: [Required] A path to the new vocabulary file (used with the
      model to be trained).
    new_vocab_size: [Required] An integer indicating how many entries of the
      new vocabulary will be used in training.
    num_oov_buckets: [Required] An integer indicating how many OOV buckets are
      associated with the vocabulary.
    old_vocab: [Required] A path to the old vocabulary file (used with the
      checkpoint to be warm-started from).
    old_vocab_size: [Optional] An integer indicating how many entries of the old
      vocabulary were used in the creation of the checkpoint. If not provided,
      the entire old vocabulary will be used.
    backup_initializer: [Optional] A variable initializer used for variables
      corresponding to new vocabulary entries and OOV. If not provided, these
      entries will be zero-initialized.
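
  For illustration, a minimal sketch of a `VocabInfo` for an embedding whose
  vocabulary grew between runs (the vocab file paths and sizes here are
  hypothetical):

  ```
  vocab_info = tf.estimator.VocabInfo(
      new_vocab="new_vocab.txt",
      new_vocab_size=100,
      num_oov_buckets=1,
      old_vocab="old_vocab.txt",
      old_vocab_size=50,
      backup_initializer=tf.truncated_normal_initializer(
          mean=0.0, stddev=0.01))
  ```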
  """

  def __new__(cls,
              new_vocab,
              new_vocab_size,
              num_oov_buckets,
              old_vocab,
              old_vocab_size=-1,
              backup_initializer=None):
    return super(VocabInfo, cls).__new__(
        cls,
        new_vocab,
        new_vocab_size,
        num_oov_buckets,
        old_vocab,
        old_vocab_size,
        backup_initializer,
    )


@tf_export("estimator.WarmStartSettings")
class WarmStartSettings(
    collections.namedtuple("WarmStartSettings", [
        "ckpt_to_initialize_from",
        "vars_to_warm_start",
        "var_name_to_vocab_info",
        "var_name_to_prev_var_name",
    ])):
  """Settings for warm-starting in Estimators.

  Example use with a canned `DNNClassifier`:

  ```
  emb_vocab_file = tf.feature_column.embedding_column(
      tf.feature_column.categorical_column_with_vocabulary_file(
          "sc_vocab_file", "new_vocab.txt", vocab_size=100),
      dimension=8)
  emb_vocab_list = tf.feature_column.embedding_column(
      tf.feature_column.categorical_column_with_vocabulary_list(
          "sc_vocab_list", vocabulary_list=["a", "b"]),
      dimension=8)
  estimator = tf.estimator.DNNClassifier(
    hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],
    warm_start_from=ws)
  ```

  where `ws` could be defined as:

  Warm-start all weights in the model (input layer and hidden weights).
  Either the directory or a specific checkpoint can be provided (in the case
  of the former, the latest checkpoint will be used):

  ```
  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp")
  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
  ```

  Warm-start only the embeddings (input layer):

  ```
  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
                         vars_to_warm_start=".*input_layer.*")
  ```

  Warm-start all weights, but the embedding parameters corresponding to
  `sc_vocab_file` have a different vocab from the one used in the current
  model:

  ```
  vocab_info = ws_util.VocabInfo(
      new_vocab=sc_vocab_file.vocabulary_file,
      new_vocab_size=sc_vocab_file.vocabulary_size,
      num_oov_buckets=sc_vocab_file.num_oov_buckets,
      old_vocab="old_vocab.txt"
  )
  ws = WarmStartSettings(
      ckpt_to_initialize_from="/tmp",
      var_name_to_vocab_info={
          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
      })
  ```

  Warm-start only `sc_vocab_file` embeddings (and no other variables), which
  have a different vocab from the one used in the current model:

  ```
  vocab_info = ws_util.VocabInfo(
      new_vocab=sc_vocab_file.vocabulary_file,
      new_vocab_size=sc_vocab_file.vocabulary_size,
      num_oov_buckets=sc_vocab_file.num_oov_buckets,
      old_vocab="old_vocab.txt"
  )
  ws = WarmStartSettings(
      ckpt_to_initialize_from="/tmp",
      vars_to_warm_start=None,
      var_name_to_vocab_info={
          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
      })
  ```

  Warm-start all weights, but the parameters corresponding to `sc_vocab_file`
  have a different vocab from the one used in the current checkpoint, and only
  100 of those entries were used:

  ```
  vocab_info = ws_util.VocabInfo(
      new_vocab=sc_vocab_file.vocabulary_file,
      new_vocab_size=sc_vocab_file.vocabulary_size,
      num_oov_buckets=sc_vocab_file.num_oov_buckets,
      old_vocab="old_vocab.txt",
      old_vocab_size=100
  )
  ws = WarmStartSettings(
      ckpt_to_initialize_from="/tmp",
      var_name_to_vocab_info={
          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
      })
  ```

  Warm-start all weights, but the parameters corresponding to `sc_vocab_file`
  have a different vocab from the one used in the current checkpoint, and the
  parameters corresponding to `sc_vocab_list` have a different name than in
  the current checkpoint:

  ```
  vocab_info = ws_util.VocabInfo(
      new_vocab=sc_vocab_file.vocabulary_file,
      new_vocab_size=sc_vocab_file.vocabulary_size,
      num_oov_buckets=sc_vocab_file.num_oov_buckets,
      old_vocab="old_vocab.txt",
      old_vocab_size=100
  )
  ws = WarmStartSettings(
      ckpt_to_initialize_from="/tmp",
      var_name_to_vocab_info={
          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
      },
      var_name_to_prev_var_name={
          "input_layer/sc_vocab_list_embedding/embedding_weights":
              "old_tensor_name"
      })
  ```

  Attributes:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or the path to a specific checkpoint from which to
      warm-start the model parameters.
    vars_to_warm_start: [Optional] A regular expression that captures which
      variables to warm-start (see tf.get_collection).  Defaults to `'.*'`,
      which warm-starts all variables.  If `None` is explicitly given, only
      variables specified in `var_name_to_vocab_info` will be warm-started.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      VocabInfo. The variable names should be those of the "full" variables,
      not the names of the partitions.  If not explicitly provided, the
      variable is assumed to have no vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      the name of the previously-trained variable in `ckpt_to_initialize_from`.
      If not explicitly provided, the name of the variable is assumed to be the
      same between the previous checkpoint and the current model.
  """

  def __new__(cls,
              ckpt_to_initialize_from,
              vars_to_warm_start=".*",
              var_name_to_vocab_info=None,
              var_name_to_prev_var_name=None):
    if not ckpt_to_initialize_from:
      raise ValueError(
          "`ckpt_to_initialize_from` MUST be set in WarmStartSettings")
    return super(WarmStartSettings, cls).__new__(
        cls,
        ckpt_to_initialize_from,
        vars_to_warm_start,
        var_name_to_vocab_info or {},
        var_name_to_prev_var_name or {},
    )


def _is_variable(x):
  return (isinstance(x, variables_lib.Variable) or
          isinstance(x, resource_variable_ops.ResourceVariable))


def _infer_var_name(var):
  """Returns name of the `var`.

  Args:
    var: A list. The list can contain either of the following:
      (i) A single `Variable`
      (ii) A single `ResourceVariable`
      (iii) Multiple `Variable` objects which must be slices of the same larger
        variable.
      (iv) A single `PartitionedVariable`

  Returns:
    Name of the `var`
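
  For illustration, a minimal sketch of the expected behavior:

  ```
  w = tf.get_variable("w", shape=[10, 10])
  _infer_var_name([w])  # Returns "w".
  ```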
  """
  name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(var)
  if len(name_to_var_dict) > 1:
    raise TypeError("`var` = %s passed as arg violates the constraints.  "
                    "name_to_var_dict = %s" % (var, name_to_var_dict))
  return list(name_to_var_dict.keys())[0]


def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    prev_ckpt: A string specifying the directory with checkpoint file(s) or the
      path to a specific checkpoint. The given checkpoint must contain a tensor
      named `prev_tensor_name` (if not None) or a tensor with the same name as
      the given `var`.
    prev_tensor_name: Name of the tensor to look up in the provided
      `prev_ckpt`. If None, we look up a tensor with the same name as the
      given `var`.
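
  For illustration, a minimal sketch (the checkpoint path and variable name
  here are hypothetical):

  ```
  w = tf.get_variable("linear/weights", shape=[10, 1])
  # Initializes `w` from the tensor of the same name in the checkpoint.
  _warm_start_var(w, "/tmp/model_dir")
  ```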
  """
  if _is_variable(var):
    current_var_name = _infer_var_name([var])
  elif isinstance(var, list) and all(_is_variable(v) for v in var):
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})


# pylint: disable=protected-access
# Accesses protected members of tf.Variable to reset the variable's internal
# state.
def _warm_start_var_with_vocab(var,
                               current_vocab_path,
                               current_vocab_size,
                               prev_ckpt,
                               prev_vocab_path,
                               previous_vocab_size=-1,
                               current_oov_buckets=0,
                               prev_tensor_name=None,
                               initializer=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Use this method when the `var` is backed by a vocabulary. This method
  stitches the given `var` such that values corresponding to individual
  features in the vocabulary remain consistent irrespective of changes in the
  order of the features between the old and new vocabularies.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    current_vocab_path: Path to the vocab file used for the given `var`.
    current_vocab_size: An `int` specifying the number of entries in the current
      vocab.
    prev_ckpt: A string specifying the directory with checkpoint file(s) or the
      path to a specific checkpoint. The given checkpoint must contain a tensor
      named `prev_tensor_name` (if not None) or a tensor with the same name as
      the given `var`.
    prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
    previous_vocab_size: If provided, will constrain the previous vocab to the
      first `previous_vocab_size` entries.  -1 means use the entire previous
      vocab.
    current_oov_buckets: An `int` specifying the number of out-of-vocabulary
      buckets used for the given `var`.
    prev_tensor_name: Name of the tensor to look up in the provided
      `prev_ckpt`. If None, we look up a tensor with the same name as the
      given `var`.
    initializer: Variable initializer to be used for missing entries.  If None,
      missing entries will be zero-initialized.

  Raises:
    ValueError: If required args are not provided.
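
  For illustration, a minimal sketch (all paths and sizes here are
  hypothetical):

  ```
  emb = tf.get_variable("embeddings", shape=[4, 8])
  # Remaps rows of `emb` so entries shared between old_vocab.txt and
  # new_vocab.txt keep their previously trained values; rows for new entries
  # fall back to zero-initialization since no initializer is given.
  _warm_start_var_with_vocab(
      emb,
      current_vocab_path="new_vocab.txt",
      current_vocab_size=4,
      prev_ckpt="/tmp/model_dir",
      prev_vocab_path="old_vocab.txt")
  ```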
  """
  if not (current_vocab_path and current_vocab_size and prev_ckpt and
          prev_vocab_path):
    raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
                     "current_vocab_size, prev_ckpt, prev_vocab_path].")
  if _is_variable(var):
    var = [var]
  elif isinstance(var, list) and all(_is_variable(v) for v in var):
    var = var
  elif isinstance(var, variables_lib.PartitionedVariable):
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = _infer_var_name(var)

  for v in var:
    v_shape = v.get_shape().as_list()
    slice_info = v._get_save_slice_info()
    partition_info = None
    if slice_info:
      partition_info = variable_scope._PartitionInfo(
          full_shape=slice_info.full_shape,
          var_offset=slice_info.var_offset)

    # TODO(eddz): Support WarmStartSettings where class vocabularies need
    # remapping too.
    init = checkpoint_ops._load_and_remap_matrix_initializer(
        ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
        old_tensor_name=prev_tensor_name,
        new_row_vocab_size=current_vocab_size,
        new_col_vocab_size=v_shape[1],
        old_row_vocab_size=previous_vocab_size,
        old_row_vocab_file=prev_vocab_path,
        new_row_vocab_file=current_vocab_path,
        old_col_vocab_file=None,
        new_col_vocab_file=None,
        num_row_oov_buckets=current_oov_buckets,
        num_col_oov_buckets=0,
        initializer=initializer)
    new_init_val = ops.convert_to_tensor(
        init(shape=v_shape, partition_info=partition_info))
    v._initializer_op = state_ops.assign(v, new_init_val)
# pylint: enable=protected-access


def _warm_start(warm_start_settings):
  """Warm-starts a model using the given settings.

  If you are using a tf.estimator.Estimator, this will automatically be called
  during training.

  Args:
    warm_start_settings: An object of `WarmStartSettings`.
  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used.  This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
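
  For illustration, a minimal sketch of how this is typically invoked (the
  checkpoint path is hypothetical, and the variables to warm-start must
  already exist in the current graph):

  ```
  tf.get_variable("linear/weights", shape=[10, 1])
  _warm_start(WarmStartSettings(ckpt_to_initialize_from="/tmp/model_dir"))
  ```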
  """
  logging.info("Warm-starting from: %s",
               (warm_start_settings.ckpt_to_initialize_from,))
  # We have to deal with partitioned variables, since get_collection flattens
  # out the list.
  grouped_variables = {}
  # Both warm_start_settings.vars_to_warm_start = '.*' and
  # warm_start_settings.vars_to_warm_start = None will match everything here.
  for v in ops.get_collection(
      # TODO(eddz): Allow for different collections here (to support
      # warm-starting accumulators).
      ops.GraphKeys.TRAINABLE_VARIABLES,
      scope=warm_start_settings.vars_to_warm_start):
    if not isinstance(v, list):
      var_name = _infer_var_name([v])
    else:
      var_name = _infer_var_name(v)
    grouped_variables.setdefault(var_name, []).append(v)

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used.  Err on the safer side by throwing
  # an exception if any are unused by the end of the loop.  It is easy to
  # misname a variable during this configuration, in which case, without this
  # check, we would silently fail to warm-start.
  prev_var_name_used = set()
  vocab_info_used = set()

  for var_name, variable in six.iteritems(grouped_variables):
    prev_var_name = warm_start_settings.var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = warm_start_settings.var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      logging.info(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name,
              vocab_info.new_vocab,
              vocab_info.new_vocab_size,
              vocab_info.old_vocab,
              (vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0
               else "All"),
              vocab_info.num_oov_buckets,
              prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=warm_start_settings.ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer)
    else:
      # For the special value of warm_start_settings.vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if warm_start_settings.vars_to_warm_start:
        logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in order
        # for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        _warm_start_var(variable, warm_start_settings.ckpt_to_initialize_from,
                        prev_var_name)

  prev_var_name_not_used = set(
      warm_start_settings.var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(
      warm_start_settings.var_name_to_vocab_info.keys()) - vocab_info_used

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "warm_start_settings.var_name_to_prev_var_name that were not used: "
        "{0}.  Perhaps you misspelled them?  Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "warm_start_settings.var_name_to_vocab_info that were not used: {0}. "
        " Perhaps you misspelled them?  Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))


def _get_default_warm_start_settings(warm_start_from):
  """Returns default WarmStartSettings.

  Args:
    warm_start_from: Either a string representing the filepath of a checkpoint
      to initialize from, or an instance of WarmStartSettings.

  Returns:
    Either None or an instance of WarmStartSettings.

  Raises:
    ValueError: If warm_start_from is not None but is neither a string nor an
      instance of WarmStartSettings.
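
  For illustration, both of these calls yield equivalent settings (the
  checkpoint path is hypothetical):

  ```
  _get_default_warm_start_settings("/tmp/model_dir")
  _get_default_warm_start_settings(
      WarmStartSettings(ckpt_to_initialize_from="/tmp/model_dir"))
  ```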
  """
  if warm_start_from is None:
    return None
  if isinstance(warm_start_from, six.string_types):
    return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
  elif isinstance(warm_start_from, WarmStartSettings):
    return warm_start_from
  else:
    raise ValueError("warm_start_from must be a string or a WarmStartSettings")
    525