/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

constexpr int64 kPointerSize = 8;

int64 ShapeSize(const Shape& shape) {
  return ShapeUtil::ByteSizeOf(shape, kPointerSize);
}

// This test suite tests HloCostAnalysis by first building a computation with
// the client ComputationBuilder and then running the analysis over it. The
// analysis reports the number of floating-point and transcendental operations
// in the graph. We test both individual HLO operations and a mixed graph.
class HloCostAnalysisTest : public ::testing::Test {
 protected:
  HloCostAnalysisTest()
      : client_(ClientLibrary::LocalClientOrDie()),
        // Accessing the service instance is required so the unit tests can
        // make white-box accesses to the user computation built from the
        // client, as shown in BuildHloGraph below.
        service_(static_cast<Service*>(ClientLibrary::GetXlaService(
            static_cast<LocalClient*>(client_)->platform()))),
        computation_tracker_(service_->computation_tracker()) {
    // Create a computation for a unary user function: x => exp(x + 0.5)
    {
      ComputationBuilder builder(client_, "add_and_exp");
      auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      auto half = builder.ConstantR0<float>(0.5);
      builder.Exp(builder.Add(x, half));
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      add_and_exp_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary user function: (x, y) => x + y
    {
      ComputationBuilder builder(client_, "add");
      auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
      builder.Add(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      add_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
    {
      ComputationBuilder builder(client_, "sigmoid");
      auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      auto one = builder.ConstantR0<float>(1.0);
      builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x))));
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      sigmoid_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary max function: (x, y) => max(x, y)
    {
      ComputationBuilder builder(client_, "max");
      auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
      builder.Max(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      max_ = computation_status.ConsumeValueOrDie();
    }

    // Create a computation for a binary GT function: (x, y) => x > y
    {
      ComputationBuilder builder(client_, "gt");
      auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
      builder.Gt(x, y);
      auto computation_status = builder.Build();
      TF_CHECK_OK(computation_status.status());
      gt_ = computation_status.ConsumeValueOrDie();
    }
  }

  // Build HLO graph from the given builder and return the HLO module.
  std::unique_ptr<HloModule> BuildHloGraph(ComputationBuilder* builder) {
    auto computation_status = builder->Build();
    TF_CHECK_OK(computation_status.status());
    auto computation = computation_status.ConsumeValueOrDie();
    auto user_computation_status =
        computation_tracker_.Resolve(computation.handle());
    TF_CHECK_OK(user_computation_status.status());
    auto user_computation = user_computation_status.ConsumeValueOrDie();
    VersionedComputationHandle versioned_handle =
        user_computation->GetVersionedHandle();
    return std::move(
        computation_tracker_
            .BuildHloModule(versioned_handle, HloModuleConfig())
            .ValueOrDie());
  }

  Client* client_;
  Service* service_;
  const ComputationTracker& computation_tracker_;

  // User computations used for higher order operations (e.g., Map, Reduce).
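  //   add_:         (x, y) => x + y
  //   add_and_exp_: x => exp(x + 0.5)
  //   sigmoid_:     x => 1 / (1 + exp(-x))
  //   max_:         (x, y) => max(x, y)
  //   gt_:          (x, y) => x > y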
  Computation add_;
  Computation add_and_exp_;
  Computation sigmoid_;
  Computation max_;
  Computation gt_;
};

TEST_F(HloCostAnalysisTest, MatrixMultiply) {
  ComputationBuilder builder(client_, "matrix_multiply");
  auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
  auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
  auto result = builder.Dot(lhs, rhs);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Check the number of flops returned by the analysis: 1500 FMAs at 2 flops
  // each.
  EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5);

  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Bytes accessed is the sum of the inputs and the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30));
}

TEST_F(HloCostAnalysisTest, Map) {
  ComputationBuilder builder(client_, "map");
  auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in");
  auto result = builder.Map({input}, add_and_exp_, {0});

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // The adds contribute 10 flops and the exps contribute 10 transcendental
  // ops.
  EXPECT_EQ(analysis.flop_count(), 10);
  EXPECT_EQ(analysis.transcendental_count(), 10);
  EXPECT_EQ(analysis.bytes_accessed(), 80);
}

TEST_F(HloCostAnalysisTest, Convolution) {
  ComputationBuilder builder(client_, "convolution");
  auto input = builder.Parameter(
      0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
                                    /*x_dim=*/20}),
      "input");
  auto kernel = builder.Parameter(
      1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
                                    /*x_dim=*/3}),
      "kernel");
  auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // The output shape is [1x1x8x18]; each output element requires (3x3) FMAs,
  // and one FMA is 2 flops.
  EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3);

  // Bytes accessed is the sum of the inputs and the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
}

TEST_F(HloCostAnalysisTest, Reduce) {
  ComputationBuilder builder(client_, "reduce");
  auto input =
      builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  auto result =
      builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1});

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Subtracting the output size from the input size gives the number of
  // reduction operations performed.
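  // Reducing [10x20] along dimension 1 to a [10] result performs 20 - 1 = 19
  // adds per row over 10 rows: 10 * 20 - 10 = 190 flops.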
  EXPECT_EQ(analysis.flop_count(), 10 * 20 - 10);
}

TEST_F(HloCostAnalysisTest, ReduceWindow) {
  ComputationBuilder builder(client_, "reduce_window");
  auto input =
      builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  auto result = builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_,
                                     {4, 5}, {4, 5}, Padding::kValid);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Each of the [2x4] output elements is generated by reducing a [4x5] window
  // of elements.
  EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1));
}

TEST_F(HloCostAnalysisTest, SelectAndScatter) {
  ComputationBuilder builder(client_, "select_and_scatter");
  auto operand =
      builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  auto source =
      builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
  auto result =
      builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid,
                               source, builder.ConstantR0<float>(0), add_);

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // Each of the [2x4] source elements computes its destination by reducing a
  // [4x5] window, followed by the scatter computation.
  EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1));
}

TEST_F(HloCostAnalysisTest, Broadcast) {
  ComputationBuilder b(client_, "broadcast");
  b.Broadcast(b.ConstantR0<float>(42), {10, 7});
  auto hlo_module = BuildHloGraph(&b);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
  EXPECT_EQ(analysis.flop_count(), 0);
}

// Calculates the computation cost of a graph with more than one HLO node.
TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
  ComputationBuilder builder(client_, "fully_connected_forward");
  auto input =
      builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
  auto weight =
      builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
  auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias");
  // sigmoid(input * weight + bias)
  auto result = builder.Map(
      {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1});

  // Run HLO cost analysis.
  auto hlo_module = BuildHloGraph(&builder);
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(
      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

  // 1000 FMAs from the matrix multiplication, 200 flops from the bias
  // addition, 600 flops from the sigmoid, and 200 transcendental ops from the
  // sigmoid.
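  // Dot: [10x5] x [5x20] => 10 * 20 * 5 = 1000 FMAs (2 flops each). The bias
  // add and the sigmoid are applied elementwise over the [10x20] result; each
  // sigmoid is neg + add + div (3 flops) plus one exp (transcendental).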
  EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200);
  EXPECT_EQ(analysis.transcendental_count(), 200);
}

TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
  HloCostAnalysis conv_analysis(ShapeSize);
  {
    ComputationBuilder builder(client_, "conv_looking_matmul");
    auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
                                 "input");
    auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
                                 "weights");
    builder.Conv(lhs, rhs, {1, 1}, Padding::kSame);
    auto hlo_module = BuildHloGraph(&builder);
    ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
        &conv_analysis));
  }

  HloCostAnalysis matmul_analysis(ShapeSize);
  {
    ComputationBuilder builder(client_, "matmul");
    auto lhs =
        builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input");
    auto rhs =
        builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64}), "weights");
    builder.Dot(lhs, rhs);
    auto hlo_module = BuildHloGraph(&builder);
    ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
        &matmul_analysis));
  }

  EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
}

using FusionCostAnalysis = HloTestBase;

TEST_F(FusionCostAnalysis, LoopFusion) {
  // Do this 4 times with different per-second rates to test the computation
  // of bottleneck time on fusion nodes.
  for (int i = 0; i < 4; ++i) {
    Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});

    // Fuse all instructions in the complicated expression:
    //
    //   add = Add(C1, C2)
    //   clamp = Clamp(C2, add, add)
    //   exp = Exp(add)
    //   mul = Mul(exp, C3)
    //   sub = Sub(mul, clamp)
    //   tuple = Tuple({sub, sub, mul, C1})
    HloComputation::Builder builder(TestName());
    auto c1 = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
            /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
    auto c2 = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
            /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
    auto c3 = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
            /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
    auto clamp = builder.AddInstruction(
        HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
    auto exp = builder.AddInstruction(
        HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
    auto mul = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
    auto sub = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
    auto tuple = builder.AddInstruction(
        HloInstruction::CreateTuple({sub, sub, mul, c1}));

    HloModule module(TestName());
    auto* computation = module.AddEntryComputation(builder.Build());
    auto* fusion = computation->CreateFusionInstruction(
        {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

    // The time given these rates at i == 0 is exactly even among the
    // properties at 1.0 seconds. For the other values of i, one of the rates
    // is slower so that it becomes the bottleneck.
    HloCostAnalysis fusion_analysis(ShapeSize);
    fusion_analysis.set_flops_per_second(16 * (i == 1 ?
        1 / 2.0 : 1.0));
    fusion_analysis.set_transcendentals_per_second(4 *
                                                   (i == 2 ? 1 / 4.0 : 1.0));
    fusion_analysis.set_bytes_per_second(64 * (i == 3 ? 1 / 8.0 : 1.0));
    ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

    EXPECT_EQ(fusion_analysis.flop_count(), 16);
    EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
    constexpr int64 bytes_accessed = sizeof(float) * 4 * 2 * 2;
    static_assert(bytes_accessed == 64, "");
    EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed);

    EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i);
  }
}

TEST_F(FusionCostAnalysis, NoLayout) {
  Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  // Instructions within a fused op may have no layout.
  Shape shape_without_layout = shape_with_layout;
  shape_without_layout.clear_layout();

  HloComputation::Builder builder(TestName());
  auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
      Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
  auto c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3})));

  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(shape_without_layout, c2, {1}));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_with_layout, HloOpcode::kAdd, c1, broadcast));

  HloModule module(TestName());
  auto* computation = module.AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {add, broadcast}, HloInstruction::FusionKind::kLoop);

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

  EXPECT_EQ(fusion_analysis.flop_count(), 120);
  EXPECT_EQ(fusion_analysis.transcendental_count(), 0);
}

TEST_F(HloCostAnalysisTest, TupleCost) {
  HloCostAnalysis analysis(ShapeSize);
  {
    ComputationBuilder builder(client_, "tuple");
    auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x");
    auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y");
    auto tuple = builder.Tuple({x, y});
    auto hlo_module = BuildHloGraph(&builder);

    ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
        &analysis));
  }

  EXPECT_EQ(analysis.flop_count(), 0);
  EXPECT_EQ(analysis.transcendental_count(), 0);
  // A tuple only touches its table of element pointers, not the operand data.
  EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
}

}  // namespace
}  // namespace xla