1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Example HLO graph which demonstrates Graphviz dumper for HLO 17 // computations. When run, pushes the example DOT graph to the Graphviz service 18 // and prints the URL. Useful for seeing effect of changes to the graph 19 // generation code. 20 21 #include <stdio.h> 22 #include <memory> 23 #include <string> 24 25 #include "tensorflow/compiler/xla/literal_util.h" 26 #include "tensorflow/compiler/xla/ptr_util.h" 27 #include "tensorflow/compiler/xla/service/hlo_computation.h" 28 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" 29 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 30 #include "tensorflow/compiler/xla/service/hlo_module.h" 31 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 32 #include "tensorflow/compiler/xla/shape_util.h" 33 #include "tensorflow/compiler/xla/types.h" 34 #include "tensorflow/compiler/xla/xla_data.pb.h" 35 #include "tensorflow/core/lib/strings/strcat.h" 36 #include "tensorflow/core/platform/init_main.h" 37 #include "tensorflow/core/platform/types.h" 38 39 namespace xla { 40 namespace { 41 42 // Adds a computation to the given HLO module which adds a scalar constant to 43 // its parameter and returns the result. 44 HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { 45 auto builder = 46 HloComputation::Builder(tensorflow::strings::StrCat("add_", addend)); 47 auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 48 0, ShapeUtil::MakeShape(F32, {}), "x_value")); 49 auto half = builder.AddInstruction( 50 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.5))); 51 builder.AddInstruction(HloInstruction::CreateBinary( 52 half->shape(), HloOpcode::kAdd, x_value, half)); 53 return module->AddEmbeddedComputation(builder.Build()); 54 } 55 56 // Adds a computation to the given HLO module which sums its two parameters and 57 // returns the result. 58 HloComputation* ScalarSumComputation(HloModule* module) { 59 auto builder = HloComputation::Builder("add"); 60 auto lhs = builder.AddInstruction( 61 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "lhs")); 62 auto rhs = builder.AddInstruction( 63 HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "rhs")); 64 builder.AddInstruction( 65 HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); 66 return module->AddEmbeddedComputation(builder.Build()); 67 } 68 69 // Adds a computation to the given HLO module which forwards its argument to a 70 // kCall instruction which then calls the given computation. 71 HloComputation* CallForwardingComputation(HloComputation* computation, 72 HloModule* module) { 73 auto builder = HloComputation::Builder("call_forward"); 74 auto arg = builder.AddInstruction( 75 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "arg")); 76 builder.AddInstruction( 77 HloInstruction::CreateCall(arg->shape(), {arg}, computation)); 78 return module->AddEmbeddedComputation(builder.Build()); 79 } 80 81 // Create a large, arbitrary computation with many different kinds of 82 // instructions. Sets the computation as the entry to an HLO module and returns 83 // the module. 84 std::unique_ptr<HloModule> MakeBigGraph() { 85 auto module = MakeUnique<HloModule>("BigGraph"); 86 87 auto builder = HloComputation::Builder("TestBigGraphvizGraph"); 88 89 // Shapes used in the computation. 90 auto mshape = ShapeUtil::MakeShape(F32, {3, 5}); 91 auto vshape = ShapeUtil::MakeShape(F32, {3}); 92 auto sshape = ShapeUtil::MakeShape(F32, {3}); 93 94 // Create a set of parameter instructions. 95 auto param_v0 = 96 builder.AddInstruction(HloInstruction::CreateParameter(0, vshape, "foo")); 97 auto param_v1 = 98 builder.AddInstruction(HloInstruction::CreateParameter(1, vshape, "bar")); 99 auto param_v2 = 100 builder.AddInstruction(HloInstruction::CreateParameter(2, vshape, "baz")); 101 auto param_s = 102 builder.AddInstruction(HloInstruction::CreateParameter(3, sshape, "qux")); 103 auto param_m = 104 builder.AddInstruction(HloInstruction::CreateParameter(4, mshape, "zzz")); 105 106 // Add an arbitrary expression of different instructions. 107 auto copy = builder.AddInstruction( 108 HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0)); 109 auto clamp = builder.AddInstruction(HloInstruction::CreateTernary( 110 vshape, HloOpcode::kClamp, copy, param_v1, param_v2)); 111 DotDimensionNumbers dot_dnums; 112 dot_dnums.add_lhs_contracting_dimensions(1); 113 dot_dnums.add_rhs_contracting_dimensions(0); 114 auto dot = builder.AddInstruction( 115 HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); 116 auto tuple = builder.AddInstruction( 117 HloInstruction::CreateTuple({dot, param_s, clamp})); 118 auto scalar = builder.AddInstruction( 119 HloInstruction::CreateGetTupleElement(sshape, tuple, 2)); 120 auto add_one = AddScalarConstantComputation(1.0, module.get()); 121 auto rng = builder.AddInstruction( 122 HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m})); 123 auto one = builder.AddInstruction( 124 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 125 auto add_computation = ScalarSumComputation(module.get()); 126 builder.AddInstruction( 127 HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation)); 128 auto map1 = builder.AddInstruction( 129 HloInstruction::CreateMap(sshape, {scalar}, add_one)); 130 auto map2 = builder.AddInstruction( 131 HloInstruction::CreateMap(sshape, {map1}, add_one)); 132 auto map3 = builder.AddInstruction( 133 HloInstruction::CreateMap(sshape, {map2}, add_one)); 134 135 // Create a fusion instruction containing the chain of map instructions. 136 auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( 137 sshape, HloInstruction::FusionKind::kLoop, map3)); 138 fusion->FuseInstruction(map2); 139 fusion->FuseInstruction(map1); 140 141 // Add a random trace instruction. 142 builder.AddInstruction(HloInstruction::CreateTrace("trace", dot)); 143 144 // Add a call instruction will calls the call-forwarding computation to call 145 // another computation. 146 auto call_computation = CallForwardingComputation(add_one, module.get()); 147 builder.AddInstruction( 148 HloInstruction::CreateCall(fusion->shape(), {fusion}, call_computation)); 149 150 module->AddEntryComputation(builder.Build()); 151 return module; 152 } 153 154 } // namespace 155 } // namespace xla 156 157 int main(int argc, char** argv) { 158 tensorflow::port::InitMain(argv[0], &argc, &argv); 159 160 auto module = xla::MakeBigGraph(); 161 162 printf("Graph URL: %s\n", xla::hlo_graph_dumper::DumpGraph( 163 *module->entry_computation(), 164 "Example computation", xla::DebugOptions()) 165 .c_str()); 166 return 0; 167 } 168