1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 LIcensed under the Apache License, Version 2.0 (the "License"); 4 You may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" 17 #include "tensorflow/compiler/xla/layout_util.h" 18 #include "tensorflow/compiler/xla/literal_util.h" 19 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 20 #include "tensorflow/compiler/xla/shape_util.h" 21 #include "tensorflow/core/framework/attr_value.pb.h" 22 #include "tensorflow/core/framework/op.h" 23 #include "tensorflow/core/framework/tensor_shape.pb.h" 24 #include "tensorflow/core/lib/strings/str_util.h" 25 #include "tensorflow/core/lib/strings/strcat.h" 26 27 using ::tensorflow::GraphDef; 28 using ::tensorflow::NodeDef; 29 using ::tensorflow::TensorShapeProto; 30 using ::tensorflow::strings::StrAppend; 31 using ::tensorflow::strings::StrCat; 32 using ::tensorflow::str_util::Join; 33 34 namespace xla { 35 namespace hlo_graph_dumper { 36 namespace { 37 38 string GetOpDefName(const HloInstruction* instruction) { 39 string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); 40 tensorflow::str_util::TitlecaseString(&name, "-"); 41 name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); 42 43 if (instruction->opcode() == HloOpcode::kFusion) { 44 string fusion_name = ToString(instruction->fusion_kind()); 45 StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1)); 46 } 47 return name; 48 } 49 50 TensorShapeProto GetTensorShape(const HloInstruction* instruction) { 51 TensorShapeProto tensor_shape; 52 const Shape& shape = instruction->shape(); 53 for (auto dim : shape.dimensions()) { 54 tensor_shape.add_dim()->set_size(dim); 55 } 56 return tensor_shape; 57 } 58 59 string GetDeviceName(int device) { return StrCat("/device/XLA:", device); } 60 61 void CleanNodeName(string* name) { 62 name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); 63 const string chars_to_replace = "<>[]"; 64 auto pred = [&](char c) { 65 return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) != 66 chars_to_replace.end(); 67 }; 68 std::replace_if(name->begin(), name->end(), pred, '_'); 69 } 70 71 } // namespace 72 73 HloTfGraphBuilder::HloTfGraphBuilder(const DebugOptions& debug_options) 74 : debug_options_(debug_options) {} 75 76 Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { 77 VLOG(2) << "Adding computation " << computation.name(); 78 for (auto embedded : computation.MakeEmbeddedComputationsList()) { 79 for (auto* instruction : embedded->instructions()) { 80 TF_RETURN_IF_ERROR(AddInstruction(instruction)); 81 } 82 } 83 for (auto* instruction : computation.instructions()) { 84 TF_RETURN_IF_ERROR(AddInstruction(instruction)); 85 } 86 return Status::OK(); 87 } 88 89 const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; } 90 91 const string& HloTfGraphBuilder::GetNodeNameForInstruction( 92 const HloInstruction* instruction) { 93 if (ContainsKey(instruction_to_node_name_, instruction)) { 94 return instruction_to_node_name_[instruction]; 95 } 96 auto append = [](string* str, const string& other) { 97 if (str->empty()) { 98 *str = other; 99 } else if (!other.empty()) { 100 StrAppend(str, "/", other); 101 } 102 }; 103 string node_name; 104 if (debug_options_.xla_hlo_tfgraph_device_scopes() && 105 instruction->has_sharding() && 106 instruction->sharding().HasUniqueDevice()) { 107 node_name = StrCat( 108 "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie()); 109 } 110 // If an instruction is fused, put it in the subgraph of the fusion; 111 // otherwise, put it in the computation subgraph. 112 const HloComputation* computation = instruction->parent(); 113 if (computation->IsFusionComputation()) { 114 append(&node_name, 115 GetNodeNameForInstruction(computation->FusionInstruction())); 116 } else { 117 append(&node_name, computation->name()); 118 if (!instruction->metadata().op_name().empty()) { 119 // Always make computations contain TF ops but not the other way around. 120 append(&node_name, instruction->metadata().op_name()); 121 } 122 } 123 string instruction_name = instruction->name(); 124 if (instruction->opcode() == HloOpcode::kParameter) { 125 StrAppend(&instruction_name, ".", instruction->parameter_number()); 126 } 127 append(&node_name, instruction_name); 128 CleanNodeName(&node_name); 129 auto ret = 130 instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); 131 CHECK(ret.second); 132 return ret.first->second; 133 } 134 135 void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, 136 NodeDef* node_def) const { 137 auto& attrs = *node_def->mutable_attr(); 138 139 // Set the number of arguments for instructions that have variadic operands. 140 if (HloOpcodeIsVariadic(instruction->opcode())) { 141 tensorflow::AttrValue attr_value; 142 attr_value.set_i(instruction->operands().size()); 143 attrs["arg_num"] = attr_value; 144 } 145 146 // Set the node type. 147 attrs["type"].set_s( 148 xla::PrimitiveType_Name(instruction->shape().element_type())); 149 150 // Set the framework op (e.g. Tensorflow op) that generated this XLA op. 151 attrs["tf_op_type"].set_s(instruction->metadata().op_type()); 152 attrs["tf_op_name"].set_s(instruction->metadata().op_name()); 153 154 // Set the shape of the output tensor. "_output_shapes" is a special attribute 155 // name used by Tensorboard for shapes of output tensors. 156 tensorflow::AttrValue shapes; 157 *shapes.mutable_list()->add_shape() = GetTensorShape(instruction); 158 attrs["_output_shapes"] = shapes; 159 160 // Set the layout. 161 if (LayoutUtil::HasLayout(instruction->shape())) { 162 string layout_string; 163 if (ShapeUtil::IsTuple(instruction->shape())) { 164 // For tuples, emit the full shape because the layout of a tuple is not 165 // represented in a single Layout field. 166 layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); 167 } else { 168 layout_string = StrCat( 169 "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}"); 170 } 171 attrs["layout"].set_s(layout_string); 172 } 173 174 // Set op-specific attributes. 175 switch (instruction->opcode()) { 176 case HloOpcode::kConcatenate: 177 case HloOpcode::kBroadcast: 178 case HloOpcode::kReduce: 179 case HloOpcode::kReverse: 180 case HloOpcode::kTranspose: 181 for (auto dim : instruction->dimensions()) { 182 attrs["dims"].mutable_list()->add_i(dim); 183 } 184 break; 185 case HloOpcode::kGetTupleElement: 186 attrs["index"].set_i(instruction->tuple_index()); 187 break; 188 case HloOpcode::kRng: 189 attrs["dist"].set_s( 190 RandomDistribution_Name(instruction->random_distribution())); 191 break; 192 case HloOpcode::kConstant: 193 if (ShapeUtil::IsScalar(instruction->shape())) { 194 attrs["value"].set_s(instruction->literal().GetAsString({})); 195 } 196 break; 197 case HloOpcode::kCustomCall: 198 attrs["custom_call_target"].set_s(instruction->custom_call_target()); 199 break; 200 case HloOpcode::kSend: 201 case HloOpcode::kRecv: 202 attrs["channel_id"].set_i(instruction->channel_id()); 203 break; 204 default: 205 break; 206 } 207 } 208 209 Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { 210 if (!visited_instructions_.insert(instruction).second) { 211 // Skip instructions that have already been added. 212 return Status::OK(); 213 } 214 215 NodeDef* node_def = graph_def_.add_node(); 216 node_def->set_name(GetNodeNameForInstruction(instruction)); 217 node_def->set_op(GetOpDefName(instruction)); 218 if (instruction->has_sharding() && 219 instruction->sharding().HasUniqueDevice()) { 220 TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice()); 221 node_def->set_device(GetDeviceName(device)); 222 } 223 SetNodeAttrs(instruction, node_def); 224 if (instruction->opcode() == HloOpcode::kFusion) { 225 for (auto* fused_instruction : instruction->fused_instructions()) { 226 TF_RETURN_IF_ERROR(AddInstruction(fused_instruction)); 227 } 228 } 229 // Add all edges including control edges. 230 for (unsigned i = 0; i < instruction->operands().size(); ++i) { 231 *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i)); 232 } 233 // Called computations are control dependencies. 234 for (const auto* called_computation : instruction->called_computations()) { 235 *node_def->add_input() = StrCat( 236 "^", GetNodeNameForInstruction(called_computation->root_instruction())); 237 } 238 return Status::OK(); 239 } 240 241 } // namespace hlo_graph_dumper 242 } // namespace xla 243