1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" 17 #include "tensorflow/compiler/xla/client/computation_builder.h" 18 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 19 #include "tensorflow/core/framework/attr_value.pb.h" 20 #include "tensorflow/core/framework/tensor_shape.pb.h" 21 22 namespace xla { 23 namespace hlo_graph_dumper { 24 namespace { 25 26 using ::tensorflow::GraphDef; 27 28 class HloTfGraphBuilderTest : public HloTestBase { 29 protected: 30 HloTfGraphBuilderTest() {} 31 HloTfGraphBuilder generator_; 32 33 // Create a computation which takes a scalar and returns its negation. 34 std::unique_ptr<HloComputation> CreateNegateComputation() { 35 auto builder = HloComputation::Builder("Negate"); 36 auto param = builder.AddInstruction( 37 HloInstruction::CreateParameter(0, r0f32_, "param0")); 38 builder.AddInstruction( 39 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); 40 return builder.Build(); 41 } 42 43 // Creates a computation which calls map with the given computation. 44 std::unique_ptr<HloComputation> CreateMapComputation( 45 HloComputation *map_computation) { 46 auto builder = HloComputation::Builder("Map"); 47 auto param = builder.AddInstruction( 48 HloInstruction::CreateParameter(0, r0f32_, "param0")); 49 builder.AddInstruction( 50 HloInstruction::CreateMap(r0f32_, {param}, map_computation)); 51 return builder.Build(); 52 } 53 Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {}); 54 }; 55 56 static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node, 57 const string &attr_name) { 58 auto attr = node.attr().find(attr_name); 59 CHECK(attr != node.attr().end()); 60 return attr->second; 61 } 62 63 TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { 64 auto builder = HloComputation::Builder("Concatenate"); 65 Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}); 66 auto param_1 = builder.AddInstruction( 67 HloInstruction::CreateParameter(0, shape, "param0")); 68 auto param_2 = builder.AddInstruction( 69 HloInstruction::CreateParameter(1, shape, "param1")); 70 builder.AddInstruction(HloInstruction::CreateConcatenate( 71 ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {param_1, param_2}, 1)); 72 TF_CHECK_OK(generator_.AddComputation(*builder.Build())); 73 GraphDef graph_def = generator_.GetGraphDef(); 74 EXPECT_EQ(graph_def.node_size(), 3); 75 const auto &node = graph_def.node(2); 76 EXPECT_EQ(node.name(), "Concatenate/concatenate"); 77 78 // Check dimensions. 79 auto dims_value = GetNodeAttr(node, "dims"); 80 EXPECT_EQ(dims_value.list().i_size(), 1); 81 EXPECT_EQ(dims_value.list().i(0), 1); 82 83 // Check shapes. 84 auto shape_value = GetNodeAttr(node, "_output_shapes"); 85 EXPECT_EQ(shape_value.list().shape_size(), 1); 86 EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2); 87 EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2); 88 EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4); 89 } 90 91 TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { 92 auto builder = HloComputation::Builder("Const"); 93 HloInstruction *instruction = builder.AddInstruction( 94 HloInstruction::CreateConstant(Literal::CreateR0(123))); 95 OpMetadata metadata; 96 metadata.set_op_name("x"); 97 metadata.set_op_type("y"); 98 instruction->set_metadata(metadata); 99 TF_CHECK_OK(generator_.AddComputation(*builder.Build())); 100 GraphDef graph_def = generator_.GetGraphDef(); 101 EXPECT_EQ(graph_def.node_size(), 1); 102 const auto &node = graph_def.node(0); 103 EXPECT_EQ(GetNodeAttr(node, "value").s(), "123"); 104 EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32"); 105 EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x"); 106 EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y"); 107 } 108 109 TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) { 110 auto negate_computation = CreateNegateComputation(); 111 TF_CHECK_OK(generator_.AddComputation(*negate_computation)); 112 GraphDef graph_def = generator_.GetGraphDef(); 113 EXPECT_EQ(graph_def.node_size(), 2); 114 EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0"); 115 EXPECT_EQ(graph_def.node(0).op(), "HloParameter"); 116 EXPECT_EQ(graph_def.node(1).name(), "Negate/negate"); 117 EXPECT_EQ(graph_def.node(1).op(), "HloNegate"); 118 EXPECT_EQ(graph_def.node(1).input_size(), 1); 119 EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0"); 120 } 121 122 TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) { 123 auto builder = HloComputation::Builder("GE"); 124 auto param_1 = builder.AddInstruction( 125 HloInstruction::CreateParameter(0, r0f32_, "param0")); 126 auto param_2 = builder.AddInstruction( 127 HloInstruction::CreateParameter(1, r0f32_, "param1")); 128 builder.AddInstruction( 129 HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); 130 TF_CHECK_OK(generator_.AddComputation(*builder.Build())); 131 GraphDef graph_def = generator_.GetGraphDef(); 132 EXPECT_EQ(graph_def.node_size(), 3); 133 EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); 134 EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); 135 EXPECT_EQ(graph_def.node(2).input_size(), 2); 136 EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to"); 137 EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); 138 } 139 140 TEST_F(HloTfGraphBuilderTest, IncorparateTfOpsStructure) { 141 auto builder = HloComputation::Builder("GE"); 142 auto param_1 = builder.AddInstruction( 143 HloInstruction::CreateParameter(0, r0f32_, "param0")); 144 auto param_2 = builder.AddInstruction( 145 HloInstruction::CreateParameter(1, r0f32_, "param1")); 146 auto ge = builder.AddInstruction( 147 HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); 148 OpMetadata metadata; 149 metadata.set_op_name("x/y"); 150 metadata.set_op_type("Y"); 151 ge->set_metadata(metadata); 152 TF_CHECK_OK(generator_.AddComputation(*builder.Build())); 153 GraphDef graph_def = generator_.GetGraphDef(); 154 EXPECT_EQ(graph_def.node_size(), 3); 155 EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); 156 EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); 157 EXPECT_EQ(graph_def.node(2).input_size(), 2); 158 EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to"); 159 EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); 160 } 161 162 TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { 163 // Create computations with a diamond-shaped callgraph. 164 auto negate_computation = CreateNegateComputation(); 165 auto map1_computation = CreateMapComputation(negate_computation.get()); 166 auto map2_computation = CreateMapComputation(negate_computation.get()); 167 168 auto builder = HloComputation::Builder(TestName()); 169 auto param = builder.AddInstruction( 170 HloInstruction::CreateParameter(0, r0f32_, "param0")); 171 auto map1 = builder.AddInstruction( 172 HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get())); 173 auto map2 = builder.AddInstruction( 174 HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get())); 175 builder.AddInstruction( 176 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); 177 auto computation = builder.Build(); 178 TF_CHECK_OK(generator_.AddComputation(*computation)); 179 EXPECT_GT(generator_.GetGraphDef().node_size(), 0); 180 } 181 182 } // namespace 183 } // namespace hlo_graph_dumper 184 } // namespace xla 185