Home | History | Annotate | Download | only in tutorials
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <cstdio>
     17 #include <functional>
     18 #include <string>
     19 #include <vector>
     20 
     21 #include "tensorflow/cc/ops/standard_ops.h"
     22 #include "tensorflow/core/framework/graph.pb.h"
     23 #include "tensorflow/core/framework/tensor.h"
     24 #include "tensorflow/core/graph/default_device.h"
     25 #include "tensorflow/core/graph/graph_def_builder.h"
     26 #include "tensorflow/core/lib/core/threadpool.h"
     27 #include "tensorflow/core/lib/strings/stringprintf.h"
     28 #include "tensorflow/core/platform/init_main.h"
     29 #include "tensorflow/core/platform/logging.h"
     30 #include "tensorflow/core/platform/types.h"
     31 #include "tensorflow/core/public/session.h"
     32 
     33 using tensorflow::string;
     34 using tensorflow::int32;
     35 
     36 namespace tensorflow {
     37 namespace example {
     38 
     39 struct Options {
     40   int num_concurrent_sessions = 1;   // The number of concurrent sessions
     41   int num_concurrent_steps = 10;     // The number of concurrent steps
     42   int num_iterations = 100;          // Each step repeats this many times
     43   bool use_gpu = false;              // Whether to use gpu in the training
     44 };
     45 
     46 // A = [3 2; -1 0]; x = rand(2, 1);
     47 // We want to compute the largest eigenvalue for A.
     48 // repeat x = y / y.norm(); y = A * x; end
     49 GraphDef CreateGraphDef() {
     50   // TODO(jeff,opensource): This should really be a more interesting
     51   // computation.  Maybe turn this into an mnist model instead?
     52   Scope root = Scope::NewRootScope();
     53   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
     54 
     55   // A = [3 2; -1 0].  Using Const<float> means the result will be a
     56   // float tensor even though the initializer has integers.
     57   auto a = Const<float>(root, {{3, 2}, {-1, 0}});
     58 
     59   // x = [1.0; 1.0]
     60   auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}});
     61 
     62   // y = A * x
     63   auto y = MatMul(root.WithOpName("y"), a, x);
     64 
     65   // y2 = y.^2
     66   auto y2 = Square(root, y);
     67 
     68   // y2_sum = sum(y2).  Note that you can pass constants directly as
     69   // inputs.  Sum() will automatically create a Const node to hold the
     70   // 0 value.
     71   auto y2_sum = Sum(root, y2, 0);
     72 
     73   // y_norm = sqrt(y2_sum)
     74   auto y_norm = Sqrt(root, y2_sum);
     75 
     76   // y_normalized = y ./ y_norm
     77   Div(root.WithOpName("y_normalized"), y, y_norm);
     78 
     79   GraphDef def;
     80   TF_CHECK_OK(root.ToGraphDef(&def));
     81 
     82   return def;
     83 }
     84 
     85 string DebugString(const Tensor& x, const Tensor& y) {
     86   CHECK_EQ(x.NumElements(), 2);
     87   CHECK_EQ(y.NumElements(), 2);
     88   auto x_flat = x.flat<float>();
     89   auto y_flat = y.flat<float>();
     90   // Compute an estimate of the eigenvalue via
     91   //      (x' A x) / (x' x) = (x' y) / (x' x)
     92   // and exploit the fact that x' x = 1 by assumption
     93   Eigen::Tensor<float, 0, Eigen::RowMajor> lambda = (x_flat * y_flat).sum();
     94   return strings::Printf("lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]",
     95                          lambda(), x_flat(0), x_flat(1), y_flat(0), y_flat(1));
     96 }
     97 
     98 void ConcurrentSteps(const Options* opts, int session_index) {
     99   // Creates a session.
    100   SessionOptions options;
    101   std::unique_ptr<Session> session(NewSession(options));
    102   GraphDef def = CreateGraphDef();
    103   if (options.target.empty()) {
    104     graph::SetDefaultDevice(opts->use_gpu ? "/device:GPU:0" : "/cpu:0", &def);
    105   }
    106 
    107   TF_CHECK_OK(session->Create(def));
    108 
    109   // Spawn M threads for M concurrent steps.
    110   const int M = opts->num_concurrent_steps;
    111   std::unique_ptr<thread::ThreadPool> step_threads(
    112       new thread::ThreadPool(Env::Default(), "trainer", M));
    113 
    114   for (int step = 0; step < M; ++step) {
    115     step_threads->Schedule([&session, opts, session_index, step]() {
    116       // Randomly initialize the input.
    117       Tensor x(DT_FLOAT, TensorShape({2, 1}));
    118       auto x_flat = x.flat<float>();
    119       x_flat.setRandom();
    120       Eigen::Tensor<float, 0, Eigen::RowMajor> inv_norm =
    121           x_flat.square().sum().sqrt().inverse();
    122       x_flat = x_flat * inv_norm();
    123 
    124       // Iterations.
    125       std::vector<Tensor> outputs;
    126       for (int iter = 0; iter < opts->num_iterations; ++iter) {
    127         outputs.clear();
    128         TF_CHECK_OK(
    129             session->Run({{"x", x}}, {"y:0", "y_normalized:0"}, {}, &outputs));
    130         CHECK_EQ(size_t{2}, outputs.size());
    131 
    132         const Tensor& y = outputs[0];
    133         const Tensor& y_norm = outputs[1];
    134         // Print out lambda, x, and y.
    135         std::printf("%06d/%06d %s\n", session_index, step,
    136                     DebugString(x, y).c_str());
    137         // Copies y_normalized to x.
    138         x = y_norm;
    139       }
    140     });
    141   }
    142 
    143   // Delete the threadpool, thus waiting for all threads to complete.
    144   step_threads.reset(nullptr);
    145   TF_CHECK_OK(session->Close());
    146 }
    147 
    148 void ConcurrentSessions(const Options& opts) {
    149   // Spawn N threads for N concurrent sessions.
    150   const int N = opts.num_concurrent_sessions;
    151 
    152   // At the moment our Session implementation only allows
    153   // one concurrently computing Session on GPU.
    154   CHECK_EQ(1, N) << "Currently can only have one concurrent session.";
    155 
    156   thread::ThreadPool session_threads(Env::Default(), "trainer", N);
    157   for (int i = 0; i < N; ++i) {
    158     session_threads.Schedule(std::bind(&ConcurrentSteps, &opts, i));
    159   }
    160 }
    161 
    162 }  // end namespace example
    163 }  // end namespace tensorflow
    164 
    165 namespace {
    166 
    167 bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
    168                     int32* dst) {
    169   if (arg.Consume(flag) && arg.Consume("=")) {
    170     char extra;
    171     return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
    172   }
    173 
    174   return false;
    175 }
    176 
    177 bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
    178                    bool* dst) {
    179   if (arg.Consume(flag)) {
    180     if (arg.empty()) {
    181       *dst = true;
    182       return true;
    183     }
    184 
    185     if (arg == "=true") {
    186       *dst = true;
    187       return true;
    188     } else if (arg == "=false") {
    189       *dst = false;
    190       return true;
    191     }
    192   }
    193 
    194   return false;
    195 }
    196 
    197 }  // namespace
    198 
    199 int main(int argc, char* argv[]) {
    200   tensorflow::example::Options opts;
    201   std::vector<char*> unknown_flags;
    202   for (int i = 1; i < argc; ++i) {
    203     if (string(argv[i]) == "--") {
    204       while (i < argc) {
    205         unknown_flags.push_back(argv[i]);
    206         ++i;
    207       }
    208       break;
    209     }
    210 
    211     if (ParseInt32Flag(argv[i], "--num_concurrent_sessions",
    212                        &opts.num_concurrent_sessions) ||
    213         ParseInt32Flag(argv[i], "--num_concurrent_steps",
    214                        &opts.num_concurrent_steps) ||
    215         ParseInt32Flag(argv[i], "--num_iterations", &opts.num_iterations) ||
    216         ParseBoolFlag(argv[i], "--use_gpu", &opts.use_gpu)) {
    217       continue;
    218     }
    219 
    220     fprintf(stderr, "Unknown flag: %s\n", argv[i]);
    221     return -1;
    222   }
    223 
    224   // Passthrough any unknown flags.
    225   int dst = 1;  // Skip argv[0]
    226   for (char* f : unknown_flags) {
    227     argv[dst++] = f;
    228   }
    229   argv[dst++] = nullptr;
    230   argc = static_cast<int>(unknown_flags.size() + 1);
    231   tensorflow::port::InitMain(argv[0], &argc, &argv);
    232   tensorflow::example::ConcurrentSessions(opts);
    233 }
    234