/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"

#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"

namespace tensorflow {

using se::port::StatusOr;

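// Returns a node name of the form "<prefix>/clone_<counter>" that is not
// already present in `name_set`.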
string CloneConstantsForBetterClusteringPass::GenerateUniqueName(
    const absl::flat_hash_set<string>& name_set, absl::string_view prefix) {
  string candidate;
  do {
    candidate = absl::StrCat(prefix, "/clone_", unique_name_counter_++);
  } while (name_set.contains(candidate));
  return candidate;
}

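// Makes a copy of `n` in `g` under a fresh unique name.  The clone receives
// duplicates of all of `n`'s incoming data and control edges and inherits
// `n`'s assigned device, but starts out with no outgoing edges.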
StatusOr<Node*> CloneConstantsForBetterClusteringPass::CloneNode(
    Graph* g, const absl::flat_hash_set<string>& name_set, Node* n) {
  NodeDef new_in_def = n->def();
  new_in_def.clear_input();
  new_in_def.set_name(GenerateUniqueName(name_set, new_in_def.name()));
  Status s;
  Node* new_in = g->AddNode(new_in_def, &s);
  TF_RETURN_IF_ERROR(s);

  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(e->src(), new_in);
    } else {
      g->AddEdge(e->src(), e->src_output(), new_in, e->dst_input());
    }
  }

  new_in->set_assigned_device_name(n->assigned_device_name());
  return new_in;
}

namespace {
// We only clone host constants for now since we want to avoid increasing memory
// pressure on GPUs.
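//
// Returns true if `n` is a constant assigned to a CPU device whose value has
// fewer than kSmallTensorThreshold elements.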
StatusOr<bool> IsSmallHostConstant(Node* n) {
  if (!n->IsConstant()) {
    return false;
  }

  DeviceNameUtils::ParsedName parsed;
  TF_RET_CHECK(
      DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed));
  if (parsed.type != DEVICE_CPU) {
    return false;
  }

  const TensorProto* proto = nullptr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));

  // TODO(sanjoy): It may make sense to combine this threshold with XLA's "large
  // constant" threshold, if there is one.
  const int kSmallTensorThreshold = 16;
  int64 total_elements = 1;
  for (const auto& dim : proto->tensor_shape().dim()) {
    if (dim.size() < 0) {
      return errors::Internal("Unknown dimension size in constant tensor ",
                              n->name());
    }
    total_elements *= dim.size();
  }
  return total_elements < kSmallTensorThreshold;
}

bool IsInPlaceOp(absl::string_view op_name) {
  return op_name == "InplaceUpdate" || op_name == "InplaceAdd" ||
         op_name == "InplaceSub";
}
}  // namespace

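// For each edge into `n` whose source is a small host constant with more than
// one outgoing edge, replaces that edge with one coming from a freshly cloned
// copy of the constant, so `n` gets a private copy.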
Status CloneConstantsForBetterClusteringPass::CloneSmallHostConstantInputs(
    Graph* g, const absl::flat_hash_set<string>& name_set, Node* n) {
  std::vector<const Edge*> in_edges;
  absl::c_copy(n->in_edges(), std::back_inserter(in_edges));
  for (const Edge* e : in_edges) {
    Node* input = e->src();
    TF_ASSIGN_OR_RETURN(bool is_small_host_constant,
                        IsSmallHostConstant(input));
    if (is_small_host_constant && input->out_edges().size() != 1) {
      VLOG(2) << "Cloning small host constant " << input->name();
      TF_ASSIGN_OR_RETURN(Node* const input_cloned,
                          CloneNode(g, name_set, input));
      if (e->IsControlEdge()) {
        g->AddControlEdge(input_cloned, e->dst());
      } else {
        int dst_input = e->dst_input();
        TF_RET_CHECK(e->src_output() == 0)
            << "expected constant to have exactly one non-control output, but "
               "found output index = "
            << e->src_output();
        g->RemoveEdge(e);
        g->AddEdge(input_cloned, 0, n, dst_input);
      }
    }
  }
  return Status::OK();
}

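// Entry point of the pass: gives every node private copies of the small host
// constants it consumes (see CloneSmallHostConstantInputs above), the intent
// being that each consumer can then be clustered together with its own copy.
// The pass is a no-op when global JIT is disabled and bails out entirely if
// the graph contains any in-place ops.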
Status CloneConstantsForBetterClusteringPass::Run(
    const GraphOptimizationPassOptions& options) {
  if (GetGlobalJitLevel(options) == OptimizerOptions::OFF) {
    return Status::OK();
  }

  Graph* g = options.graph->get();
  absl::flat_hash_set<string> name_set;
  absl::c_transform(g->nodes(), std::inserter(name_set, name_set.begin()),
                    [](Node* n) { return n->name(); });
  std::vector<Node*> nodes;
  for (Node* n : g->nodes()) {
    // We rely on the immutability of Tensors to safely clone Const operations.
    // However, "in place" ops do not respect the immutability of Tensors so we
    // avoid this transformation when such ops are present in the graph.
    //
    // In-place operations are problematic because they break the semantic
    // illusion that tensorflow::Tensor instances are immutable.  For instance,
    // if we have the following graph:
    //
    // digraph {
    //   SRC -> Const
    //   SRC -> I
    //   SRC -> V
    //   Const -> Identity
    //   Const -> InplaceAdd [label="x"]
    //   I -> InplaceAdd [label="i"]
    //   V -> InplaceAdd [label="v"]
    //   InplaceAdd -> Identity [style=dotted]
    // }
    //
    // then the value produced by `Identity` is Const+I*V since InplaceAdd
    // modifies the tensor in place.  However, if we clone `Const` and turn the
    // graph into:
    //
    // digraph {
    //   SRC -> "Const/clone_1"
    //   SRC -> "Const/clone_2"
    //   SRC -> I
    //   SRC -> V
    //   "Const/clone_1" -> Identity
    //   "Const/clone_2" -> InplaceAdd [label="x"]
    //   I -> InplaceAdd [label="i"]
    //   V -> InplaceAdd [label="v"]
    //   InplaceAdd -> Identity [style=dotted]
    // }
    //
    // then `Identity` no longer produces Const+I*V because the InplaceAdd
    // operation only modifies Const/clone_2 in place.

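    // If any in-place op is present, bail out and leave the graph untouched.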
    if (IsInPlaceOp(n->type_string())) {
      return Status::OK();
    }
    nodes.push_back(n);
  }

  // Iterate over a copy of the nodes to avoid iterating over g->nodes() while
  // creating more nodes.
  for (Node* n : nodes) {
    TF_RETURN_IF_ERROR(CloneSmallHostConstantInputs(g, name_set, n));
  }
  return Status::OK();
}

}  // namespace tensorflow