Home | History | Annotate | Download | only in python

Lines Matching defs:flags

46 flags = tf.app.flags
47 flags.DEFINE_string("data_dir", "/tmp/mnist-data",
49 flags.DEFINE_boolean("download_only", False,
52 flags.DEFINE_integer("task_index", None,
56 flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine."
58 flags.DEFINE_integer("replicas_to_aggregate", None,
62 flags.DEFINE_integer("hidden_units", 100,
64 flags.DEFINE_integer("train_steps", 200,
66 flags.DEFINE_integer("batch_size", 100, "Training batch size")
67 flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
68 flags.DEFINE_boolean(
73 flags.DEFINE_boolean(
78 flags.DEFINE_string("ps_hosts", "localhost:2222",
80 flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
82 flags.DEFINE_string("job_name", None, "job name: worker or ps")
84 FLAGS = flags.FLAGS
90 mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
91 if FLAGS.download_only:
94 if FLAGS.job_name is None or FLAGS.job_name == "":
96 if FLAGS.task_index is None or FLAGS.task_index == "":
99 print("job name = %s" % FLAGS.job_name)
100 print("task index = %d" % FLAGS.task_index)
103 ps_spec = FLAGS.ps_hosts.split(",")
104 worker_spec = FLAGS.worker_hosts.split(",")
111 if not FLAGS.existing_servers:
114 cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
115 if FLAGS.job_name == "ps":
118 is_chief = (FLAGS.task_index == 0)
119 if FLAGS.num_gpus > 0:
122 gpu = (FLAGS.task_index % FLAGS.num_gpus)
123 worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
124 elif FLAGS.num_gpus == 0:
127 worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
141 [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
144 hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
149 [FLAGS.hidden_units, 10],
150 stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
154 # Ops: located on the worker specified with FLAGS.task_index
164 opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
166 if FLAGS.sync_replicas:
167 if FLAGS.replicas_to_aggregate is None:
170 replicas_to_aggregate = FLAGS.replicas_to_aggregate
180 if FLAGS.sync_replicas:
194 if FLAGS.sync_replicas:
215 "/job:worker/task:%d" % FLAGS.task_index])
220 print("Worker %d: Initializing session..." % FLAGS.task_index)
223 FLAGS.task_index)
225 if FLAGS.existing_servers:
226 server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
233 print("Worker %d: Session initialization complete." % FLAGS.task_index)
235 if FLAGS.sync_replicas and is_chief:
247 batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
255 (now, FLAGS.task_index, local_step, step))
257 if step >= FLAGS.train_steps:
269 (FLAGS.train_steps, val_xent))