Home | History | Annotate | Download | only in subgroups
      1 /*------------------------------------------------------------------------
      2  * Vulkan Conformance Tests
      3  * ------------------------
      4  *
      5  * Copyright (c) 2017 The Khronos Group Inc.
      6  * Copyright (c) 2017 Codeplay Software Ltd.
      7  *
      8  * Licensed under the Apache License, Version 2.0 (the "License");
      9  * you may not use this file except in compliance with the License.
     10  * You may obtain a copy of the License at
     11  *
     12  *      http://www.apache.org/licenses/LICENSE-2.0
     13  *
     14  * Unless required by applicable law or agreed to in writing, software
     15  * distributed under the License is distributed on an "AS IS" BASIS,
     16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     17  * See the License for the specific language governing permissions and
     18  * limitations under the License.
     19  *
     20  */ /*!
     21  * \file
     22  * \brief Subgroups Tests
     23  */ /*--------------------------------------------------------------------*/
     24 
#include "vktSubgroupsShuffleTests.hpp"
#include "vktSubgroupsTestsUtils.hpp"

#include <locale>
#include <sstream>
#include <string>
#include <vector>
     30 
     31 using namespace tcu;
     32 using namespace std;
     33 using namespace vk;
     34 using namespace vkt;
     35 
     36 namespace
     37 {
     38 enum OpType
     39 {
     40 	OPTYPE_SHUFFLE = 0,
     41 	OPTYPE_SHUFFLE_XOR,
     42 	OPTYPE_SHUFFLE_UP,
     43 	OPTYPE_SHUFFLE_DOWN,
     44 	OPTYPE_LAST
     45 };
     46 
     47 static bool checkVertexPipelineStages(std::vector<const void*> datas,
     48 									  deUint32 width, deUint32)
     49 {
     50 	return vkt::subgroups::check(datas, width, 1);
     51 }
     52 
     53 static bool checkCompute(std::vector<const void*> datas,
     54 						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
     55 						 deUint32)
     56 {
     57 	return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
     58 }
     59 
     60 std::string getOpTypeName(int opType)
     61 {
     62 	switch (opType)
     63 	{
     64 		default:
     65 			DE_FATAL("Unsupported op type");
     66 			return "";
     67 		case OPTYPE_SHUFFLE:
     68 			return "subgroupShuffle";
     69 		case OPTYPE_SHUFFLE_XOR:
     70 			return "subgroupShuffleXor";
     71 		case OPTYPE_SHUFFLE_UP:
     72 			return "subgroupShuffleUp";
     73 		case OPTYPE_SHUFFLE_DOWN:
     74 			return "subgroupShuffleDown";
     75 	}
     76 }
     77 
     78 struct CaseDefinition
     79 {
     80 	int					opType;
     81 	VkShaderStageFlags	shaderStage;
     82 	VkFormat			format;
     83 };
     84 
     85 const std::string to_string(int x) {
     86 	std::ostringstream oss;
     87 	oss << x;
     88 	return oss.str();
     89 }
     90 
     91 const std::string DeclSource(CaseDefinition caseDef, int baseBinding)
     92 {
     93 	return
     94 		"layout(set = 0, binding = " + to_string(baseBinding) + ", std430) readonly buffer Buffer2\n"
     95 		"{\n"
     96 		"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data1[];\n"
     97 		"};\n"
     98 		"layout(set = 0, binding = " + to_string(baseBinding + 1) + ", std430) readonly buffer Buffer3\n"
     99 		"{\n"
    100 		"  uint data2[];\n"
    101 		"};\n";
    102 }
    103 
    104 const std::string TestSource(CaseDefinition caseDef)
    105 {
    106 	std::string						idTable[OPTYPE_LAST];
    107 	idTable[OPTYPE_SHUFFLE]			= "id_in";
    108 	idTable[OPTYPE_SHUFFLE_XOR]		= "gl_SubgroupInvocationID ^ id_in";
    109 	idTable[OPTYPE_SHUFFLE_UP]		= "gl_SubgroupInvocationID - id_in";
    110 	idTable[OPTYPE_SHUFFLE_DOWN]	= "gl_SubgroupInvocationID + id_in";
    111 
    112 	const std::string testSource =
    113 		"  uint temp_res;\n"
    114 		"  uvec4 mask = subgroupBallot(true);\n"
    115 		"  uint id_in = data2[gl_SubgroupInvocationID] & (gl_SubgroupSize - 1);\n"
    116 		"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " op = "
    117 		+ getOpTypeName(caseDef.opType) + "(data1[gl_SubgroupInvocationID], id_in);\n"
    118 		"  uint id = " + idTable[caseDef.opType] + ";\n"
    119 		"  if ((id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
    120 		"  {\n"
    121 		"    temp_res = (op == data1[id]) ? 1 : 0;\n"
    122 		"  }\n"
    123 		"  else\n"
    124 		"  {\n"
    125 		"    temp_res = 1; // Invocation we read from was inactive, so we can't verify results!\n"
    126 		"  }\n";
    127 
    128 	return testSource;
    129 }
    130 
    131 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
    132 {
    133 	const vk::ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
    134 
    135 	subgroups::setFragmentShaderFrameBuffer(programCollection);
    136 
    137 	if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
    138 		subgroups::setVertexShaderFrameBuffer(programCollection);
    139 
    140 	const std::string extSource =
    141 	(OPTYPE_SHUFFLE == caseDef.opType || OPTYPE_SHUFFLE_XOR == caseDef.opType) ?
    142 		"#extension GL_KHR_shader_subgroup_shuffle: enable\n" :
    143 		"#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
    144 
    145 	const std::string testSource = TestSource(caseDef);
    146 
    147 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    148 	{
    149 		std::ostringstream vertexSrc;
    150 		vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
    151 			<< "layout(location = 0) in highp vec4 in_position;\n"
    152 			<< "layout(location = 0) out float result;\n"
    153 			<< extSource
    154 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    155 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
    156 			<< "{\n"
    157 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" << subgroups::maxSupportedSubgroupSize() << "];\n"
    158 			<< "};\n"
    159 			<< "layout(set = 0, binding = 1) uniform Buffer2\n"
    160 			<< "{\n"
    161 			<< "  uint data2[" << subgroups::maxSupportedSubgroupSize() << "];\n"
    162 			<< "};\n"
    163 			<< "\n"
    164 			<< "void main (void)\n"
    165 			<< "{\n"
    166 			<< testSource
    167 			<< "  result = temp_res;\n"
    168 			<< "  gl_Position = in_position;\n"
    169 			<< "  gl_PointSize = 1.0f;\n"
    170 			<< "}\n";
    171 		programCollection.glslSources.add("vert")
    172 			<< glu::VertexSource(vertexSrc.str()) << buildOptions;
    173 	}
    174 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
    175 	{
    176 		std::ostringstream geometry;
    177 
    178 		geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
    179 			<< extSource
    180 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    181 			<< "layout(points) in;\n"
    182 			<< "layout(points, max_vertices = 1) out;\n"
    183 			<< "layout(location = 0) out float out_color;\n"
    184 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
    185 			<< "{\n"
    186 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" << subgroups::maxSupportedSubgroupSize() << "];\n"
    187 			<< "};\n"
    188 			<< "layout(set = 0, binding = 1) uniform Buffer2\n"
    189 			<< "{\n"
    190 			<< "  uint data2[" << subgroups::maxSupportedSubgroupSize() << "];\n"
    191 			<< "};\n"
    192 			<< "\n"
    193 			<< "void main (void)\n"
    194 			<< "{\n"
    195 			<< testSource
    196 			<< "  out_color = temp_res;\n"
    197 			<< "  gl_Position = gl_in[0].gl_Position;\n"
    198 			<< "  EmitVertex();\n"
    199 			<< "  EndPrimitive();\n"
    200 			<< "}\n";
    201 
    202 		programCollection.glslSources.add("geometry")
    203 			<< glu::GeometrySource(geometry.str()) << buildOptions;
    204 	}
    205 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
    206 	{
    207 		std::ostringstream controlSource;
    208 
    209 		controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
    210 			<< extSource
    211 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    212 			<< "layout(vertices = 2) out;\n"
    213 			<< "layout(location = 0) out float out_color[];\n"
    214 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
    215 			<< "{\n"
    216 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" << subgroups::maxSupportedSubgroupSize() << "];\n"
    217 			<< "};\n"
    218 			<< "layout(set = 0, binding = 1) uniform Buffer2\n"
    219 			<< "{\n"
    220 			<< "  uint data2[" << subgroups::maxSupportedSubgroupSize() << "];\n"
    221 			<< "};\n"
    222 			<< "\n"
    223 			<< "void main (void)\n"
    224 			<< "{\n"
    225 			<< "  if (gl_InvocationID == 0)\n"
    226 			<<"  {\n"
    227 			<< "    gl_TessLevelOuter[0] = 1.0f;\n"
    228 			<< "    gl_TessLevelOuter[1] = 1.0f;\n"
    229 			<< "  }\n"
    230 			<< testSource
    231 			<< "  out_color[gl_InvocationID] = temp_res;\n"
    232 			<< "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
    233 			<< "}\n";
    234 
    235 		programCollection.glslSources.add("tesc")
    236 			<< glu::TessellationControlSource(controlSource.str()) << buildOptions;
    237 		subgroups::setTesEvalShaderFrameBuffer(programCollection);
    238 
    239 	}
    240 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
    241 	{
    242 		std::ostringstream evaluationSource;
    243 		evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
    244 			<< extSource
    245 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    246 			<< "layout(isolines, equal_spacing, ccw ) in;\n"
    247 			<< "layout(location = 0) out float out_color;\n"
    248 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
    249 			<< "{\n"
    250 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" << subgroups::maxSupportedSubgroupSize() << "];\n"
    251 			<< "};\n"
    252 			<< "layout(set = 0, binding = 1) uniform Buffer2\n"
    253 			<< "{\n"
    254 			<< "  uint data2[" << subgroups::maxSupportedSubgroupSize() << "];\n"
    255 			<< "};\n"
    256 			<< "\n"
    257 			<< "void main (void)\n"
    258 			<< "{\n"
    259 			<< testSource
    260 			<< "  out_color = temp_res;\n"
    261 			<< "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
    262 			<< "}\n";
    263 
    264 		subgroups::setTesCtrlShaderFrameBuffer(programCollection);
    265 		programCollection.glslSources.add("tese")
    266 				<< glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
    267 	}
    268 	else
    269 	{
    270 		DE_FATAL("Unsupported shader stage");
    271 	}
    272 }
    273 
    274 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
    275 {
    276 	const std::string vSource =
    277 		"#version 450\n"
    278 		"#extension GL_KHR_shader_subgroup_ballot: enable\n";
    279 	const std::string eSource =
    280 	(OPTYPE_SHUFFLE == caseDef.opType || OPTYPE_SHUFFLE_XOR == caseDef.opType) ?
    281 		"#extension GL_KHR_shader_subgroup_shuffle: enable\n" :
    282 		"#extension GL_KHR_shader_subgroup_shuffle_relative: enable\n";
    283 	const std::string extSource = vSource + eSource;
    284 
    285 	const std::string testSource = TestSource(caseDef);
    286 
    287 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
    288 	{
    289 		std::ostringstream src;
    290 
    291 	src << extSource
    292 			<< "layout (local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;\n"
    293 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    294 			<< "{\n"
    295 			<< "  uint result[];\n"
    296 			<< "};\n"
    297 			<< DeclSource(caseDef, 1)
    298 			<< "\n"
    299 			<< "void main (void)\n"
    300 			<< "{\n"
    301 			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
    302 			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
    303 			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
    304 			"gl_GlobalInvocationID.x;\n"
    305 			<< testSource
    306 			<< "  result[offset] = temp_res;\n"
    307 			<< "}\n";
    308 
    309 		programCollection.glslSources.add("comp")
    310 				<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
    311 	}
    312 	else
    313 	{
    314 		const std::string declSource = DeclSource(caseDef, 4);
    315 
    316 		{
    317 			const string vertex =
    318 				extSource +
    319 				"layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    320 				"{\n"
    321 				"  uint result[];\n"
    322 				"};\n"
    323 				+ declSource +
    324 				"\n"
    325 				"void main (void)\n"
    326 				"{\n"
    327 				+ testSource +
    328 				"  result[gl_VertexIndex] = temp_res;\n"
    329 				"  float pixelSize = 2.0f/1024.0f;\n"
    330 				"  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
    331 				"  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
    332 				"  gl_PointSize = 1.0f;\n"
    333 				"}\n";
    334 
    335 			programCollection.glslSources.add("vert")
    336 				<< glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
    337 		}
    338 
    339 		{
    340 			const string tesc =
    341 				extSource +
    342 				"layout(vertices=1) out;\n"
    343 				"layout(set = 0, binding = 1, std430)  buffer Buffer1\n"
    344 				"{\n"
    345 				"  uint result[];\n"
    346 				"};\n"
    347 				+ declSource +
    348 				"\n"
    349 				"void main (void)\n"
    350 				"{\n"
    351 				+ testSource +
    352 				"  result[gl_PrimitiveID] = temp_res;\n"
    353 				"  if (gl_InvocationID == 0)\n"
    354 				"  {\n"
    355 				"    gl_TessLevelOuter[0] = 1.0f;\n"
    356 				"    gl_TessLevelOuter[1] = 1.0f;\n"
    357 				"  }\n"
    358 				"  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
    359 				"}\n";
    360 
    361 			programCollection.glslSources.add("tesc")
    362 					<< glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
    363 		}
    364 
    365 		{
    366 			const string tese =
    367 				extSource +
    368 				"layout(isolines) in;\n"
    369 				"layout(set = 0, binding = 2, std430) buffer Buffer1\n"
    370 				"{\n"
    371 				"  uint result[];\n"
    372 				"};\n"
    373 				+ declSource +
    374 				"\n"
    375 				"void main (void)\n"
    376 				"{\n"
    377 				+ testSource +
    378 				"  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = temp_res;\n"
    379 				"  float pixelSize = 2.0f/1024.0f;\n"
    380 				"  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
    381 				"}\n";
    382 
    383 			programCollection.glslSources.add("tese")
    384 					<< glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
    385 		}
    386 
    387 		{
    388 			const string geometry =
    389 				extSource +
    390 				"layout(${TOPOLOGY}) in;\n"
    391 				"layout(points, max_vertices = 1) out;\n"
    392 				"layout(set = 0, binding = 3, std430) buffer Buffer1\n"
    393 				"{\n"
    394 				"  uint result[];\n"
    395 				"};\n"
    396 				+ declSource +
    397 				"\n"
    398 				"void main (void)\n"
    399 				"{\n"
    400 				+ testSource +
    401 				"  result[gl_PrimitiveIDIn] = temp_res;\n"
    402 				"  gl_Position = gl_in[0].gl_Position;\n"
    403 				"  EmitVertex();\n"
    404 				"  EndPrimitive();\n"
    405 				"}\n";
    406 
    407 			subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
    408 													  programCollection.glslSources);
    409 		}
    410 		{
    411 			const string fragment =
    412 				extSource +
    413 				"layout(location = 0) out uint result;\n"
    414 				+ declSource +
    415 				"void main (void)\n"
    416 				"{\n"
    417 				+ testSource +
    418 				"  result = temp_res;\n"
    419 				"}\n";
    420 
    421 			programCollection.glslSources.add("fragment")
    422 				<< glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
    423 		}
    424 
    425 		subgroups::addNoSubgroupShader(programCollection);
    426 	}
    427 }
    428 
    429 void supportedCheck (Context& context, CaseDefinition caseDef)
    430 {
    431 	if (!subgroups::isSubgroupSupported(context))
    432 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
    433 
    434 	switch (caseDef.opType)
    435 	{
    436 		case OPTYPE_SHUFFLE:
    437 		case OPTYPE_SHUFFLE_XOR:
    438 			if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_SHUFFLE_BIT))
    439 			{
    440 				TCU_THROW(NotSupportedError, "Device does not support subgroup shuffle operations");
    441 			}
    442 			break;
    443 		default:
    444 			if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_SHUFFLE_RELATIVE_BIT))
    445 			{
    446 				TCU_THROW(NotSupportedError, "Device does not support subgroup shuffle relative operations");
    447 			}
    448 			break;
    449 	}
    450 
    451 	if (subgroups::isDoubleFormat(caseDef.format) &&
    452 			!subgroups::isDoubleSupportedForDevice(context))
    453 		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
    454 }
    455 
    456 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
    457 {
    458 	if (!subgroups::areSubgroupOperationsSupportedForStage(
    459 				context, caseDef.shaderStage))
    460 	{
    461 		if (subgroups::areSubgroupOperationsRequiredForStage(
    462 					caseDef.shaderStage))
    463 		{
    464 			return tcu::TestStatus::fail(
    465 					   "Shader stage " +
    466 					   subgroups::getShaderStageName(caseDef.shaderStage) +
    467 					   " is required to support subgroup operations!");
    468 		}
    469 		else
    470 		{
    471 			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
    472 		}
    473 	}
    474 
    475 	subgroups::SSBOData inputData[2];
    476 	inputData[0].format = caseDef.format;
    477 	inputData[0].layout = subgroups::SSBOData::LayoutStd140;
    478 	inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
    479 	inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
    480 
    481 	inputData[1].format = VK_FORMAT_R32_UINT;
    482 	inputData[1].layout = subgroups::SSBOData::LayoutStd140;
    483 	inputData[1].numElements = inputData[0].numElements;
    484 	inputData[1].initializeType = subgroups::SSBOData::InitializeNonZero;
    485 
    486 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    487 		return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 2, checkVertexPipelineStages);
    488 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
    489 		return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 2, checkVertexPipelineStages);
    490 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
    491 		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 2, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
    492 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
    493 		return subgroups::makeTessellationEvaluationFrameBufferTest(context,  VK_FORMAT_R32_UINT, inputData, 2, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
    494 	else
    495 		TCU_THROW(InternalError, "Unhandled shader stage");
    496 }
    497 
    498 
    499 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
    500 {
    501 	switch (caseDef.opType)
    502 	{
    503 		case OPTYPE_SHUFFLE:
    504 		case OPTYPE_SHUFFLE_XOR:
    505 			if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_SHUFFLE_BIT))
    506 			{
    507 				TCU_THROW(NotSupportedError, "Device does not support subgroup shuffle operations");
    508 			}
    509 			break;
    510 		default:
    511 			if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_SHUFFLE_RELATIVE_BIT))
    512 			{
    513 				TCU_THROW(NotSupportedError, "Device does not support subgroup shuffle relative operations");
    514 			}
    515 			break;
    516 	}
    517 
    518 	if (subgroups::isDoubleFormat(caseDef.format) && !subgroups::isDoubleSupportedForDevice(context))
    519 	{
    520 		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
    521 	}
    522 
    523 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
    524 	{
    525 		if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
    526 		{
    527 			return tcu::TestStatus::fail(
    528 					   "Shader stage " +
    529 					   subgroups::getShaderStageName(caseDef.shaderStage) +
    530 					   " is required to support subgroup operations!");
    531 		}
    532 		subgroups::SSBOData inputData[2];
    533 		inputData[0].format = caseDef.format;
    534 		inputData[0].layout = subgroups::SSBOData::LayoutStd430;
    535 		inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
    536 		inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
    537 
    538 		inputData[1].format = VK_FORMAT_R32_UINT;
    539 		inputData[1].layout = subgroups::SSBOData::LayoutStd430;
    540 		inputData[1].numElements = inputData[0].numElements;
    541 		inputData[1].initializeType = subgroups::SSBOData::InitializeNonZero;
    542 
    543 		return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, inputData, 2, checkCompute);
    544 	}
    545 
    546 	else
    547 	{
    548 		VkPhysicalDeviceSubgroupProperties subgroupProperties;
    549 		subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
    550 		subgroupProperties.pNext = DE_NULL;
    551 
    552 		VkPhysicalDeviceProperties2 properties;
    553 		properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    554 		properties.pNext = &subgroupProperties;
    555 
    556 		context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
    557 
    558 		VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage  & subgroupProperties.supportedStages);
    559 
    560 		if (VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
    561 		{
    562 			if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
    563 				TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
    564 			else
    565 				stages = VK_SHADER_STAGE_FRAGMENT_BIT;
    566 		}
    567 
    568 		if ((VkShaderStageFlagBits)0u == stages)
    569 			TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
    570 
    571 		subgroups::SSBOData inputData[2];
    572 		inputData[0].format			= caseDef.format;
    573 		inputData[0].layout			= subgroups::SSBOData::LayoutStd430;
    574 		inputData[0].numElements	= subgroups::maxSupportedSubgroupSize();
    575 		inputData[0].initializeType	= subgroups::SSBOData::InitializeNonZero;
    576 		inputData[0].binding		= 4u;
    577 		inputData[0].stages			= stages;
    578 
    579 		inputData[1].format			= VK_FORMAT_R32_UINT;
    580 		inputData[1].layout			= subgroups::SSBOData::LayoutStd430;
    581 		inputData[1].numElements	= inputData[0].numElements;
    582 		inputData[1].initializeType	= subgroups::SSBOData::InitializeNonZero;
    583 		inputData[1].binding		= 5u;
    584 		inputData[1].stages			= stages;
    585 
    586 		return subgroups::allStages(context, VK_FORMAT_R32_UINT, inputData, 2, checkVertexPipelineStages, stages);
    587 	}
    588 }
    589 }
    590 
    591 namespace vkt
    592 {
    593 namespace subgroups
    594 {
    595 tcu::TestCaseGroup* createSubgroupsShuffleTests(tcu::TestContext& testCtx)
    596 {
    597 
    598 	de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
    599 		testCtx, "graphics", "Subgroup shuffle category tests: graphics"));
    600 	de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
    601 		testCtx, "compute", "Subgroup shuffle category tests: compute"));
    602 	de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
    603 		testCtx, "framebuffer", "Subgroup shuffle category tests: framebuffer"));
    604 
    605 	const VkFormat formats[] =
    606 	{
    607 		VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
    608 		VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
    609 		VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
    610 		VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
    611 		VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
    612 		VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
    613 		VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
    614 		VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
    615 		VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
    616 	};
    617 
    618 	const VkShaderStageFlags stages[] =
    619 	{
    620 		VK_SHADER_STAGE_VERTEX_BIT,
    621 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
    622 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
    623 		VK_SHADER_STAGE_GEOMETRY_BIT,
    624 	};
    625 
    626 	for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
    627 	{
    628 		const VkFormat format = formats[formatIndex];
    629 
    630 		for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
    631 		{
    632 
    633 			const string name =
    634 				de::toLower(getOpTypeName(opTypeIndex)) +
    635 				"_" + subgroups::getFormatNameForGLSL(format);
    636 
    637 			{
    638 				const CaseDefinition caseDef =
    639 				{
    640 					opTypeIndex,
    641 					VK_SHADER_STAGE_ALL_GRAPHICS,
    642 					format
    643 				};
    644 				addFunctionCaseWithPrograms(graphicGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
    645 			}
    646 
    647 			{
    648 				const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format};
    649 				addFunctionCaseWithPrograms(computeGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
    650 			}
    651 
    652 			for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
    653 			{
    654 				const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
    655 				addFunctionCaseWithPrograms(framebufferGroup.get(), name + "_" + getShaderStageName(caseDef.shaderStage), "",
    656 											supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
    657 			}
    658 		}
    659 	}
    660 
    661 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
    662 		testCtx, "shuffle", "Subgroup shuffle category tests"));
    663 
    664 	group->addChild(graphicGroup.release());
    665 	group->addChild(computeGroup.release());
    666 	group->addChild(framebufferGroup.release());
    667 
    668 	return group.release();
    669 }
    670 
    671 } // subgroups
    672 } // vkt
    673