/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
You may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"

#include <algorithm>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"

using ::tensorflow::GraphDef;
using ::tensorflow::NodeDef;
using ::tensorflow::TensorShapeProto;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
using ::tensorflow::str_util::Join;

namespace xla {
namespace hlo_graph_dumper {
namespace {

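// Returns a TF-style CamelCase op name for `instruction`, e.g.
// "hlo-get-tuple-element" becomes "HloGetTupleElement". Fusion instructions
// additionally get their fusion kind appended, e.g. a kLoop fusion becomes
// "HloFusionLoop".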
string GetOpDefName(const HloInstruction* instruction) {
  string name = StrCat("hlo-", HloOpcodeString(instruction->opcode()));
  tensorflow::str_util::TitlecaseString(&name, "-");
  name.erase(std::remove(name.begin(), name.end(), '-'), name.end());

  if (instruction->opcode() == HloOpcode::kFusion) {
    string fusion_name = ToString(instruction->fusion_kind());
    StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1));
  }
  return name;
}

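// Converts the dimensions of `instruction`'s XLA shape into a
// TensorShapeProto. Only dimension sizes are copied; the layout is attached
// separately as a node attribute in SetNodeAttrs.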
TensorShapeProto GetTensorShape(const HloInstruction* instruction) {
  TensorShapeProto tensor_shape;
  const Shape& shape = instruction->shape();
  for (auto dim : shape.dimensions()) {
    tensor_shape.add_dim()->set_size(dim);
  }
  return tensor_shape;
}

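// Returns a TF-style device name for the given device ordinal, e.g.
// "/device/XLA:0".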
string GetDeviceName(int device) { return StrCat("/device/XLA:", device); }

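// Makes an HLO instruction name safe for use as a TF graph node name:
// '%' characters are dropped and '<', '>', '[', ']' are replaced with '_',
// e.g. "%add.1" becomes "add.1".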
void CleanNodeName(string* name) {
  name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
  const string chars_to_replace = "<>[]";
  auto pred = [&](char c) {
    return chars_to_replace.find(c) != string::npos;
  };
  std::replace_if(name->begin(), name->end(), pred, '_');
}

}  // namespace

HloTfGraphBuilder::HloTfGraphBuilder(const DebugOptions& debug_options)
    : debug_options_(debug_options) {}

Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) {
  VLOG(2) << "Adding computation " << computation.name();
  for (auto* embedded : computation.MakeEmbeddedComputationsList()) {
    for (auto* instruction : embedded->instructions()) {
      TF_RETURN_IF_ERROR(AddInstruction(instruction));
    }
  }
  for (auto* instruction : computation.instructions()) {
    TF_RETURN_IF_ERROR(AddInstruction(instruction));
  }
  return Status::OK();
}

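// Typical usage of the builder (a minimal sketch; `module` stands in for an
// HloModule* obtained elsewhere):
//   HloTfGraphBuilder builder(debug_options);
//   TF_RETURN_IF_ERROR(builder.AddComputation(*module->entry_computation()));
//   const GraphDef& graph_def = builder.GetGraphDef();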
const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; }

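// Returns the node name for `instruction`, computing and caching it on first
// use. Names are hierarchical: e.g. an unfused instruction %add.1 in
// computation "cluster_0" becomes "cluster_0/add.1", while a fused
// instruction is nested under its fusion instruction's name.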
const string& HloTfGraphBuilder::GetNodeNameForInstruction(
    const HloInstruction* instruction) {
  if (ContainsKey(instruction_to_node_name_, instruction)) {
    return instruction_to_node_name_[instruction];
  }
  auto append = [](string* str, const string& other) {
    if (str->empty()) {
      *str = other;
    } else if (!other.empty()) {
      StrAppend(str, "/", other);
    }
  };
  string node_name;
  if (debug_options_.xla_hlo_tfgraph_device_scopes() &&
      instruction->has_sharding() &&
      instruction->sharding().HasUniqueDevice()) {
    node_name = StrCat(
        "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie());
  }
  // If an instruction is fused, put it in the subgraph of the fusion;
  // otherwise, put it in the computation subgraph.
  const HloComputation* computation = instruction->parent();
  if (computation->IsFusionComputation()) {
    append(&node_name,
           GetNodeNameForInstruction(computation->FusionInstruction()));
  } else {
    append(&node_name, computation->name());
    if (!instruction->metadata().op_name().empty()) {
      // Nest TF ops inside computation subgraphs, never the other way around.
      append(&node_name, instruction->metadata().op_name());
    }
  }
  string instruction_name = instruction->name();
  if (instruction->opcode() == HloOpcode::kParameter) {
    StrAppend(&instruction_name, ".", instruction->parameter_number());
  }
  append(&node_name, instruction_name);
  CleanNodeName(&node_name);
  auto ret =
      instruction_to_node_name_.insert(std::make_pair(instruction, node_name));
  CHECK(ret.second);
  return ret.first->second;
}

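// Fills in the attributes of `node_def` from `instruction`: element type,
// the originating framework op, output shape, layout, and a few op-specific
// attributes (e.g. "dims" for kTranspose, "channel_id" for kSend/kRecv).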
void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
                                     NodeDef* node_def) const {
  auto& attrs = *node_def->mutable_attr();

  // Set the number of arguments for instructions that have variadic operands.
  if (HloOpcodeIsVariadic(instruction->opcode())) {
    tensorflow::AttrValue attr_value;
    attr_value.set_i(instruction->operands().size());
    attrs["arg_num"] = attr_value;
  }

  // Set the node type.
  attrs["type"].set_s(
      xla::PrimitiveType_Name(instruction->shape().element_type()));

  // Set the framework op (e.g. the TensorFlow op) that generated this XLA op.
  attrs["tf_op_type"].set_s(instruction->metadata().op_type());
  attrs["tf_op_name"].set_s(instruction->metadata().op_name());

  // Set the shape of the output tensor. "_output_shapes" is a special
  // attribute name that TensorBoard uses to display the shapes of output
  // tensors.
  tensorflow::AttrValue shapes;
  *shapes.mutable_list()->add_shape() = GetTensorShape(instruction);
  attrs["_output_shapes"] = shapes;

  // Set the layout. For non-tuple shapes this is the minor-to-major dimension
  // order, e.g. "{1,0}" for a row-major matrix.
  if (LayoutUtil::HasLayout(instruction->shape())) {
    string layout_string;
    if (ShapeUtil::IsTuple(instruction->shape())) {
      // For tuples, emit the full shape because the layout of a tuple is not
      // represented in a single Layout field.
      layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape());
    } else {
      layout_string = StrCat(
          "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}");
    }
    attrs["layout"].set_s(layout_string);
  }

  // Set op-specific attributes.
  switch (instruction->opcode()) {
    case HloOpcode::kConcatenate:
    case HloOpcode::kBroadcast:
    case HloOpcode::kReduce:
    case HloOpcode::kReverse:
    case HloOpcode::kTranspose:
      for (auto dim : instruction->dimensions()) {
        attrs["dims"].mutable_list()->add_i(dim);
      }
      break;
    case HloOpcode::kGetTupleElement:
      attrs["index"].set_i(instruction->tuple_index());
      break;
    case HloOpcode::kRng:
      attrs["dist"].set_s(
          RandomDistribution_Name(instruction->random_distribution()));
      break;
    case HloOpcode::kConstant:
      if (ShapeUtil::IsScalar(instruction->shape())) {
        attrs["value"].set_s(instruction->literal().GetAsString({}));
      }
      break;
    case HloOpcode::kCustomCall:
      attrs["custom_call_target"].set_s(instruction->custom_call_target());
      break;
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
      attrs["channel_id"].set_i(instruction->channel_id());
      break;
    default:
      break;
  }
}

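// Adds a node for `instruction` to graph_def_, recursing into the fused
// instructions of fusion ops. Operand edges become regular inputs, and the
// root of each called computation becomes a control input ("^name").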
Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
  if (!visited_instructions_.insert(instruction).second) {
    // Skip instructions that have already been added.
    return Status::OK();
  }

  NodeDef* node_def = graph_def_.add_node();
  node_def->set_name(GetNodeNameForInstruction(instruction));
  node_def->set_op(GetOpDefName(instruction));
  if (instruction->has_sharding() &&
      instruction->sharding().HasUniqueDevice()) {
    TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice());
    node_def->set_device(GetDeviceName(device));
  }
  SetNodeAttrs(instruction, node_def);
  if (instruction->opcode() == HloOpcode::kFusion) {
    for (auto* fused_instruction : instruction->fused_instructions()) {
      TF_RETURN_IF_ERROR(AddInstruction(fused_instruction));
    }
  }
  // Add all edges, including control edges.
  for (const HloInstruction* operand : instruction->operands()) {
    *node_def->add_input() = GetNodeNameForInstruction(operand);
  }
  // Called computations are control dependencies.
  for (const auto* called_computation : instruction->called_computations()) {
    *node_def->add_input() = StrCat(
        "^", GetNodeNameForInstruction(called_computation->root_instruction()));
  }
  return Status::OK();
}

}  // namespace hlo_graph_dumper
}  // namespace xla