Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Support for training models.
     17 
     18 See the @{$python/train} guide.
     19 
     20 @@Optimizer
     21 @@GradientDescentOptimizer
     22 @@AdadeltaOptimizer
     23 @@AdagradOptimizer
     24 @@AdagradDAOptimizer
     25 @@MomentumOptimizer
     26 @@AdamOptimizer
     27 @@FtrlOptimizer
     28 @@ProximalGradientDescentOptimizer
     29 @@ProximalAdagradOptimizer
     30 @@RMSPropOptimizer
     31 @@gradients
     32 @@AggregationMethod
     33 @@stop_gradient
     34 @@hessians
     35 @@clip_by_value
     36 @@clip_by_norm
     37 @@clip_by_average_norm
     38 @@clip_by_global_norm
     39 @@global_norm
     40 @@cosine_decay
     41 @@cosine_decay_restarts
     42 @@linear_cosine_decay
     43 @@noisy_linear_cosine_decay
     44 @@exponential_decay
     45 @@inverse_time_decay
     46 @@natural_exp_decay
     47 @@piecewise_constant
     48 @@polynomial_decay
     49 @@ExponentialMovingAverage
     50 @@Coordinator
     51 @@QueueRunner
     52 @@LooperThread
     53 @@add_queue_runner
     54 @@start_queue_runners
     55 @@Server
     56 @@Supervisor
     57 @@SessionManager
     58 @@ClusterSpec
     59 @@replica_device_setter
     60 @@MonitoredTrainingSession
     61 @@MonitoredSession
     62 @@SingularMonitoredSession
     63 @@Scaffold
     64 @@SessionCreator
     65 @@ChiefSessionCreator
     66 @@WorkerSessionCreator
     67 @@summary_iterator
     68 @@SessionRunHook
     69 @@SessionRunArgs
     70 @@SessionRunContext
     71 @@SessionRunValues
     72 @@LoggingTensorHook
     73 @@StopAtStepHook
     74 @@CheckpointSaverHook
     75 @@CheckpointSaverListener
     76 @@NewCheckpointReader
     77 @@StepCounterHook
     78 @@NanLossDuringTrainingError
     79 @@NanTensorHook
     80 @@SummarySaverHook
     81 @@GlobalStepWaiterHook
     82 @@FinalOpsHook
     83 @@FeedFnHook
     84 @@ProfilerHook
     85 @@SecondOrStepTimer
     86 @@global_step
     87 @@basic_train_loop
     88 @@get_global_step
     89 @@get_or_create_global_step
     90 @@create_global_step
     91 @@assert_global_step
     92 @@write_graph
     93 @@load_checkpoint
     94 @@load_variable
     95 @@list_variables
     96 @@init_from_checkpoint
     97 """
     98 
     99 # Optimizers.
    100 from __future__ import absolute_import
    101 from __future__ import division
    102 from __future__ import print_function
    103 
    104 import sys as _sys
    105 
    106 from tensorflow.python.ops import io_ops as _io_ops
    107 from tensorflow.python.ops import sdca_ops as _sdca_ops
    108 from tensorflow.python.ops import state_ops as _state_ops
    109 from tensorflow.python.util.all_util import remove_undocumented
    110 
    111 # pylint: disable=g-bad-import-order,unused-import
    112 from tensorflow.python.ops.sdca_ops import sdca_optimizer
    113 from tensorflow.python.ops.sdca_ops import sdca_fprint
    114 from tensorflow.python.ops.sdca_ops import sdca_shrink_l1
    115 from tensorflow.python.training.adadelta import AdadeltaOptimizer
    116 from tensorflow.python.training.adagrad import AdagradOptimizer
    117 from tensorflow.python.training.adagrad_da import AdagradDAOptimizer
    118 from tensorflow.python.training.proximal_adagrad import ProximalAdagradOptimizer
    119 from tensorflow.python.training.adam import AdamOptimizer
    120 from tensorflow.python.training.ftrl import FtrlOptimizer
    121 from tensorflow.python.training.momentum import MomentumOptimizer
    122 from tensorflow.python.training.moving_averages import ExponentialMovingAverage
    123 from tensorflow.python.training.optimizer import Optimizer
    124 from tensorflow.python.training.rmsprop import RMSPropOptimizer
    125 from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
    126 from tensorflow.python.training.proximal_gradient_descent import ProximalGradientDescentOptimizer
    127 from tensorflow.python.training.sync_replicas_optimizer import SyncReplicasOptimizer
    128 
    129 # Utility classes for training.
    130 from tensorflow.python.training.coordinator import Coordinator
    131 from tensorflow.python.training.coordinator import LooperThread
    132 # go/tf-wildcard-import
    133 # pylint: disable=wildcard-import
    134 from tensorflow.python.training.queue_runner import *
    135 
    136 # For the module level doc.
    137 from tensorflow.python.training import input as _input
    138 from tensorflow.python.training.input import *  # pylint: disable=redefined-builtin
    139 # pylint: enable=wildcard-import
    140 
    141 from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer
    142 from tensorflow.python.training.basic_session_run_hooks import LoggingTensorHook
    143 from tensorflow.python.training.basic_session_run_hooks import StopAtStepHook
    144 from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverHook
    145 from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverListener
    146 from tensorflow.python.training.basic_session_run_hooks import StepCounterHook
    147 from tensorflow.python.training.basic_session_run_hooks import NanLossDuringTrainingError
    148 from tensorflow.python.training.basic_session_run_hooks import NanTensorHook
    149 from tensorflow.python.training.basic_session_run_hooks import SummarySaverHook
    150 from tensorflow.python.training.basic_session_run_hooks import GlobalStepWaiterHook
    151 from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook
    152 from tensorflow.python.training.basic_session_run_hooks import FeedFnHook
    153 from tensorflow.python.training.basic_session_run_hooks import ProfilerHook
    154 from tensorflow.python.training.basic_loops import basic_train_loop
    155 from tensorflow.python.training.checkpoint_utils import init_from_checkpoint
    156 from tensorflow.python.training.checkpoint_utils import list_variables
    157 from tensorflow.python.training.checkpoint_utils import load_checkpoint
    158 from tensorflow.python.training.checkpoint_utils import load_variable
    159 
    160 from tensorflow.python.training.device_setter import replica_device_setter
    161 from tensorflow.python.training.monitored_session import Scaffold
    162 from tensorflow.python.training.monitored_session import MonitoredTrainingSession
    163 from tensorflow.python.training.monitored_session import SessionCreator
    164 from tensorflow.python.training.monitored_session import ChiefSessionCreator
    165 from tensorflow.python.training.monitored_session import WorkerSessionCreator
    166 from tensorflow.python.training.monitored_session import MonitoredSession
    167 from tensorflow.python.training.monitored_session import SingularMonitoredSession
    168 from tensorflow.python.training.saver import Saver
    169 from tensorflow.python.training.saver import checkpoint_exists
    170 from tensorflow.python.training.saver import generate_checkpoint_state_proto
    171 from tensorflow.python.training.saver import get_checkpoint_mtimes
    172 from tensorflow.python.training.saver import get_checkpoint_state
    173 from tensorflow.python.training.saver import latest_checkpoint
    174 from tensorflow.python.training.saver import update_checkpoint_state
    175 from tensorflow.python.training.saver import export_meta_graph
    176 from tensorflow.python.training.saver import import_meta_graph
    177 from tensorflow.python.training.session_run_hook import SessionRunHook
    178 from tensorflow.python.training.session_run_hook import SessionRunArgs
    179 from tensorflow.python.training.session_run_hook import SessionRunContext
    180 from tensorflow.python.training.session_run_hook import SessionRunValues
    181 from tensorflow.python.training.session_manager import SessionManager
    182 from tensorflow.python.training.summary_io import summary_iterator
    183 from tensorflow.python.training.supervisor import Supervisor
    184 from tensorflow.python.training.training_util import write_graph
    185 from tensorflow.python.training.training_util import global_step
    186 from tensorflow.python.training.training_util import get_global_step
    187 from tensorflow.python.training.training_util import assert_global_step
    188 from tensorflow.python.training.training_util import create_global_step
    189 from tensorflow.python.training.training_util import get_or_create_global_step
    190 from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef
    191 from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
    192 from tensorflow.python.util.tf_export import tf_export
    193 
    194 # pylint: disable=wildcard-import
    195 # Training data protos.
    196 from tensorflow.core.example.example_pb2 import *
    197 from tensorflow.core.example.feature_pb2 import *
    198 from tensorflow.core.protobuf.saver_pb2 import *
    199 
    200 # Utility op.  Open Source. TODO(touts): move to nn?
    201 from tensorflow.python.training.learning_rate_decay import *
    202 # pylint: enable=wildcard-import
    203 
    204 # Distributed computing support.
    205 from tensorflow.core.protobuf.cluster_pb2 import ClusterDef
    206 from tensorflow.core.protobuf.cluster_pb2 import JobDef
    207 from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
    208 from tensorflow.python.training.server_lib import ClusterSpec
    209 from tensorflow.python.training.server_lib import Server
    210 
    211 # Symbols whitelisted for export without documentation.
    212 _allowed_symbols = [
    213     # TODO(cwhipkey): review these and move to contrib or expose through
    214     # documentation.
    215     "generate_checkpoint_state_proto",  # Used internally by saver.
    216     "checkpoint_exists",  # Only used in test?
    217     "get_checkpoint_mtimes",  # Only used in test?
    218 
    219     # Legacy: remove.
    220     "do_quantize_training_on_graphdef",  # At least use grah_def, not graphdef.
    221     # No uses within tensorflow.
    222     "queue_runner",  # Use tf.train.start_queue_runner etc directly.
    223     # This is also imported internally.
    224 
    225     # TODO(drpng): document these. The reference in howtos/distributed does
    226     # not link.
    227     "SyncReplicasOptimizer",
    228     # Protobufs:
    229     "BytesList",  # from example_pb2.
    230     "ClusterDef",
    231     "Example",  # from example_pb2
    232     "Feature",  # from example_pb2
    233     "Features",  # from example_pb2
    234     "FeatureList",  # from example_pb2
    235     "FeatureLists",  # from example_pb2
    236     "FloatList",  # from example_pb2.
    237     "Int64List",  # from example_pb2.
    238     "JobDef",
    239     "SaverDef",  # From saver_pb2.
    240     "SequenceExample",  # from example_pb2.
    241     "ServerDef",
    242 ]
    243 
    244 # pylint: disable=undefined-variable
    245 tf_export("train.BytesList")(BytesList)
    246 tf_export("train.ClusterDef")(ClusterDef)
    247 tf_export("train.Example")(Example)
    248 tf_export("train.Feature")(Feature)
    249 tf_export("train.Features")(Features)
    250 tf_export("train.FeatureList")(FeatureList)
    251 tf_export("train.FeatureLists")(FeatureLists)
    252 tf_export("train.FloatList")(FloatList)
    253 tf_export("train.Int64List")(Int64List)
    254 tf_export("train.JobDef")(JobDef)
    255 tf_export("train.SaverDef")(SaverDef)
    256 tf_export("train.SequenceExample")(SequenceExample)
    257 tf_export("train.ServerDef")(ServerDef)
    258 # pylint: enable=undefined-variable
    259 
    260 # Include extra modules for docstrings because:
    261 # * Input methods in tf.train are documented in io_ops.
    262 # * Saver methods in tf.train are documented in state_ops.
    263 remove_undocumented(__name__, _allowed_symbols,
    264                     [_sys.modules[__name__], _io_ops, _sdca_ops, _state_ops])
    265