1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Utilities to warm-start TF.Learn Estimators.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import collections 22 import six 23 24 from tensorflow.python.framework import ops 25 from tensorflow.python.ops import resource_variable_ops 26 from tensorflow.python.ops import state_ops 27 from tensorflow.python.ops import variable_scope 28 from tensorflow.python.ops import variables as variables_lib 29 from tensorflow.python.platform import tf_logging as logging 30 from tensorflow.python.training import checkpoint_ops 31 from tensorflow.python.training import checkpoint_utils 32 from tensorflow.python.training import saver 33 from tensorflow.python.util.tf_export import tf_export 34 35 36 @tf_export("estimator.VocabInfo") 37 class VocabInfo( 38 collections.namedtuple("VocabInfo", [ 39 "new_vocab", 40 "new_vocab_size", 41 "num_oov_buckets", 42 "old_vocab", 43 "old_vocab_size", 44 "backup_initializer", 45 ])): 46 """Vocabulary information for WarmStartSettings. 47 48 See @{tf.estimator.WarmStartSettings$WarmStartSettings} for examples of using 49 VocabInfo to warm-start. 50 51 Attributes: 52 new_vocab: [Required] A path to the new vocabulary file (used with the 53 model to be trained). 54 new_vocab_size: [Required] An integer indicating how many entries of the new 55 vocabulary will used in training. 56 num_oov_buckets: [Required] An integer indicating how many OOV buckets are 57 associated with the vocabulary. 58 old_vocab: [Required] A path to the old vocabulary file (used with the 59 checkpoint to be warm-started from). 60 old_vocab_size: [Optional] An integer indicating how many entries of the old 61 vocabulary were used in the creation of the checkpoint. If not provided, 62 the entire old vocabulary will be used. 63 backup_initializer: [Optional] A variable initializer used for variables 64 corresponding to new vocabulary entries and OOV. If not provided, these 65 entries will be zero-initialized. 66 """ 67 68 def __new__(cls, 69 new_vocab, 70 new_vocab_size, 71 num_oov_buckets, 72 old_vocab, 73 old_vocab_size=-1, 74 backup_initializer=None): 75 return super(VocabInfo, cls).__new__( 76 cls, 77 new_vocab, 78 new_vocab_size, 79 num_oov_buckets, 80 old_vocab, 81 old_vocab_size, 82 backup_initializer, 83 ) 84 85 86 @tf_export("estimator.WarmStartSettings") 87 class WarmStartSettings( 88 collections.namedtuple("WarmStartSettings", [ 89 "ckpt_to_initialize_from", 90 "vars_to_warm_start", 91 "var_name_to_vocab_info", 92 "var_name_to_prev_var_name", 93 ])): 94 """Settings for warm-starting in Estimators. 95 96 Example Use with canned `DNNEstimator`: 97 98 ``` 99 emb_vocab_file = tf.feature_column.embedding_column( 100 tf.feature_column.categorical_column_with_vocabulary_file( 101 "sc_vocab_file", "new_vocab.txt", vocab_size=100), 102 dimension=8) 103 emb_vocab_list = tf.feature_column.embedding_column( 104 tf.feature_column.categorical_column_with_vocabulary_list( 105 "sc_vocab_list", vocabulary_list=["a", "b"]), 106 dimension=8) 107 estimator = tf.estimator.DNNClassifier( 108 hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list], 109 warm_start_from=ws) 110 ``` 111 112 where `ws` could be defined as: 113 114 Warm-start all weights in the model (input layer and hidden weights). 115 Either the directory or a specific checkpoint can be provided (in the case 116 of the former, the latest checkpoint will be used): 117 118 ``` 119 ws = WarmStartSettings(ckpt_to_initialize_from="/tmp") 120 ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000") 121 ``` 122 123 Warm-start only the embeddings (input layer): 124 125 ``` 126 ws = WarmStartSettings(ckpt_to_initialize_from="/tmp", 127 vars_to_warm_start=".*input_layer.*") 128 ``` 129 130 Warm-start all weights but the embedding parameters corresponding to 131 `sc_vocab_file` have a different vocab from the one used in the current 132 model: 133 134 ``` 135 vocab_info = ws_util.VocabInfo( 136 new_vocab=sc_vocab_file.vocabulary_file, 137 new_vocab_size=sc_vocab_file.vocabulary_size, 138 num_oov_buckets=sc_vocab_file.num_oov_buckets, 139 old_vocab="old_vocab.txt" 140 ) 141 ws = WarmStartSettings( 142 ckpt_to_initialize_from="/tmp", 143 var_name_to_vocab_info={ 144 "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info 145 }) 146 ``` 147 148 Warm-start only `sc_vocab_file` embeddings (and no other variables), which 149 have a different vocab from the one used in the current model: 150 151 ``` 152 vocab_info = ws_util.VocabInfo( 153 new_vocab=sc_vocab_file.vocabulary_file, 154 new_vocab_size=sc_vocab_file.vocabulary_size, 155 num_oov_buckets=sc_vocab_file.num_oov_buckets, 156 old_vocab="old_vocab.txt" 157 ) 158 ws = WarmStartSettings( 159 ckpt_to_initialize_from="/tmp", 160 vars_to_warm_start=None, 161 var_name_to_vocab_info={ 162 "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info 163 }) 164 ``` 165 166 Warm-start all weights but the parameters corresponding to `sc_vocab_file` 167 have a different vocab from the one used in current checkpoint, and only 168 100 of those entries were used: 169 170 ``` 171 vocab_info = ws_util.VocabInfo( 172 new_vocab=sc_vocab_file.vocabulary_file, 173 new_vocab_size=sc_vocab_file.vocabulary_size, 174 num_oov_buckets=sc_vocab_file.num_oov_buckets, 175 old_vocab="old_vocab.txt", 176 old_vocab_size=100 177 ) 178 ws = WarmStartSettings( 179 ckpt_to_initialize_from="/tmp", 180 var_name_to_vocab_info={ 181 "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info 182 }) 183 ``` 184 185 Warm-start all weights but the parameters corresponding to `sc_vocab_file` 186 have a different vocab from the one used in current checkpoint and the 187 parameters corresponding to `sc_vocab_list` have a different name from the 188 current checkpoint: 189 190 ``` 191 vocab_info = ws_util.VocabInfo( 192 new_vocab=sc_vocab_file.vocabulary_file, 193 new_vocab_size=sc_vocab_file.vocabulary_size, 194 num_oov_buckets=sc_vocab_file.num_oov_buckets, 195 old_vocab="old_vocab.txt", 196 old_vocab_size=100 197 ) 198 ws = WarmStartSettings( 199 ckpt_to_initialize_from="/tmp", 200 var_name_to_vocab_info={ 201 "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info 202 }, 203 var_name_to_prev_var_name={ 204 "input_layer/sc_vocab_list_embedding/embedding_weights": 205 "old_tensor_name" 206 }) 207 ``` 208 209 Attributes: 210 ckpt_to_initialize_from: [Required] A string specifying the directory with 211 checkpoint file(s) or path to checkpoint from which to warm-start the 212 model parameters. 213 vars_to_warm_start: [Optional] A regular expression that captures which 214 variables to warm-start (see tf.get_collection). Defaults to `'.*'`, 215 which warm-starts all variables. If `None` is explicitly given, only 216 variables specified in `var_name_to_vocab_info` will be warm-started. 217 var_name_to_vocab_info: [Optional] Dict of variable names (strings) to 218 VocabInfo. The variable names should be "full" variables, not the names 219 of the partitions. If not explicitly provided, the variable is assumed to 220 have no vocabulary. 221 var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to 222 name of the previously-trained variable in `ckpt_to_initialize_from`. If 223 not explicitly provided, the name of the variable is assumed to be same 224 between previous checkpoint and current model. 225 """ 226 227 def __new__(cls, 228 ckpt_to_initialize_from, 229 vars_to_warm_start=".*", 230 var_name_to_vocab_info=None, 231 var_name_to_prev_var_name=None): 232 if not ckpt_to_initialize_from: 233 raise ValueError( 234 "`ckpt_to_initialize_from` MUST be set in WarmStartSettings") 235 return super(WarmStartSettings, cls).__new__( 236 cls, 237 ckpt_to_initialize_from, 238 vars_to_warm_start, 239 var_name_to_vocab_info or {}, 240 var_name_to_prev_var_name or {}, 241 ) 242 243 244 def _is_variable(x): 245 return (isinstance(x, variables_lib.Variable) or 246 isinstance(x, resource_variable_ops.ResourceVariable)) 247 248 249 def _infer_var_name(var): 250 """Returns name of the `var`. 251 252 Args: 253 var: A list. The list can contain either of the following: 254 (i) A single `Variable` 255 (ii) A single `ResourceVariable` 256 (iii) Multiple `Variable` objects which must be slices of the same larger 257 variable. 258 (iv) A single `PartitionedVariable` 259 260 Returns: 261 Name of the `var` 262 """ 263 name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(var) 264 if len(name_to_var_dict) > 1: 265 raise TypeError("`var` = %s passed as arg violates the constraints. " 266 "name_to_var_dict = %s" % (var, name_to_var_dict)) 267 return list(name_to_var_dict.keys())[0] 268 269 270 def _warm_start_var(var, prev_ckpt, prev_tensor_name=None): 271 """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`. 272 273 Args: 274 var: Current graph's variable that needs to be warm-started (initialized). 275 Can be either of the following: 276 (i) `Variable` 277 (ii) `ResourceVariable` 278 (iii) list of `Variable`: The list must contain slices of the same larger 279 variable. 280 (iv) `PartitionedVariable` 281 prev_ckpt: A string specifying the directory with checkpoint file(s) or path 282 to checkpoint. The given checkpoint must have tensor with name 283 `prev_tensor_name` (if not None) or tensor with name same as given `var`. 284 prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If 285 None, we lookup tensor with same name as given `var`. 286 """ 287 if _is_variable(var): 288 current_var_name = _infer_var_name([var]) 289 elif isinstance(var, list) and all(_is_variable(v) for v in var): 290 current_var_name = _infer_var_name(var) 291 elif isinstance(var, variables_lib.PartitionedVariable): 292 current_var_name = _infer_var_name([var]) 293 var = var._get_variable_list() # pylint: disable=protected-access 294 else: 295 raise TypeError( 296 "var MUST be one of the following: a Variable, list of Variable or " 297 "PartitionedVariable, but is {}".format(type(var))) 298 if not prev_tensor_name: 299 # Assume tensor name remains the same. 300 prev_tensor_name = current_var_name 301 checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var}) 302 303 304 # pylint: disable=protected-access 305 # Accesses protected members of tf.Variable to reset the variable's internal 306 # state. 307 def _warm_start_var_with_vocab(var, 308 current_vocab_path, 309 current_vocab_size, 310 prev_ckpt, 311 prev_vocab_path, 312 previous_vocab_size=-1, 313 current_oov_buckets=0, 314 prev_tensor_name=None, 315 initializer=None): 316 """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`. 317 318 Use this method when the `var` is backed by vocabulary. This method stitches 319 the given `var` such that values corresponding to individual features in the 320 vocabulary remain consistent irrespective of changing order of the features 321 between old and new vocabularies. 322 323 Args: 324 var: Current graph's variable that needs to be warm-started (initialized). 325 Can be either of the following: 326 (i) `Variable` 327 (ii) `ResourceVariable` 328 (iii) list of `Variable`: The list must contain slices of the same larger 329 variable. 330 (iv) `PartitionedVariable` 331 current_vocab_path: Path to the vocab file used for the given `var`. 332 current_vocab_size: An `int` specifying the number of entries in the current 333 vocab. 334 prev_ckpt: A string specifying the directory with checkpoint file(s) or path 335 to checkpoint. The given checkpoint must have tensor with name 336 `prev_tensor_name` (if not None) or tensor with name same as given `var`. 337 prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`. 338 previous_vocab_size: If provided, will constrain previous vocab to the first 339 `previous_vocab_size` entries. -1 means use the entire previous vocab. 340 current_oov_buckets: An `int` specifying the number of out-of-vocabulary 341 buckets used for given `var`. 342 prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If 343 None, we lookup tensor with same name as given `var`. 344 initializer: Variable initializer to be used for missing entries. If None, 345 missing entries will be zero-initialized. 346 347 Raises: 348 ValueError: If required args are not provided. 349 """ 350 if not (current_vocab_path and current_vocab_size and prev_ckpt and 351 prev_vocab_path): 352 raise ValueError("Invalid args: Must provide all of [current_vocab_path, " 353 "current_vocab_size, prev_ckpt, prev_vocab_path}.") 354 if _is_variable(var): 355 var = [var] 356 elif isinstance(var, list) and all(_is_variable(v) for v in var): 357 var = var 358 elif isinstance(var, variables_lib.PartitionedVariable): 359 var = var._get_variable_list() 360 else: 361 raise TypeError( 362 "var MUST be one of the following: a Variable, list of Variable or " 363 "PartitionedVariable, but is {}".format(type(var))) 364 365 if not prev_tensor_name: 366 # Assume tensor name remains the same. 367 prev_tensor_name = _infer_var_name(var) 368 369 for v in var: 370 v_shape = v.get_shape().as_list() 371 slice_info = v._get_save_slice_info() 372 partition_info = None 373 if slice_info: 374 partition_info = variable_scope._PartitionInfo( 375 full_shape=slice_info.full_shape, 376 var_offset=slice_info.var_offset) 377 378 # TODO(eddz): Support WarmStartSettings where class vocabularies need 379 # remapping too. 380 init = checkpoint_ops._load_and_remap_matrix_initializer( 381 ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt), 382 old_tensor_name=prev_tensor_name, 383 new_row_vocab_size=current_vocab_size, 384 new_col_vocab_size=v_shape[1], 385 old_row_vocab_size=previous_vocab_size, 386 old_row_vocab_file=prev_vocab_path, 387 new_row_vocab_file=current_vocab_path, 388 old_col_vocab_file=None, 389 new_col_vocab_file=None, 390 num_row_oov_buckets=current_oov_buckets, 391 num_col_oov_buckets=0, 392 initializer=initializer) 393 new_init_val = ops.convert_to_tensor( 394 init(shape=v_shape, partition_info=partition_info)) 395 v._initializer_op = state_ops.assign(v, new_init_val) 396 # pylint: enable=protected-access 397 398 399 def _warm_start(warm_start_settings): 400 """Warm-starts a model using the given settings. 401 402 If you are using a tf.estimator.Estimator, this will automatically be called 403 during training. 404 405 Args: 406 warm_start_settings: An object of `WarmStartSettings`. 407 Raises: 408 ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo 409 configuration for variable names that are not used. This is to ensure 410 a stronger check for variable configuration than relying on users to 411 examine the logs. 412 """ 413 logging.info("Warm-starting from: %s", 414 (warm_start_settings.ckpt_to_initialize_from,)) 415 # We have to deal with partitioned variables, since get_collection flattens 416 # out the list. 417 grouped_variables = {} 418 # Both warm_start_settings.vars_to_warm_start = '.*' and 419 # warm_start_settings.vars_to_warm_start = None will match everything here. 420 for v in ops.get_collection( 421 # TODO(eddz): Allow for different collections here (to support 422 # warm-starting accumulators). 423 ops.GraphKeys.TRAINABLE_VARIABLES, 424 scope=warm_start_settings.vars_to_warm_start): 425 if not isinstance(v, list): 426 var_name = _infer_var_name([v]) 427 else: 428 var_name = _infer_var_name(v) 429 grouped_variables.setdefault(var_name, []).append(v) 430 431 # Keep track of which var_names in var_name_to_prev_var_name and 432 # var_name_to_vocab_info have been used. Err on the safer side by throwing an 433 # exception if any are unused by the end of the loop. It is easy to misname 434 # a variable during this configuration, in which case without this check, we 435 # would fail to warm-start silently. 436 prev_var_name_used = set() 437 vocab_info_used = set() 438 439 for var_name, variable in six.iteritems(grouped_variables): 440 prev_var_name = warm_start_settings.var_name_to_prev_var_name.get(var_name) 441 if prev_var_name: 442 prev_var_name_used.add(var_name) 443 vocab_info = warm_start_settings.var_name_to_vocab_info.get(var_name) 444 if vocab_info: 445 vocab_info_used.add(var_name) 446 logging.info( 447 "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}" 448 " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}" 449 " initializer: {}".format( 450 var_name, 451 vocab_info.new_vocab, 452 vocab_info.new_vocab_size, 453 vocab_info.old_vocab, 454 (vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0 455 else "All"), 456 vocab_info.num_oov_buckets, 457 prev_var_name or "Unchanged", 458 vocab_info.backup_initializer or "zero-initialized")) 459 _warm_start_var_with_vocab( 460 variable, 461 current_vocab_path=vocab_info.new_vocab, 462 current_vocab_size=vocab_info.new_vocab_size, 463 prev_ckpt=warm_start_settings.ckpt_to_initialize_from, 464 prev_vocab_path=vocab_info.old_vocab, 465 previous_vocab_size=vocab_info.old_vocab_size, 466 current_oov_buckets=vocab_info.num_oov_buckets, 467 prev_tensor_name=prev_var_name, 468 initializer=vocab_info.backup_initializer) 469 else: 470 # For the special value of warm_start_settings.vars_to_warm_start = None, 471 # we only warm-start variables with explicitly specified vocabularies. 472 if warm_start_settings.vars_to_warm_start: 473 logging.info("Warm-starting variable: {}; prev_var_name: {}".format( 474 var_name, prev_var_name or "Unchanged")) 475 # Because we use a default empty list in grouped_variables, single 476 # unpartitioned variables will be lists here, which we rectify in order 477 # for init_from_checkpoint logic to work correctly. 478 if len(variable) == 1: 479 variable = variable[0] 480 _warm_start_var(variable, warm_start_settings.ckpt_to_initialize_from, 481 prev_var_name) 482 483 prev_var_name_not_used = set( 484 warm_start_settings.var_name_to_prev_var_name.keys()) - prev_var_name_used 485 vocab_info_not_used = set( 486 warm_start_settings.var_name_to_vocab_info.keys()) - vocab_info_used 487 488 if prev_var_name_not_used: 489 raise ValueError( 490 "You provided the following variables in " 491 "warm_start_settings.var_name_to_prev_var_name that were not used: " 492 "{0}. Perhaps you misspelled them? Here is the list of viable " 493 "variable names: {1}".format(prev_var_name_not_used, 494 grouped_variables.keys())) 495 if vocab_info_not_used: 496 raise ValueError( 497 "You provided the following variables in " 498 "warm_start_settings.var_name_to_vocab_info that were not used: {0}. " 499 " Perhaps you misspelled them? Here is the list of viable variable " 500 "names: {1}".format(vocab_info_not_used, grouped_variables.keys())) 501 502 503 def _get_default_warm_start_settings(warm_start_from): 504 """Returns default WarmStartSettings. 505 506 Args: 507 warm_start_from: Either a string representing the filepath of a checkpoint 508 to initialize from, or an instance of WarmStartSettings. 509 510 Returns: 511 Either None or an instance of WarmStartSettings. 512 513 Raises: 514 ValueError: If warm_start_from is not None but is neither a string nor an 515 instance of WarmStartSettings. 516 """ 517 if warm_start_from is None: 518 return None 519 if isinstance(warm_start_from, six.string_types): 520 return WarmStartSettings(ckpt_to_initialize_from=warm_start_from) 521 elif isinstance(warm_start_from, WarmStartSettings): 522 return warm_start_from 523 else: 524 raise ValueError("warm_start_from must be a string or a WarmStartSettings") 525