1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h" 17 18 #include <queue> 19 #include <set> 20 #include <unordered_map> 21 22 #include "tensorflow/compiler/tf2xla/sharding_util.h" 23 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" 24 #include "tensorflow/compiler/xla/xla_data.pb.h" 25 #include "tensorflow/core/framework/graph.pb.h" 26 #include "tensorflow/core/framework/graph_def_util.h" 27 #include "tensorflow/core/framework/node_def.pb.h" 28 #include "tensorflow/core/framework/tensor_shape.h" 29 #include "tensorflow/core/framework/tensor_shape.pb.h" 30 #include "tensorflow/core/framework/versions.pb.h" 31 #include "tensorflow/core/graph/tensor_id.h" 32 #include "tensorflow/core/lib/core/errors.h" 33 #include "tensorflow/core/lib/core/status.h" 34 #include "tensorflow/core/lib/gtl/optional.h" 35 #include "tensorflow/core/lib/strings/strcat.h" 36 37 namespace tensorflow { 38 39 namespace { 40 41 Status ValidateTensorId(const tf2xla::TensorId& id) { 42 if (id.node_name().empty()) { 43 return errors::InvalidArgument("TensorId node_name must be non-empty"); 44 } 45 if (id.output_index() < 0) { 46 return errors::InvalidArgument("TensorId output_index must be positive"); 47 } 48 return Status::OK(); 49 } 50 51 Status CheckNameDuplicates(const string& kind, const string& name, 52 std::set<string>* names) { 53 if (!name.empty()) { 54 if (!names->insert(name).second) { 55 return errors::InvalidArgument("duplicate ", kind, " name: ", name); 56 } 57 } 58 return Status::OK(); 59 } 60 61 Status CheckFeedFetchNameConflicts(const string& kind, 62 const std::set<string>& names) { 63 // We don't allow the feeds or fetches to contain both "foo" and "foo_data", 64 // since that will cause a collision in codegen symbols. 65 for (const string& name : names) { 66 const string name_data(name + "_data"); 67 if (names.find(name_data) != names.end()) { 68 return errors::InvalidArgument("conflicting ", kind, " name: ", name, 69 " and ", name_data); 70 } 71 } 72 return Status::OK(); 73 } 74 75 } // namespace 76 77 Status ValidateConfig(const tf2xla::Config& config) { 78 std::set<string> names; 79 for (const tf2xla::Feed& feed : config.feed()) { 80 TF_RETURN_IF_ERROR(ValidateTensorId(feed.id())); 81 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape())); 82 TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names)); 83 } 84 TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names)); 85 names.clear(); 86 for (const tf2xla::Fetch& fetch : config.fetch()) { 87 TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id())); 88 TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names)); 89 } 90 TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names)); 91 if (config.fetch().empty()) { 92 return errors::InvalidArgument("fetches must be specified"); 93 } 94 return Status::OK(); 95 } 96 97 Status AddPlaceholdersForFeeds( 98 const tf2xla::Config& config, const OpRegistryInterface* op_registry, 99 std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) { 100 struct PlaceholderInfo { 101 const tf2xla::Feed* feed = nullptr; // point to Feed in <config>. 102 string placeholder_name; 103 DataType data_type = DT_INVALID; 104 }; 105 106 // Put each fed tensor into a map by name:port. A map is used for determinism 107 // when creating placeholders (genrules want deterministic output). 108 std::map<string, PlaceholderInfo> placeholder_info; 109 for (int i = 0; i < config.feed_size(); ++i) { 110 const tf2xla::Feed* feed = &config.feed(i); 111 const string name_port = TensorIdToString(feed->id()); 112 PlaceholderInfo& info = placeholder_info[name_port]; 113 info.feed = feed; 114 info.placeholder_name = strings::StrCat( 115 "aot_feed_", feed->id().output_index(), "/", feed->id().node_name()); 116 (*feed_remapping)[name_port] = info.placeholder_name; 117 } 118 119 // Verify node exists and determine data type. 120 std::unordered_map<string, const NodeDef*> name_to_node; 121 for (int i = 0; i < graph_def->node_size(); ++i) { 122 name_to_node[graph_def->node(i).name()] = &graph_def->node(i); 123 } 124 for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { 125 PlaceholderInfo& info = it->second; 126 const tf2xla::TensorId& feed_id = info.feed->id(); 127 128 // Find the existing node and determine data type. 129 auto node_it = name_to_node.find(feed_id.node_name()); 130 if (node_it == name_to_node.end()) { 131 return errors::NotFound("Can't find feed node: ", 132 TensorIdToString(feed_id)); 133 } 134 const NodeDef* existing = node_it->second; 135 136 if (info.feed->type() != DT_INVALID) { 137 info.data_type = info.feed->type(); 138 } else { 139 // Build the node in order to infer its type. 140 141 // Must first add default attrs as well, so do this in a copied GraphDef. 142 GraphDef gd; 143 *gd.mutable_versions() = graph_def->versions(); 144 *gd.add_node() = *existing; 145 TF_RETURN_IF_ERROR( 146 AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/)); 147 148 // Now build the node from the copied node def. 149 Graph g(op_registry); 150 g.set_versions(graph_def->versions()); 151 Status status; 152 Node* feed_node = g.AddNode(gd.node(0), &status); 153 TF_RETURN_IF_ERROR(status); 154 info.data_type = 155 BaseType(feed_node->output_type(info.feed->id().output_index())); 156 } 157 } 158 159 // Create placeholders. Note that we could avoid creating a placeholder for 160 // feeds which are already placeholders, but we omit that to avoid more cases 161 // in this code. 162 for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { 163 const PlaceholderInfo& info = it->second; 164 NodeDef* d = graph_def->add_node(); 165 d->set_name(info.placeholder_name); 166 d->set_op("PlaceholderV2"); 167 auto& attr_map = *d->mutable_attr(); 168 attr_map["dtype"].set_type(info.data_type); 169 *attr_map["shape"].mutable_shape() = info.feed->shape(); 170 } 171 172 // Rewrite references to the fed tensors to refer to the placeholder. 173 for (int i = 0; i < graph_def->node_size(); ++i) { 174 NodeDef* node_def = graph_def->mutable_node(i); 175 for (int j = 0; j < node_def->input_size(); ++j) { 176 auto id = ParseTensorName(node_def->input(j)); 177 auto it = placeholder_info.find(id.ToString()); 178 if (it != placeholder_info.end()) { 179 node_def->set_input(j, it->second.placeholder_name); 180 } 181 } 182 } 183 184 return Status::OK(); 185 } 186 187 Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, 188 GraphDef* out) { 189 *out = in; 190 out->clear_node(); 191 192 // Tensors needed for feeding. 193 std::set<std::pair<string, int>> feed_tensors; 194 for (const tf2xla::Feed& feed : config.feed()) { 195 feed_tensors.insert( 196 std::make_pair(feed.id().node_name(), feed.id().output_index())); 197 } 198 199 // Maps node name to reachability. 200 std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name; 201 for (const NodeDef& node : in.node()) { 202 node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node); 203 } 204 205 // Traverse. 206 std::queue<string> name_queue; 207 for (int i = 0; i < config.fetch_size(); ++i) { 208 name_queue.push(config.fetch(i).id().node_name()); 209 } 210 while (!name_queue.empty()) { 211 const string name = name_queue.front(); 212 name_queue.pop(); 213 214 auto find_it = node_by_name.find(name); 215 if (find_it == node_by_name.end()) { 216 return errors::InvalidArgument("While pruning graph, node ", name, 217 " needed but not found in the graph."); 218 } 219 auto& map_entry = find_it->second; 220 if (map_entry.first) { 221 continue; 222 } 223 map_entry.first = true; 224 225 // Push input nodes of the currently visited node to name_queue. 226 for (const string& in_edge : map_entry.second->input()) { 227 auto id = ParseTensorName(in_edge); 228 const string node_name = id.first.ToString(); 229 if (feed_tensors.find(std::make_pair(node_name, id.second)) == 230 feed_tensors.end()) { 231 name_queue.push(node_name); 232 } else { 233 // The input tensor is from an edge that is being fed. Therefore, 234 // we skip recursing down that edge, to avoid requiring nodes that 235 // may not be needed (note that the input node may still be added 236 // to name_queue later if one of its output edges is not being fed). 237 } 238 } 239 } 240 241 // Copy over, preserving order of original and only nodes that are reachable 242 // from the fetches. 243 out->mutable_node()->Reserve(in.node_size()); 244 for (const NodeDef& node : in.node()) { 245 if (node_by_name[node.name()].first) { 246 *out->add_node() = node; 247 } 248 } 249 return Status::OK(); 250 } 251 252 string TensorIdToString(const tf2xla::TensorId& id) { 253 return strings::StrCat(id.node_name(), ":", id.output_index()); 254 } 255 256 Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { 257 int core = -1; 258 const Node* matching_node = nullptr; 259 for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) { 260 if (edge->IsControlEdge()) continue; 261 const Node* possible_match = out_edges ? edge->dst() : edge->src(); 262 TF_ASSIGN_OR_RETURN( 263 tensorflow::gtl::optional<xla::OpSharding> sharding, 264 ParseShardingFromDevice( 265 *possible_match, 266 /*num_cores_per_replica=*/std::numeric_limits<int32>::max())); 267 if (sharding.has_value()) { 268 TF_RET_CHECK(sharding.value().type() == 269 xla::OpSharding::Type::OpSharding_Type_MAXIMAL); 270 const int core_annotation = sharding.value().tile_assignment_devices(0); 271 if (core == -1 || core > core_annotation) { 272 core = core_annotation; 273 matching_node = possible_match; 274 } 275 } 276 } 277 if (matching_node != nullptr) { 278 n->set_assigned_device_name(matching_node->assigned_device_name()); 279 n->set_requested_device(matching_node->requested_device()); 280 } 281 return Status::OK(); 282 } 283 284 } // namespace tensorflow 285