Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// Example HLO graph which demonstrates the Graphviz dumper for HLO
// computations. When run, it pushes the example DOT graph to the Graphviz
// service and prints the resulting URL. Useful for seeing the effect of changes
// to the graph generation code.
     20 
     21 #include <stdio.h>
     22 #include <memory>
     23 #include <string>
     24 
     25 #include "tensorflow/compiler/xla/literal_util.h"
     26 #include "tensorflow/compiler/xla/ptr_util.h"
     27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     28 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
     29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     30 #include "tensorflow/compiler/xla/service/hlo_module.h"
     31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     32 #include "tensorflow/compiler/xla/shape_util.h"
     33 #include "tensorflow/compiler/xla/types.h"
     34 #include "tensorflow/compiler/xla/xla_data.pb.h"
     35 #include "tensorflow/core/lib/strings/strcat.h"
     36 #include "tensorflow/core/platform/init_main.h"
     37 #include "tensorflow/core/platform/types.h"
     38 
     39 namespace xla {
     40 namespace {
     41 
     42 // Adds a computation to the given HLO module which adds a scalar constant to
     43 // its parameter and returns the result.
     44 HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) {
     45   auto builder =
     46       HloComputation::Builder(tensorflow::strings::StrCat("add_", addend));
     47   auto x_value = builder.AddInstruction(HloInstruction::CreateParameter(
     48       0, ShapeUtil::MakeShape(F32, {}), "x_value"));
     49   auto half = builder.AddInstruction(
     50       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.5)));
     51   builder.AddInstruction(HloInstruction::CreateBinary(
     52       half->shape(), HloOpcode::kAdd, x_value, half));
     53   return module->AddEmbeddedComputation(builder.Build());
     54 }
     55 
     56 // Adds a computation to the given HLO module which sums its two parameters and
     57 // returns the result.
     58 HloComputation* ScalarSumComputation(HloModule* module) {
     59   auto builder = HloComputation::Builder("add");
     60   auto lhs = builder.AddInstruction(
     61       HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "lhs"));
     62   auto rhs = builder.AddInstruction(
     63       HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "rhs"));
     64   builder.AddInstruction(
     65       HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
     66   return module->AddEmbeddedComputation(builder.Build());
     67 }
     68 
     69 // Adds a computation to the given HLO module which forwards its argument to a
     70 // kCall instruction which then calls the given computation.
     71 HloComputation* CallForwardingComputation(HloComputation* computation,
     72                                           HloModule* module) {
     73   auto builder = HloComputation::Builder("call_forward");
     74   auto arg = builder.AddInstruction(
     75       HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "arg"));
     76   builder.AddInstruction(
     77       HloInstruction::CreateCall(arg->shape(), {arg}, computation));
     78   return module->AddEmbeddedComputation(builder.Build());
     79 }
     80 
     81 // Create a large, arbitrary computation with many different kinds of
     82 // instructions. Sets the computation as the entry to an HLO module and returns
     83 // the module.
     84 std::unique_ptr<HloModule> MakeBigGraph() {
     85   auto module = MakeUnique<HloModule>("BigGraph");
     86 
     87   auto builder = HloComputation::Builder("TestBigGraphvizGraph");
     88 
     89   // Shapes used in the computation.
     90   auto mshape = ShapeUtil::MakeShape(F32, {3, 5});
     91   auto vshape = ShapeUtil::MakeShape(F32, {3});
     92   auto sshape = ShapeUtil::MakeShape(F32, {3});
     93 
     94   // Create a set of parameter instructions.
     95   auto param_v0 =
     96       builder.AddInstruction(HloInstruction::CreateParameter(0, vshape, "foo"));
     97   auto param_v1 =
     98       builder.AddInstruction(HloInstruction::CreateParameter(1, vshape, "bar"));
     99   auto param_v2 =
    100       builder.AddInstruction(HloInstruction::CreateParameter(2, vshape, "baz"));
    101   auto param_s =
    102       builder.AddInstruction(HloInstruction::CreateParameter(3, sshape, "qux"));
    103   auto param_m =
    104       builder.AddInstruction(HloInstruction::CreateParameter(4, mshape, "zzz"));
    105 
    106   // Add an arbitrary expression of different instructions.
    107   auto copy = builder.AddInstruction(
    108       HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0));
    109   auto clamp = builder.AddInstruction(HloInstruction::CreateTernary(
    110       vshape, HloOpcode::kClamp, copy, param_v1, param_v2));
    111   DotDimensionNumbers dot_dnums;
    112   dot_dnums.add_lhs_contracting_dimensions(1);
    113   dot_dnums.add_rhs_contracting_dimensions(0);
    114   auto dot = builder.AddInstruction(
    115       HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums));
    116   auto tuple = builder.AddInstruction(
    117       HloInstruction::CreateTuple({dot, param_s, clamp}));
    118   auto scalar = builder.AddInstruction(
    119       HloInstruction::CreateGetTupleElement(sshape, tuple, 2));
    120   auto add_one = AddScalarConstantComputation(1.0, module.get());
    121   auto rng = builder.AddInstruction(
    122       HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m}));
    123   auto one = builder.AddInstruction(
    124       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    125   auto add_computation = ScalarSumComputation(module.get());
    126   builder.AddInstruction(
    127       HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation));
    128   auto map1 = builder.AddInstruction(
    129       HloInstruction::CreateMap(sshape, {scalar}, add_one));
    130   auto map2 = builder.AddInstruction(
    131       HloInstruction::CreateMap(sshape, {map1}, add_one));
    132   auto map3 = builder.AddInstruction(
    133       HloInstruction::CreateMap(sshape, {map2}, add_one));
    134 
    135   // Create a fusion instruction containing the chain of map instructions.
    136   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
    137       sshape, HloInstruction::FusionKind::kLoop, map3));
    138   fusion->FuseInstruction(map2);
    139   fusion->FuseInstruction(map1);
    140 
    141   // Add a random trace instruction.
    142   builder.AddInstruction(HloInstruction::CreateTrace("trace", dot));
    143 
    144   // Add a call instruction will calls the call-forwarding computation to call
    145   // another computation.
    146   auto call_computation = CallForwardingComputation(add_one, module.get());
    147   builder.AddInstruction(
    148       HloInstruction::CreateCall(fusion->shape(), {fusion}, call_computation));
    149 
    150   module->AddEntryComputation(builder.Build());
    151   return module;
    152 }
    153 
    154 }  // namespace
    155 }  // namespace xla
    156 
    157 int main(int argc, char** argv) {
    158   tensorflow::port::InitMain(argv[0], &argc, &argv);
    159 
    160   auto module = xla::MakeBigGraph();
    161 
    162   printf("Graph URL: %s\n", xla::hlo_graph_dumper::DumpGraph(
    163                                 *module->entry_computation(),
    164                                 "Example computation", xla::DebugOptions())
    165                                 .c_str());
    166   return 0;
    167 }
    168