// tensorflow/core/common_runtime/gpu/gpu_stream_util.cc
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
     17 
     18 #include <set>
     19 #include <string>
     20 #include <unordered_set>
     21 #include <vector>
     22 
     23 #include "tensorflow/core/graph/algorithm.h"
     24 #include "tensorflow/core/lib/core/errors.h"
     25 #include "tensorflow/core/lib/strings/strcat.h"
     26 
     27 namespace tensorflow {
     28 namespace gpu_stream_util {
     29 
     30 Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
     31                      std::unordered_map<int, int>* node_to_stream_id) {
     32   VLOG(1) << "AssignStreams";
     33   Status status;
     34 
     35   // Sanity check arguments.
     36   if (graph == nullptr)
     37     status.Update(errors::InvalidArgument("Bad graph argument supplied."));
     38   if (node_to_stream_id == nullptr) {
     39     status.Update(
     40         errors::InvalidArgument("Bad node_to_stream_id argument supplied."));
     41   }
     42   if ((opts.max_streams < 1) || (opts.send_stream >= opts.max_streams) ||
     43       (opts.recv_stream >= opts.max_streams) ||
     44       (opts.const_stream >= opts.max_streams) ||
     45       (opts.compute_stream >= opts.max_streams)) {
     46     status.Update(errors::InvalidArgument("Bad graph argument supplied."));
     47   }
     48   TF_RETURN_IF_ERROR(status);
     49 
     50   // Topologically sort the nodes.
     51   std::vector<Node*> order;
     52   GetReversePostOrder(*graph, &order);
     53   if (VLOG_IS_ON(2)) {
     54     for (Node* n : order) {
     55       const int node_id = n->id();
     56       VLOG(2) << "Node " << node_id << " " << n->type_string() << " "
     57               << n->name() << " " << n->in_edges().size() << " inputs";
     58       for (const Edge* e : n->in_edges()) {
     59         VLOG(2) << "  Edge from " << e->src()->id() << "  " << e->src()->name()
     60                 << " fanout " << e->src()->out_edges().size();
     61       }
     62     }
     63   }
     64   // We perform stream assignment assuming a large number of
     65   // stream IDs and then map these down to the required number of streams
     66   // using simple round-robin.
     67   // Stream Assignment strategy:
     68   // 1. Nodes with zero inputs are always be executed on a
     69   // fresh stream.
     70   // 2. Try to execute a node on the same stream as one of its
     71   // inputs to avoid inter-stream dependencies.
     72   // 3. If any input comes from a node with a large fanout then
     73   // perhaps an indication that it is shared between parallel
     74   // streams of work. We choose a new stream here so that all consumers
     75   // of the tensor are likely to run in parallel.
     76   int highest_stream_id = -1;
     77   for (Node* n : order) {
     78     VLOG(3) << "Inspecting node " << n->DebugString();
     79     const int node_id = n->id();
     80     const string& op = n->type_string();
     81 
     82     // Determine a suitable stream to use.
     83     int stream_id = highest_stream_id + 1;
     84     for (const Edge* e : n->in_edges()) {
     85       const size_t fanout = e->src()->out_edges().size();
     86       if (fanout == 1) {
     87         stream_id = (*node_to_stream_id)[e->src()->id()];
     88         break;
     89       }
     90     }
     91     // Override stream for specific op types.
     92     if (op == "_Send") {
     93       if (opts.send_stream >= 0) stream_id = opts.send_stream;
     94     } else if (op == "_Recv") {
     95       if (opts.recv_stream >= 0) stream_id = opts.recv_stream;
     96     } else if (op == "Const") {
     97       if (opts.const_stream >= 0) stream_id = opts.const_stream;
     98     } else {
     99       if (opts.compute_stream >= 0) stream_id = opts.compute_stream;
    100     }
    101 
    102     (*node_to_stream_id)[node_id] = stream_id % opts.max_streams;
    103     highest_stream_id = std::max(stream_id, highest_stream_id);
    104   }
    105   VLOG(1) << "Identified " << highest_stream_id << " candidate streams for "
    106           << order.size() << " nodes.";
    107 
    108   return Status::OK();
    109 }
    110 
    111 }  // namespace gpu_stream_util
    112 }  // namespace tensorflow
    113