Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/tf2xla.h"
     17 
     18 #include <map>
     19 #include <memory>
     20 #include <string>
     21 #include <unordered_map>
     22 #include <utility>
     23 #include <vector>
     24 
     25 #include "tensorflow/compiler/tf2xla/dump_graph.h"
     26 #include "tensorflow/compiler/tf2xla/shape_util.h"
     27 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
     28 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
     29 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     30 #include "tensorflow/core/common_runtime/function.h"
     31 #include "tensorflow/core/framework/function.h"
     32 #include "tensorflow/core/framework/graph.pb.h"
     33 #include "tensorflow/core/framework/graph_def_util.h"
     34 #include "tensorflow/core/framework/op.h"
     35 #include "tensorflow/core/framework/tensor_shape.h"
     36 #include "tensorflow/core/framework/versions.pb.h"
     37 #include "tensorflow/core/graph/algorithm.h"
     38 #include "tensorflow/core/graph/graph.h"
     39 #include "tensorflow/core/graph/graph_constructor.h"
     40 #include "tensorflow/core/graph/node_builder.h"
     41 #include "tensorflow/core/lib/core/errors.h"
     42 #include "tensorflow/core/lib/strings/str_util.h"
     43 #include "tensorflow/core/lib/strings/strcat.h"
     44 #include "tensorflow/core/platform/logging.h"
     45 #include "tensorflow/core/platform/types.h"
     46 
     47 namespace tensorflow {
     48 
// Op type names of the function-argument and return-value nodes inserted by
// the rewrite pass below.
const char* const kArgOp = "_Arg";
const char* const kRetvalOp = "_Retval";
// Attribute names attached to each inserted _Arg/_Retval node.  They carry the
// originating feed/fetch id, the feed's declared shape, and a human-readable
// debug name through to XLA argument construction (see CreateXlaArgs).
const char* const kFeedIdAttr = "_feed_id";
const char* const kFetchIdAttr = "_fetch_id";
const char* const kShapeAttr = "_shape";
const char* const kDebugNameAttr = "_debug_name";
     55 
     56 namespace {
     57 
     58 typedef std::unordered_map<string, Node*> NodeMap;
     59 
     60 // Each feed id identifies the positional output of some node, which may consist
     61 // of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
     62 // tensor with a placeholder.  For each feed tensor, replaces all edges so they
     63 // point from a new _Arg node instead.
     64 Status AddArgNodes(Graph* graph, const NodeMap& node_map,
     65                    const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
     66                    const std::unordered_map<string, string>& feed_remapping) {
     67   for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
     68     const tf2xla::Feed& feed = feeds[arg_index];
     69     // All feeds have been replaced by placeholders.
     70     const int output_index = 0;
     71 
     72     const string key = TensorIdToString(feed.id());
     73     const auto remap_it = feed_remapping.find(key);
     74     auto node_it = node_map.find(remap_it->second);
     75     if (node_it == node_map.end()) {
     76       // Strip off the aot_feed_#/ prefix.
     77       StringPiece name(remap_it->second);
     78       const auto index = name.find('/');
     79       if (index > 0) name.remove_prefix(index + 1);
     80       return errors::InvalidArgument(
     81           "Node is fed but not needed for fetching: ", name);
     82     }
     83     const Node* feed_node = node_it->second;
     84 
     85     // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
     86     // "_shape" attr if we can determine it.  That way the graph will be
     87     // initialized with whatever shapes we can infer, while the user can still
     88     // explicitly specify or override them.
     89     Node* arg_node = nullptr;
     90     TF_RETURN_IF_ERROR(
     91         NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
     92             .Attr("T", BaseType(feed_node->output_type(output_index)))
     93             .Attr("index", arg_index)
     94             .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
     95             .Attr(kShapeAttr, TensorShape(feed.shape()))
     96             .Attr(kDebugNameAttr, feed.name())
     97             .Finalize(graph, &arg_node));
     98 
     99     // Collects out-edges from the feed node that have a matching edge index;
    100     // these will be replaced with edges from the arg node instead.
    101     //
    102     // We must collect the edges first and process them in a second pass, since
    103     // removing the edge from the graph invalidates feed_node->out_edges.
    104     std::vector<const Edge*> feed_edges;
    105     for (const Edge* edge : feed_node->out_edges()) {
    106       if (edge->src_output() == output_index) {
    107         feed_edges.push_back(edge);
    108       }
    109     }
    110     for (const Edge* edge : feed_edges) {
    111       graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
    112       graph->RemoveEdge(edge);
    113     }
    114   }
    115   return Status::OK();
    116 }
    117 
    118 // Each fetch id identifies the positional output of some node.  For each fetch
    119 // node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
    120 Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
    121                       const protobuf::RepeatedPtrField<tf2xla::Fetch>& fetches,
    122                       std::unordered_set<const Node*>* retval_nodes) {
    123   for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
    124     const tf2xla::TensorId& id = fetches[ret_index].id();
    125     auto it = node_map.find(id.node_name());
    126     if (it == node_map.end()) {
    127       return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
    128     }
    129     Node* fetch_node = it->second;
    130     if (id.output_index() >= fetch_node->num_outputs()) {
    131       return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
    132                                      ", output index should be < ",
    133                                      fetch_node->num_outputs());
    134     }
    135     // Connects fetch_node -> retval_node.
    136     Node* retval_node = nullptr;
    137     TF_RETURN_IF_ERROR(
    138         NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
    139             .Input(fetch_node, id.output_index())
    140             .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
    141             .Attr("index", ret_index)
    142             .Attr(kFetchIdAttr, TensorIdToString(id))
    143             .Finalize(graph, &retval_node));
    144     retval_nodes->insert(retval_node);
    145   }
    146   return Status::OK();
    147 }
    148 
    149 // RewriteAndPruneGraph identifies input and output edges (named by the feed and
    150 // fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
    151 // nodes, and outputs flow to _Retval nodes.  This allows the symbolic graph
    152 // execution to know the input and output args for the generated function.
    153 Status RewriteAndPruneGraph(
    154     Graph* graph, const tf2xla::Config& config,
    155     const std::unordered_map<string, string>& feed_remapping) {
    156   NodeMap node_map;
    157   for (Node* n : graph->nodes()) {
    158     node_map[n->name()] = n;
    159   }
    160   TF_RETURN_IF_ERROR(
    161       AddArgNodes(graph, node_map, config.feed(), feed_remapping));
    162   std::unordered_set<const Node*> retval_nodes;
    163   TF_RETURN_IF_ERROR(
    164       AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
    165   VLOG(2) << "Post rewrite: "
    166           << dump_graph::DumpGraphToFile("tf2xla_post_rewrite", *graph);
    167   PruneForReverseReachability(graph, retval_nodes);
    168   FixupSourceAndSinkEdges(graph);
    169   VLOG(2) << "Post prune: "
    170           << dump_graph::DumpGraphToFile("tfcompile_post_prune", *graph);
    171   // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
    172   std::set<string> missing_feeds, missing_fetches;
    173   for (const tf2xla::Feed& feed : config.feed()) {
    174     missing_feeds.insert(TensorIdToString(feed.id()));
    175   }
    176   for (const tf2xla::Fetch& fetch : config.fetch()) {
    177     missing_fetches.insert(TensorIdToString(fetch.id()));
    178   }
    179   for (const Node* n : graph->op_nodes()) {
    180     if (n->type_string() == kArgOp) {
    181       string feed_id;
    182       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
    183       if (missing_feeds.erase(feed_id) == 0) {
    184         return errors::Aborted(kArgOp,
    185                                " node found with unknown feed id: ", feed_id);
    186       }
    187     } else if (n->type_string() == kRetvalOp) {
    188       string fetch_id;
    189       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
    190       if (missing_fetches.erase(fetch_id) == 0) {
    191         return errors::Aborted(kRetvalOp,
    192                                " node found with unknown fetch id: ", fetch_id);
    193       }
    194     }
    195   }
    196   if (!missing_feeds.empty() || !missing_fetches.empty()) {
    197     return errors::Aborted(
    198         "Post graph-pruning",
    199         ", missing feeds: ", str_util::Join(missing_feeds, ", "),
    200         ", missing fetches: ", str_util::Join(missing_fetches, ", "));
    201   }
    202   return Status::OK();
    203 }
    204 
    205 // CollectArgNodes collects _Arg nodes from the graph, and performs basic
    206 // sanity-checking to ensure the index and type attributes of each node are
    207 // initialized correctly.
    208 Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
    209   std::map<int, Node*> indexed_arg_nodes;
    210   for (Node* n : graph.nodes()) {
    211     if (n->type_string() == kArgOp) {
    212       int index;
    213       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    214       auto insert_result = indexed_arg_nodes.insert({index, n});
    215       if (!insert_result.second) {
    216         const Node* dup = insert_result.first->second;
    217         return errors::InvalidArgument(
    218             "Multiple ", kArgOp, " nodes with index ", index, ", ",
    219             n->DebugString(), " and ", dup->DebugString());
    220       }
    221     }
    222   }
    223   arg_nodes->clear();
    224   for (const auto& index_node : indexed_arg_nodes) {
    225     if (index_node.first != arg_nodes->size()) {
    226       return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
    227                                      arg_nodes->size(), ", but got index ",
    228                                      index_node.first);
    229     }
    230     arg_nodes->push_back(index_node.second);
    231   }
    232   return Status::OK();
    233 }
    234 
    235 // Fills in xla_args from the corresponding _Arg nodes in the graph.
    236 Status CreateXlaArgs(const Graph& graph,
    237                      std::vector<XlaCompiler::Argument>* xla_args) {
    238   std::vector<Node*> arg_nodes;
    239   TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
    240   for (const Node* node : arg_nodes) {
    241     XlaCompiler::Argument arg;
    242     arg.kind = XlaCompiler::Argument::kParameter;
    243     TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
    244     TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape));
    245     TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
    246     xla_args->push_back(arg);
    247   }
    248   return Status::OK();
    249 }
    250 
    251 // Converts the TensorFlow graph into an XLA computation, by executing the
    252 // graph symbolically, with each op building up the XLA HLO.
    253 Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
    254                          xla::Computation* computation) {
    255   XlaOpRegistry::RegisterCompilationKernels();
    256   for (Node* node : graph->nodes()) {
    257     node->set_assigned_device_name(
    258         strings::StrCat("/device:", DEVICE_CPU_XLA_JIT));
    259   }
    260   std::vector<XlaCompiler::Argument> xla_args;
    261   TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
    262 
    263   // Compile the graph into an XLA computation.
    264   XlaCompiler::Options compiler_options;
    265   compiler_options.client = client;
    266   DeviceType device_type(DEVICE_CPU_XLA_JIT);
    267   compiler_options.device_type = &device_type;
    268   compiler_options.flib_def = &graph->flib_def();
    269   compiler_options.graph_def_version = graph->versions().producer();
    270   compiler_options.allow_cpu_custom_calls = true;
    271   XlaCompiler compiler(compiler_options);
    272 
    273   XlaCompiler::CompilationResult result;
    274   TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
    275                                            "tfcompile", std::move(graph),
    276                                            xla_args, &result));
    277   *computation = std::move(*result.computation);
    278 
    279   int num_const_results = 0;
    280   for (int i = 0; i < result.outputs.size(); ++i) {
    281     // Ending up with const results (i.e. output args) is an error, since it
    282     // means that one or more fetches that the user specified will be dropped
    283     // from the generated function.  It's most likely a configuration error,
    284     // since the user shouldn't be asking for output args that end up as consts.
    285     //
    286     // TODO(toddw): Provide a way for the user to access const output args,
    287     // e.g. perhaps hard-coded into the header, or somehow copied into the
    288     // output buffers.
    289     if (result.outputs[i].is_constant) {
    290       ++num_const_results;
    291       LOG(ERROR) << "ConstRetVal index:" << i
    292                  << " value:" << result.outputs[i].constant_value.DebugString();
    293     }
    294   }
    295   if (num_const_results > 0) {
    296     return errors::Unimplemented(
    297         "Conversion from TensorFlow graph to XLA resulted in ",
    298         num_const_results,
    299         " constant results.  The configuration of "
    300         "the output args (i.e. fetch ids) is probably wrong.");
    301   }
    302   return Status::OK();
    303 }
    304 
    305 // InitGraph creates a graph based on the graph_def, that may then be converted
    306 // to an xla::Computation via ConvertGraphToXla.
    307 //
    308 // The graph is rewritten with _Arg and _Retval nodes, representing the inputs
    309 // and outputs of the function that will be compiled.  Each feed id causes a new
    310 // _Arg node to be created, where we first collect all existing edges pointing
    311 // from the named node's output index, and then rewrite them to point from that
    312 // _Arg node instead.  Each fetch id causes a new _Retval node to be created,
    313 // with a new edge pointing from the named node's output index to that _Retval
    314 // node.
    315 Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
    316                  std::unique_ptr<Graph>* graph) {
    317   TF_RETURN_IF_ERROR(ValidateConfig(config));
    318 
    319   FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
    320   std::unique_ptr<Graph> g(new Graph(flib_def));
    321 
    322   // Replace references to fed tensors with references to newly added
    323   // placeholders.
    324   GraphDef first_copy_def = graph_def;
    325 
    326   // Maps from name:port of a feed to the name:port of the placeholder to use.
    327   std::unordered_map<string, string> feed_remapping;
    328   TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
    329                                              &feed_remapping, &first_copy_def));
    330 
    331   // Prune the GraphDef first so that unknown ops that we aren't compiling get
    332   // filtered out.
    333   GraphDef second_copy_def;
    334   TF_RETURN_IF_ERROR(
    335       PruneGraphDefInto(config, first_copy_def, &second_copy_def));
    336 
    337   TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
    338       &second_copy_def, *g->op_registry(), /*node_offset=*/0));
    339 
    340   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
    341                                             second_copy_def, g.get()));
    342   TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
    343   *graph = std::move(g);
    344   return Status::OK();
    345 }
    346 
    347 }  // namespace
    348 
    349 Status ConvertGraphDefToXla(const GraphDef& graph_def,
    350                             const tf2xla::Config& config, xla::Client* client,
    351                             xla::Computation* computation) {
    352   std::unique_ptr<Graph> graph;
    353   TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
    354   TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation));
    355   return Status::OK();
    356 }
    357 
    358 }  // namespace tensorflow
    359