Home | History | Annotate | Download | only in subgroups
      1 /*------------------------------------------------------------------------
      2  * Vulkan Conformance Tests
      3  * ------------------------
      4  *
      5  * Copyright (c) 2017 The Khronos Group Inc.
      6  * Copyright (c) 2017 Codeplay Software Ltd.
      7  *
      8  * Licensed under the Apache License, Version 2.0 (the "License");
      9  * you may not use this file except in compliance with the License.
     10  * You may obtain a copy of the License at
     11  *
     12  *      http://www.apache.org/licenses/LICENSE-2.0
     13  *
     14  * Unless required by applicable law or agreed to in writing, software
     15  * distributed under the License is distributed on an "AS IS" BASIS,
     16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     17  * See the License for the specific language governing permissions and
     18  * limitations under the License.
     19  *
     20  */ /*!
     21  * \file
     22  * \brief Subgroups Tests
     23  */ /*--------------------------------------------------------------------*/
     24 
     25 #include "vktSubgroupsBallotBroadcastTests.hpp"
     26 #include "vktSubgroupsTestsUtils.hpp"
     27 
     28 #include <string>
     29 #include <vector>
     30 
     31 using namespace tcu;
     32 using namespace std;
     33 using namespace vk;
     34 using namespace vkt;
     35 
     36 namespace
     37 {
     38 enum OpType
     39 {
     40 	OPTYPE_BROADCAST = 0,
     41 	OPTYPE_BROADCAST_FIRST,
     42 	OPTYPE_LAST
     43 };
     44 
     45 static bool checkVertexPipelineStages(std::vector<const void*> datas,
     46 									  deUint32 width, deUint32)
     47 {
     48 	const deUint32* data =
     49 		reinterpret_cast<const deUint32*>(datas[0]);
     50 	for (deUint32 x = 0; x < width; ++x)
     51 	{
     52 		deUint32 val = data[x];
     53 
     54 		if (0x3 != val)
     55 		{
     56 			return false;
     57 		}
     58 	}
     59 
     60 	return true;
     61 }
     62 
     63 static bool checkFragment(std::vector<const void*> datas,
     64 						  deUint32 width, deUint32 height, deUint32)
     65 {
     66 	const deUint32* data =
     67 		reinterpret_cast<const deUint32*>(datas[0]);
     68 	for (deUint32 x = 0; x < width; ++x)
     69 	{
     70 		for (deUint32 y = 0; y < height; ++y)
     71 		{
     72 			deUint32 val = data[x * height + y];
     73 
     74 			if (0x3 != val)
     75 			{
     76 				return false;
     77 			}
     78 		}
     79 	}
     80 
     81 	return true;
     82 }
     83 
     84 static bool checkCompute(std::vector<const void*> datas,
     85 						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
     86 						 deUint32)
     87 {
     88 	const deUint32* data =
     89 		reinterpret_cast<const deUint32*>(datas[0]);
     90 
     91 	for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
     92 	{
     93 		for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
     94 		{
     95 			for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
     96 			{
     97 				for (deUint32 lX = 0; lX < localSize[0]; ++lX)
     98 				{
     99 					for (deUint32 lY = 0; lY < localSize[1]; ++lY)
    100 					{
    101 						for (deUint32 lZ = 0; lZ < localSize[2];
    102 								++lZ)
    103 						{
    104 							const deUint32 globalInvocationX =
    105 								nX * localSize[0] + lX;
    106 							const deUint32 globalInvocationY =
    107 								nY * localSize[1] + lY;
    108 							const deUint32 globalInvocationZ =
    109 								nZ * localSize[2] + lZ;
    110 
    111 							const deUint32 globalSizeX =
    112 								numWorkgroups[0] * localSize[0];
    113 							const deUint32 globalSizeY =
    114 								numWorkgroups[1] * localSize[1];
    115 
    116 							const deUint32 offset =
    117 								globalSizeX *
    118 								((globalSizeY *
    119 								  globalInvocationZ) +
    120 								 globalInvocationY) +
    121 								globalInvocationX;
    122 
    123 							if (0x3 != data[offset])
    124 							{
    125 								return false;
    126 							}
    127 						}
    128 					}
    129 				}
    130 			}
    131 		}
    132 	}
    133 
    134 	return true;
    135 }
    136 
    137 
    138 std::string getOpTypeName(int opType)
    139 {
    140 	switch (opType)
    141 	{
    142 		default:
    143 			DE_FATAL("Unsupported op type");
    144 		case OPTYPE_BROADCAST:
    145 			return "subgroupBroadcast";
    146 		case OPTYPE_BROADCAST_FIRST:
    147 			return "subgroupBroadcastFirst";
    148 	}
    149 }
    150 
    151 
    152 struct CaseDefinition
    153 {
    154 	int					opType;
    155 	VkShaderStageFlags	shaderStage;
    156 	VkFormat			format;
    157 	bool				noSSBO;
    158 };
    159 
    160 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
    161 {
    162 	std::ostringstream bdy;
    163 
    164 	bdy << "  uint tempResult = 0;\n";
    165 
    166 	if (OPTYPE_BROADCAST == caseDef.opType)
    167 	{
    168 		bdy << "  tempResult = 0x3;\n";
    169 
    170 		for (deUint32 i = 0; i < subgroups::maxSupportedSubgroupSize(); i++)
    171 		{
    172 			bdy	<< "  {\n"
    173 				<< "    const uint id = " << i << ";\n"
    174 				<< "    " << subgroups::getFormatNameForGLSL(caseDef.format)
    175 				<< " op = subgroupBroadcast(data1[gl_SubgroupInvocationID], id);\n"
    176 				<< "    if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
    177 				<< "    {\n"
    178 				<< "      if (op != data1[id])\n"
    179 				<< "      {\n"
    180 				<< "        tempResult = 0;\n"
    181 				<< "      }\n"
    182 				<< "    }\n"
    183 				<< "  }\n";
    184 		}
    185 	}
    186 	else
    187 	{
    188 		bdy	<< "  uint firstActive = 0;\n"
    189 			<< "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
    190 			<< "  {\n"
    191 			<< "    if (subgroupBallotBitExtract(mask, i))\n"
    192 			<< "    {\n"
    193 			<< "      firstActive = i;\n"
    194 			<< "      break;\n"
    195 			<< "    }\n"
    196 			<< "  }\n"
    197 			<< "  tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x1 : 0;\n"
    198 			<< "  // make the firstActive invocation inactive now\n"
    199 			<< "  if (firstActive == gl_SubgroupInvocationID)\n"
    200 			<< "  {\n"
    201 			<< "    for (uint i = 0; i < gl_SubgroupSize; i++)\n"
    202 			<< "    {\n"
    203 			<< "      if (subgroupBallotBitExtract(mask, i))\n"
    204 			<< "      {\n"
    205 			<< "        firstActive = i;\n"
    206 			<< "        break;\n"
    207 			<< "      }\n"
    208 			<< "    }\n"
    209 			<< "    tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x2 : 0;\n"
    210 			<< "  }\n"
    211 			<< "  else\n"
    212 			<< "  {\n"
    213 			<< "    // the firstActive invocation didn't partake in the second result so set it to true\n"
    214 			<< "    tempResult |= 0x2;\n"
    215 			<< "  }\n";
    216 	}
    217 
    218 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    219 	{
    220 		std::ostringstream src;
    221 		std::ostringstream	fragmentSrc;
    222 
    223 		src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
    224 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    225 			<< "layout(location = 0) in highp vec4 in_position;\n"
    226 			<< "layout(location = 0) out float out_color;\n"
    227 			<< "layout(set = 0, binding = 0) uniform  Buffer1\n"
    228 			<< "{\n"
    229 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" << subgroups::maxSupportedSubgroupSize() << "];\n"
    230 			<< "};\n"
    231 			<< "\n"
    232 			<< "void main (void)\n"
    233 			<< "{\n"
    234 			<< "  uvec4 mask = subgroupBallot(true);\n"
    235 			<< bdy.str()
    236 			<< "  out_color = float(tempResult);\n"
    237 			<< "  gl_Position = in_position;\n"
    238 			<< "  gl_PointSize = 1.0f;\n"
    239 			<< "}\n";
    240 
    241 		programCollection.glslSources.add("vert")
    242 				<< glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    243 
    244 		fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
    245 			<< "layout(location = 0) in float in_color;\n"
    246 			<< "layout(location = 0) out uint out_color;\n"
    247 			<< "void main()\n"
    248 			<<"{\n"
    249 			<< "	out_color = uint(in_color);\n"
    250 			<< "}\n";
    251 		programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    252 	}
    253 	else
    254 	{
    255 		DE_FATAL("Unsupported shader stage");
    256 	}
    257 }
    258 
    259 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
    260 {
    261 	std::ostringstream bdy;
    262 
    263 	bdy << "  uint tempResult = 0;\n";
    264 
    265 	if (OPTYPE_BROADCAST == caseDef.opType)
    266 	{
    267 		bdy << "  tempResult = 0x3;\n";
    268 
    269 		for (deUint32 i = 0; i < subgroups::maxSupportedSubgroupSize(); i++)
    270 		{
    271 			bdy	<< "  {\n"
    272 				<< "    const uint id = " << i << ";\n"
    273 				<< "    " << subgroups::getFormatNameForGLSL(caseDef.format)
    274 				<< " op = subgroupBroadcast(data1[gl_SubgroupInvocationID], id);\n"
    275 				<< "    if ((0 <= id) && (id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
    276 				<< "    {\n"
    277 				<< "      if (op != data1[id])\n"
    278 				<< "      {\n"
    279 				<< "        tempResult = 0;\n"
    280 				<< "      }\n"
    281 				<< "    }\n"
    282 				<< "  }\n";
    283 		}
    284 	}
    285 	else
    286 	{
    287 		bdy	<< "  uint firstActive = 0;\n"
    288 			<< "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
    289 			<< "  {\n"
    290 			<< "    if (subgroupBallotBitExtract(mask, i))\n"
    291 			<< "    {\n"
    292 			<< "      firstActive = i;\n"
    293 			<< "      break;\n"
    294 			<< "    }\n"
    295 			<< "  }\n"
    296 			<< "  tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x1 : 0;\n"
    297 			<< "  // make the firstActive invocation inactive now\n"
    298 			<< "  if (firstActive == gl_SubgroupInvocationID)\n"
    299 			<< "  {\n"
    300 			<< "    for (uint i = 0; i < gl_SubgroupSize; i++)\n"
    301 			<< "    {\n"
    302 			<< "      if (subgroupBallotBitExtract(mask, i))\n"
    303 			<< "      {\n"
    304 			<< "        firstActive = i;\n"
    305 			<< "        break;\n"
    306 			<< "      }\n"
    307 			<< "    }\n"
    308 			<< "    tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x2 : 0;\n"
    309 			<< "  }\n"
    310 			<< "  else\n"
    311 			<< "  {\n"
    312 			<< "    // the firstActive invocation didn't partake in the second result so set it to true\n"
    313 			<< "    tempResult |= 0x2;\n"
    314 			<< "  }\n";
    315 	}
    316 
    317 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
    318 	{
    319 		std::ostringstream src;
    320 
    321 		src << "#version 450\n"
    322 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    323 			<< "layout (local_size_x_id = 0, local_size_y_id = 1, "
    324 			"local_size_z_id = 2) in;\n"
    325 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    326 			<< "{\n"
    327 			<< "  uint result[];\n"
    328 			<< "};\n"
    329 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
    330 			<< "{\n"
    331 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
    332 			<< "};\n"
    333 			<< "\n"
    334 			<< "void main (void)\n"
    335 			<< "{\n"
    336 			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
    337 			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
    338 			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
    339 			"gl_GlobalInvocationID.x;\n"
    340 			<< "  uvec4 mask = subgroupBallot(true);\n"
    341 			<< bdy.str()
    342 			<< "  result[offset] = tempResult;\n"
    343 			<< "}\n";
    344 
    345 		programCollection.glslSources.add("comp")
    346 				<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    347 	}
    348 	else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
    349 	{
    350 		programCollection.glslSources.add("vert")
    351 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    352 
    353 		std::ostringstream frag;
    354 
    355 		frag << "#version 450\n"
    356 			 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    357 			 << "layout(location = 0) out uint result;\n"
    358 			 << "layout(set = 0, binding = 0, std430) readonly buffer Buffer1\n"
    359 			 << "{\n"
    360 			 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
    361 			 << "};\n"
    362 			 << "void main (void)\n"
    363 			 << "{\n"
    364 			 << "  uvec4 mask = subgroupBallot(true);\n"
    365 			 << bdy.str()
    366 			 << "  result = tempResult;\n"
    367 			 << "}\n";
    368 
    369 		programCollection.glslSources.add("frag")
    370 				<< glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    371 	}
    372 	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    373 	{
    374 		std::ostringstream src;
    375 
    376 		src << "#version 450\n"
    377 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    378 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    379 			<< "{\n"
    380 			<< "  uint result[];\n"
    381 			<< "};\n"
    382 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
    383 			<< "{\n"
    384 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
    385 			<< "};\n"
    386 			<< "\n"
    387 			<< "void main (void)\n"
    388 			<< "{\n"
    389 			<< "  uvec4 mask = subgroupBallot(true);\n"
    390 			<< bdy.str()
    391 			<< "  result[gl_VertexIndex] = tempResult;\n"
    392 			<< "  gl_PointSize = 1.0f;\n"
    393 			<< "}\n";
    394 
    395 		programCollection.glslSources.add("vert")
    396 				<< glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    397 	}
    398 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
    399 	{
    400 		programCollection.glslSources.add("vert")
    401 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    402 
    403 		std::ostringstream src;
    404 
    405 		src << "#version 450\n"
    406 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    407 			<< "layout(points) in;\n"
    408 			<< "layout(points, max_vertices = 1) out;\n"
    409 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    410 			<< "{\n"
    411 			<< "  uint result[];\n"
    412 			<< "};\n"
    413 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
    414 			<< "{\n"
    415 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
    416 			<< "};\n"
    417 			<< "\n"
    418 			<< "void main (void)\n"
    419 			<< "{\n"
    420 			<< "  uvec4 mask = subgroupBallot(true);\n"
    421 			<< bdy.str()
    422 			<< "  result[gl_PrimitiveIDIn] = tempResult;\n"
    423 			<< "}\n";
    424 
    425 		programCollection.glslSources.add("geom")
    426 				<< glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    427 	}
    428 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
    429 	{
    430 		programCollection.glslSources.add("vert")
    431 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    432 
    433 		programCollection.glslSources.add("tese")
    434 				<< glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
    435 
    436 		std::ostringstream src;
    437 
    438 		src << "#version 450\n"
    439 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    440 			<< "layout(vertices=1) out;\n"
    441 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    442 			<< "{\n"
    443 			<< "  uint result[];\n"
    444 			<< "};\n"
    445 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
    446 			<< "{\n"
    447 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
    448 			<< "};\n"
    449 			<< "\n"
    450 			<< "void main (void)\n"
    451 			<< "{\n"
    452 			<< "  uvec4 mask = subgroupBallot(true);\n"
    453 			<< bdy.str()
    454 			<< "  result[gl_PrimitiveID] = tempResult;\n"
    455 			<< "}\n";
    456 
    457 		programCollection.glslSources.add("tesc")
    458 				<< glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    459 	}
    460 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
    461 	{
    462 		programCollection.glslSources.add("vert")
    463 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    464 
    465 		programCollection.glslSources.add("tesc")
    466 				<< glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
    467 
    468 		std::ostringstream src;
    469 
    470 		src << "#version 450\n"
    471 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    472 			<< "layout(isolines) in;\n"
    473 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    474 			<< "{\n"
    475 			<< "  uint result[];\n"
    476 			<< "};\n"
    477 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
    478 			<< "{\n"
    479 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
    480 			<< "};\n"
    481 			<< "\n"
    482 			<< "void main (void)\n"
    483 			<< "{\n"
    484 			<< "  uvec4 mask = subgroupBallot(true);\n"
    485 			<< bdy.str()
    486 			<< "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
    487 			<< "}\n";
    488 
    489 		programCollection.glslSources.add("tese")
    490 				<< glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    491 	}
    492 	else
    493 	{
    494 		DE_FATAL("Unsupported shader stage");
    495 	}
    496 }
    497 
    498 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
    499 {
    500 	if (!subgroups::isSubgroupSupported(context))
    501 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
    502 
    503 	if (!subgroups::areSubgroupOperationsSupportedForStage(
    504 				context, caseDef.shaderStage))
    505 	{
    506 		if (subgroups::areSubgroupOperationsRequiredForStage(
    507 					caseDef.shaderStage))
    508 		{
    509 			return tcu::TestStatus::fail(
    510 					   "Shader stage " +
    511 					   subgroups::getShaderStageName(caseDef.shaderStage) +
    512 					   " is required to support subgroup operations!");
    513 		}
    514 		else
    515 		{
    516 			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
    517 		}
    518 	}
    519 
    520 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BALLOT_BIT))
    521 	{
    522 		TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
    523 	}
    524 
    525 	if (subgroups::isDoubleFormat(caseDef.format) &&
    526 			!subgroups::isDoubleSupportedForDevice(context))
    527 	{
    528 		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
    529 	}
    530 
    531 	//Tests which don't use the SSBO
    532 	if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    533 	{
    534 		subgroups::SSBOData inputData[1];
    535 		inputData[0].format = caseDef.format;
    536 		inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
    537 		inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
    538 
    539 		return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 1, checkVertexPipelineStages);
    540 	}
    541 
    542 	if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
    543 			(VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
    544 	{
    545 		if (!subgroups::isVertexSSBOSupportedForDevice(context))
    546 		{
    547 			TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
    548 		}
    549 	}
    550 
    551 	if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
    552 	{
    553 		subgroups::SSBOData inputData[1];
    554 		inputData[0].format = caseDef.format;
    555 		inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
    556 		inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
    557 
    558 		return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT,
    559 										   inputData, 1, checkFragment);
    560 	}
    561 	else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
    562 	{
    563 		subgroups::SSBOData inputData[1];
    564 		inputData[0].format = caseDef.format;
    565 		inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
    566 		inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
    567 
    568 		return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT,
    569 										  inputData, 1, checkCompute);
    570 	}
    571 	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    572 	{
    573 		subgroups::SSBOData inputData[1];
    574 		inputData[0].format = caseDef.format;
    575 		inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
    576 		inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
    577 
    578 		return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT,
    579 										 inputData, 1, checkVertexPipelineStages);
    580 	}
    581 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
    582 	{
    583 		subgroups::SSBOData inputData[1];
    584 		inputData[0].format = caseDef.format;
    585 		inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
    586 		inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
    587 
    588 		return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT,
    589 										   inputData, 1, checkVertexPipelineStages);
    590 	}
    591 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
    592 	{
    593 		subgroups::SSBOData inputData[1];
    594 		inputData[0].format = caseDef.format;
    595 		inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
    596 		inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
    597 
    598 		return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT,
    599 				inputData, 1, checkVertexPipelineStages);
    600 	}
    601 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
    602 	{
    603 		subgroups::SSBOData inputData[1];
    604 		inputData[0].format = caseDef.format;
    605 		inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
    606 		inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
    607 
    608 		return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT,
    609 				inputData, 1, checkVertexPipelineStages);
    610 	}
    611 	else
    612 	{
    613 		TCU_THROW(InternalError, "Unhandled shader stage");
    614 	}
    615 }
    616 }
    617 
    618 namespace vkt
    619 {
    620 namespace subgroups
    621 {
    622 tcu::TestCaseGroup* createSubgroupsBallotBroadcastTests(tcu::TestContext& testCtx)
    623 {
    624 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
    625 			testCtx, "ballot_broadcast", "Subgroup ballot broadcast category tests"));
    626 
    627 	const VkShaderStageFlags stages[] =
    628 	{
    629 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
    630 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
    631 		VK_SHADER_STAGE_GEOMETRY_BIT,
    632 		VK_SHADER_STAGE_VERTEX_BIT,
    633 		VK_SHADER_STAGE_FRAGMENT_BIT,
    634 		VK_SHADER_STAGE_COMPUTE_BIT
    635 	};
    636 
    637 	const VkFormat formats[] =
    638 	{
    639 		VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
    640 		VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
    641 		VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
    642 		VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
    643 		VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
    644 		VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
    645 		VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
    646 		VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
    647 		VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
    648 	};
    649 
    650 	for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
    651 	{
    652 		const VkShaderStageFlags stage = stages[stageIndex];
    653 
    654 		for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
    655 		{
    656 			const VkFormat format = formats[formatIndex];
    657 
    658 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
    659 			{
    660 				CaseDefinition caseDef = {opTypeIndex, stage, format, false};
    661 
    662 				std::ostringstream name;
    663 
    664 				std::string op = getOpTypeName(opTypeIndex);
    665 
    666 				name << de::toLower(op) << "_" << subgroups::getFormatNameForGLSL(format)
    667 					  << "_" << getShaderStageName(stage);
    668 
    669 				addFunctionCaseWithPrograms(group.get(), name.str(),
    670 											"", initPrograms, test, caseDef);
    671 
    672 				if (VK_SHADER_STAGE_VERTEX_BIT == stage )
    673 				{
    674 					caseDef.noSSBO = true;
    675 					addFunctionCaseWithPrograms(group.get(), name.str()+"_framebuffer", "",
    676 								initFrameBufferPrograms, test, caseDef);
    677 				}
    678 			}
    679 		}
    680 	}
    681 
    682 	return group.release();
    683 }
    684 
    685 } // subgroups
    686 } // vkt
    687