Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
     17 
     18 #include <queue>
     19 #include <set>
     20 #include <unordered_map>
     21 
     22 #include "tensorflow/compiler/tf2xla/sharding_util.h"
     23 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
     24 #include "tensorflow/compiler/xla/xla_data.pb.h"
     25 #include "tensorflow/core/framework/graph.pb.h"
     26 #include "tensorflow/core/framework/graph_def_util.h"
     27 #include "tensorflow/core/framework/node_def.pb.h"
     28 #include "tensorflow/core/framework/tensor_shape.h"
     29 #include "tensorflow/core/framework/tensor_shape.pb.h"
     30 #include "tensorflow/core/framework/versions.pb.h"
     31 #include "tensorflow/core/graph/tensor_id.h"
     32 #include "tensorflow/core/lib/core/errors.h"
     33 #include "tensorflow/core/lib/core/status.h"
     34 #include "tensorflow/core/lib/gtl/optional.h"
     35 #include "tensorflow/core/lib/strings/strcat.h"
     36 
     37 namespace tensorflow {
     38 
     39 namespace {
     40 
     41 Status ValidateTensorId(const tf2xla::TensorId& id) {
     42   if (id.node_name().empty()) {
     43     return errors::InvalidArgument("TensorId node_name must be non-empty");
     44   }
     45   if (id.output_index() < 0) {
     46     return errors::InvalidArgument("TensorId output_index must be positive");
     47   }
     48   return Status::OK();
     49 }
     50 
     51 Status CheckNameDuplicates(const string& kind, const string& name,
     52                            std::set<string>* names) {
     53   if (!name.empty()) {
     54     if (!names->insert(name).second) {
     55       return errors::InvalidArgument("duplicate ", kind, " name: ", name);
     56     }
     57   }
     58   return Status::OK();
     59 }
     60 
     61 Status CheckFeedFetchNameConflicts(const string& kind,
     62                                    const std::set<string>& names) {
     63   // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
     64   // since that will cause a collision in codegen symbols.
     65   for (const string& name : names) {
     66     const string name_data(name + "_data");
     67     if (names.find(name_data) != names.end()) {
     68       return errors::InvalidArgument("conflicting ", kind, " name: ", name,
     69                                      " and ", name_data);
     70     }
     71   }
     72   return Status::OK();
     73 }
     74 
     75 }  // namespace
     76 
     77 Status ValidateConfig(const tf2xla::Config& config) {
     78   std::set<string> names;
     79   for (const tf2xla::Feed& feed : config.feed()) {
     80     TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
     81     TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
     82     TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
     83   }
     84   TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
     85   names.clear();
     86   for (const tf2xla::Fetch& fetch : config.fetch()) {
     87     TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
     88     TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
     89   }
     90   TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
     91   if (config.fetch().empty()) {
     92     return errors::InvalidArgument("fetches must be specified");
     93   }
     94   return Status::OK();
     95 }
     96 
     97 Status AddPlaceholdersForFeeds(
     98     const tf2xla::Config& config, const OpRegistryInterface* op_registry,
     99     std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
    100   struct PlaceholderInfo {
    101     const tf2xla::Feed* feed = nullptr;  // point to Feed in <config>.
    102     string placeholder_name;
    103     DataType data_type = DT_INVALID;
    104   };
    105 
    106   // Put each fed tensor into a map by name:port. A map is used for determinism
    107   // when creating placeholders (genrules want deterministic output).
    108   std::map<string, PlaceholderInfo> placeholder_info;
    109   for (int i = 0; i < config.feed_size(); ++i) {
    110     const tf2xla::Feed* feed = &config.feed(i);
    111     const string name_port = TensorIdToString(feed->id());
    112     PlaceholderInfo& info = placeholder_info[name_port];
    113     info.feed = feed;
    114     info.placeholder_name = strings::StrCat(
    115         "aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
    116     (*feed_remapping)[name_port] = info.placeholder_name;
    117   }
    118 
    119   // Verify node exists and determine data type.
    120   std::unordered_map<string, const NodeDef*> name_to_node;
    121   for (int i = 0; i < graph_def->node_size(); ++i) {
    122     name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
    123   }
    124   for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    125     PlaceholderInfo& info = it->second;
    126     const tf2xla::TensorId& feed_id = info.feed->id();
    127 
    128     // Find the existing node and determine data type.
    129     auto node_it = name_to_node.find(feed_id.node_name());
    130     if (node_it == name_to_node.end()) {
    131       return errors::NotFound("Can't find feed node: ",
    132                               TensorIdToString(feed_id));
    133     }
    134     const NodeDef* existing = node_it->second;
    135 
    136     if (info.feed->type() != DT_INVALID) {
    137       info.data_type = info.feed->type();
    138     } else {
    139       // Build the node in order to infer its type.
    140 
    141       // Must first add default attrs as well, so do this in a copied GraphDef.
    142       GraphDef gd;
    143       *gd.mutable_versions() = graph_def->versions();
    144       *gd.add_node() = *existing;
    145       TF_RETURN_IF_ERROR(
    146           AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));
    147 
    148       // Now build the node from the copied node def.
    149       Graph g(op_registry);
    150       g.set_versions(graph_def->versions());
    151       Status status;
    152       Node* feed_node = g.AddNode(gd.node(0), &status);
    153       TF_RETURN_IF_ERROR(status);
    154       info.data_type =
    155           BaseType(feed_node->output_type(info.feed->id().output_index()));
    156     }
    157   }
    158 
    159   // Create placeholders. Note that we could avoid creating a placeholder for
    160   // feeds which are already placeholders, but we omit that to avoid more cases
    161   // in this code.
    162   for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    163     const PlaceholderInfo& info = it->second;
    164     NodeDef* d = graph_def->add_node();
    165     d->set_name(info.placeholder_name);
    166     d->set_op("PlaceholderV2");
    167     auto& attr_map = *d->mutable_attr();
    168     attr_map["dtype"].set_type(info.data_type);
    169     *attr_map["shape"].mutable_shape() = info.feed->shape();
    170   }
    171 
    172   // Rewrite references to the fed tensors to refer to the placeholder.
    173   for (int i = 0; i < graph_def->node_size(); ++i) {
    174     NodeDef* node_def = graph_def->mutable_node(i);
    175     for (int j = 0; j < node_def->input_size(); ++j) {
    176       auto id = ParseTensorName(node_def->input(j));
    177       auto it = placeholder_info.find(id.ToString());
    178       if (it != placeholder_info.end()) {
    179         node_def->set_input(j, it->second.placeholder_name);
    180       }
    181     }
    182   }
    183 
    184   return Status::OK();
    185 }
    186 
    187 Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
    188                          GraphDef* out) {
    189   *out = in;
    190   out->clear_node();
    191 
    192   // Tensors needed for feeding.
    193   std::set<std::pair<string, int>> feed_tensors;
    194   for (const tf2xla::Feed& feed : config.feed()) {
    195     feed_tensors.insert(
    196         std::make_pair(feed.id().node_name(), feed.id().output_index()));
    197   }
    198 
    199   // Maps node name to reachability.
    200   std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
    201   for (const NodeDef& node : in.node()) {
    202     node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
    203   }
    204 
    205   // Traverse.
    206   std::queue<string> name_queue;
    207   for (int i = 0; i < config.fetch_size(); ++i) {
    208     name_queue.push(config.fetch(i).id().node_name());
    209   }
    210   while (!name_queue.empty()) {
    211     const string name = name_queue.front();
    212     name_queue.pop();
    213 
    214     auto find_it = node_by_name.find(name);
    215     if (find_it == node_by_name.end()) {
    216       return errors::InvalidArgument("While pruning graph, node ", name,
    217                                      " needed but not found in the graph.");
    218     }
    219     auto& map_entry = find_it->second;
    220     if (map_entry.first) {
    221       continue;
    222     }
    223     map_entry.first = true;
    224 
    225     // Push input nodes of the currently visited node to name_queue.
    226     for (const string& in_edge : map_entry.second->input()) {
    227       auto id = ParseTensorName(in_edge);
    228       const string node_name = id.first.ToString();
    229       if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
    230           feed_tensors.end()) {
    231         name_queue.push(node_name);
    232       } else {
    233         // The input tensor is from an edge that is being fed. Therefore,
    234         // we skip recursing down that edge, to avoid requiring nodes that
    235         // may not be needed (note that the input node may still be added
    236         // to name_queue later if one of its output edges is not being fed).
    237       }
    238     }
    239   }
    240 
    241   // Copy over, preserving order of original and only nodes that are reachable
    242   // from the fetches.
    243   out->mutable_node()->Reserve(in.node_size());
    244   for (const NodeDef& node : in.node()) {
    245     if (node_by_name[node.name()].first) {
    246       *out->add_node() = node;
    247     }
    248   }
    249   return Status::OK();
    250 }
    251 
    252 string TensorIdToString(const tf2xla::TensorId& id) {
    253   return strings::StrCat(id.node_name(), ":", id.output_index());
    254 }
    255 
    256 Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
    257   int core = -1;
    258   const Node* matching_node = nullptr;
    259   for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
    260     if (edge->IsControlEdge()) continue;
    261     const Node* possible_match = out_edges ? edge->dst() : edge->src();
    262     TF_ASSIGN_OR_RETURN(
    263         tensorflow::gtl::optional<xla::OpSharding> sharding,
    264         ParseShardingFromDevice(
    265             *possible_match,
    266             /*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
    267     if (sharding.has_value()) {
    268       TF_RET_CHECK(sharding.value().type() ==
    269                    xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
    270       const int core_annotation = sharding.value().tile_assignment_devices(0);
    271       if (core == -1 || core > core_annotation) {
    272         core = core_annotation;
    273         matching_node = possible_match;
    274       }
    275     }
    276   }
    277   if (matching_node != nullptr) {
    278     n->set_assigned_device_name(matching_node->assigned_device_name());
    279     n->set_requested_device(matching_node->requested_device());
    280   }
    281   return Status::OK();
    282 }
    283 
    284 }  // namespace tensorflow
    285