Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
     17 #include "tensorflow/compiler/xla/client/computation_builder.h"
     18 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     19 #include "tensorflow/core/framework/attr_value.pb.h"
     20 #include "tensorflow/core/framework/tensor_shape.pb.h"
     21 
     22 namespace xla {
     23 namespace hlo_graph_dumper {
     24 namespace {
     25 
     26 using ::tensorflow::GraphDef;
     27 
     28 class HloTfGraphBuilderTest : public HloTestBase {
     29  protected:
     30   HloTfGraphBuilderTest() {}
     31   HloTfGraphBuilder generator_;
     32 
     33   // Create a computation which takes a scalar and returns its negation.
     34   std::unique_ptr<HloComputation> CreateNegateComputation() {
     35     auto builder = HloComputation::Builder("Negate");
     36     auto param = builder.AddInstruction(
     37         HloInstruction::CreateParameter(0, r0f32_, "param0"));
     38     builder.AddInstruction(
     39         HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
     40     return builder.Build();
     41   }
     42 
     43   // Creates a computation which calls map with the given computation.
     44   std::unique_ptr<HloComputation> CreateMapComputation(
     45       HloComputation *map_computation) {
     46     auto builder = HloComputation::Builder("Map");
     47     auto param = builder.AddInstruction(
     48         HloInstruction::CreateParameter(0, r0f32_, "param0"));
     49     builder.AddInstruction(
     50         HloInstruction::CreateMap(r0f32_, {param}, map_computation));
     51     return builder.Build();
     52   }
     53   Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {});
     54 };
     55 
     56 static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node,
     57                                                 const string &attr_name) {
     58   auto attr = node.attr().find(attr_name);
     59   CHECK(attr != node.attr().end());
     60   return attr->second;
     61 }
     62 
     63 TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
     64   auto builder = HloComputation::Builder("Concatenate");
     65   Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2});
     66   auto param_1 = builder.AddInstruction(
     67       HloInstruction::CreateParameter(0, shape, "param0"));
     68   auto param_2 = builder.AddInstruction(
     69       HloInstruction::CreateParameter(1, shape, "param1"));
     70   builder.AddInstruction(HloInstruction::CreateConcatenate(
     71       ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {param_1, param_2}, 1));
     72   TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
     73   GraphDef graph_def = generator_.GetGraphDef();
     74   EXPECT_EQ(graph_def.node_size(), 3);
     75   const auto &node = graph_def.node(2);
     76   EXPECT_EQ(node.name(), "Concatenate/concatenate");
     77 
     78   // Check dimensions.
     79   auto dims_value = GetNodeAttr(node, "dims");
     80   EXPECT_EQ(dims_value.list().i_size(), 1);
     81   EXPECT_EQ(dims_value.list().i(0), 1);
     82 
     83   // Check shapes.
     84   auto shape_value = GetNodeAttr(node, "_output_shapes");
     85   EXPECT_EQ(shape_value.list().shape_size(), 1);
     86   EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2);
     87   EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2);
     88   EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4);
     89 }
     90 
     91 TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
     92   auto builder = HloComputation::Builder("Const");
     93   HloInstruction *instruction = builder.AddInstruction(
     94       HloInstruction::CreateConstant(Literal::CreateR0(123)));
     95   OpMetadata metadata;
     96   metadata.set_op_name("x");
     97   metadata.set_op_type("y");
     98   instruction->set_metadata(metadata);
     99   TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
    100   GraphDef graph_def = generator_.GetGraphDef();
    101   EXPECT_EQ(graph_def.node_size(), 1);
    102   const auto &node = graph_def.node(0);
    103   EXPECT_EQ(GetNodeAttr(node, "value").s(), "123");
    104   EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32");
    105   EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x");
    106   EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y");
    107 }
    108 
    109 TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {
    110   auto negate_computation = CreateNegateComputation();
    111   TF_CHECK_OK(generator_.AddComputation(*negate_computation));
    112   GraphDef graph_def = generator_.GetGraphDef();
    113   EXPECT_EQ(graph_def.node_size(), 2);
    114   EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0");
    115   EXPECT_EQ(graph_def.node(0).op(), "HloParameter");
    116   EXPECT_EQ(graph_def.node(1).name(), "Negate/negate");
    117   EXPECT_EQ(graph_def.node(1).op(), "HloNegate");
    118   EXPECT_EQ(graph_def.node(1).input_size(), 1);
    119   EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0");
    120 }
    121 
    122 TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) {
    123   auto builder = HloComputation::Builder("GE");
    124   auto param_1 = builder.AddInstruction(
    125       HloInstruction::CreateParameter(0, r0f32_, "param0"));
    126   auto param_2 = builder.AddInstruction(
    127       HloInstruction::CreateParameter(1, r0f32_, "param1"));
    128   builder.AddInstruction(
    129       HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2));
    130   TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
    131   GraphDef graph_def = generator_.GetGraphDef();
    132   EXPECT_EQ(graph_def.node_size(), 3);
    133   EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0");
    134   EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1");
    135   EXPECT_EQ(graph_def.node(2).input_size(), 2);
    136   EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to");
    137   EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
    138 }
    139 
    140 TEST_F(HloTfGraphBuilderTest, IncorparateTfOpsStructure) {
    141   auto builder = HloComputation::Builder("GE");
    142   auto param_1 = builder.AddInstruction(
    143       HloInstruction::CreateParameter(0, r0f32_, "param0"));
    144   auto param_2 = builder.AddInstruction(
    145       HloInstruction::CreateParameter(1, r0f32_, "param1"));
    146   auto ge = builder.AddInstruction(
    147       HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2));
    148   OpMetadata metadata;
    149   metadata.set_op_name("x/y");
    150   metadata.set_op_type("Y");
    151   ge->set_metadata(metadata);
    152   TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
    153   GraphDef graph_def = generator_.GetGraphDef();
    154   EXPECT_EQ(graph_def.node_size(), 3);
    155   EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0");
    156   EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1");
    157   EXPECT_EQ(graph_def.node(2).input_size(), 2);
    158   EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to");
    159   EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
    160 }
    161 
    162 TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) {
    163   // Create computations with a diamond-shaped callgraph.
    164   auto negate_computation = CreateNegateComputation();
    165   auto map1_computation = CreateMapComputation(negate_computation.get());
    166   auto map2_computation = CreateMapComputation(negate_computation.get());
    167 
    168   auto builder = HloComputation::Builder(TestName());
    169   auto param = builder.AddInstruction(
    170       HloInstruction::CreateParameter(0, r0f32_, "param0"));
    171   auto map1 = builder.AddInstruction(
    172       HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get()));
    173   auto map2 = builder.AddInstruction(
    174       HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get()));
    175   builder.AddInstruction(
    176       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
    177   auto computation = builder.Build();
    178   TF_CHECK_OK(generator_.AddComputation(*computation));
    179   EXPECT_GT(generator_.GetGraphDef().node_size(), 0);
    180 }
    181 
    182 }  // namespace
    183 }  // namespace hlo_graph_dumper
    184 }  // namespace xla
    185