# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Support for training models.

See the @{$python/train} guide.

@@Optimizer
@@GradientDescentOptimizer
@@AdadeltaOptimizer
@@AdagradOptimizer
@@AdagradDAOptimizer
@@MomentumOptimizer
@@AdamOptimizer
@@FtrlOptimizer
@@ProximalGradientDescentOptimizer
@@ProximalAdagradOptimizer
@@RMSPropOptimizer
@@gradients
@@AggregationMethod
@@stop_gradient
@@hessians
@@clip_by_value
@@clip_by_norm
@@clip_by_average_norm
@@clip_by_global_norm
@@global_norm
@@cosine_decay
@@cosine_decay_restarts
@@linear_cosine_decay
@@noisy_linear_cosine_decay
@@exponential_decay
@@inverse_time_decay
@@natural_exp_decay
@@piecewise_constant
@@polynomial_decay
@@ExponentialMovingAverage
@@Coordinator
@@QueueRunner
@@LooperThread
@@add_queue_runner
@@start_queue_runners
@@Server
@@Supervisor
@@SessionManager
@@ClusterSpec
@@replica_device_setter
@@MonitoredTrainingSession
@@MonitoredSession
@@SingularMonitoredSession
@@Scaffold
@@SessionCreator
@@ChiefSessionCreator
@@WorkerSessionCreator
@@summary_iterator
@@SessionRunHook
@@SessionRunArgs
@@SessionRunContext
@@SessionRunValues
@@LoggingTensorHook
@@StopAtStepHook
@@CheckpointSaverHook
@@CheckpointSaverListener
@@NewCheckpointReader
@@StepCounterHook
@@NanLossDuringTrainingError
@@NanTensorHook
@@SummarySaverHook
@@GlobalStepWaiterHook
@@FinalOpsHook
@@FeedFnHook
@@ProfilerHook
@@SecondOrStepTimer
@@global_step
@@basic_train_loop
@@get_global_step
@@get_or_create_global_step
@@create_global_step
@@assert_global_step
@@write_graph
@@load_checkpoint
@@load_variable
@@list_variables
@@init_from_checkpoint
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys as _sys

# Underscore-prefixed module handles: these are only used as extra docstring
# sources for remove_undocumented() at the bottom of this file.
from tensorflow.python.ops import io_ops as _io_ops
from tensorflow.python.ops import sdca_ops as _sdca_ops
from tensorflow.python.ops import state_ops as _state_ops
from tensorflow.python.util.all_util import remove_undocumented

# pylint: disable=g-bad-import-order,unused-import
# SDCA ops. These names are not listed in this module's docstring or in
# _allowed_symbols, so they survive remove_undocumented() only if the
# sdca_ops docstring (passed in below) documents them.
from tensorflow.python.ops.sdca_ops import sdca_optimizer
from tensorflow.python.ops.sdca_ops import sdca_fprint
from tensorflow.python.ops.sdca_ops import sdca_shrink_l1

# Optimizers.
from tensorflow.python.training.adadelta import AdadeltaOptimizer
from tensorflow.python.training.adagrad import AdagradOptimizer
from tensorflow.python.training.adagrad_da import AdagradDAOptimizer
from tensorflow.python.training.proximal_adagrad import ProximalAdagradOptimizer
from tensorflow.python.training.adam import AdamOptimizer
from tensorflow.python.training.ftrl import FtrlOptimizer
from tensorflow.python.training.momentum import MomentumOptimizer
from tensorflow.python.training.moving_averages import ExponentialMovingAverage
from tensorflow.python.training.optimizer import Optimizer
from tensorflow.python.training.rmsprop import RMSPropOptimizer
from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
from tensorflow.python.training.proximal_gradient_descent import ProximalGradientDescentOptimizer
from tensorflow.python.training.sync_replicas_optimizer import SyncReplicasOptimizer

# Utility classes for training.
from tensorflow.python.training.coordinator import Coordinator
from tensorflow.python.training.coordinator import LooperThread
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.training.queue_runner import *

# For the module level doc.
# NOTE(review): `_input` is not referenced again in this file. In particular,
# it is not among the docstring modules passed to remove_undocumented() below,
# where the input methods are attributed to io_ops instead. Confirm whether
# this import is still needed.
from tensorflow.python.training import input as _input
from tensorflow.python.training.input import *  # pylint: disable=redefined-builtin
# pylint: enable=wildcard-import

# Session-run hooks, training loops and checkpoint utilities.
from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer
from tensorflow.python.training.basic_session_run_hooks import LoggingTensorHook
from tensorflow.python.training.basic_session_run_hooks import StopAtStepHook
from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverHook
from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverListener
from tensorflow.python.training.basic_session_run_hooks import StepCounterHook
from tensorflow.python.training.basic_session_run_hooks import NanLossDuringTrainingError
from tensorflow.python.training.basic_session_run_hooks import NanTensorHook
from tensorflow.python.training.basic_session_run_hooks import SummarySaverHook
from tensorflow.python.training.basic_session_run_hooks import GlobalStepWaiterHook
from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook
from tensorflow.python.training.basic_session_run_hooks import FeedFnHook
from tensorflow.python.training.basic_session_run_hooks import ProfilerHook
from tensorflow.python.training.basic_loops import basic_train_loop
from tensorflow.python.training.checkpoint_utils import init_from_checkpoint
from tensorflow.python.training.checkpoint_utils import list_variables
from tensorflow.python.training.checkpoint_utils import load_checkpoint
from tensorflow.python.training.checkpoint_utils import load_variable

from tensorflow.python.training.device_setter import replica_device_setter
from tensorflow.python.training.monitored_session import Scaffold
from tensorflow.python.training.monitored_session import MonitoredTrainingSession
from tensorflow.python.training.monitored_session import SessionCreator
from tensorflow.python.training.monitored_session import ChiefSessionCreator
from tensorflow.python.training.monitored_session import WorkerSessionCreator
from tensorflow.python.training.monitored_session import MonitoredSession
from tensorflow.python.training.monitored_session import SingularMonitoredSession
from tensorflow.python.training.saver import Saver
from tensorflow.python.training.saver import checkpoint_exists
from tensorflow.python.training.saver import generate_checkpoint_state_proto
from tensorflow.python.training.saver import get_checkpoint_mtimes
from tensorflow.python.training.saver import get_checkpoint_state
from tensorflow.python.training.saver import latest_checkpoint
from tensorflow.python.training.saver import update_checkpoint_state
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.training.saver import import_meta_graph
from tensorflow.python.training.session_run_hook import SessionRunHook
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.session_run_hook import SessionRunContext
from tensorflow.python.training.session_run_hook import SessionRunValues
from tensorflow.python.training.session_manager import SessionManager
from tensorflow.python.training.summary_io import summary_iterator
from tensorflow.python.training.supervisor import Supervisor
from tensorflow.python.training.training_util import write_graph
from tensorflow.python.training.training_util import global_step
from tensorflow.python.training.training_util import get_global_step
from tensorflow.python.training.training_util import assert_global_step
from tensorflow.python.training.training_util import create_global_step
from tensorflow.python.training.training_util import get_or_create_global_step
from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
from tensorflow.python.util.tf_export import tf_export

# pylint: disable=wildcard-import
# Training data protos.
# example_pb2 provides Example and SequenceExample; feature_pb2 provides the
# Feature/Features/FeatureList(s) and Bytes/Float/Int64List messages;
# saver_pb2 provides SaverDef.
from tensorflow.core.example.example_pb2 import *
from tensorflow.core.example.feature_pb2 import *
from tensorflow.core.protobuf.saver_pb2 import *

# Learning-rate decay schedules (cosine_decay, exponential_decay, ... in the
# module docstring). TODO(touts): move to nn?
from tensorflow.python.training.learning_rate_decay import *
# pylint: enable=wildcard-import

# Distributed computing support.
from tensorflow.core.protobuf.cluster_pb2 import ClusterDef
from tensorflow.core.protobuf.cluster_pb2 import JobDef
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.training.server_lib import Server

# Symbols whitelisted for export without documentation: they are kept by
# remove_undocumented() below even though no `@@name` line documents them.
_allowed_symbols = [
    # TODO(cwhipkey): review these and move to contrib or expose through
    # documentation.
    "generate_checkpoint_state_proto",  # Used internally by saver.
    "checkpoint_exists",  # Only used in test?
    "get_checkpoint_mtimes",  # Only used in test?

    # Legacy: remove.
    "do_quantize_training_on_graphdef",  # At least use graph_def, not graphdef.
    # No uses within tensorflow.
    "queue_runner",  # Use tf.train.start_queue_runners etc directly.
    # This is also imported internally.

    # TODO(drpng): document these. The reference in howtos/distributed does
    # not link.
    "SyncReplicasOptimizer",
    # Protobufs:
    "BytesList",  # from feature_pb2.
    "ClusterDef",  # from cluster_pb2.
    "Example",  # from example_pb2.
    "Feature",  # from feature_pb2.
    "Features",  # from feature_pb2.
    "FeatureList",  # from feature_pb2.
    "FeatureLists",  # from feature_pb2.
    "FloatList",  # from feature_pb2.
    "Int64List",  # from feature_pb2.
    "JobDef",  # from cluster_pb2.
    "SaverDef",  # from saver_pb2.
    "SequenceExample",  # from example_pb2.
    "ServerDef",  # from tensorflow_server_pb2.
]

# Register the protobuf message classes for export under their public
# `tf.train.<Name>` API names. Most of them arrive via the wildcard imports
# above, hence the undefined-variable suppression.
# pylint: disable=undefined-variable
tf_export("train.BytesList")(BytesList)
tf_export("train.ClusterDef")(ClusterDef)
tf_export("train.Example")(Example)
tf_export("train.Feature")(Feature)
tf_export("train.Features")(Features)
tf_export("train.FeatureList")(FeatureList)
tf_export("train.FeatureLists")(FeatureLists)
tf_export("train.FloatList")(FloatList)
tf_export("train.Int64List")(Int64List)
tf_export("train.JobDef")(JobDef)
tf_export("train.SaverDef")(SaverDef)
tf_export("train.SequenceExample")(SequenceExample)
tf_export("train.ServerDef")(ServerDef)
# pylint: enable=undefined-variable

# Hide every public symbol of this module that is neither documented (via an
# `@@name` line) in the docstring of one of the modules listed below nor
# whitelisted in _allowed_symbols.
# Include extra modules for docstrings because:
# * Input methods in tf.train are documented in io_ops.
# * Saver methods in tf.train are documented in state_ops.
# * The sdca_* ops imported above are presumably documented in sdca_ops.
remove_undocumented(__name__, _allowed_symbols,
                    [_sys.modules[__name__], _io_ops, _sdca_ops, _state_ops])