/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdio>
#include <functional>
#include <string>
#include <vector>

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

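// Size of the test cluster, and the thread pool whose threads host the
// in-process gRPC servers that make up that cluster.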
static const int kWorkers = 60;
static thread::ThreadPool* worker_threads;

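// Starts `n` in-process gRPC servers on unused local ports, all belonging to
// a single job named "localhost". Returns the "grpc://localhost:<port>"
// targets in `*workers` and the device attributes reported by the first
// worker in `*devices`.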
void MakeGRPCCluster(const SessionOptions& options, int n,
                     std::vector<string>* workers,
                     std::vector<DeviceAttributes>* devices) {
  CHECK_GE(n, 1);

  workers->clear();
  std::vector<int> port(n);
  for (int i = 0; i < n; ++i) {
    port[i] = testing::PickUnusedPortOrDie();
    workers->push_back(strings::StrCat("grpc://localhost:", port[i]));
  }

  int num_cpus = 1;
  int num_gpus = 0;
  auto iter = options.config.device_count().find("CPU");
  if (iter != options.config.device_count().end()) {
    num_cpus = iter->second;
  }
  iter = options.config.device_count().find("GPU");
  if (iter != options.config.device_count().end()) {
    num_gpus = iter->second;
  }

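  // Launch one server per task. Each server blocks in Join() for the lifetime
  // of the process, so it occupies one thread of the n-thread pool.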
  worker_threads = new thread::ThreadPool(Env::Default(), "worker_threads", n);
  for (int worker_idx = 0; worker_idx < n; ++worker_idx) {
    worker_threads->Schedule([worker_idx, n, num_cpus, num_gpus, &port] {
      ServerDef server;
      server.set_protocol("grpc");
      server.set_job_name("localhost");
      server.set_task_index(worker_idx);

      auto job_def = server.mutable_cluster()->add_job();
      job_def->set_name("localhost");
      for (int i = 0; i < n; i++) {
        (*(job_def->mutable_tasks()))[i] =
            strings::StrCat("localhost:", port[i]);
      }

      auto config = server.mutable_default_session_config();
      (*config->mutable_device_count())["CPU"] = num_cpus;
      (*config->mutable_device_count())["GPU"] = num_gpus;

      std::unique_ptr<ServerInterface> svr;
      TF_CHECK_OK(NewServer(server, &svr));
      TF_CHECK_OK(svr->Start());
      TF_CHECK_OK(svr->Join());
    });
  }

  // Get attributes for all devices.
  LOG(ERROR) << "W '" << (*workers)[0] << "'";
  SessionOptions options_copy(options);
  options_copy.target = (*workers)[0];
  std::unique_ptr<GrpcSession> session;
  TF_CHECK_OK(GrpcSession::Create(options_copy, &session));
  TF_CHECK_OK(session->ListDevices(devices));
}

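// Describes the benchmark cluster: the session options used to reach it, the
// gRPC targets of its workers, and the devices they expose. Constructing a
// Cluster spins up the kWorkers-task cluster and points `options.target` at
// the first worker.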
struct Cluster {
  SessionOptions options;
  std::vector<string> workers;
  std::vector<DeviceAttributes> devices;  // One per process

  Cluster() {
    (*options.config.mutable_device_count())["CPU"] = 1;
    options.config.set_intra_op_parallelism_threads(1);
    options.config.set_inter_op_parallelism_threads(1);
    MakeGRPCCluster(options, kWorkers, &workers, &devices);
    LOG(ERROR) << "C " << workers.size() << " " << devices.size() << " "
               << workers[0] << " " << workers[1];
    options.target = workers[0];
  }
};

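// Returns a process-wide Cluster that is created lazily on first use and
// shared by all benchmarks; it is intentionally never deleted.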
static const Cluster* GetCluster() {
  static Cluster* result = new Cluster;
  return result;
}

// Make a program with specified number of stages and "width" ops per stage.
GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
                        bool use_multiple_devices, const Cluster* cluster) {
  CHECK_GE(cluster->devices.size(), width);

  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  Scope s = Scope::NewRootScope();

  // x is from the feed.
  Output x = Const(s.WithOpName("x"), 0.0f, {tensor_size, 1});

  // Create stages.
  std::vector<Output> last_stage;
  last_stage.push_back(x);
  for (int i = 0; i < num_stages; i++) {
    std::vector<Output> this_stage;
    for (int j = 0; j < width; j++) {
      Output combine = AddN(
          s.WithDevice(cluster->devices[use_multiple_devices ? j : 0].name()),
          last_stage);
      this_stage.push_back(combine);
    }
    last_stage = this_stage;
  }

  // Create output.
  /* Output y =*/AddN(s.WithOpName("y"), last_stage);

  GraphDef def;
  TF_CHECK_OK(s.ToGraphDef(&def));
  return def;
}

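// Formats the first two elements of `x` and `y` for logging.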
string DebugString(const Tensor& x, const Tensor& y, int tensor_size) {
  CHECK_EQ(x.NumElements(), tensor_size);
  CHECK_EQ(y.NumElements(), tensor_size);
  auto x_flat = x.flat<float>();
  auto y_flat = y.flat<float>();
  // Just print the first couple of elements of each tensor
  CHECK_GE(tensor_size, 2);
  return strings::Printf("x = [%8.6f %8.6f] y = [%8.6f %8.6f]", x_flat(0),
                         x_flat(1), y_flat(0), y_flat(1));
}

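// Shared benchmark body: builds a graph of `num_stages` layers of `width`
// AddN ops over a [tensor_size, 1] input, creates a session against the
// shared cluster, runs a few untimed warmup steps, and then times `iters`
// Session::Run() calls.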
// TODO: Support sharding and depth.
static void BM_Helper(int iters, int width, int num_stages, int tensor_size,
                      bool use_multiple_devices) {
  testing::StopTiming();
  const Cluster* cluster = GetCluster();

  // Create a session.
  std::unique_ptr<Session> session(NewSession(cluster->options));
  GraphDef def = CreateGraphDef(num_stages, width, tensor_size,
                                use_multiple_devices, cluster);
  graph::SetDefaultDevice(cluster->devices[0].name(), &def);

  TF_CHECK_OK(session->Create(def));

  // Allocate the input tensor; its values are left uninitialized.
  Tensor x(DT_FLOAT, TensorShape({tensor_size, 1}));

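  // Label the benchmark output with the graph size, the device placement
  // mode, and the number of bytes sent per tensor.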
  testing::SetLabel(
      strings::StrCat(def.node_size(), " nodes; ",
                      use_multiple_devices ? "Multi device" : "Single device",
                      "; tensor bytes/send: ", tensor_size * sizeof(float)));

  std::vector<Tensor> outputs;

  // Do a few warmup iterations.
  for (int i = 0; i < 3; i++) {
    outputs.clear();
    TF_CHECK_OK(session->Run({{"x", x}}, {"y:0"}, {}, &outputs));
    CHECK_EQ(size_t{1}, outputs.size());

    if (i == 0) {
      // Print out x and y.
      const Tensor& y = outputs[0];
      VLOG(1) << DebugString(x, y, tensor_size);
    }
  }

  // Timed iterations.
  testing::StartTiming();
  for (int i = 0; i < iters; i++) {
    outputs.clear();
    TF_CHECK_OK(session->Run({{"x", x}}, {"y:0"}, {}, &outputs));
    CHECK_EQ(size_t{1}, outputs.size());
  }
  testing::StopTiming();
  TF_CHECK_OK(session->Close());
}
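
// Benchmarks a program sharded across devices; ArgPair is (width, num_stages)
// with a fixed 2-element tensor.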
static void BM_ShardedProgram(int iters, int width, int num_stages) {
  BM_Helper(iters, width, num_stages, 2 /*tensor_size*/, true /*multi-device*/);
}
BENCHMARK(BM_ShardedProgram)
    ->ArgPair(1, 1)
    ->ArgPair(1, 3)
    ->ArgPair(1, 5)
    ->ArgPair(1, 15)
    ->ArgPair(1, 60)
    ->ArgPair(15, 1)
    ->ArgPair(15, 3)
    ->ArgPair(15, 5)
    ->ArgPair(30, 1)
    ->ArgPair(30, 2)
    ->ArgPair(30, 3)
    ->ArgPair(30, 5)
    ->ArgPair(60, 1)
    ->ArgPair(60, 3)
    ->ArgPair(60, 5);

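// Varies the tensor size to exercise the RPC transfer path; ArgPair is
// (width, tensor_size) with a fixed two-stage, multi-device program.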
static void BM_RPC(int iters, int width, int tensor_size) {
  BM_Helper(iters, width, 2 /*num_stages*/, tensor_size, true /*multi-device*/);
}
BENCHMARK(BM_RPC)->ArgPair(30, 2)->ArgPair(30, 1000)->ArgPair(30, 100000);

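// Same program shape, but with every AddN placed on a single device; ArgPair
// is (width, num_stages) with a fixed 2-element tensor.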
static void BM_SingleDevice(int iters, int width, int num_stages) {
  BM_Helper(iters, width, num_stages, 2 /*tensor_size*/,
            false /*not multi-device*/);
}
BENCHMARK(BM_SingleDevice)
    ->ArgPair(1, 1)
    ->ArgPair(30, 2)
    ->ArgPair(60, 5)
    ->ArgPair(4, 10000)
    ->ArgPair(1, 1000000);

}  // namespace tensorflow