# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Save and restore variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path
import re
import time
import uuid

import numpy as np
import six

from google.protobuf import text_format

from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = set(["Variable",
                     "VariableV2",
                     "AutoReloadVariable",
                     "VarHandleOp",
                     "ReadVariableOp"])


def _set_cpu0(device_string):
  """Creates a new device string based on `device_string` but using /CPU:0.

  If the device is already on /CPU:0, this is a no-op.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  parsed_device = pydev.DeviceSpec.from_string(device_string)
  parsed_device.device_type = "CPU"
  parsed_device.device_index = 0
  return parsed_device.to_string()
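
# Example (illustrative; the device string below is hypothetical):
#
#   _set_cpu0("/job:worker/replica:0/task:0/device:GPU:3")
#   # -> "/job:worker/replica:0/task:0/device:CPU:0"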


class BaseSaverBuilder(object):
  """Base class for Savers.

  Can be extended to create different Ops.
  """

  class SaveSpec(object):
    """Class used to describe tensor slices that need to be saved."""

    def __init__(self, tensor, slice_spec, name):
      """Creates a `SaveSpec` object.

      Args:
        tensor: the tensor to save or callable that produces a tensor to save.
        slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
        name: the name to save the tensor under.
      """
      self._tensor = tensor
      self.slice_spec = slice_spec
      self.name = name

    @property
    def tensor(self):
      return self._tensor() if callable(self._tensor) else self._tensor

  class SaveableObject(object):
    """Base class for saving and restoring saveable objects."""

    def __init__(self, op, specs, name):
      """Creates a `SaveableObject` object.

      Args:
        op: the "producer" object that this class wraps; it produces a list of
          tensors to save. E.g., a "Variable" object saving its backing tensor.
        specs: a list of SaveSpec, each element of which describes one tensor
          to save under this object.
        name: the name to save the object under.
      """
      self.op = op
      self.specs = specs
      self.name = name
      # The device of this saveable. All tensors must be on the same device.
      self.device = specs[0].tensor.device

    def restore(self, restored_tensors, restored_shapes):
      """Restores this object from 'restored_tensors'.

      Args:
        restored_tensors: the tensors that were loaded from a checkpoint.
        restored_shapes: the shapes this object should conform to after
          restore, or None.

      Returns:
        An operation that restores the state of the object.

      Raises:
        ValueError: If the object cannot be restored using the provided
          parameters.
      """
      # pylint: disable=unused-argument
      raise ValueError("Calling an abstract method.")

  class VariableSaveable(SaveableObject):
    """SaveableObject implementation that handles Variables."""

    def __init__(self, var, slice_spec, name):
      spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name)
      super(BaseSaverBuilder.VariableSaveable, self).__init__(var, [spec], name)

    def restore(self, restored_tensors, restored_shapes):
      restored_tensor = restored_tensors[0]
      if restored_shapes is not None:
        restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
      return state_ops.assign(
          self.op,
          restored_tensor,
          validate_shape=restored_shapes is None and
          self.op.get_shape().is_fully_defined())

  class ResourceVariableSaveable(SaveableObject):
    """SaveableObject implementation that handles ResourceVariables."""

    def __init__(self, var, slice_spec, name):
      self._var_device = var.device
      self._var_shape = var.shape
      if isinstance(var, ops.Tensor):
        self.handle_op = var.op.inputs[0]
        tensor = var
      elif isinstance(var, resource_variable_ops.ResourceVariable):

        def _read_variable_closure(v):
          def f():
            with ops.device(v.device):
              x = v.read_value()
            with ops.device("/device:CPU:0"):
              return array_ops.identity(x)
          return f

        self.handle_op = var.handle
        tensor = _read_variable_closure(var)
      else:
        raise ValueError(
            "Saveable is neither a resource variable nor a read operation."
            " Got: %s" % repr(var))
      spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name)
      super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__(
          var, [spec], name)

    def restore(self, restored_tensors, restored_shapes):
      restored_tensor = restored_tensors[0]
      if restored_shapes is not None:
        restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
      # Copy the restored tensor to the variable's device.
      with ops.device(self._var_device):
        restored_tensor = array_ops.identity(restored_tensor)
      return resource_variable_ops.shape_safe_assign_variable_handle(
          self.handle_op, self._var_shape, restored_tensor)
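
  # Example (illustrative sketch, not part of the original API surface): a
  # minimal custom saveable. `_make_tensor` is a hypothetical callable that
  # returns the tensor to write; a real implementation's restore() would
  # return an op that applies `restored_tensors` to its backing object.
  #
  #   class _SketchSaveable(BaseSaverBuilder.SaveableObject):
  #
  #     def __init__(self, _make_tensor, name):
  #       spec = BaseSaverBuilder.SaveSpec(_make_tensor, "", name)
  #       super(_SketchSaveable, self).__init__(_make_tensor, [spec], name)
  #
  #     def restore(self, restored_tensors, restored_shapes):
  #       return control_flow_ops.no_op()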

  def __init__(self, write_version=saver_pb2.SaverDef.V2):
    self._write_version = write_version

  def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      An Operation that saves the variables.

    Raises:
      RuntimeError: (implementation detail) if "self._write_version" is an
        unexpected value.
    """
    # pylint: disable=protected-access
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in saveables:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    if self._write_version == saver_pb2.SaverDef.V1:
      return io_ops._save(
          filename=filename_tensor,
          tensor_names=tensor_names,
          tensors=tensors,
          tensor_slices=tensor_slices)
    elif self._write_version == saver_pb2.SaverDef.V2:
      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                            tensors)
    else:
      raise RuntimeError("Unexpected write_version: " +
                         str(self._write_version))
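
  # Example (illustrative): with the V2 format, feeding the prefix
  # "/tmp/model.ckpt" ultimately produces files such as "/tmp/model.ckpt.index"
  # and "/tmp/model.ckpt.data-00000-of-00001", never a single file literally
  # named "/tmp/model.ckpt".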

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restore all tensors contained in saveables.

    By default, this issues separate calls to `restore_op` for each saveable.
    Subclasses may override to load multiple saveables in a single call.

    Args:
      filename_tensor: String Tensor.
      saveables: List of BaseSaverBuilder.SaveableObject objects.
      preferred_shard: Int. Shard to open first when loading a sharded file.
      restore_sequentially: Bool. If true, each restore is sequential.

    Returns:
      A list of Tensors resulting from reading 'saveables' from
      'filename'.
    """
    all_tensors = []
    assign_ops = []
    for saveable in saveables:
      restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
      with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
        with ops.control_dependencies(restore_control_inputs):
          all_tensors.extend(
              self.restore_op(filename_tensor, saveable, preferred_shard))
    return all_tensors

  # pylint: disable=unused-argument
  def restore_op(self, filename_tensor, saveable, preferred_shard):
    """Create ops to restore 'saveable'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveable: A BaseSaverBuilder.SaveableObject object.
      preferred_shard: Int. Shard to open first when loading a sharded file.

    Returns:
      A list of Tensors resulting from reading 'saveable' from
      'filename'.
    """
    # pylint: disable=protected-access
    tensors = []
    for spec in saveable.specs:
      tensors.append(
          io_ops.restore_v2(
              filename_tensor,
              [spec.name],
              [spec.slice_spec],
              [spec.tensor.dtype])[0])

    return tensors
  # pylint: enable=unused-argument

  def sharded_filename(self, filename_tensor, shard, num_shards):
    """Append sharding information to a filename.

    Args:
      filename_tensor: A string tensor.
      shard: Integer. The shard for the filename.
      num_shards: An int Tensor for the number of shards.

    Returns:
      A string tensor.
    """
    # pylint: disable=protected-access
    return gen_io_ops._sharded_filename(filename_tensor, shard, num_shards)
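
  # Example (illustrative): sharded_filename over the prefix "model.ckpt" for
  # shard 0 of 2 yields a string tensor holding "model.ckpt-00000-of-00002".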

  def _AddSaveOps(self, filename_tensor, saveables):
    """Add ops to save variables that are on the same shard.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of SaveableObject objects.

    Returns:
      A tensor with the filename used to save.
    """
    save = self.save_op(filename_tensor, saveables)
    return control_flow_ops.with_dependencies([save], filename_tensor)

  def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
    """Add ops to save the params per shard, for the V2 format.

    Note that the sharded save procedure for the V2 format is different from
    V1: there is a special "merge" step that merges the small metadata produced
    from each device.

    Args:
      checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A
        FILENAME*, but as a prefix of a V2 checkpoint.
      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables, which, when evaluated, returns the prefix
      "<user-fed prefix>" only and does not include the sharded spec suffix.
    """
    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore(). Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed
    #   to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case. Save() and Restore() work with the
    # prefix directly, instead of any physical pathname. (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    _SHARDED_SUFFIX = "_temp_%s/part" % uuid.uuid4().hex
    tmp_checkpoint_prefix = string_ops.string_join(
        [checkpoint_prefix, _SHARDED_SUFFIX])

    num_shards = len(per_device)
    sharded_saves = []
    sharded_prefixes = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    last_device = None
    for shard, (device, saveables) in enumerate(per_device):
      last_device = device
      with ops.device(_set_cpu0(device)):
        sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
                                                 num_shards_tensor)
        sharded_prefixes.append(sharded_filename)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))

    with ops.control_dependencies([x.op for x in sharded_saves]):
      # Co-locates the merge step with the last device.
      with ops.device(_set_cpu0(last_device)):
        # V2 format write path consists of a metadata merge step. Once merged,
        # attempts to delete the temporary directory,
        # "<user-fed prefix>_temp".
        merge_step = gen_io_ops.merge_v2_checkpoints(
            sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
        with ops.control_dependencies([merge_step]):
          # Returns the prefix "<user-fed prefix>" only. DOES NOT include the
          # sharded spec suffix.
          return array_ops.identity(checkpoint_prefix)

  def _AddShardedSaveOps(self, filename_tensor, per_device):
    """Add ops to save the params per shard.

    Args:
      filename_tensor: a scalar String Tensor.
      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables.
    """
    if self._write_version == saver_pb2.SaverDef.V2:
      return self._AddShardedSaveOpsForV2(filename_tensor, per_device)

    num_shards = len(per_device)
    sharded_saves = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_filename = self.sharded_filename(filename_tensor, shard,
                                                 num_shards_tensor)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
    # Return the sharded name for the save path.
    with ops.control_dependencies([x.op for x in sharded_saves]):
      # pylint: disable=protected-access
      return gen_io_ops._sharded_filespec(filename_tensor, num_shards_tensor)

  def _AddRestoreOps(self,
                     filename_tensor,
                     saveables,
                     restore_sequentially,
                     reshape,
                     preferred_shard=-1,
                     name="restore_all"):
    """Add operations to restore saveables.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      saveables: A list of SaveableObject objects.
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of
        the corresponding variable.
      preferred_shard: Shard to open first when loading a sharded file.
      name: Name for the returned op.

    Returns:
      An Operation that restores the variables.
    """
    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
                                    restore_sequentially)

    assign_ops = []
    idx = 0
    # Load and optionally reshape on the CPU, as string tensors are not
    # available on the GPU.
    # TODO(touts): Re-enable restore on GPU when we can support annotating
    # string tensors as "HostMemory" inputs.
    for saveable in saveables:
      shapes = None
      if reshape:
        # Compute the shapes, let the restore op decide if and how to do
        # the reshape.
        shapes = []
        for spec in saveable.specs:
          v = spec.tensor
          shape = v.get_shape()
          if not shape.is_fully_defined():
            shape = array_ops.shape(v)
          shapes.append(shape)
      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
      idx += len(saveable.specs)
      assign_ops.append(saveable.restore(saveable_tensors, shapes))

    # Create a Noop that has control dependencies from all the updates.
    return control_flow_ops.group(*assign_ops, name=name)

  def _AddShardedRestoreOps(self, filename_tensor, per_device,
                            restore_sequentially, reshape):
    """Add Ops to restore variables from multiple devices.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      per_device: A list of (device, SaveableObject) pairs, as
        returned by _GroupByDevices().
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of
        the corresponding variable.

    Returns:
      An Operation that restores the variables.
    """
    sharded_restores = []
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_restores.append(
            self._AddRestoreOps(
                filename_tensor,
                saveables,
                restore_sequentially,
                reshape,
                preferred_shard=shard,
                name="restore_shard"))
    return control_flow_ops.group(*sharded_restores, name="restore_all")

  @staticmethod
  def _IsVariable(v):
    return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS

  def _GroupByDevices(self, saveables):
    """Group Variable tensor slices per device.

    TODO(touts): Make sure that all the devices found are on different
    job/replica/task/cpu|gpu. It would be bad if 2 were on the same device.
    It can happen if the devices are unspecified.

    Args:
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      A list of (device_name, BaseSaverBuilder.SaveableObject) tuples.
      The list is sorted by ascending device_name.

    Raises:
      ValueError: If the tensors of a saveable are on different devices.
    """
    per_device = collections.defaultdict(list)
    for saveable in saveables:
      canonical_device = set(
          pydev.canonical_name(spec.tensor.device) for spec in saveable.specs)
      if len(canonical_device) != 1:
        raise ValueError("All tensors of a saveable object must be "
                         "on the same device: %s" % saveable.name)
      per_device[canonical_device.pop()].append(saveable)
    return sorted(per_device.items(), key=lambda t: t[0])
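
  # Example (illustrative; devices and saveables are hypothetical): saveables
  # whose tensors live on "/device:GPU:0" and "/device:GPU:1" come back as
  #
  #   [("/device:GPU:0", [saveables on GPU 0]),
  #    ("/device:GPU:1", [saveables on GPU 1])]
  #
  # with one entry per distinct canonical device, sorted by device name.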

  @staticmethod
  def OpListToDict(op_list, convert_variable_to_tensor=True):
    """Create a dictionary of names to operation lists.

    Args:
      op_list: A list, tuple, or set of Variables or SaveableObjects.
      convert_variable_to_tensor: Whether or not to convert single Variables
        with no slice info into Tensors.

    Returns:
      A dictionary of names to the operations that must be saved under
      that name. Variables with save_slice_info are grouped together under the
      same key in no particular order.

    Raises:
      TypeError: If the type of op_list or its elements is not supported.
      ValueError: If at least two saveables share the same name.
    """
    if not isinstance(op_list, (list, tuple, set)):
      raise TypeError("Variables to save should be passed in a dict or a "
                      "list: %s" % op_list)
    # When ResourceVariables are converted to Tensors, read ops are added to
    # the graph. Sorting the op_list ensures that the resulting graph is
    # always constructed in a deterministic way:
    op_list = sorted(op_list, key=lambda x: x.name)
    names_to_saveables = {}
    # pylint: disable=protected-access
    for var in op_list:
      if isinstance(var, BaseSaverBuilder.SaveableObject):
        names_to_saveables[var.name] = var
      elif isinstance(var, variables.PartitionedVariable):
        if var.name in names_to_saveables:
          raise ValueError("At least two variables have the same name: %s" %
                           var.name)
        names_to_saveables[var.name] = var
      elif isinstance(var, variables.Variable) and var._save_slice_info:
        name = var._save_slice_info.full_name
        if name in names_to_saveables:
          if not isinstance(names_to_saveables[name], list):
            raise ValueError("Mixing slices and non-slices with the same name: "
                             "%s" % name)
          names_to_saveables[name].append(var)
        else:
          names_to_saveables[name] = [var]
      else:
        if context.in_graph_mode():
          if convert_variable_to_tensor:
            var = ops.internal_convert_to_tensor(var, as_ref=True)
            if not BaseSaverBuilder._IsVariable(var):
              raise TypeError("Variable to save is not a Variable: %s" % var)
          if var.op.type == "ReadVariableOp":
            name = var.op.inputs[0].op.name
          else:
            name = var.op.name
          if name in names_to_saveables:
            raise ValueError("At least two variables have the same name: %s" %
                             name)
          names_to_saveables[name] = var
        else:
          if not isinstance(var, resource_variable_ops.ResourceVariable):
            raise ValueError("Can only save/restore ResourceVariables when "
                             "eager execution is enabled, type: %s." %
                             type(var))
          set_var = names_to_saveables.setdefault(var._shared_name, var)
          if set_var is not var:
            raise ValueError(
                ("Two different ResourceVariable objects with the same "
                 "shared_name '%s' were passed to the Saver. This likely means "
                 "that they were created in different Graphs or isolation "
                 "contexts, and may not be checkpointed together.") %
                (var._shared_name,))

    # pylint: enable=protected-access
    return names_to_saveables
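
  # Example (illustrative sketch; `v1` and `v2` are hypothetical variables
  # named "v1" and "v2"):
  #
  #   BaseSaverBuilder.OpListToDict([v1, v2])
  #   # -> {"v1": <v1 converted to a reference Tensor>,
  #   #     "v2": <v2 converted to a reference Tensor>}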

  def _ValidateAndSliceInputs(self, names_to_saveables):
    """Returns the variables and names that will be used for a Saver.

    Args:
      names_to_saveables: A dict (k, v) where k is the name of an operation and
        v is an operation to save or a BaseSaverBuilder.Saver.

    Returns:
      A list of BaseSaverBuilder.SaveableObject objects.

    Raises:
      TypeError: If any of the keys are not strings or any of the
        values are not one of Tensor or Variable or a checkpointable operation.
      ValueError: If the same operation is given in more than one value
        (this also applies to slices of SlicedVariables).
    """
    if not isinstance(names_to_saveables, dict):
      names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables)

    saveables = []
    seen_ops = set()
    for name in sorted(names_to_saveables.keys()):
      if not isinstance(name, six.string_types):
        raise TypeError(
            "names_to_saveables must be a dict mapping string names to "
            "checkpointable operations. Name is not a string: %s" % name)
      op = names_to_saveables[name]
      if isinstance(op, BaseSaverBuilder.SaveableObject):
        self._AddSaveable(saveables, seen_ops, op)
      elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
        if isinstance(op, variables.PartitionedVariable):
          op = list(op)
        # A set of slices.
        slice_name = None
        # pylint: disable=protected-access
        for variable in op:
          if not isinstance(variable, variables.Variable):
            raise ValueError("Slices must all be Variables: %s" % variable)
          if not variable._save_slice_info:
            raise ValueError("Slices must all be slices: %s" % variable)
          if slice_name is None:
            slice_name = variable._save_slice_info.full_name
          elif slice_name != variable._save_slice_info.full_name:
            raise ValueError(
                "Slices must all be from the same tensor: %s != %s" %
                (slice_name, variable._save_slice_info.full_name))
          if variable.op.type in ["Variable", "VariableV2",
                                  "AutoReloadVariable"]:
            saveable = BaseSaverBuilder.VariableSaveable(
                variable, variable._save_slice_info.spec, name)
          else:
            saveable = BaseSaverBuilder.ResourceVariableSaveable(
                variable, variable._save_slice_info.spec, name)
          self._AddSaveable(saveables, seen_ops, saveable)
        # pylint: enable=protected-access
      else:
        # A variable or tensor.
        if context.in_eager_mode():
          if not isinstance(op, resource_variable_ops.ResourceVariable):
            raise ValueError("Can only save/restore ResourceVariables when "
                             "eager execution is enabled, type: %s." %
                             type(op))
          saveable = BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
        else:
          variable = ops.internal_convert_to_tensor(op, as_ref=True)
          if not BaseSaverBuilder._IsVariable(variable):
            raise TypeError("names_to_saveables must be a dict mapping string "
                            "names to Tensors/Variables. Not a variable: %s" %
                            variable)
          if variable.op.type in ["Variable", "VariableV2",
                                  "AutoReloadVariable"]:
            saveable = BaseSaverBuilder.VariableSaveable(variable, "", name)
          else:
            saveable = BaseSaverBuilder.ResourceVariableSaveable(
                variable, "", name)
        self._AddSaveable(saveables, seen_ops, saveable)
    return saveables

  def _AddSaveable(self, saveables, seen_ops, saveable):
    """Adds the saveable to the saveables list.

    Args:
      saveables: List to append the SaveableObject to.
      seen_ops: Set of the ops of the saveables already processed. Used to
        check that each saveable is only saved once.
      saveable: The saveable.

    Raises:
      ValueError: If the saveable has already been processed.
    """
    if saveable.op in seen_ops:
      raise ValueError("The same saveable will be restored with two names: %s" %
                       saveable.name)
    saveables.append(saveable)
    seen_ops.add(saveable.op)

  def build(self,
            names_to_saveables,
            reshape=False,
            sharded=False,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=10000.0,
            name=None,
            restore_sequentially=False,
            filename="model"):
    """Builds save/restore graph nodes or runs save/restore in eager mode.

    Args:
      names_to_saveables: A dictionary mapping name to a Variable or
        SaveableObject. Each name will be associated with the
        corresponding variable in the checkpoint.
      reshape: If True, allow restoring parameters from a checkpoint
        where the parameters have a different shape. This is
        only needed when you try to restore from a Dist-Belief checkpoint,
        and only sometimes.
      sharded: If True, shard the checkpoints, one per device that has
        Variable nodes.
      max_to_keep: Maximum number of checkpoints to keep. As new checkpoints
        are created, old ones are deleted. If None or 0, no checkpoints are
        deleted from the filesystem but only the last one is kept in the
        `checkpoint` file. Presently the number is only roughly enforced. For
        example in case of restarts more than max_to_keep checkpoints may be
        kept.
      keep_checkpoint_every_n_hours: How often checkpoints should be kept.
        Defaults to 10,000 hours.
      name: String. Optional name to use as a prefix when adding operations.
      restore_sequentially: A Bool, which if true, causes restore of different
        variables to happen sequentially within each device.
      filename: If known at graph construction time, filename used for variable
        loading/saving. If None, then the default name "model" will be used.

    Returns:
      A SaverDef proto.

    Raises:
      TypeError: If 'names_to_saveables' is not a dictionary mapping string
        keys to variable Tensors.
      ValueError: If any of the keys or values in 'names_to_saveables' is not
        unique.
    """
    return self._build_internal(
        names_to_saveables=names_to_saveables,
        reshape=reshape,
        sharded=sharded,
        max_to_keep=max_to_keep,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        name=name,
        restore_sequentially=restore_sequentially,
        filename=filename)
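
  # Example (illustrative sketch; `v1` and `v2` are hypothetical variables):
  # using a builder directly to produce a SaverDef.
  #
  #   builder = BulkSaverBuilder()
  #   saver_def = builder.build({"v1": v1, "v2": v2}, sharded=True)
  #   # saver_def.save_tensor_name and saver_def.restore_op_name now name the
  #   # graph nodes that perform the save and the restore.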

  def _build_internal(self,
                      names_to_saveables,
                      reshape=False,
                      sharded=False,
                      max_to_keep=5,
                      keep_checkpoint_every_n_hours=10000.0,
                      name=None,
                      restore_sequentially=False,
                      filename="model",
                      build_save=True,
                      build_restore=True):
    """build() with option to only perform save and restore."""
    if context.in_graph_mode() and (not build_save or not build_restore):
      raise ValueError("Graph mode needs to build save and restore together.")

    saveables = self._ValidateAndSliceInputs(names_to_saveables)
    if max_to_keep is None:
      max_to_keep = 0

    with ops.name_scope(name, "save",
                        [saveable.op for saveable in saveables]) as name:
      # Add the Constant string tensor for the filename.
      filename_tensor = constant_op.constant(filename or "model")

      # Add the save ops.
      if sharded:
        per_device = self._GroupByDevices(saveables)
        if build_save:
          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
        if build_restore:
          restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
                                                  restore_sequentially, reshape)
      else:
        if build_save:
          save_tensor = self._AddSaveOps(filename_tensor, saveables)
        if build_restore:
          restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                           restore_sequentially, reshape)

    # In the following use case, it's possible to have restore_ops be called
    # something else:
    # - Build inference graph and export a meta_graph.
    # - Import the inference meta_graph
    # - Extend the inference graph to a train graph.
    # - Export a new meta_graph.
    # Now the second restore_op will be called "restore_all_1".
    # As such, comment out the assert for now until we know whether supporting
    # such usage model makes sense.
    #
    # assert restore_op.name.endswith("restore_all"), restore_op.name
    if context.in_graph_mode():
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.name,
          save_tensor_name=save_tensor.name,
          restore_op_name=restore_op.name,
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)
    else:
      # Store the tensor values to the tensor_names.
      save_tensor_name = save_tensor.numpy() if build_save else ""
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.numpy(),
          save_tensor_name=save_tensor_name,
          restore_op_name="",
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)


class BulkSaverBuilder(BaseSaverBuilder):
  """SaverBuilder with support for bulk restoring multiple saveables."""

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    # Ignored: bulk restore is internally sequential.
    del restore_sequentially
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.tensor.dtype))

    names, slices, dtypes = zip(*restore_specs)
    # Load all tensors onto CPU 0 for compatibility with existing code.
    with ops.device("cpu:0"):
      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
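
# Example (illustrative; `v1` is a hypothetical variable): a builder instance
# can be supplied when constructing a Saver; both names refer to classes in
# this module.
#
#   saver = Saver({"v1": v1}, builder=BulkSaverBuilder())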


def _get_saver_or_default():
  """Returns the saver from SAVERS collection, or creates a default one.

  This method is used by other members of the training module, such as
  `Scaffold`, or `CheckpointSaverHook`.

  Returns:
    `Saver`.

  Raises:
    RuntimeError: If the SAVERS collection already has more than one item.
  """
  collection_key = ops.GraphKeys.SAVERS
  savers = ops.get_collection(collection_key)
  if savers:
    if len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor.".
          format(collection_key))
    return savers[0]
  saver = Saver(sharded=True, allow_empty=True)
  if saver is not None:
    ops.add_to_collection(collection_key, saver)
  return saver


def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)


@tf_export("train.generate_checkpoint_state_proto")
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None):
  """Generates a checkpoint state proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.

  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.
  """
  if all_model_checkpoint_paths is None:
    all_model_checkpoint_paths = []

  if (not all_model_checkpoint_paths or
      all_model_checkpoint_paths[-1] != model_checkpoint_path):
    logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
                 model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)

  # Relative paths need to be rewritten to be relative to the "save_dir"
  # if model_checkpoint_path already contains "save_dir".
  if not os.path.isabs(save_dir):
    if not os.path.isabs(model_checkpoint_path):
      model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    for i in range(len(all_model_checkpoint_paths)):
      p = all_model_checkpoint_paths[i]
      if not os.path.isabs(p):
        all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)

  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths)

  return coord_checkpoint_proto
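
# Example (illustrative; paths are hypothetical): for a model saved under
# "/tmp/train", the generated proto resembles, in text format:
#
#   model_checkpoint_path: "/tmp/train/model.ckpt-1000"
#   all_model_checkpoint_paths: "/tmp/train/model.ckpt-500"
#   all_model_checkpoint_paths: "/tmp/train/model.ckpt-1000"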


@tf_export("train.update_checkpoint_state")
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  _update_checkpoint_state(
      save_dir=save_dir,
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      latest_filename=latest_filename,
      save_relative_paths=False)


def _update_checkpoint_state(save_dir,
                             model_checkpoint_path,
                             all_model_checkpoint_paths=None,
                             latest_filename=None,
                             save_relative_paths=False):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest. If this is a non-empty list,
      the last element must be equal to model_checkpoint_path. These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  if save_relative_paths:
    if os.path.isabs(model_checkpoint_path):
      rel_model_checkpoint_path = os.path.relpath(
          model_checkpoint_path, save_dir)
    else:
      rel_model_checkpoint_path = model_checkpoint_path
    rel_all_model_checkpoint_paths = []
    for p in all_model_checkpoint_paths:
      if os.path.isabs(p):
        rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
      else:
        rel_all_model_checkpoint_paths.append(p)
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        rel_model_checkpoint_path,
        all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
  else:
    ckpt = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)

  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state. Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to
  # a file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))


@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from %s" %
                         checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i in range(len(ckpt.all_model_checkpoint_paths)):
        p = ckpt.all_model_checkpoint_paths[i]
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read.
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  finally:
    if f:
      f.close()
  return ckpt
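
# Example (illustrative sketch; the directory and the `saver`/`sess` objects
# are hypothetical): the common restore-from-latest pattern built on
# get_checkpoint_state.
#
#   ckpt = get_checkpoint_state("/tmp/train")
#   if ckpt and ckpt.model_checkpoint_path:
#     saver.restore(sess, ckpt.model_checkpoint_path)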


@tf_export("train.Saver")
class Saver(object):
  """Saves and restores variables.

  See @{$variables$Variables}
  for an overview of variables, saving and restoring.

  The `Saver` class adds ops to save and restore variables to and from
  *checkpoints*. It also provides convenience methods to run these ops.

  Checkpoints are binary files in a proprietary format which map variable
  names to tensor values. The best way to examine the contents of a checkpoint
  is to load it using a `Saver`.

  Savers can automatically number checkpoint filenames with a provided counter.
  This lets you keep multiple checkpoints at different steps while training a
  model. For example you can number the checkpoint filenames with the training
  step number. To avoid filling up disks, savers manage checkpoint files
  automatically. For example, they can keep only the N most recent files, or
  one checkpoint for every N hours of training.

  You number checkpoint filenames by passing a value to the optional
  `global_step` argument to `save()`:

  ```python
  saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
  ...
  saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
  ```

  Additionally, optional arguments to the `Saver()` constructor let you control
  the proliferation of checkpoint files on disk:

  * `max_to_keep` indicates the maximum number of recent checkpoint files to
    keep. As new files are created, older files are deleted. If None or 0,
    all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent
    checkpoint files are kept.)

  * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
    `max_to_keep` checkpoint files, you might want to keep one checkpoint file
    for every N hours of training. This can be useful if you want to later
    analyze how a model progressed during a long training session. For
    example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
    one checkpoint file for every 2 hours of training. The default value of
    10,000 hours effectively disables the feature.

  Note that you still have to call the `save()` method to save the model.
  Passing these arguments to the constructor will not save variables
  automatically for you.

  A training program that saves regularly looks like:

  ```python
  ...
  # Create a saver.
  saver = tf.train.Saver(...variables...)
  # Launch the graph and train, saving the model every 1,000 steps.
  sess = tf.Session()
  for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
      # Append the step number to the checkpoint name:
      saver.save(sess, 'my-model', global_step=step)
  ```

  In addition to checkpoint files, savers keep a protocol buffer on disk with
  the list of recent checkpoints. This is used to manage numbered checkpoint
  files and by `latest_checkpoint()`, which makes it easy to discover the path
  to the most recent checkpoint. That protocol buffer is stored in a file named
  'checkpoint' next to the checkpoint files.

  If you create several savers, you can specify a different filename for the
  protocol buffer file in the call to `save()`.
  """

  def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None):
    """Creates a `Saver`.

    The constructor adds ops to save and restore variables.

    `var_list` specifies the variables that will be saved and restored. It can
    be passed as a `dict` or a list:

    * A `dict` of names to variables: The keys are the names that will be
      used to save or restore the variables in the checkpoint files.
    * A list of variables: The variables will be keyed with their op name in
      the checkpoint files.

    For example:

    ```python
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')

    # Pass the variables as a dict:
    saver = tf.train.Saver({'v1': v1, 'v2': v2})

    # Or pass them as a list.
    saver = tf.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    The optional `reshape` argument, if `True`, allows restoring a variable
    from a save file where the variable had a different shape, but the same
    number of elements and type. This is useful if you have reshaped a variable
    and want to reload it from an older checkpoint.

    The optional `sharded` argument, if `True`, instructs the saver to shard
    checkpoints per device.

    Args:
      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
        names to `SaveableObject`s. If `None`, defaults to the list of all
        saveable objects.
      reshape: If `True`, allows restoring parameters from a checkpoint
        where the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep.
        Defaults to 5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints.
        Defaults to 10,000 hours.
      name: String. Optional name to use as a prefix when adding operations.
      restore_sequentially: A `Bool`, which if true, causes restore of
        different variables to happen sequentially within each device. This
        can lower memory usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. This is only useful for specialty code that wants to recreate
        a `Saver` object for a previously built `Graph` that had a `Saver`.
        The `saver_def` proto should be the one returned by the
        `as_saver_def()` call of the `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use if a `saver_def` was not
        provided. Defaults to `BulkSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops to the
        `build()` call. In that case `build()` should be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default) raise an error if there are no
        variables in the graph. Otherwise, construct the saver anyway and make
        it a no-op.
      write_version: controls what format to use when saving checkpoints. It
        also affects certain filepath matching logic. The V2 format is the
        recommended choice: it is much more optimized than V1 in terms of
        memory required and latency incurred during restore. Regardless of
        this flag, the Saver is able to restore from both V2 and V1
        checkpoints.
      pad_step_number: if True, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default). This is turned off by
        default.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file. This is needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.
      filename: If known at graph construction time, filename used for variable
        loading/saving.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique.
      RuntimeError: If eager execution is enabled and `var_list` does not
        specify a list of variables to save.

    @compatibility(eager)
    When eager execution is enabled, `var_list` must specify a `list` or `dict`
    of variables to save. Otherwise, a `RuntimeError` will be raised.
    @end_compatibility
    """
    if defer_build and var_list:
      raise ValueError(
          "If `var_list` is provided then build cannot be deferred. "
          "Either set defer_build=False or var_list=None.")
    if context.in_eager_mode() and var_list is None:
      raise RuntimeError(
          "When eager execution is enabled, `var_list` must specify a list "
          "or dict of variables to save")
    self._var_list = var_list
    self._reshape = reshape
    self._sharded = sharded
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._name = name
    self._restore_sequentially = restore_sequentially
    self.saver_def = saver_def
    self._builder = builder
    self._is_built = False
    self._allow_empty = allow_empty
    self._is_empty = None
    self._write_version = write_version
    self._pad_step_number = pad_step_number
    self._filename = filename
    if not defer_build and context.in_graph_mode():
      self.build()
    if self.saver_def:
      self._check_saver_def()
      self._write_version = self.saver_def.version
    self._save_relative_paths = save_relative_paths
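
  # Example (illustrative sketch): deferring op creation until all variables
  # exist. The variable below is hypothetical.
  #
  #   saver = tf.train.Saver(defer_build=True)
  #   v = tf.Variable(0, name="v")
  #   saver.build()  # Must be called before finalizing or using the graph.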

  def build(self):
    if context.in_eager_mode():
      raise RuntimeError("Use save/restore instead of build in eager mode.")
    self._build(self._filename, build_save=True, build_restore=True)

  def _build_eager(self, checkpoint_path, build_save, build_restore):
    self._build(
        checkpoint_path, build_save=build_save, build_restore=build_restore)

  def _build(self, checkpoint_path, build_save, build_restore):
    """Builds saver_def."""
    if context.in_graph_mode():
      if self._is_built:
        return
      self._is_built = True

    if not self.saver_def or context.in_eager_mode():
      if self._builder is None:
        self._builder = BulkSaverBuilder(self._write_version)

      if self._var_list is None:
        # pylint: disable=protected-access
        self._var_list = variables._all_saveable_objects()
      if not self._var_list:
        if self._allow_empty:
          self._is_empty = True
          return
        else:
          raise ValueError("No variables to save")
      self._is_empty = False

      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
          self._var_list,
          reshape=self._reshape,
          sharded=self._sharded,
          max_to_keep=self._max_to_keep,
          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
          name=self._name,
          restore_sequentially=self._restore_sequentially,
          filename=checkpoint_path,
          build_save=build_save, build_restore=build_restore)
    elif self.saver_def and self._name:
      # Since self._name is used as a name_scope by builder(), we are
      # overloading the use of this field to represent the "import_scope" as
      # well.
      self.saver_def.filename_tensor_name = ops.prepend_name_scope(
          self.saver_def.filename_tensor_name, self._name)
      self.saver_def.save_tensor_name = ops.prepend_name_scope(
          self.saver_def.save_tensor_name, self._name)
      self.saver_def.restore_op_name = ops.prepend_name_scope(
          self.saver_def.restore_op_name, self._name)

    self._check_saver_def()
    # Updates next checkpoint time.
    self._next_checkpoint_time = (
        time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)
    self._last_checkpoints = []
    self._checkpoints_to_be_deleted = []

  def _check_saver_def(self):
    if not isinstance(self.saver_def, saver_pb2.SaverDef):
      raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
                       self.saver_def)
    if context.in_graph_mode():
      if not self.saver_def.save_tensor_name:
        raise ValueError("saver_def must specify the save_tensor_name: %s" %
                         str(self.saver_def))
      if not self.saver_def.restore_op_name:
        raise ValueError("saver_def must specify the restore_op_name: %s" %
                         str(self.saver_def))

  def _CheckpointFilename(self, p):
    """Returns the checkpoint filename given a `(filename, time)` pair.

    Args:
      p: (filename, time) pair.

    Returns:
      Checkpoint file name.
    """
    name, _ = p
    return name

  def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
    """Returns the meta graph filename.

    Args:
      checkpoint_filename: Name of the checkpoint file.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

    Returns:
      MetaGraph file name.
    """
    # If the checkpoint_filename is sharded, the checkpoint_filename could
    # be of format model.ckpt-step#-?????-of-shard#. For example,
    # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
    basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
    meta_graph_filename = ".".join([basename, meta_graph_suffix])
    return meta_graph_filename
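
  # Example (illustrative): the sharded suffix, if any, is stripped before the
  # meta graph suffix is appended.
  #
  #   self._MetaGraphFilename("model.ckpt-123456-?????-of-00005")
  #   # -> "model.ckpt-123456.meta"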

  def _RecordLastCheckpoint(self, latest_save_path):
    """Manages the list of the latest checkpoints."""
    if not self.saver_def.max_to_keep:
      return
    # Remove first from list if the same name was used before.
    for p in self._last_checkpoints:
      if latest_save_path == self._CheckpointFilename(p):
        self._last_checkpoints.remove(p)
    # Append new path to list
    self._last_checkpoints.append((latest_save_path, time.time()))

    # If more than max_to_keep, remove oldest.
    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
      self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))

  def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
    """Deletes old checkpoints if necessary.

    `self._checkpoints_to_be_deleted` is going to contain checkpoints that are
    over `max_to_keep`. They are going to be deleted. If
    `keep_checkpoint_every_n_hours` was specified, keep an additional
    checkpoint every `N` hours. For example, if `N` is 0.5, an additional
    checkpoint is kept for every 0.5 hours of training; if `N` is 10, an
    additional checkpoint is kept for every 10 hours of training.

    Args:
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
    """
    if self._checkpoints_to_be_deleted:
      p = self._checkpoints_to_be_deleted.pop(0)
      # Do not delete the file if keep_checkpoint_every_n_hours is set and we
      # have reached N hours of training.
      should_keep = p[1] > self._next_checkpoint_time
      if should_keep:
        self._next_checkpoint_time += (
            self.saver_def.keep_checkpoint_every_n_hours * 3600)
        return

      # Otherwise delete the files.
      try:
        checkpoint_prefix = self._CheckpointFilename(p)
        self._delete_file_if_exists(
            self._MetaGraphFilename(checkpoint_prefix, meta_graph_suffix))
        if self.saver_def.version == saver_pb2.SaverDef.V2:
          # V2 has a metadata file and some data files.
          self._delete_file_if_exists(checkpoint_prefix + ".index")
          self._delete_file_if_exists(checkpoint_prefix +
                                      ".data-?????-of-?????")
        else:
          # V1, Legacy. Exact match on the data file.
          self._delete_file_if_exists(checkpoint_prefix)
      except Exception as e:  # pylint: disable=broad-except
        logging.warning("Ignoring: %s", str(e))

  def _delete_file_if_exists(self, filespec):
    for pathname in file_io.get_matching_files(filespec):
      file_io.delete_file(pathname)

  def as_saver_def(self):
    """Generates a `SaverDef` representation of this saver.

    Returns:
      A `SaverDef` proto.
    """
    return self.saver_def

  def to_proto(self, export_scope=None):
    """Converts this `Saver` to a `SaverDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `SaverDef` protocol buffer.
    """
    if export_scope is None:
      return self.saver_def

    if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
            self.saver_def.save_tensor_name.startswith(export_scope) and
            self.saver_def.restore_op_name.startswith(export_scope)):
      return None

    saver_def = saver_pb2.SaverDef()
    saver_def.CopyFrom(self.saver_def)
    saver_def.filename_tensor_name = ops.strip_name_scope(
        saver_def.filename_tensor_name, export_scope)
    saver_def.save_tensor_name = ops.strip_name_scope(
        saver_def.save_tensor_name, export_scope)
    saver_def.restore_op_name = ops.strip_name_scope(
        saver_def.restore_op_name, export_scope)
    return saver_def

  @staticmethod
  def from_proto(saver_def, import_scope=None):
    """Returns a `Saver` object created from `saver_def`.

    Args:
      saver_def: a `SaverDef` protocol buffer.
      import_scope: Optional `string`. Name scope to use.

    Returns:
      A `Saver` built from saver_def.
    """
    return Saver(saver_def=saver_def, name=import_scope)

  @property
  def last_checkpoints(self):
    """List of not-yet-deleted checkpoint filenames.

    You can pass any of the returned values to `restore()`.

    Returns:
      A list of checkpoint filenames, sorted from oldest to newest.
    """
    return list(self._CheckpointFilename(p) for p in self._last_checkpoints)
  @property
  def last_checkpoints(self):
    """List of not-yet-deleted checkpoint filenames.

    You can pass any of the returned values to `restore()`.

    Returns:
      A list of checkpoint filenames, sorted from oldest to newest.
    """
    return list(self._CheckpointFilename(p) for p in self._last_checkpoints)

  def set_last_checkpoints(self, last_checkpoints):
    """DEPRECATED: Use set_last_checkpoints_with_time.

    Sets the list of old checkpoint filenames.

    Args:
      last_checkpoints: A list of checkpoint filenames.

    Raises:
      AssertionError: If last_checkpoints is not a list.
    """
    assert isinstance(last_checkpoints, list)
    # We use a timestamp of +inf so that this checkpoint will never be
    # deleted. This is both safe and backwards compatible with a previous
    # version of the code which used s[1] as the "timestamp".
    self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]

  def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
    """Sets the list of old checkpoint filenames and timestamps.

    Args:
      last_checkpoints_with_time: A list of tuples of checkpoint filenames
        and timestamps.

    Raises:
      AssertionError: If last_checkpoints_with_time is not a list.
    """
    assert isinstance(last_checkpoints_with_time, list)
    self._last_checkpoints = last_checkpoints_with_time

  def recover_last_checkpoints(self, checkpoint_paths):
    """Recovers the internal saver state after a crash.

    This method is useful for recovering the "self._last_checkpoints" state.

    Globs for the checkpoints pointed to by `checkpoint_paths`. If the files
    exist, use their mtime as the checkpoint timestamp.

    Args:
      checkpoint_paths: a list of checkpoint paths.
    """
    mtimes = get_checkpoint_mtimes(checkpoint_paths)
    self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))
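  # Illustrative sketch (assumes a hypothetical `train_dir` holding
  # checkpoints written by an earlier run): after a process restart, the
  # bookkeeping that `max_to_keep` relies on can be rebuilt from the on-disk
  # checkpoint state:
  #
  #   ckpt = get_checkpoint_state(train_dir)
  #   if ckpt:
  #     saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)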
  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched. The variables to
    save must also have been initialized.

    The method returns the path prefix of the newly created checkpoint files.
    This string can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String. Prefix of filenames created for the checkpoint.
      global_step: If provided the global step number is appended to
        `save_path` to create the checkpoint filenames. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoints. That file,
        kept in the same directory as the checkpoint files, is automatically
        managed by the saver to keep track of recent checkpoints. Defaults to
        'checkpoint'.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
      write_meta_graph: `Boolean` indicating whether or not to write the meta
        graph file.
      write_state: `Boolean` indicating whether or not to write the
        `CheckpointStateProto`.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

    Returns:
      A string: path prefix used for the checkpoint files. If the saver is
      sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
      is the number of shards created.
      If the saver is empty, returns None.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components, or if it
        collides with `save_path`.
      RuntimeError: If save and restore ops weren't built.
    """
    # pylint: enable=line-too-long
    if not self._is_built and context.in_graph_mode():
      raise RuntimeError(
          "`build()` should be called before save if defer_build==True")
    if latest_filename is None:
      latest_filename = "checkpoint"
    if self._write_version != saver_pb2.SaverDef.V2:
      logging.warning("*******************************************************")
      logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
      logging.warning("Consider switching to the more efficient V2 format:")
      logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
      logging.warning("which is now on by default.")
      logging.warning("*******************************************************")

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
      if self._pad_step_number:
        # Zero-pads the step numbers, so that they are sorted when listed.
        checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
    else:
      checkpoint_file = save_path
      if os.path.basename(
          save_path) == latest_filename and not self._sharded:
        # Guard against collision between the data file and the checkpoint
        # state file.
        raise ValueError(
            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
            (latest_filename, save_path))

    if (context.in_graph_mode() and
        not isinstance(sess, session.SessionInterface)):
      raise TypeError("'sess' must be a Session; %s" % sess)

    save_path_parent = os.path.dirname(save_path)
    if not self._is_empty:
      try:
        if context.in_graph_mode():
          model_checkpoint_path = sess.run(
              self.saver_def.save_tensor_name,
              {self.saver_def.filename_tensor_name: checkpoint_file})
        else:
          self._build_eager(
              checkpoint_file, build_save=True, build_restore=False)
          model_checkpoint_path = self.saver_def.save_tensor_name

        model_checkpoint_path = compat.as_str(model_checkpoint_path)
        if write_state:
          self._RecordLastCheckpoint(model_checkpoint_path)
          _update_checkpoint_state(
              save_dir=save_path_parent,
              model_checkpoint_path=model_checkpoint_path,
              all_model_checkpoint_paths=self.last_checkpoints,
              latest_filename=latest_filename,
              save_relative_paths=self._save_relative_paths)
          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
        if not gfile.IsDirectory(save_path_parent):
          exc = ValueError(
              "Parent directory of {} doesn't exist, can't save.".format(
                  save_path))
        raise exc

    if write_meta_graph:
      meta_graph_filename = self._MetaGraphFilename(
          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
      if context.in_graph_mode():
        with sess.graph.as_default():
          self.export_meta_graph(
              meta_graph_filename, strip_default_attrs=strip_default_attrs)

    if self._is_empty:
      return None
    else:
      return model_checkpoint_path
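  # Illustrative sketch (assumes an initialized session `sess` and a writable
  # hypothetical directory "/tmp/train"): appending `global_step` yields a
  # per-step checkpoint prefix that `restore()` accepts directly:
  #
  #   saver = Saver()
  #   path = saver.save(sess, "/tmp/train/model.ckpt", global_step=100)
  #   # path == "/tmp/train/model.ckpt-100"
  #   saver.restore(sess, path)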
  def export_meta_graph(self,
                        filename=None,
                        collection_list=None,
                        as_text=False,
                        export_scope=None,
                        clear_devices=False,
                        clear_extraneous_savers=False,
                        strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Writes `MetaGraphDef` to save_path/filename.

    Args:
      filename: Optional meta_graph filename including the path.
      collection_list: List of string keys to collect.
      as_text: If `True`, writes the meta_graph as an ASCII proto.
      export_scope: Optional `string`. Name scope to remove.
      clear_devices: Whether or not to clear the device field for an
        `Operation` or `Tensor` during export.
      clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with this Saver.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

    Returns:
      A `MetaGraphDef` proto.
    """
    # pylint: enable=line-too-long
    return export_meta_graph(
        filename=filename,
        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
        saver_def=self.saver_def,
        collection_list=collection_list,
        as_text=as_text,
        export_scope=export_scope,
        clear_devices=clear_devices,
        clear_extraneous_savers=clear_extraneous_savers,
        strip_default_attrs=strip_default_attrs)
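  # Illustrative sketch (assumes a built saver on the default graph; the
  # filename is hypothetical): `save()` calls this method automatically, but
  # the meta graph can also be exported on its own:
  #
  #   meta_graph_def = saver.export_meta_graph("/tmp/train/model.ckpt.meta")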
  def restore(self, sess, save_path):
    """Restores previously saved variables.

    This method runs the ops added by the constructor for restoring
    variables. It requires a session in which the graph was launched. The
    variables to restore do not have to have been initialized, as restoring
    is itself a way to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.

    Args:
      sess: A `Session` to use to restore the parameters. None in eager mode.
      save_path: Path where parameters were previously saved.

    Raises:
      ValueError: If save_path is None.
    """
    if self._is_empty:
      return
    if save_path is None:
      raise ValueError("Can't load save_path when it is None.")
    logging.info("Restoring parameters from %s", save_path)
    if context.in_graph_mode():
      sess.run(self.saver_def.restore_op_name,
               {self.saver_def.filename_tensor_name: save_path})
    else:
      self._build_eager(save_path, build_save=False, build_restore=True)

  @staticmethod
  def _add_collection_def(meta_graph_def, key, export_scope=None):
    """Adds a collection to MetaGraphDef protocol buffer.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
      export_scope: Optional `string`. Name scope to remove.
    """
    meta_graph.add_collection_def(meta_graph_def, key,
                                  export_scope=export_scope)


def _prefix_to_checkpoint_path(prefix, format_version):
  """Returns the pathname of a checkpoint file, given the checkpoint prefix.

  For a V1 checkpoint, simply returns the prefix itself (the data file). For
  V2, returns the pathname to the index file.

  Args:
    prefix: a string, the prefix of a checkpoint.
    format_version: the checkpoint format version that corresponds to the
      prefix.
  Returns:
    The pathname of a checkpoint file, taking into account the checkpoint
    format version.
  """
  if format_version == saver_pb2.SaverDef.V2:
    return prefix + ".index"  # The index file identifies a checkpoint.
  return prefix  # Just the data file.


@tf_export("train.latest_checkpoint")
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of the latest saved checkpoint file.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `Saver.save()`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was
    found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if ckpt and ckpt.model_checkpoint_path:
    # Look for either a V2 path or a V1 path, with priority for V2.
    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V2)
    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V1)
    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
        v1_path):
      return ckpt.model_checkpoint_path
    else:
      logging.error("Couldn't match files for checkpoint %s",
                    ckpt.model_checkpoint_path)
  return None
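# Illustrative sketch (assumes a hypothetical directory "/tmp/train" that
# contains checkpoints written by `Saver.save()`): resuming from the most
# recent checkpoint, if one exists:
#
#   path = latest_checkpoint("/tmp/train")
#   if path is not None:
#     saver.restore(sess, path)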
@tf_export("train.import_meta_graph")
def import_meta_graph(meta_graph_or_file, clear_devices=False,
                      import_scope=None, **kwargs):
  """Recreates a Graph saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates all the collections, and returns a saver
  constructed from the `saver_def` field.

  In combination with `export_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  ```Python
  ...
  # Create a saver.
  saver = tf.train.Saver(...variables...)
  # Remember the training_op we want to run by adding it to a collection.
  tf.add_to_collection('train_op', train_op)
  sess = tf.Session()
  for step in xrange(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
      # Saves checkpoint, which by default also exports a meta_graph
      # named 'my-model-global_step.meta'.
      saver.save(sess, 'my-model', global_step=step)
  ```

  Later we can continue training from this saved `meta_graph` without
  building the model from scratch.

  ```Python
  with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
    new_saver.restore(sess, 'my-save-dir/my-model-10000')
    # tf.get_collection() returns a list. In this example we only want the
    # first one.
    train_op = tf.get_collection('train_op')[0]
    for step in xrange(1000000):
      sess.run(train_op)
  ```

  NOTE: Restarting training from a saved `meta_graph` only works if the
  device assignments have not changed.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Whether or not to clear the device field for an
      `Operation` or `Tensor` during import.
    import_scope: Optional `string`. Name scope to add. Only used when
      initializing from protocol buffer.
    **kwargs: Optional keyed arguments.

  Returns:
    A saver constructed from `saver_def` in `MetaGraphDef` or None.

    A None value is returned if no variables exist in the `MetaGraphDef`
    (i.e., there are no variables to restore).

  Raises:
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported. No graph exists when
  eager execution is enabled.
  @end_compatibility
  """  # pylint: disable=g-doc-exception
  if context.in_eager_mode():
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when "
                       "eager execution is enabled.")
  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
  else:
    meta_graph_def = meta_graph_or_file

  meta_graph.import_scoped_meta_graph(meta_graph_def,
                                      clear_devices=clear_devices,
                                      import_scope=import_scope,
                                      **kwargs)
  if meta_graph_def.HasField("saver_def"):
    return Saver(saver_def=meta_graph_def.saver_def, name=import_scope)
  else:
    if variables._all_saveable_objects():  # pylint: disable=protected-access
      # Return the default saver instance for all graph variables.
      return Saver()
    else:
      # If no graph variables exist, then a Saver cannot be constructed.
      logging.info("Saver not created because there are no variables in the"
                   " graph to restore")
      return None
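# Illustrative sketch (the filename and scope below are hypothetical):
# importing under `import_scope` prefixes every imported node name, so a
# tensor exported as "weights:0" becomes "new_model/weights:0":
#
#   saver = import_meta_graph("my-model.meta", import_scope="new_model")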
@tf_export("train.export_meta_graph")
def export_meta_graph(filename=None,
                      meta_info_def=None,
                      graph_def=None,
                      saver_def=None,
                      collection_list=None,
                      as_text=False,
                      graph=None,
                      export_scope=None,
                      clear_devices=False,
                      clear_extraneous_savers=False,
                      strip_default_attrs=False,
                      **kwargs):
  # pylint: disable=line-too-long
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported. graph_def and export_scope cannot both be specified.
    clear_devices: Whether or not to clear the device field for an
      `Operation` or `Tensor` during export.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs. For a detailed guide, see
      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    **kwargs: Optional keyed arguments.

  Returns:
    A `MetaGraphDef` proto.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported. No graph exists when
  eager execution is enabled.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  if context.in_eager_mode():
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when "
                       "eager execution is enabled.")
  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
      filename=filename,
      meta_info_def=meta_info_def,
      graph_def=graph_def,
      saver_def=saver_def,
      collection_list=collection_list,
      as_text=as_text,
      graph=graph,
      export_scope=export_scope,
      clear_devices=clear_devices,
      clear_extraneous_savers=clear_extraneous_savers,
      strip_default_attrs=strip_default_attrs,
      **kwargs)
  return meta_graph_def
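# Illustrative sketch (assumes the default graph has ops under a hypothetical
# name scope "tower_0"): exporting only that subgraph, with the scope
# stripped so it can later be re-imported under a new scope:
#
#   meta_graph_def = export_meta_graph(filename="tower.meta",
#                                      export_scope="tower_0")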
@tf_export("train.checkpoint_exists")
def checkpoint_exists(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is the recommended way to check if a checkpoint exists, since it takes
  into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
  Returns:
    A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
  """
  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                        saver_pb2.SaverDef.V2)
  if file_io.get_matching_files(pathname):
    return True
  elif file_io.get_matching_files(checkpoint_prefix):
    return True
  else:
    return False


@tf_export("train.get_checkpoint_mtimes")
def get_checkpoint_mtimes(checkpoint_prefixes):
  """Returns the mtimes (modification timestamps) of the checkpoints.

  Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
  exist, collect their mtime. Both V2 and V1 checkpoints are considered, in
  that priority.

  This is the recommended way to get the mtimes, since it takes into account
  the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefixes: a list of checkpoint paths, typically the results of
      `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless
      of sharded/non-sharded or V1/V2.
  Returns:
    A list of mtimes (in seconds since the Epoch, as floats) of the found
    checkpoints.
  """
  mtimes = []

  def match_maybe_append(pathname):
    fnames = file_io.get_matching_files(pathname)
    if fnames:
      mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
      return True
    return False

  for checkpoint_prefix in checkpoint_prefixes:
    # Tries V2's metadata file first.
    pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                          saver_pb2.SaverDef.V2)
    if match_maybe_append(pathname):
      continue
    # Otherwise, tries V1, where the prefix is the complete pathname.
    match_maybe_append(checkpoint_prefix)

  return mtimes


ops.register_proto_function(
    ops.GraphKeys.SAVERS,
    proto_type=saver_pb2.SaverDef,
    to_proto=Saver.to_proto,
    from_proto=Saver.from_proto)
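# Illustrative sketch (the prefix below is hypothetical, e.g. written by
# `Saver.save()`): both helpers accept the checkpoint prefix, independent of
# whether the underlying files use the V1 or V2 layout:
#
#   if checkpoint_exists("/tmp/train/model.ckpt-100"):
#     mtimes = get_checkpoint_mtimes(["/tmp/train/model.ckpt-100"])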