Home | History | Annotate | Download | only in subgroups
      1 /*------------------------------------------------------------------------
      2  * Vulkan Conformance Tests
      3  * ------------------------
      4  *
      5  * Copyright (c) 2017 The Khronos Group Inc.
      6  * Copyright (c) 2017 Codeplay Software Ltd.
      7  *
      8  * Licensed under the Apache License, Version 2.0 (the "License");
      9  * you may not use this file except in compliance with the License.
     10  * You may obtain a copy of the License at
     11  *
     12  *      http://www.apache.org/licenses/LICENSE-2.0
     13  *
     14  * Unless required by applicable law or agreed to in writing, software
     15  * distributed under the License is distributed on an "AS IS" BASIS,
     16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     17  * See the License for the specific language governing permissions and
     18  * limitations under the License.
     19  *
     20  */ /*!
     21  * \file
     22  * \brief Subgroups Tests
     23  */ /*--------------------------------------------------------------------*/
     24 
     25 #include "vktSubgroupsBuiltinVarTests.hpp"
     26 #include "vktSubgroupsTestsUtils.hpp"
     27 
     28 #include <string>
     29 #include <vector>
     30 
     31 using namespace tcu;
     32 using namespace std;
     33 using namespace vk;
     34 
     35 namespace vkt
     36 {
     37 namespace subgroups
     38 {
     39 
     40 bool checkVertexPipelineStagesSubgroupSize(std::vector<const void*> datas,
     41 		deUint32 width, deUint32 subgroupSize)
     42 {
     43 	const deUint32* data =
     44 		reinterpret_cast<const deUint32*>(datas[0]);
     45 	for (deUint32 x = 0; x < width; ++x)
     46 	{
     47 		deUint32 val = data[x * 4];
     48 
     49 		if (subgroupSize != val)
     50 		{
     51 			return false;
     52 		}
     53 	}
     54 
     55 	return true;
     56 }
     57 
     58 bool checkVertexPipelineStagesSubgroupInvocationID(std::vector<const void*> datas,
     59 		deUint32 width, deUint32 subgroupSize)
     60 {
     61 	const deUint32* data =
     62 		reinterpret_cast<const deUint32*>(datas[0]);
     63 	vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
     64 
     65 	for (deUint32 x = 0; x < width; ++x)
     66 	{
     67 		deUint32 subgroupInvocationID = data[(x * 4) + 1];
     68 
     69 		if (subgroupInvocationID >= subgroupSize)
     70 		{
     71 			return false;
     72 		}
     73 
     74 		subgroupInvocationHits[subgroupInvocationID]++;
     75 	}
     76 
     77 	const deUint32 totalSize = width;
     78 
     79 	deUint32 totalInvocationsRun = 0;
     80 	for (deUint32 i = 0; i < subgroupSize; ++i)
     81 	{
     82 		totalInvocationsRun += subgroupInvocationHits[i];
     83 	}
     84 
     85 	if (totalInvocationsRun != totalSize)
     86 	{
     87 		return false;
     88 	}
     89 
     90 	return true;
     91 }
     92 
     93 static bool checkFragmentSubgroupSize(std::vector<const void*> datas,
     94 									  deUint32 width, deUint32 height, deUint32 subgroupSize)
     95 {
     96 	const deUint32* data =
     97 		reinterpret_cast<const deUint32*>(datas[0]);
     98 	for (deUint32 x = 0; x < width; ++x)
     99 	{
    100 		for (deUint32 y = 0; y < height; ++y)
    101 		{
    102 			deUint32 val = data[(x * height + y) * 4];
    103 
    104 			if (subgroupSize != val)
    105 			{
    106 				return false;
    107 			}
    108 		}
    109 	}
    110 
    111 	return true;
    112 }
    113 
    114 static bool checkFragmentSubgroupInvocationID(
    115 	std::vector<const void*> datas, deUint32 width, deUint32 height,
    116 	deUint32 subgroupSize)
    117 {
    118 	const deUint32* data =
    119 		reinterpret_cast<const deUint32*>(datas[0]);
    120 	vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
    121 
    122 	for (deUint32 x = 0; x < width; ++x)
    123 	{
    124 		for (deUint32 y = 0; y < height; ++y)
    125 		{
    126 			deUint32 subgroupInvocationID = data[((x * height + y) * 4) + 1];
    127 
    128 			if (subgroupInvocationID >= subgroupSize)
    129 			{
    130 				return false;
    131 			}
    132 
    133 			subgroupInvocationHits[subgroupInvocationID]++;
    134 		}
    135 	}
    136 
    137 	const deUint32 totalSize = width * height;
    138 
    139 	deUint32 totalInvocationsRun = 0;
    140 	for (deUint32 i = 0; i < subgroupSize; ++i)
    141 	{
    142 		totalInvocationsRun += subgroupInvocationHits[i];
    143 	}
    144 
    145 	if (totalInvocationsRun != totalSize)
    146 	{
    147 		return false;
    148 	}
    149 
    150 	return true;
    151 }
    152 
    153 static bool checkComputeSubgroupSize(std::vector<const void*> datas,
    154 									 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
    155 									 deUint32 subgroupSize)
    156 {
    157 	const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
    158 
    159 	for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
    160 	{
    161 		for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
    162 		{
    163 			for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
    164 			{
    165 				for (deUint32 lX = 0; lX < localSize[0]; ++lX)
    166 				{
    167 					for (deUint32 lY = 0; lY < localSize[1]; ++lY)
    168 					{
    169 						for (deUint32 lZ = 0; lZ < localSize[2];
    170 								++lZ)
    171 						{
    172 							const deUint32 globalInvocationX =
    173 								nX * localSize[0] + lX;
    174 							const deUint32 globalInvocationY =
    175 								nY * localSize[1] + lY;
    176 							const deUint32 globalInvocationZ =
    177 								nZ * localSize[2] + lZ;
    178 
    179 							const deUint32 globalSizeX =
    180 								numWorkgroups[0] * localSize[0];
    181 							const deUint32 globalSizeY =
    182 								numWorkgroups[1] * localSize[1];
    183 
    184 							const deUint32 offset =
    185 								globalSizeX *
    186 								((globalSizeY *
    187 								  globalInvocationZ) +
    188 								 globalInvocationY) +
    189 								globalInvocationX;
    190 
    191 							if (subgroupSize != data[offset * 4])
    192 							{
    193 								return false;
    194 							}
    195 						}
    196 					}
    197 				}
    198 			}
    199 		}
    200 	}
    201 
    202 	return true;
    203 }
    204 
    205 static bool checkComputeSubgroupInvocationID(std::vector<const void*> datas,
    206 		const deUint32 numWorkgroups[3], const deUint32 localSize[3],
    207 		deUint32 subgroupSize)
    208 {
    209 	const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
    210 
    211 	for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
    212 	{
    213 		for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
    214 		{
    215 			for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
    216 			{
    217 				const deUint32 totalLocalSize =
    218 					localSize[0] * localSize[1] * localSize[2];
    219 				vector<deUint32> subgroupInvocationHits(subgroupSize, 0);
    220 
    221 				for (deUint32 lX = 0; lX < localSize[0]; ++lX)
    222 				{
    223 					for (deUint32 lY = 0; lY < localSize[1]; ++lY)
    224 					{
    225 						for (deUint32 lZ = 0; lZ < localSize[2];
    226 								++lZ)
    227 						{
    228 							const deUint32 globalInvocationX =
    229 								nX * localSize[0] + lX;
    230 							const deUint32 globalInvocationY =
    231 								nY * localSize[1] + lY;
    232 							const deUint32 globalInvocationZ =
    233 								nZ * localSize[2] + lZ;
    234 
    235 							const deUint32 globalSizeX =
    236 								numWorkgroups[0] * localSize[0];
    237 							const deUint32 globalSizeY =
    238 								numWorkgroups[1] * localSize[1];
    239 
    240 							const deUint32 offset =
    241 								globalSizeX *
    242 								((globalSizeY *
    243 								  globalInvocationZ) +
    244 								 globalInvocationY) +
    245 								globalInvocationX;
    246 
    247 							deUint32 subgroupInvocationID = data[(offset * 4) + 1];
    248 
    249 							if (subgroupInvocationID >= subgroupSize)
    250 							{
    251 								return false;
    252 							}
    253 
    254 							subgroupInvocationHits[subgroupInvocationID]++;
    255 						}
    256 					}
    257 				}
    258 
    259 				deUint32 totalInvocationsRun = 0;
    260 				for (deUint32 i = 0; i < subgroupSize; ++i)
    261 				{
    262 					totalInvocationsRun += subgroupInvocationHits[i];
    263 				}
    264 
    265 				if (totalInvocationsRun != totalLocalSize)
    266 				{
    267 					return false;
    268 				}
    269 			}
    270 		}
    271 	}
    272 
    273 	return true;
    274 }
    275 
    276 static bool checkComputeNumSubgroups(std::vector<const void*> datas,
    277 									 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
    278 									 deUint32)
    279 {
    280 	const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
    281 
    282 	for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
    283 	{
    284 		for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
    285 		{
    286 			for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
    287 			{
    288 				const deUint32 totalLocalSize =
    289 					localSize[0] * localSize[1] * localSize[2];
    290 
    291 				for (deUint32 lX = 0; lX < localSize[0]; ++lX)
    292 				{
    293 					for (deUint32 lY = 0; lY < localSize[1]; ++lY)
    294 					{
    295 						for (deUint32 lZ = 0; lZ < localSize[2];
    296 								++lZ)
    297 						{
    298 							const deUint32 globalInvocationX =
    299 								nX * localSize[0] + lX;
    300 							const deUint32 globalInvocationY =
    301 								nY * localSize[1] + lY;
    302 							const deUint32 globalInvocationZ =
    303 								nZ * localSize[2] + lZ;
    304 
    305 							const deUint32 globalSizeX =
    306 								numWorkgroups[0] * localSize[0];
    307 							const deUint32 globalSizeY =
    308 								numWorkgroups[1] * localSize[1];
    309 
    310 							const deUint32 offset =
    311 								globalSizeX *
    312 								((globalSizeY *
    313 								  globalInvocationZ) +
    314 								 globalInvocationY) +
    315 								globalInvocationX;
    316 
    317 							deUint32 numSubgroups = data[(offset * 4) + 2];
    318 
    319 							if (numSubgroups > totalLocalSize)
    320 							{
    321 								return false;
    322 							}
    323 						}
    324 					}
    325 				}
    326 			}
    327 		}
    328 	}
    329 
    330 	return true;
    331 }
    332 
    333 static bool checkComputeSubgroupID(std::vector<const void*> datas,
    334 								   const deUint32 numWorkgroups[3], const deUint32 localSize[3],
    335 								   deUint32)
    336 {
    337 	const deUint32* data = reinterpret_cast<const deUint32*>(datas[0]);
    338 
    339 	for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
    340 	{
    341 		for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
    342 		{
    343 			for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
    344 			{
    345 				for (deUint32 lX = 0; lX < localSize[0]; ++lX)
    346 				{
    347 					for (deUint32 lY = 0; lY < localSize[1]; ++lY)
    348 					{
    349 						for (deUint32 lZ = 0; lZ < localSize[2];
    350 								++lZ)
    351 						{
    352 							const deUint32 globalInvocationX =
    353 								nX * localSize[0] + lX;
    354 							const deUint32 globalInvocationY =
    355 								nY * localSize[1] + lY;
    356 							const deUint32 globalInvocationZ =
    357 								nZ * localSize[2] + lZ;
    358 
    359 							const deUint32 globalSizeX =
    360 								numWorkgroups[0] * localSize[0];
    361 							const deUint32 globalSizeY =
    362 								numWorkgroups[1] * localSize[1];
    363 
    364 							const deUint32 offset =
    365 								globalSizeX *
    366 								((globalSizeY *
    367 								  globalInvocationZ) +
    368 								 globalInvocationY) +
    369 								globalInvocationX;
    370 
    371 							deUint32 numSubgroups = data[(offset * 4) + 2];
    372 							deUint32 subgroupID = data[(offset * 4) + 3];
    373 
    374 							if (subgroupID >= numSubgroups)
    375 							{
    376 								return false;
    377 							}
    378 						}
    379 					}
    380 				}
    381 			}
    382 		}
    383 	}
    384 
    385 	return true;
    386 }
    387 
    388 namespace
    389 {
    390 struct CaseDefinition
    391 {
    392 	std::string varName;
    393 	VkShaderStageFlags shaderStage;
    394 	bool noSSBO;
    395 };
    396 }
    397 
    398 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
    399 {
    400 	std::ostringstream src;
    401 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    402 	{
    403 		src << "#version 450\n"
    404 			<< "#extension GL_KHR_shader_subgroup_basic: enable\n"
    405 			<< "layout(location = 0) out vec4 out_color;\n"
    406 			<< "layout(location = 0) in highp vec4 in_position;\n"
    407 			<< "\n"
    408 			<< "void main (void)\n"
    409 			<< "{\n"
    410 			<< "  out_color = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 1.0f, 1.0f);\n"
    411 			<< "  gl_Position = in_position;\n"
    412 			<< "  gl_PointSize = 1.0f;\n"
    413 			<< "}\n";
    414 
    415 		programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    416 
    417 		std::ostringstream source;
    418 		source	<< glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
    419 				<< "layout(location = 0) in vec4 in_color;\n"
    420 				<< "layout(location = 0) out uvec4 out_color;\n"
    421 				<< "void main()\n"
    422 				<<"{\n"
    423 				<< "	out_color = uvec4(in_color);\n"
    424 				<< "}\n";
    425 		programCollection.glslSources.add("fragment") << glu::FragmentSource(source.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    426 	}
    427 	else
    428 	{
    429 		DE_FATAL("Unsupported shader stage");
    430 	}
    431 }
    432 
    433 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
    434 {
    435 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
    436 	{
    437 		std::ostringstream src;
    438 
    439 		src << "#version 450\n"
    440 			<< "#extension GL_KHR_shader_subgroup_basic: enable\n"
    441 			<< "layout (local_size_x_id = 0, local_size_y_id = 1, "
    442 			"local_size_z_id = 2) in;\n"
    443 			<< "layout(set = 0, binding = 0, std430) buffer Output\n"
    444 			<< "{\n"
    445 			<< "  uvec4 result[];\n"
    446 			<< "};\n"
    447 			<< "\n"
    448 			<< "void main (void)\n"
    449 			<< "{\n"
    450 			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
    451 			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
    452 			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
    453 			"gl_GlobalInvocationID.x;\n"
    454 			<< "  result[offset] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, gl_NumSubgroups, gl_SubgroupID);\n"
    455 			<< "}\n";
    456 
    457 		programCollection.glslSources.add("comp")
    458 				<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    459 	}
    460 	else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
    461 	{
    462 		programCollection.glslSources.add("vert")
    463 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    464 
    465 		std::ostringstream frag;
    466 
    467 		frag << "#version 450\n"
    468 			 << "#extension GL_KHR_shader_subgroup_basic: enable\n"
    469 			 << "layout(location = 0) out uvec4 data;\n"
    470 			 << "void main (void)\n"
    471 			 << "{\n"
    472 			 << "  data = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
    473 			 << "}\n";
    474 
    475 		programCollection.glslSources.add("frag")
    476 				<< glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    477 	}
    478 	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    479 	{
    480 		std::ostringstream src;
    481 
    482 		src << "#version 450\n"
    483 			<< "#extension GL_KHR_shader_subgroup_basic: enable\n"
    484 			<< "layout(set = 0, binding = 0, std430) buffer Output\n"
    485 			<< "{\n"
    486 			<< "  uvec4 result[];\n"
    487 			<< "};\n"
    488 			<< "\n"
    489 			<< "void main (void)\n"
    490 			<< "{\n"
    491 			<< "  result[gl_VertexIndex] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
    492 			<< "  gl_PointSize = 1.0f;\n"
    493 			<< "}\n";
    494 
    495 		programCollection.glslSources.add("vert")
    496 				<< glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    497 	}
    498 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
    499 	{
    500 		programCollection.glslSources.add("vert")
    501 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    502 
    503 		std::ostringstream src;
    504 
    505 		src << "#version 450\n"
    506 			<< "#extension GL_KHR_shader_subgroup_basic: enable\n"
    507 			<< "layout(points) in;\n"
    508 			<< "layout(points, max_vertices = 1) out;\n"
    509 			<< "layout(set = 0, binding = 0, std430) buffer Output\n"
    510 			<< "{\n"
    511 			<< "  uvec4 result[];\n"
    512 			<< "};\n"
    513 			<< "\n"
    514 			<< "void main (void)\n"
    515 			<< "{\n"
    516 			<< "  result[gl_PrimitiveIDIn] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
    517 			<< "}\n";
    518 
    519 		programCollection.glslSources.add("geom")
    520 				<< glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    521 	}
    522 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
    523 	{
    524 		programCollection.glslSources.add("vert")
    525 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    526 
    527 		programCollection.glslSources.add("tese")
    528 				<< glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
    529 
    530 		std::ostringstream src;
    531 
    532 		src << "#version 450\n"
    533 			<< "#extension GL_KHR_shader_subgroup_basic: enable\n"
    534 			<< "layout(vertices=1) out;\n"
    535 			<< "layout(set = 0, binding = 0, std430) buffer Output\n"
    536 			<< "{\n"
    537 			<< "  uvec4 result[];\n"
    538 			<< "};\n"
    539 			<< "\n"
    540 			<< "void main (void)\n"
    541 			<< "{\n"
    542 			<< "  result[gl_PrimitiveID] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
    543 			<< "}\n";
    544 
    545 		programCollection.glslSources.add("tesc")
    546 				<< glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    547 	}
    548 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
    549 	{
    550 		programCollection.glslSources.add("vert")
    551 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    552 
    553 		programCollection.glslSources.add("tesc")
    554 				<< glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
    555 
    556 		std::ostringstream src;
    557 
    558 		src << "#version 450\n"
    559 			<< "#extension GL_KHR_shader_subgroup_basic: enable\n"
    560 			<< "layout(isolines) in;\n"
    561 			<< "layout(set = 0, binding = 0, std430) buffer Output\n"
    562 			<< "{\n"
    563 			<< "  uvec4 result[];\n"
    564 			<< "};\n"
    565 			<< "\n"
    566 			<< "void main (void)\n"
    567 			<< "{\n"
    568 			<< "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = uvec4(gl_SubgroupSize, gl_SubgroupInvocationID, 0, 0);\n"
    569 			<< "}\n";
    570 
    571 		programCollection.glslSources.add("tese")
    572 				<< glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    573 	}
    574 	else
    575 	{
    576 		DE_FATAL("Unsupported shader stage");
    577 	}
    578 }
    579 
    580 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
    581 {
    582 	if (!subgroups::isSubgroupSupported(context))
    583 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
    584 
    585 	if (!areSubgroupOperationsSupportedForStage(
    586 				context, caseDef.shaderStage))
    587 	{
    588 		if (areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
    589 		{
    590 			return tcu::TestStatus::fail(
    591 					   "Shader stage " + getShaderStageName(caseDef.shaderStage) +
    592 					   " is required to support subgroup operations!");
    593 		}
    594 		else
    595 		{
    596 			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
    597 		}
    598 	}
    599 
    600 	//Tests which don't use the SSBO
    601 	if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    602 	{
    603 		if ("gl_SubgroupSize" == caseDef.varName)
    604 		{
    605 			return makeVertexFrameBufferTest(
    606 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
    607 		}
    608 		else if ("gl_SubgroupInvocationID" == caseDef.varName)
    609 		{
    610 			return makeVertexFrameBufferTest(
    611 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
    612 		}
    613 	}
    614 
    615 	if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
    616 			(VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
    617 	{
    618 		if (!subgroups::isVertexSSBOSupportedForDevice(context))
    619 		{
    620 			TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
    621 		}
    622 	}
    623 
    624 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
    625 	{
    626 		if ("gl_SubgroupSize" == caseDef.varName)
    627 		{
    628 			return makeComputeTest(
    629 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupSize);
    630 		}
    631 		else if ("gl_SubgroupInvocationID" == caseDef.varName)
    632 		{
    633 			return makeComputeTest(
    634 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupInvocationID);
    635 		}
    636 		else if ("gl_NumSubgroups" == caseDef.varName)
    637 		{
    638 			return makeComputeTest(
    639 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeNumSubgroups);
    640 		}
    641 		else if ("gl_SubgroupID" == caseDef.varName)
    642 		{
    643 			return makeComputeTest(
    644 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkComputeSubgroupID);
    645 		}
    646 		else
    647 		{
    648 			return tcu::TestStatus::fail(
    649 					   caseDef.varName + " failed (unhandled error checking case " +
    650 					   caseDef.varName + ")!");
    651 		}
    652 	}
    653 	else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
    654 	{
    655 		if ("gl_SubgroupSize" == caseDef.varName)
    656 		{
    657 			return makeFragmentTest(
    658 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkFragmentSubgroupSize);
    659 		}
    660 		else if ("gl_SubgroupInvocationID" == caseDef.varName)
    661 		{
    662 			return makeFragmentTest(
    663 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkFragmentSubgroupInvocationID);
    664 		}
    665 		else
    666 		{
    667 			return tcu::TestStatus::fail(
    668 					   caseDef.varName + " failed (unhandled error checking case " +
    669 					   caseDef.varName + ")!");
    670 		}
    671 	}
    672 	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    673 	{
    674 		if ("gl_SubgroupSize" == caseDef.varName)
    675 		{
    676 			return makeVertexTest(
    677 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
    678 		}
    679 		else if ("gl_SubgroupInvocationID" == caseDef.varName)
    680 		{
    681 			return makeVertexTest(
    682 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
    683 		}
    684 		else
    685 		{
    686 			return tcu::TestStatus::fail(
    687 					   caseDef.varName + " failed (unhandled error checking case " +
    688 					   caseDef.varName + ")!");
    689 		}
    690 	}
    691 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
    692 	{
    693 		if ("gl_SubgroupSize" == caseDef.varName)
    694 		{
    695 			return makeGeometryTest(
    696 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
    697 		}
    698 		else if ("gl_SubgroupInvocationID" == caseDef.varName)
    699 		{
    700 			return makeGeometryTest(
    701 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
    702 		}
    703 		else
    704 		{
    705 			return tcu::TestStatus::fail(
    706 					   caseDef.varName + " failed (unhandled error checking case " +
    707 					   caseDef.varName + ")!");
    708 		}
    709 	}
    710 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
    711 	{
    712 		if ("gl_SubgroupSize" == caseDef.varName)
    713 		{
    714 			return makeTessellationControlTest(
    715 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
    716 		}
    717 		else if ("gl_SubgroupInvocationID" == caseDef.varName)
    718 		{
    719 			return makeTessellationControlTest(
    720 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
    721 		}
    722 		else
    723 		{
    724 			return tcu::TestStatus::fail(
    725 					   caseDef.varName + " failed (unhandled error checking case " +
    726 					   caseDef.varName + ")!");
    727 		}
    728 	}
    729 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
    730 	{
    731 		if ("gl_SubgroupSize" == caseDef.varName)
    732 		{
    733 			return makeTessellationEvaluationTest(
    734 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupSize);
    735 		}
    736 		else if ("gl_SubgroupInvocationID" == caseDef.varName)
    737 		{
    738 			return makeTessellationEvaluationTest(
    739 					   context, VK_FORMAT_R32G32B32A32_UINT, DE_NULL, 0, checkVertexPipelineStagesSubgroupInvocationID);
    740 		}
    741 		else
    742 		{
    743 			return tcu::TestStatus::fail(
    744 					   caseDef.varName + " failed (unhandled error checking case " +
    745 					   caseDef.varName + ")!");
    746 		}
    747 	}
    748 	else
    749 	{
    750 		TCU_THROW(InternalError, "Unhandled shader stage");
    751 	}
    752 }
    753 
    754 tcu::TestCaseGroup* createSubgroupsBuiltinVarTests(tcu::TestContext& testCtx)
    755 {
    756 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
    757 			testCtx, "builtin_var", "Subgroup builtin variable tests"));
    758 
    759 	const char* const all_stages_vars[] =
    760 	{
    761 		"SubgroupSize",
    762 		"SubgroupInvocationID"
    763 	};
    764 
    765 	const char* const compute_only_vars[] =
    766 	{
    767 		"NumSubgroups",
    768 		"SubgroupID"
    769 	};
    770 
    771 	const VkShaderStageFlags stages[] =
    772 	{
    773 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
    774 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
    775 		VK_SHADER_STAGE_GEOMETRY_BIT,
    776 		VK_SHADER_STAGE_VERTEX_BIT,
    777 		VK_SHADER_STAGE_FRAGMENT_BIT,
    778 		VK_SHADER_STAGE_COMPUTE_BIT,
    779 	};
    780 
    781 	for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
    782 	{
    783 		const VkShaderStageFlags stage = stages[stageIndex];
    784 
    785 		for (int a = 0; a < DE_LENGTH_OF_ARRAY(all_stages_vars); ++a)
    786 		{
    787 			const std::string var = all_stages_vars[a];
    788 
    789 			CaseDefinition caseDef = {"gl_" + var, stage, false};
    790 
    791 			addFunctionCaseWithPrograms(group.get(),
    792 										de::toLower(var) + "_" +
    793 										getShaderStageName(stage), "",
    794 										initPrograms, test, caseDef);
    795 
    796 			if (VK_SHADER_STAGE_VERTEX_BIT == stage)
    797 			{
    798 				caseDef.noSSBO = true;
    799 				addFunctionCaseWithPrograms(group.get(),
    800 							de::toLower(var) + "_" +
    801 							getShaderStageName(stage)+"_framebuffer", "",
    802 							initFrameBufferPrograms, test, caseDef);
    803 			}
    804 		}
    805 	}
    806 
    807 	for (int a = 0; a < DE_LENGTH_OF_ARRAY(compute_only_vars); ++a)
    808 	{
    809 		const VkShaderStageFlags stage = VK_SHADER_STAGE_COMPUTE_BIT;
    810 		const std::string var = compute_only_vars[a];
    811 
    812 		CaseDefinition caseDef = {"gl_" + var, stage, false};
    813 
    814 		addFunctionCaseWithPrograms(group.get(), de::toLower(var) +
    815 									"_" + getShaderStageName(stage), "",
    816 									initPrograms, test, caseDef);
    817 	}
    818 
    819 	return group.release();
    820 }
    821 
    822 } // subgroups
    823 } // vkt
    824