1 /*------------------------------------------------------------------------ 2 * Vulkan Conformance Tests 3 * ------------------------ 4 * 5 * Copyright (c) 2017 The Khronos Group Inc. 6 * Copyright (c) 2017 Codeplay Software Ltd. 7 * 8 * Licensed under the Apache License, Version 2.0 (the "License"); 9 * you may not use this file except in compliance with the License. 10 * You may obtain a copy of the License at 11 * 12 * http://www.apache.org/licenses/LICENSE-2.0 13 * 14 * Unless required by applicable law or agreed to in writing, software 15 * distributed under the License is distributed on an "AS IS" BASIS, 16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 * See the License for the specific language governing permissions and 18 * limitations under the License. 19 * 20 */ /*! 21 * \file 22 * \brief Subgroups Tests 23 */ /*--------------------------------------------------------------------*/ 24 25 #include "vktSubgroupsClusteredTests.hpp" 26 #include "vktSubgroupsTestsUtils.hpp" 27 28 #include <string> 29 #include <vector> 30 31 using namespace tcu; 32 using namespace std; 33 using namespace vk; 34 using namespace vkt; 35 36 namespace 37 { 38 enum OpType 39 { 40 OPTYPE_CLUSTERED_ADD = 0, 41 OPTYPE_CLUSTERED_MUL, 42 OPTYPE_CLUSTERED_MIN, 43 OPTYPE_CLUSTERED_MAX, 44 OPTYPE_CLUSTERED_AND, 45 OPTYPE_CLUSTERED_OR, 46 OPTYPE_CLUSTERED_XOR, 47 OPTYPE_CLUSTERED_LAST 48 }; 49 50 static bool checkVertexPipelineStages(std::vector<const void*> datas, 51 deUint32 width, deUint32) 52 { 53 const deUint32* data = 54 reinterpret_cast<const deUint32*>(datas[0]); 55 for (deUint32 x = 0; x < width; ++x) 56 { 57 deUint32 val = data[x]; 58 59 if (0x1 != val) 60 { 61 return false; 62 } 63 } 64 65 return true; 66 } 67 68 static bool checkFragment(std::vector<const void*> datas, 69 deUint32 width, deUint32 height, deUint32) 70 { 71 const deUint32* data = 72 reinterpret_cast<const deUint32*>(datas[0]); 73 for (deUint32 x = 0; x < width; ++x) 74 { 75 for 
(deUint32 y = 0; y < height; ++y) 76 { 77 deUint32 val = data[x * height + y]; 78 79 if (0x1 != val) 80 { 81 return false; 82 } 83 } 84 } 85 86 return true; 87 } 88 89 static bool checkCompute(std::vector<const void*> datas, 90 const deUint32 numWorkgroups[3], const deUint32 localSize[3], 91 deUint32) 92 { 93 const deUint32* data = 94 reinterpret_cast<const deUint32*>(datas[0]); 95 96 for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX) 97 { 98 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY) 99 { 100 for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ) 101 { 102 for (deUint32 lX = 0; lX < localSize[0]; ++lX) 103 { 104 for (deUint32 lY = 0; lY < localSize[1]; ++lY) 105 { 106 for (deUint32 lZ = 0; lZ < localSize[2]; 107 ++lZ) 108 { 109 const deUint32 globalInvocationX = 110 nX * localSize[0] + lX; 111 const deUint32 globalInvocationY = 112 nY * localSize[1] + lY; 113 const deUint32 globalInvocationZ = 114 nZ * localSize[2] + lZ; 115 116 const deUint32 globalSizeX = 117 numWorkgroups[0] * localSize[0]; 118 const deUint32 globalSizeY = 119 numWorkgroups[1] * localSize[1]; 120 121 const deUint32 offset = 122 globalSizeX * 123 ((globalSizeY * 124 globalInvocationZ) + 125 globalInvocationY) + 126 globalInvocationX; 127 128 if (0x1 != data[offset]) 129 { 130 return false; 131 } 132 } 133 } 134 } 135 } 136 } 137 } 138 139 return true; 140 } 141 142 std::string getOpTypeName(int opType) 143 { 144 switch (opType) 145 { 146 default: 147 DE_FATAL("Unsupported op type"); 148 case OPTYPE_CLUSTERED_ADD: 149 return "subgroupClusteredAdd"; 150 case OPTYPE_CLUSTERED_MUL: 151 return "subgroupClusteredMul"; 152 case OPTYPE_CLUSTERED_MIN: 153 return "subgroupClusteredMin"; 154 case OPTYPE_CLUSTERED_MAX: 155 return "subgroupClusteredMax"; 156 case OPTYPE_CLUSTERED_AND: 157 return "subgroupClusteredAnd"; 158 case OPTYPE_CLUSTERED_OR: 159 return "subgroupClusteredOr"; 160 case OPTYPE_CLUSTERED_XOR: 161 return "subgroupClusteredXor"; 162 } 163 } 164 165 std::string getOpTypeOperation(int 
opType, vk::VkFormat format, std::string lhs, std::string rhs) 166 { 167 switch (opType) 168 { 169 default: 170 DE_FATAL("Unsupported op type"); 171 case OPTYPE_CLUSTERED_ADD: 172 return lhs + " + " + rhs; 173 case OPTYPE_CLUSTERED_MUL: 174 return lhs + " * " + rhs; 175 case OPTYPE_CLUSTERED_MIN: 176 switch (format) 177 { 178 default: 179 return "min(" + lhs + ", " + rhs + ")"; 180 case VK_FORMAT_R32_SFLOAT: 181 case VK_FORMAT_R64_SFLOAT: 182 return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs + ")))"; 183 case VK_FORMAT_R32G32_SFLOAT: 184 case VK_FORMAT_R32G32B32_SFLOAT: 185 case VK_FORMAT_R32G32B32A32_SFLOAT: 186 case VK_FORMAT_R64G64_SFLOAT: 187 case VK_FORMAT_R64G64B64_SFLOAT: 188 case VK_FORMAT_R64G64B64A64_SFLOAT: 189 return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))"; 190 } 191 case OPTYPE_CLUSTERED_MAX: 192 switch (format) 193 { 194 default: 195 return "max(" + lhs + ", " + rhs + ")"; 196 case VK_FORMAT_R32_SFLOAT: 197 case VK_FORMAT_R64_SFLOAT: 198 return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? 
" + lhs + " : max(" + lhs + ", " + rhs + ")))"; 199 case VK_FORMAT_R32G32_SFLOAT: 200 case VK_FORMAT_R32G32B32_SFLOAT: 201 case VK_FORMAT_R32G32B32A32_SFLOAT: 202 case VK_FORMAT_R64G64_SFLOAT: 203 case VK_FORMAT_R64G64B64_SFLOAT: 204 case VK_FORMAT_R64G64B64A64_SFLOAT: 205 return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))"; 206 } 207 case OPTYPE_CLUSTERED_AND: 208 switch (format) 209 { 210 default: 211 return lhs + " & " + rhs; 212 case VK_FORMAT_R8_USCALED: 213 return lhs + " && " + rhs; 214 case VK_FORMAT_R8G8_USCALED: 215 return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)"; 216 case VK_FORMAT_R8G8B8_USCALED: 217 return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z)"; 218 case VK_FORMAT_R8G8B8A8_USCALED: 219 return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z, " + lhs + ".w && " + rhs + ".w)"; 220 } 221 case OPTYPE_CLUSTERED_OR: 222 switch (format) 223 { 224 default: 225 return lhs + " | " + rhs; 226 case VK_FORMAT_R8_USCALED: 227 return lhs + " || " + rhs; 228 case VK_FORMAT_R8G8_USCALED: 229 return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)"; 230 case VK_FORMAT_R8G8B8_USCALED: 231 return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z)"; 232 case VK_FORMAT_R8G8B8A8_USCALED: 233 return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z, " + lhs + ".w || " + rhs + ".w)"; 234 } 235 case OPTYPE_CLUSTERED_XOR: 236 switch (format) 237 { 238 default: 239 return lhs + " ^ " + rhs; 240 case VK_FORMAT_R8_USCALED: 241 return lhs + " ^^ " + rhs; 242 case VK_FORMAT_R8G8_USCALED: 243 return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)"; 244 case VK_FORMAT_R8G8B8_USCALED: 245 return "bvec3(" + lhs + ".x ^^ " + rhs + 
".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z)"; 246 case VK_FORMAT_R8G8B8A8_USCALED: 247 return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z, " + lhs + ".w ^^ " + rhs + ".w)"; 248 } 249 } 250 } 251 252 253 std::string getIdentity(int opType, vk::VkFormat format) 254 { 255 bool isFloat = false; 256 bool isInt = false; 257 bool isUnsigned = false; 258 259 switch (format) 260 { 261 default: 262 DE_FATAL("Unhandled format!"); 263 case VK_FORMAT_R32_SINT: 264 case VK_FORMAT_R32G32_SINT: 265 case VK_FORMAT_R32G32B32_SINT: 266 case VK_FORMAT_R32G32B32A32_SINT: 267 isInt = true; 268 break; 269 case VK_FORMAT_R32_UINT: 270 case VK_FORMAT_R32G32_UINT: 271 case VK_FORMAT_R32G32B32_UINT: 272 case VK_FORMAT_R32G32B32A32_UINT: 273 isUnsigned = true; 274 break; 275 case VK_FORMAT_R32_SFLOAT: 276 case VK_FORMAT_R32G32_SFLOAT: 277 case VK_FORMAT_R32G32B32_SFLOAT: 278 case VK_FORMAT_R32G32B32A32_SFLOAT: 279 case VK_FORMAT_R64_SFLOAT: 280 case VK_FORMAT_R64G64_SFLOAT: 281 case VK_FORMAT_R64G64B64_SFLOAT: 282 case VK_FORMAT_R64G64B64A64_SFLOAT: 283 isFloat = true; 284 break; 285 case VK_FORMAT_R8_USCALED: 286 case VK_FORMAT_R8G8_USCALED: 287 case VK_FORMAT_R8G8B8_USCALED: 288 case VK_FORMAT_R8G8B8A8_USCALED: 289 break; // bool types are not anything 290 } 291 292 switch (opType) 293 { 294 default: 295 DE_FATAL("Unsupported op type"); 296 case OPTYPE_CLUSTERED_ADD: 297 return subgroups::getFormatNameForGLSL(format) + "(0)"; 298 case OPTYPE_CLUSTERED_MUL: 299 return subgroups::getFormatNameForGLSL(format) + "(1)"; 300 case OPTYPE_CLUSTERED_MIN: 301 if (isFloat) 302 { 303 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))"; 304 } 305 else if (isInt) 306 { 307 return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)"; 308 } 309 else if (isUnsigned) 310 { 311 return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)"; 312 } 313 else 314 { 315 DE_FATAL("Unhandled case"); 
316 } 317 case OPTYPE_CLUSTERED_MAX: 318 if (isFloat) 319 { 320 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))"; 321 } 322 else if (isInt) 323 { 324 return subgroups::getFormatNameForGLSL(format) + "(0x80000000)"; 325 } 326 else if (isUnsigned) 327 { 328 return subgroups::getFormatNameForGLSL(format) + "(0)"; 329 } 330 else 331 { 332 DE_FATAL("Unhandled case"); 333 } 334 case OPTYPE_CLUSTERED_AND: 335 return subgroups::getFormatNameForGLSL(format) + "(~0)"; 336 case OPTYPE_CLUSTERED_OR: 337 return subgroups::getFormatNameForGLSL(format) + "(0)"; 338 case OPTYPE_CLUSTERED_XOR: 339 return subgroups::getFormatNameForGLSL(format) + "(0)"; 340 } 341 } 342 343 std::string getCompare(int opType, vk::VkFormat format, std::string lhs, std::string rhs) 344 { 345 std::string formatName = subgroups::getFormatNameForGLSL(format); 346 switch (format) 347 { 348 default: 349 return "all(equal(" + lhs + ", " + rhs + "))"; 350 case VK_FORMAT_R8_USCALED: 351 case VK_FORMAT_R32_UINT: 352 case VK_FORMAT_R32_SINT: 353 return "(" + lhs + " == " + rhs + ")"; 354 case VK_FORMAT_R32_SFLOAT: 355 case VK_FORMAT_R64_SFLOAT: 356 switch (opType) 357 { 358 default: 359 return "(abs(" + lhs + " - " + rhs + ") < 0.00001)"; 360 case OPTYPE_CLUSTERED_MIN: 361 case OPTYPE_CLUSTERED_MAX: 362 return "(" + lhs + " == " + rhs + ")"; 363 } 364 case VK_FORMAT_R32G32_SFLOAT: 365 case VK_FORMAT_R32G32B32_SFLOAT: 366 case VK_FORMAT_R32G32B32A32_SFLOAT: 367 case VK_FORMAT_R64G64_SFLOAT: 368 case VK_FORMAT_R64G64B64_SFLOAT: 369 case VK_FORMAT_R64G64B64A64_SFLOAT: 370 switch (opType) 371 { 372 default: 373 return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))"; 374 case OPTYPE_CLUSTERED_MIN: 375 case OPTYPE_CLUSTERED_MAX: 376 return "all(equal(" + lhs + ", " + rhs + "))"; 377 } 378 } 379 } 380 381 struct CaseDefinition 382 { 383 int opType; 384 VkShaderStageFlags shaderStage; 385 VkFormat format; 386 bool noSSBO; 387 }; 388 389 void 
initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
{
	// Shared GLSL body: for every power-of-two cluster size that fits in the
	// subgroup, compare the subgroupClustered* result against a reference
	// reduction computed over the active invocations of the same cluster
	// (membership taken from the ballot in "mask").
	std::ostringstream bdy;

	bdy << "  bool tempResult = true;\n";

	for (deUint32 i = 1; i <= subgroups::maxSupportedSubgroupSize(); i *= 2)
	{
		bdy << "  {\n"
			<< "    const uint clusterSize = " << i << ";\n"
			<< "    if (clusterSize <= gl_SubgroupSize)\n"
			<< "    {\n"
			<< "      " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
			<< getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID], clusterSize);\n"
			<< "      for (uint clusterOffset = 0; clusterOffset < gl_SubgroupSize; clusterOffset += clusterSize)\n"
			<< "      {\n"
			<< "        " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = "
			<< getIdentity(caseDef.opType, caseDef.format) << ";\n"
			<< "        for (uint index = clusterOffset; index < (clusterOffset + clusterSize); index++)\n"
			<< "        {\n"
			<< "          if (subgroupBallotBitExtract(mask, index))\n"
			<< "          {\n"
			<< "            ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
			<< "          }\n"
			<< "        }\n"
			<< "        if ((clusterOffset <= gl_SubgroupInvocationID) && (gl_SubgroupInvocationID < (clusterOffset + clusterSize)))\n"
			<< "        {\n"
			<< "          if (!"
			<< getCompare(caseDef.opType, caseDef.format, "ref", "op") << ")\n"
			<< "          {\n"
			<< "            tempResult = false;\n"
			<< "          }\n"
			<< "        }\n"
			<< "      }\n"
			<< "    }\n"
			<< "  }\n";
	}

	// Framebuffer variant: only the vertex stage is supported; the result is
	// written to a color output instead of an SSBO.
	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
	{
		std::ostringstream src;
		std::ostringstream fragmentSrc;

		// Vertex shader reads input data from a UBO and emits 1.0/0.0.
		src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450 )<< "\n"
			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
			<< "layout(location = 0) in highp vec4 in_position;\n"
			<< "layout(location = 0) out float out_color;\n"
			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
			<< "{\n"
			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
			<< "};\n"
			<< "\n"
			<< "void main (void)\n"
			<< "{\n"
			<< "  uvec4 mask = subgroupBallot(true);\n"
			<< bdy.str()
			<< "  out_color = float(tempResult ? 1 : 0);\n"
			<< "  gl_Position = in_position;\n"
			<< "}\n";

		programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);

		// Trivial pass-through fragment shader converting the interpolated
		// pass/fail value to a uint color.
		fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
			<< "layout(location = 0) in float in_color;\n"
			<< "layout(location = 0) out uint out_color;\n"
			<< "void main()\n"
			<<"{\n"
			<< "  out_color = uint(in_color);\n"
			<< "}\n";
		programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
	}
	else
	{
		DE_FATAL("Unsupported shader stage");
	}
}

void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
{
	// Same shared verification body as initFrameBufferPrograms (see above in
	// this function's sibling); kept duplicated here by the original code.
	std::ostringstream bdy;

	bdy << "  bool tempResult = true;\n";

	for (deUint32 i = 1; i <= subgroups::maxSupportedSubgroupSize(); i *= 2)
	{
		bdy << "  {\n"
			<< "    const uint clusterSize = " << i << ";\n"
			<< "    if (clusterSize <= gl_SubgroupSize)\n"
			<< "    {\n"
			<< "      " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
			<< getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID], clusterSize);\n"
			<< "      for (uint clusterOffset = 0; clusterOffset < gl_SubgroupSize; clusterOffset += clusterSize)\n"
			<< "      {\n"
			<< "        " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = "
			<< getIdentity(caseDef.opType, caseDef.format) << ";\n"
			<< "        for (uint index = clusterOffset; index < (clusterOffset + clusterSize); index++)\n"
			<< "        {\n"
			<< "          if (subgroupBallotBitExtract(mask, index))\n"
			<< "          {\n"
			<< "            ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
			<< "          }\n"
			<< "        }\n"
			<< "        if ((clusterOffset <= gl_SubgroupInvocationID) && (gl_SubgroupInvocationID < (clusterOffset + clusterSize)))\n"
			<< "        {\n"
			<< "          if (!" << getCompare(caseDef.opType, caseDef.format, "ref", "op") << ")\n"
			<< "          {\n"
			<< "            tempResult = false;\n"
			<< "          }\n"
			<< "        }\n"
			<< "      }\n"
			<< "    }\n"
			<< "  }\n";
	}

	// Compute variant: one result word per global invocation, written at the
	// linearized global-invocation offset (matches checkCompute).
	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
	{
		std::ostringstream src;

		src << "#version 450\n"
			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
			<< "layout (local_size_x_id = 0, local_size_y_id = 1, "
			"local_size_z_id = 2) in;\n"
			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
			<< "{\n"
			<< "  uint result[];\n"
			<< "};\n"
			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
			<< "{\n"
			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
			<< "};\n"
			<< "\n"
			<< "void main (void)\n"
			<< "{\n"
			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
			"gl_GlobalInvocationID.x;\n"
			<< "  uvec4 mask = subgroupBallot(true);\n"
			<< bdy.str()
			<< "  result[offset] = tempResult ? 1 : 0;\n"
			<< "}\n";

		programCollection.glslSources.add("comp")
			<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
	}
	// Fragment variant: pass-through vertex shader plus a fragment shader
	// writing the pass/fail value to a uint color output.
	else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
	{
		programCollection.glslSources.add("vert")
			<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);

		std::ostringstream frag;

		frag << "#version 450\n"
			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
			<< "layout(location = 0) out uint result;\n"
			<< "layout(set = 0, binding = 0, std430) readonly buffer Buffer2\n"
			<< "{\n"
			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
			<< "};\n"
			<< "void main (void)\n"
			<< "{\n"
			<< "  uvec4 mask = subgroupBallot(true);\n"
			<< bdy.str()
			<< "  result = tempResult ? 1 : 0;\n"
			<< "}\n";

		programCollection.glslSources.add("frag")
			<< glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
	}
	// Vertex variant: one result word per vertex, indexed by gl_VertexIndex.
	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
	{
		std::ostringstream src;

		src << "#version 450\n"
			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
			<< "{\n"
			<< "  uint result[];\n"
			<< "};\n"
			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
			<< "{\n"
			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
			<< "};\n"
			<< "\n"
			<< "void main (void)\n"
			<< "{\n"
			<< "  uvec4 mask = subgroupBallot(true);\n"
			<< bdy.str()
			<< "  result[gl_VertexIndex] = tempResult ? 1 : 0;\n"
			<< "}\n";

		programCollection.glslSources.add("vert")
			<< glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
	}
	// Geometry variant: point-in/point-out shader, result indexed by
	// gl_PrimitiveIDIn.
	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
	{
		programCollection.glslSources.add("vert")
			<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);

		std::ostringstream src;

		src << "#version 450\n"
			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
			<< "layout(points) in;\n"
			<< "layout(points, max_vertices = 1) out;\n"
			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
			<< "{\n"
			<< "  uint result[];\n"
			<< "};\n"
			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
			<< "{\n"
			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
			<< "};\n"
			<< "\n"
			<< "void main (void)\n"
			<< "{\n"
			<< "  uvec4 mask = subgroupBallot(true);\n"
			<< bdy.str()
			<< "  result[gl_PrimitiveIDIn] = tempResult ? 1 : 0;\n"
			<< "}\n";

		programCollection.glslSources.add("geom")
			<< glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
	}
	// Tessellation control variant: verification runs in the TCS; a stub TES
	// is supplied to complete the pipeline.
	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
	{
		programCollection.glslSources.add("vert")
			<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);

		programCollection.glslSources.add("tese")
			<< glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");

		std::ostringstream src;

		src << "#version 450\n"
			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
			<< "layout(vertices=1) out;\n"
			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
			<< "{\n"
			<< "  uint result[];\n"
			<< "};\n"
			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
			<< "{\n"
			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
			<< "};\n"
			<< "\n"
			<< "void main (void)\n"
			<< "{\n"
			<< "  uvec4 mask = subgroupBallot(true);\n"
			<< bdy.str()
			<< "  result[gl_PrimitiveID] = tempResult ? 1 : 0;\n"
			<< "}\n";

		programCollection.glslSources.add("tesc")
			<< glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
	}
	// Tessellation evaluation variant: verification runs in the TES; a stub
	// TCS sets all outer tessellation levels to 1.
	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
	{
		programCollection.glslSources.add("vert")
			<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);

		programCollection.glslSources.add("tesc")
			<< glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");

		std::ostringstream src;

		src << "#version 450\n"
			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
			<< "layout(isolines) in;\n"
			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
			<< "{\n"
			<< "  uint result[];\n"
			<< "};\n"
			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
			<< "{\n"
			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
			<< "};\n"
			<< "\n"
			<< "void main (void)\n"
			<< "{\n"
			<< "  uvec4 mask = subgroupBallot(true);\n"
			<< bdy.str()
			// Two isoline vertices per patch: index by primitive ID and coord.
			<< "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult ? 1 : 0;\n"
			<< "}\n";

		programCollection.glslSources.add("tese")
			<< glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
	}
	else
	{
		DE_FATAL("Unsupported shader stage");
	}
}

// Runs one clustered-subgroup test case: checks feature support, then
// dispatches to the framework helper matching the stage under test.
tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
{
	if (!subgroups::isSubgroupSupported(context))
		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");

	if (!subgroups::areSubgroupOperationsSupportedForStage(
				context, caseDef.shaderStage))
	{
		// Missing support is a failure only for stages where the Vulkan spec
		// requires subgroup operations; otherwise the case is NotSupported.
		if (subgroups::areSubgroupOperationsRequiredForStage(
					caseDef.shaderStage))
		{
			return tcu::TestStatus::fail(
					   "Shader stage " +
					   subgroups::getShaderStageName(caseDef.shaderStage) +
					   " is required to support subgroup operations!");
		}
		else
		{
			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
		}
	}

	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_CLUSTERED_BIT))
	{
		TCU_THROW(NotSupportedError, "Device does not support subgroup clustered operations");
	}

	if (subgroups::isDoubleFormat(caseDef.format) &&
			!subgroups::isDoubleSupportedForDevice(context))
	{
		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
	}

	// Tests which don't use the SSBO (framebuffer-based vertex variant)
	if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
	{
		subgroups::SSBOData inputData;
		inputData.format = caseDef.format;
		inputData.numElements = subgroups::maxSupportedSubgroupSize();
		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;

		return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
	}

	// All remaining non-fragment/non-compute variants write results through
	// an SSBO from a vertex-pipeline stage.
	if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
			(VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
	{
		if
(!subgroups::isVertexSSBOSupportedForDevice(context)) 737 { 738 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes"); 739 } 740 } 741 742 if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage) 743 { 744 subgroups::SSBOData inputData; 745 inputData.format = caseDef.format; 746 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 747 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 748 749 return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT, 750 &inputData, 1, checkFragment); 751 } 752 else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage) 753 { 754 subgroups::SSBOData inputData; 755 inputData.format = caseDef.format; 756 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 757 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 758 759 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 760 1, checkCompute); 761 } 762 else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage) 763 { 764 subgroups::SSBOData inputData; 765 inputData.format = caseDef.format; 766 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 767 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 768 769 return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT, &inputData, 770 1, checkVertexPipelineStages); 771 } 772 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage) 773 { 774 subgroups::SSBOData inputData; 775 inputData.format = caseDef.format; 776 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 777 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 778 779 return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT, &inputData, 780 1, checkVertexPipelineStages); 781 } 782 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage) 783 { 784 subgroups::SSBOData inputData; 785 inputData.format = caseDef.format; 786 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 787 
inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 788 789 return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT, &inputData, 790 1, checkVertexPipelineStages); 791 } 792 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage) 793 { 794 subgroups::SSBOData inputData; 795 inputData.format = caseDef.format; 796 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 797 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 798 799 return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT, &inputData, 800 1, checkVertexPipelineStages); 801 } 802 else 803 { 804 return tcu::TestStatus::pass("Unhandled shader stage!"); 805 } 806 } 807 } 808 809 namespace vkt 810 { 811 namespace subgroups 812 { 813 tcu::TestCaseGroup* createSubgroupsClusteredTests(tcu::TestContext& testCtx) 814 { 815 de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup( 816 testCtx, "clustered", "Subgroup clustered category tests")); 817 818 const VkShaderStageFlags stages[] = 819 { 820 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, 821 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, 822 VK_SHADER_STAGE_GEOMETRY_BIT, 823 VK_SHADER_STAGE_VERTEX_BIT, 824 VK_SHADER_STAGE_FRAGMENT_BIT, 825 VK_SHADER_STAGE_COMPUTE_BIT 826 }; 827 828 const VkFormat formats[] = 829 { 830 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT, 831 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT, 832 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT, 833 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT, 834 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT, 835 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT, 836 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT, 837 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED, 838 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED, 839 }; 840 841 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex) 842 { 843 const 
VkShaderStageFlags stage = stages[stageIndex]; 844 845 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex) 846 { 847 const VkFormat format = formats[formatIndex]; 848 849 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_CLUSTERED_LAST; ++opTypeIndex) 850 { 851 bool isBool = false; 852 bool isFloat = false; 853 854 switch (format) 855 { 856 default: 857 break; 858 case VK_FORMAT_R32_SFLOAT: 859 case VK_FORMAT_R32G32_SFLOAT: 860 case VK_FORMAT_R32G32B32_SFLOAT: 861 case VK_FORMAT_R32G32B32A32_SFLOAT: 862 case VK_FORMAT_R64_SFLOAT: 863 case VK_FORMAT_R64G64_SFLOAT: 864 case VK_FORMAT_R64G64B64_SFLOAT: 865 case VK_FORMAT_R64G64B64A64_SFLOAT: 866 isFloat = true; 867 break; 868 case VK_FORMAT_R8_USCALED: 869 case VK_FORMAT_R8G8_USCALED: 870 case VK_FORMAT_R8G8B8_USCALED: 871 case VK_FORMAT_R8G8B8A8_USCALED: 872 isBool = true; 873 break; 874 } 875 876 bool isBitwiseOp = false; 877 878 switch (opTypeIndex) 879 { 880 default: 881 break; 882 case OPTYPE_CLUSTERED_AND: 883 case OPTYPE_CLUSTERED_OR: 884 case OPTYPE_CLUSTERED_XOR: 885 isBitwiseOp = true; 886 break; 887 } 888 889 if (isFloat && isBitwiseOp) 890 { 891 // Skip float with bitwise category. 892 continue; 893 } 894 895 if (isBool && !isBitwiseOp) 896 { 897 // Skip bool when its not the bitwise category. 
898 continue; 899 } 900 901 CaseDefinition caseDef = {opTypeIndex, stage, format, false}; 902 903 std::ostringstream name; 904 905 std::string op = getOpTypeName(opTypeIndex); 906 907 name << de::toLower(op) 908 << "_" << subgroups::getFormatNameForGLSL(format) 909 << "_" << getShaderStageName(stage); 910 911 addFunctionCaseWithPrograms(group.get(), name.str(), 912 "", initPrograms, test, caseDef); 913 914 if (VK_SHADER_STAGE_VERTEX_BIT == stage) 915 { 916 caseDef.noSSBO = true; 917 addFunctionCaseWithPrograms(group.get(), name.str()+"_framebuffer", "", 918 initFrameBufferPrograms, test, caseDef); 919 } 920 } 921 } 922 } 923 924 return group.release(); 925 } 926 927 } // subgroups 928 } // vkt 929