Home | History | Annotate | Download | only in subgroups
      1 /*------------------------------------------------------------------------
      2  * Vulkan Conformance Tests
      3  * ------------------------
      4  *
      5  * Copyright (c) 2017 The Khronos Group Inc.
      6  * Copyright (c) 2017 Codeplay Software Ltd.
      7  *
      8  * Licensed under the Apache License, Version 2.0 (the "License");
      9  * you may not use this file except in compliance with the License.
     10  * You may obtain a copy of the License at
     11  *
     12  *      http://www.apache.org/licenses/LICENSE-2.0
     13  *
     14  * Unless required by applicable law or agreed to in writing, software
     15  * distributed under the License is distributed on an "AS IS" BASIS,
     16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     17  * See the License for the specific language governing permissions and
     18  * limitations under the License.
     19  *
     20  */ /*!
     21  * \file
     22  * \brief Subgroups Tests
     23  */ /*--------------------------------------------------------------------*/
     24 
     25 #include "vktSubgroupsClusteredTests.hpp"
     26 #include "vktSubgroupsTestsUtils.hpp"
     27 
     28 #include <string>
     29 #include <vector>
     30 
     31 using namespace tcu;
     32 using namespace std;
     33 using namespace vk;
     34 using namespace vkt;
     35 
     36 namespace
     37 {
     38 enum OpType
     39 {
     40 	OPTYPE_CLUSTERED_ADD = 0,
     41 	OPTYPE_CLUSTERED_MUL,
     42 	OPTYPE_CLUSTERED_MIN,
     43 	OPTYPE_CLUSTERED_MAX,
     44 	OPTYPE_CLUSTERED_AND,
     45 	OPTYPE_CLUSTERED_OR,
     46 	OPTYPE_CLUSTERED_XOR,
     47 	OPTYPE_CLUSTERED_LAST
     48 };
     49 
     50 static bool checkVertexPipelineStages(std::vector<const void*> datas,
     51 									  deUint32 width, deUint32)
     52 {
     53 	const deUint32* data =
     54 		reinterpret_cast<const deUint32*>(datas[0]);
     55 	for (deUint32 x = 0; x < width; ++x)
     56 	{
     57 		deUint32 val = data[x];
     58 
     59 		if (0x1 != val)
     60 		{
     61 			return false;
     62 		}
     63 	}
     64 
     65 	return true;
     66 }
     67 
     68 static bool checkFragment(std::vector<const void*> datas,
     69 						  deUint32 width, deUint32 height, deUint32)
     70 {
     71 	const deUint32* data =
     72 		reinterpret_cast<const deUint32*>(datas[0]);
     73 	for (deUint32 x = 0; x < width; ++x)
     74 	{
     75 		for (deUint32 y = 0; y < height; ++y)
     76 		{
     77 			deUint32 val = data[x * height + y];
     78 
     79 			if (0x1 != val)
     80 			{
     81 				return false;
     82 			}
     83 		}
     84 	}
     85 
     86 	return true;
     87 }
     88 
     89 static bool checkCompute(std::vector<const void*> datas,
     90 						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
     91 						 deUint32)
     92 {
     93 	const deUint32* data =
     94 		reinterpret_cast<const deUint32*>(datas[0]);
     95 
     96 	for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
     97 	{
     98 		for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
     99 		{
    100 			for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
    101 			{
    102 				for (deUint32 lX = 0; lX < localSize[0]; ++lX)
    103 				{
    104 					for (deUint32 lY = 0; lY < localSize[1]; ++lY)
    105 					{
    106 						for (deUint32 lZ = 0; lZ < localSize[2];
    107 								++lZ)
    108 						{
    109 							const deUint32 globalInvocationX =
    110 								nX * localSize[0] + lX;
    111 							const deUint32 globalInvocationY =
    112 								nY * localSize[1] + lY;
    113 							const deUint32 globalInvocationZ =
    114 								nZ * localSize[2] + lZ;
    115 
    116 							const deUint32 globalSizeX =
    117 								numWorkgroups[0] * localSize[0];
    118 							const deUint32 globalSizeY =
    119 								numWorkgroups[1] * localSize[1];
    120 
    121 							const deUint32 offset =
    122 								globalSizeX *
    123 								((globalSizeY *
    124 								  globalInvocationZ) +
    125 								 globalInvocationY) +
    126 								globalInvocationX;
    127 
    128 							if (0x1 != data[offset])
    129 							{
    130 								return false;
    131 							}
    132 						}
    133 					}
    134 				}
    135 			}
    136 		}
    137 	}
    138 
    139 	return true;
    140 }
    141 
    142 std::string getOpTypeName(int opType)
    143 {
    144 	switch (opType)
    145 	{
    146 		default:
    147 			DE_FATAL("Unsupported op type");
    148 		case OPTYPE_CLUSTERED_ADD:
    149 			return "subgroupClusteredAdd";
    150 		case OPTYPE_CLUSTERED_MUL:
    151 			return "subgroupClusteredMul";
    152 		case OPTYPE_CLUSTERED_MIN:
    153 			return "subgroupClusteredMin";
    154 		case OPTYPE_CLUSTERED_MAX:
    155 			return "subgroupClusteredMax";
    156 		case OPTYPE_CLUSTERED_AND:
    157 			return "subgroupClusteredAnd";
    158 		case OPTYPE_CLUSTERED_OR:
    159 			return "subgroupClusteredOr";
    160 		case OPTYPE_CLUSTERED_XOR:
    161 			return "subgroupClusteredXor";
    162 	}
    163 }
    164 
    165 std::string getOpTypeOperation(int opType, vk::VkFormat format, std::string lhs, std::string rhs)
    166 {
    167 	switch (opType)
    168 	{
    169 		default:
    170 			DE_FATAL("Unsupported op type");
    171 		case OPTYPE_CLUSTERED_ADD:
    172 			return lhs + " + " + rhs;
    173 		case OPTYPE_CLUSTERED_MUL:
    174 			return lhs + " * " + rhs;
    175 		case OPTYPE_CLUSTERED_MIN:
    176 			switch (format)
    177 			{
    178 				default:
    179 					return "min(" + lhs + ", " + rhs + ")";
    180 				case VK_FORMAT_R32_SFLOAT:
    181 				case VK_FORMAT_R64_SFLOAT:
    182 					return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs + ")))";
    183 				case VK_FORMAT_R32G32_SFLOAT:
    184 				case VK_FORMAT_R32G32B32_SFLOAT:
    185 				case VK_FORMAT_R32G32B32A32_SFLOAT:
    186 				case VK_FORMAT_R64G64_SFLOAT:
    187 				case VK_FORMAT_R64G64B64_SFLOAT:
    188 				case VK_FORMAT_R64G64B64A64_SFLOAT:
    189 					return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
    190 			}
    191 		case OPTYPE_CLUSTERED_MAX:
    192 			switch (format)
    193 			{
    194 				default:
    195 					return "max(" + lhs + ", " + rhs + ")";
    196 				case VK_FORMAT_R32_SFLOAT:
    197 				case VK_FORMAT_R64_SFLOAT:
    198 					return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : max(" + lhs + ", " + rhs + ")))";
    199 				case VK_FORMAT_R32G32_SFLOAT:
    200 				case VK_FORMAT_R32G32B32_SFLOAT:
    201 				case VK_FORMAT_R32G32B32A32_SFLOAT:
    202 				case VK_FORMAT_R64G64_SFLOAT:
    203 				case VK_FORMAT_R64G64B64_SFLOAT:
    204 				case VK_FORMAT_R64G64B64A64_SFLOAT:
    205 					return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
    206 			}
    207 		case OPTYPE_CLUSTERED_AND:
    208 			switch (format)
    209 			{
    210 				default:
    211 					return lhs + " & " + rhs;
    212 				case VK_FORMAT_R8_USCALED:
    213 					return lhs + " && " + rhs;
    214 				case VK_FORMAT_R8G8_USCALED:
    215 					return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)";
    216 				case VK_FORMAT_R8G8B8_USCALED:
    217 					return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z)";
    218 				case VK_FORMAT_R8G8B8A8_USCALED:
    219 					return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z, " + lhs + ".w && " + rhs + ".w)";
    220 			}
    221 		case OPTYPE_CLUSTERED_OR:
    222 			switch (format)
    223 			{
    224 				default:
    225 					return lhs + " | " + rhs;
    226 				case VK_FORMAT_R8_USCALED:
    227 					return lhs + " || " + rhs;
    228 				case VK_FORMAT_R8G8_USCALED:
    229 					return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)";
    230 				case VK_FORMAT_R8G8B8_USCALED:
    231 					return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z)";
    232 				case VK_FORMAT_R8G8B8A8_USCALED:
    233 					return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z, " + lhs + ".w || " + rhs + ".w)";
    234 			}
    235 		case OPTYPE_CLUSTERED_XOR:
    236 			switch (format)
    237 			{
    238 				default:
    239 					return lhs + " ^ " + rhs;
    240 				case VK_FORMAT_R8_USCALED:
    241 					return lhs + " ^^ " + rhs;
    242 				case VK_FORMAT_R8G8_USCALED:
    243 					return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)";
    244 				case VK_FORMAT_R8G8B8_USCALED:
    245 					return "bvec3(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z)";
    246 				case VK_FORMAT_R8G8B8A8_USCALED:
    247 					return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z, " + lhs + ".w ^^ " + rhs + ".w)";
    248 			}
    249 	}
    250 }
    251 
    252 
    253 std::string getIdentity(int opType, vk::VkFormat format)
    254 {
    255 	bool isFloat = false;
    256 	bool isInt = false;
    257 	bool isUnsigned = false;
    258 
    259 	switch (format)
    260 	{
    261 		default:
    262 			DE_FATAL("Unhandled format!");
    263 		case VK_FORMAT_R32_SINT:
    264 		case VK_FORMAT_R32G32_SINT:
    265 		case VK_FORMAT_R32G32B32_SINT:
    266 		case VK_FORMAT_R32G32B32A32_SINT:
    267 			isInt = true;
    268 			break;
    269 		case VK_FORMAT_R32_UINT:
    270 		case VK_FORMAT_R32G32_UINT:
    271 		case VK_FORMAT_R32G32B32_UINT:
    272 		case VK_FORMAT_R32G32B32A32_UINT:
    273 			isUnsigned = true;
    274 			break;
    275 		case VK_FORMAT_R32_SFLOAT:
    276 		case VK_FORMAT_R32G32_SFLOAT:
    277 		case VK_FORMAT_R32G32B32_SFLOAT:
    278 		case VK_FORMAT_R32G32B32A32_SFLOAT:
    279 		case VK_FORMAT_R64_SFLOAT:
    280 		case VK_FORMAT_R64G64_SFLOAT:
    281 		case VK_FORMAT_R64G64B64_SFLOAT:
    282 		case VK_FORMAT_R64G64B64A64_SFLOAT:
    283 			isFloat = true;
    284 			break;
    285 		case VK_FORMAT_R8_USCALED:
    286 		case VK_FORMAT_R8G8_USCALED:
    287 		case VK_FORMAT_R8G8B8_USCALED:
    288 		case VK_FORMAT_R8G8B8A8_USCALED:
    289 			break; // bool types are not anything
    290 	}
    291 
    292 	switch (opType)
    293 	{
    294 		default:
    295 			DE_FATAL("Unsupported op type");
    296 		case OPTYPE_CLUSTERED_ADD:
    297 			return subgroups::getFormatNameForGLSL(format) + "(0)";
    298 		case OPTYPE_CLUSTERED_MUL:
    299 			return subgroups::getFormatNameForGLSL(format) + "(1)";
    300 		case OPTYPE_CLUSTERED_MIN:
    301 			if (isFloat)
    302 			{
    303 				return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))";
    304 			}
    305 			else if (isInt)
    306 			{
    307 				return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)";
    308 			}
    309 			else if (isUnsigned)
    310 			{
    311 				return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)";
    312 			}
    313 			else
    314 			{
    315 				DE_FATAL("Unhandled case");
    316 			}
    317 		case OPTYPE_CLUSTERED_MAX:
    318 			if (isFloat)
    319 			{
    320 				return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))";
    321 			}
    322 			else if (isInt)
    323 			{
    324 				return subgroups::getFormatNameForGLSL(format) + "(0x80000000)";
    325 			}
    326 			else if (isUnsigned)
    327 			{
    328 				return subgroups::getFormatNameForGLSL(format) + "(0)";
    329 			}
    330 			else
    331 			{
    332 				DE_FATAL("Unhandled case");
    333 			}
    334 		case OPTYPE_CLUSTERED_AND:
    335 			return subgroups::getFormatNameForGLSL(format) + "(~0)";
    336 		case OPTYPE_CLUSTERED_OR:
    337 			return subgroups::getFormatNameForGLSL(format) + "(0)";
    338 		case OPTYPE_CLUSTERED_XOR:
    339 			return subgroups::getFormatNameForGLSL(format) + "(0)";
    340 	}
    341 }
    342 
    343 std::string getCompare(int opType, vk::VkFormat format, std::string lhs, std::string rhs)
    344 {
    345 	std::string formatName = subgroups::getFormatNameForGLSL(format);
    346 	switch (format)
    347 	{
    348 		default:
    349 			return "all(equal(" + lhs + ", " + rhs + "))";
    350 		case VK_FORMAT_R8_USCALED:
    351 		case VK_FORMAT_R32_UINT:
    352 		case VK_FORMAT_R32_SINT:
    353 			return "(" + lhs + " == " + rhs + ")";
    354 		case VK_FORMAT_R32_SFLOAT:
    355 		case VK_FORMAT_R64_SFLOAT:
    356 			switch (opType)
    357 			{
    358 				default:
    359 					return "(abs(" + lhs + " - " + rhs + ") < 0.00001)";
    360 				case OPTYPE_CLUSTERED_MIN:
    361 				case OPTYPE_CLUSTERED_MAX:
    362 					return "(" + lhs + " == " + rhs + ")";
    363 			}
    364 		case VK_FORMAT_R32G32_SFLOAT:
    365 		case VK_FORMAT_R32G32B32_SFLOAT:
    366 		case VK_FORMAT_R32G32B32A32_SFLOAT:
    367 		case VK_FORMAT_R64G64_SFLOAT:
    368 		case VK_FORMAT_R64G64B64_SFLOAT:
    369 		case VK_FORMAT_R64G64B64A64_SFLOAT:
    370 			switch (opType)
    371 			{
    372 				default:
    373 					return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))";
    374 				case OPTYPE_CLUSTERED_MIN:
    375 				case OPTYPE_CLUSTERED_MAX:
    376 					return "all(equal(" + lhs + ", " + rhs + "))";
    377 			}
    378 	}
    379 }
    380 
    381 struct CaseDefinition
    382 {
    383 	int					opType;
    384 	VkShaderStageFlags	shaderStage;
    385 	VkFormat			format;
    386 	bool				noSSBO;
    387 };
    388 
    389 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
    390 {
    391 	std::ostringstream bdy;
    392 
    393 	bdy << "  bool tempResult = true;\n";
    394 
    395 	for (deUint32 i = 1; i <= subgroups::maxSupportedSubgroupSize(); i *= 2)
    396 	{
    397 		bdy	<< "  {\n"
    398 			<< "    const uint clusterSize = " << i << ";\n"
    399 			<< "    if (clusterSize <= gl_SubgroupSize)\n"
    400 			<< "    {\n"
    401 			<< "      " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
    402 			<< getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID], clusterSize);\n"
    403 			<< "      for (uint clusterOffset = 0; clusterOffset < gl_SubgroupSize; clusterOffset += clusterSize)\n"
    404 			<< "      {\n"
    405 			<< "        " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = "
    406 			<< getIdentity(caseDef.opType, caseDef.format) << ";\n"
    407 			<< "        for (uint index = clusterOffset; index < (clusterOffset + clusterSize); index++)\n"
    408 			<< "        {\n"
    409 			<< "          if (subgroupBallotBitExtract(mask, index))\n"
    410 			<< "          {\n"
    411 			<< "            ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
    412 			<< "          }\n"
    413 			<< "        }\n"
    414 			<< "        if ((clusterOffset <= gl_SubgroupInvocationID) && (gl_SubgroupInvocationID < (clusterOffset + clusterSize)))\n"
    415 			<< "        {\n"
    416 			<< "          if (!" << getCompare(caseDef.opType, caseDef.format, "ref", "op") << ")\n"
    417 			<< "          {\n"
    418 			<< "            tempResult = false;\n"
    419 			<< "          }\n"
    420 			<< "        }\n"
    421 			<< "      }\n"
    422 			<< "    }\n"
    423 			<< "  }\n";
    424 	}
    425 
    426 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    427 	{
    428 		std::ostringstream src;
    429 		std::ostringstream	fragmentSrc;
    430 
    431 		src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450 )<< "\n"
    432 			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
    433 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    434 			<< "layout(location = 0) in highp vec4 in_position;\n"
    435 			<< "layout(location = 0) out float out_color;\n"
    436 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
    437 			<< "{\n"
    438 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
    439 			<< "};\n"
    440 			<< "\n"
    441 			<< "void main (void)\n"
    442 			<< "{\n"
    443 			<< "  uvec4 mask = subgroupBallot(true);\n"
    444 			<< bdy.str()
    445 			<< "  out_color = float(tempResult ? 1 : 0);\n"
    446 			<< "  gl_Position = in_position;\n"
    447 			<< "}\n";
    448 
    449 		programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    450 
    451 		fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
    452 			<< "layout(location = 0) in float in_color;\n"
    453 			<< "layout(location = 0) out uint out_color;\n"
    454 			<< "void main()\n"
    455 			<<"{\n"
    456 			<< "	out_color = uint(in_color);\n"
    457 			<< "}\n";
    458 		programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    459 	}
    460 	else
    461 	{
    462 		DE_FATAL("Unsupported shader stage");
    463 	}
    464 }
    465 
    466 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
    467 {
    468 	std::ostringstream bdy;
    469 
    470 	bdy << "  bool tempResult = true;\n";
    471 
    472 	for (deUint32 i = 1; i <= subgroups::maxSupportedSubgroupSize(); i *= 2)
    473 	{
    474 		bdy	<< "  {\n"
    475 			<< "    const uint clusterSize = " << i << ";\n"
    476 			<< "    if (clusterSize <= gl_SubgroupSize)\n"
    477 			<< "    {\n"
    478 			<< "      " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
    479 			<< getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID], clusterSize);\n"
    480 			<< "      for (uint clusterOffset = 0; clusterOffset < gl_SubgroupSize; clusterOffset += clusterSize)\n"
    481 			<< "      {\n"
    482 			<< "        " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = "
    483 			<< getIdentity(caseDef.opType, caseDef.format) << ";\n"
    484 			<< "        for (uint index = clusterOffset; index < (clusterOffset + clusterSize); index++)\n"
    485 			<< "        {\n"
    486 			<< "          if (subgroupBallotBitExtract(mask, index))\n"
    487 			<< "          {\n"
    488 			<< "            ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
    489 			<< "          }\n"
    490 			<< "        }\n"
    491 			<< "        if ((clusterOffset <= gl_SubgroupInvocationID) && (gl_SubgroupInvocationID < (clusterOffset + clusterSize)))\n"
    492 			<< "        {\n"
    493 			<< "          if (!" << getCompare(caseDef.opType, caseDef.format, "ref", "op") << ")\n"
    494 			<< "          {\n"
    495 			<< "            tempResult = false;\n"
    496 			<< "          }\n"
    497 			<< "        }\n"
    498 			<< "      }\n"
    499 			<< "    }\n"
    500 			<< "  }\n";
    501 	}
    502 
    503 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
    504 	{
    505 		std::ostringstream src;
    506 
    507 		src << "#version 450\n"
    508 			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
    509 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    510 			<< "layout (local_size_x_id = 0, local_size_y_id = 1, "
    511 			"local_size_z_id = 2) in;\n"
    512 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    513 			<< "{\n"
    514 			<< "  uint result[];\n"
    515 			<< "};\n"
    516 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
    517 			<< "{\n"
    518 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
    519 			<< "};\n"
    520 			<< "\n"
    521 			<< "void main (void)\n"
    522 			<< "{\n"
    523 			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
    524 			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
    525 			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
    526 			"gl_GlobalInvocationID.x;\n"
    527 			<< "  uvec4 mask = subgroupBallot(true);\n"
    528 			<< bdy.str()
    529 			<< "  result[offset] = tempResult ? 1 : 0;\n"
    530 			<< "}\n";
    531 
    532 		programCollection.glslSources.add("comp")
    533 				<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    534 	}
    535 	else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
    536 	{
    537 		programCollection.glslSources.add("vert")
    538 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    539 
    540 		std::ostringstream frag;
    541 
    542 		frag << "#version 450\n"
    543 			 << "#extension GL_KHR_shader_subgroup_clustered: enable\n"
    544 			 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    545 			 << "layout(location = 0) out uint result;\n"
    546 			 << "layout(set = 0, binding = 0, std430) readonly buffer Buffer2\n"
    547 			 << "{\n"
    548 			 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
    549 			 << "};\n"
    550 			 << "void main (void)\n"
    551 			 << "{\n"
    552 			 << "  uvec4 mask = subgroupBallot(true);\n"
    553 			 << bdy.str()
    554 			 << "  result = tempResult ? 1 : 0;\n"
    555 			 << "}\n";
    556 
    557 		programCollection.glslSources.add("frag")
    558 				<< glu::FragmentSource(frag.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    559 	}
    560 	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    561 	{
    562 		std::ostringstream src;
    563 
    564 		src << "#version 450\n"
    565 			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
    566 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    567 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    568 			<< "{\n"
    569 			<< "  uint result[];\n"
    570 			<< "};\n"
    571 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
    572 			<< "{\n"
    573 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
    574 			<< "};\n"
    575 			<< "\n"
    576 			<< "void main (void)\n"
    577 			<< "{\n"
    578 			<< "  uvec4 mask = subgroupBallot(true);\n"
    579 			<< bdy.str()
    580 			<< "  result[gl_VertexIndex] = tempResult ? 1 : 0;\n"
    581 			<< "}\n";
    582 
    583 		programCollection.glslSources.add("vert")
    584 				<< glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    585 	}
    586 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
    587 	{
    588 		programCollection.glslSources.add("vert")
    589 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    590 
    591 		std::ostringstream src;
    592 
    593 		src << "#version 450\n"
    594 			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
    595 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    596 			<< "layout(points) in;\n"
    597 			<< "layout(points, max_vertices = 1) out;\n"
    598 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    599 			<< "{\n"
    600 			<< "  uint result[];\n"
    601 			<< "};\n"
    602 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
    603 			<< "{\n"
    604 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
    605 			<< "};\n"
    606 			<< "\n"
    607 			<< "void main (void)\n"
    608 			<< "{\n"
    609 			<< "  uvec4 mask = subgroupBallot(true);\n"
    610 			<< bdy.str()
    611 			<< "  result[gl_PrimitiveIDIn] = tempResult ? 1 : 0;\n"
    612 			<< "}\n";
    613 
    614 		programCollection.glslSources.add("geom")
    615 				<< glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    616 	}
    617 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
    618 	{
    619 		programCollection.glslSources.add("vert")
    620 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    621 
    622 		programCollection.glslSources.add("tese")
    623 				<< glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
    624 
    625 		std::ostringstream src;
    626 
    627 		src << "#version 450\n"
    628 			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
    629 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    630 			<< "layout(vertices=1) out;\n"
    631 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    632 			<< "{\n"
    633 			<< "  uint result[];\n"
    634 			<< "};\n"
    635 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
    636 			<< "{\n"
    637 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
    638 			<< "};\n"
    639 			<< "\n"
    640 			<< "void main (void)\n"
    641 			<< "{\n"
    642 			<< "  uvec4 mask = subgroupBallot(true);\n"
    643 			<< bdy.str()
    644 			<< "  result[gl_PrimitiveID] = tempResult ? 1 : 0;\n"
    645 			<< "}\n";
    646 
    647 		programCollection.glslSources.add("tesc")
    648 				<< glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    649 	}
    650 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
    651 	{
    652 		programCollection.glslSources.add("vert")
    653 				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    654 
    655 		programCollection.glslSources.add("tesc")
    656 				<< glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
    657 
    658 		std::ostringstream src;
    659 
    660 		src << "#version 450\n"
    661 			<< "#extension GL_KHR_shader_subgroup_clustered: enable\n"
    662 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
    663 			<< "layout(isolines) in;\n"
    664 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
    665 			<< "{\n"
    666 			<< "  uint result[];\n"
    667 			<< "};\n"
    668 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
    669 			<< "{\n"
    670 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
    671 			<< "};\n"
    672 			<< "\n"
    673 			<< "void main (void)\n"
    674 			<< "{\n"
    675 			<< "  uvec4 mask = subgroupBallot(true);\n"
    676 			<< bdy.str()
    677 			<< "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult ? 1 : 0;\n"
    678 			<< "}\n";
    679 
    680 		programCollection.glslSources.add("tese")
    681 				<< glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
    682 	}
    683 	else
    684 	{
    685 		DE_FATAL("Unsupported shader stage");
    686 	}
    687 }
    688 
    689 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
    690 {
    691 	if (!subgroups::isSubgroupSupported(context))
    692 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
    693 
    694 	if (!subgroups::areSubgroupOperationsSupportedForStage(
    695 				context, caseDef.shaderStage))
    696 	{
    697 		if (subgroups::areSubgroupOperationsRequiredForStage(
    698 					caseDef.shaderStage))
    699 		{
    700 			return tcu::TestStatus::fail(
    701 					   "Shader stage " +
    702 					   subgroups::getShaderStageName(caseDef.shaderStage) +
    703 					   " is required to support subgroup operations!");
    704 		}
    705 		else
    706 		{
    707 			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
    708 		}
    709 	}
    710 
    711 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_CLUSTERED_BIT))
    712 	{
    713 		TCU_THROW(NotSupportedError, "Device does not support subgroup clustered operations");
    714 	}
    715 
    716 	if (subgroups::isDoubleFormat(caseDef.format) &&
    717 			!subgroups::isDoubleSupportedForDevice(context))
    718 	{
    719 		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
    720 	}
    721 
    722 	//Tests which don't use the SSBO
    723 	if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    724 	{
    725 		subgroups::SSBOData inputData;
    726 		inputData.format = caseDef.format;
    727 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
    728 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
    729 
    730 		return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
    731 	}
    732 
    733 	if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
    734 			(VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
    735 	{
    736 		if (!subgroups::isVertexSSBOSupportedForDevice(context))
    737 		{
    738 			TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
    739 		}
    740 	}
    741 
    742 	if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
    743 	{
    744 		subgroups::SSBOData inputData;
    745 		inputData.format = caseDef.format;
    746 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
    747 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
    748 
    749 		return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT,
    750 										   &inputData, 1, checkFragment);
    751 	}
    752 	else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
    753 	{
    754 		subgroups::SSBOData inputData;
    755 		inputData.format = caseDef.format;
    756 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
    757 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
    758 
    759 		return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData,
    760 										  1, checkCompute);
    761 	}
    762 	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
    763 	{
    764 		subgroups::SSBOData inputData;
    765 		inputData.format = caseDef.format;
    766 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
    767 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
    768 
    769 		return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT, &inputData,
    770 										 1, checkVertexPipelineStages);
    771 	}
    772 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
    773 	{
    774 		subgroups::SSBOData inputData;
    775 		inputData.format = caseDef.format;
    776 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
    777 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
    778 
    779 		return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT, &inputData,
    780 										   1, checkVertexPipelineStages);
    781 	}
    782 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
    783 	{
    784 		subgroups::SSBOData inputData;
    785 		inputData.format = caseDef.format;
    786 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
    787 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
    788 
    789 		return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT, &inputData,
    790 				1, checkVertexPipelineStages);
    791 	}
    792 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
    793 	{
    794 		subgroups::SSBOData inputData;
    795 		inputData.format = caseDef.format;
    796 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
    797 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
    798 
    799 		return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT, &inputData,
    800 				1, checkVertexPipelineStages);
    801 	}
    802 	else
    803 	{
    804 		return tcu::TestStatus::pass("Unhandled shader stage!");
    805 	}
    806 }
    807 }
    808 
    809 namespace vkt
    810 {
    811 namespace subgroups
    812 {
    813 tcu::TestCaseGroup* createSubgroupsClusteredTests(tcu::TestContext& testCtx)
    814 {
    815 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
    816 			testCtx, "clustered", "Subgroup clustered category tests"));
    817 
    818 	const VkShaderStageFlags stages[] =
    819 	{
    820 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
    821 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
    822 		VK_SHADER_STAGE_GEOMETRY_BIT,
    823 		VK_SHADER_STAGE_VERTEX_BIT,
    824 		VK_SHADER_STAGE_FRAGMENT_BIT,
    825 		VK_SHADER_STAGE_COMPUTE_BIT
    826 	};
    827 
    828 	const VkFormat formats[] =
    829 	{
    830 		VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
    831 		VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
    832 		VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
    833 		VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
    834 		VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
    835 		VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
    836 		VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
    837 		VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
    838 		VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
    839 	};
    840 
    841 	for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
    842 	{
    843 		const VkShaderStageFlags stage = stages[stageIndex];
    844 
    845 		for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
    846 		{
    847 			const VkFormat format = formats[formatIndex];
    848 
    849 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_CLUSTERED_LAST; ++opTypeIndex)
    850 			{
    851 				bool isBool = false;
    852 				bool isFloat = false;
    853 
    854 				switch (format)
    855 				{
    856 					default:
    857 						break;
    858 					case VK_FORMAT_R32_SFLOAT:
    859 					case VK_FORMAT_R32G32_SFLOAT:
    860 					case VK_FORMAT_R32G32B32_SFLOAT:
    861 					case VK_FORMAT_R32G32B32A32_SFLOAT:
    862 					case VK_FORMAT_R64_SFLOAT:
    863 					case VK_FORMAT_R64G64_SFLOAT:
    864 					case VK_FORMAT_R64G64B64_SFLOAT:
    865 					case VK_FORMAT_R64G64B64A64_SFLOAT:
    866 						isFloat = true;
    867 						break;
    868 					case VK_FORMAT_R8_USCALED:
    869 					case VK_FORMAT_R8G8_USCALED:
    870 					case VK_FORMAT_R8G8B8_USCALED:
    871 					case VK_FORMAT_R8G8B8A8_USCALED:
    872 						isBool = true;
    873 						break;
    874 				}
    875 
    876 				bool isBitwiseOp = false;
    877 
    878 				switch (opTypeIndex)
    879 				{
    880 					default:
    881 						break;
    882 					case OPTYPE_CLUSTERED_AND:
    883 					case OPTYPE_CLUSTERED_OR:
    884 					case OPTYPE_CLUSTERED_XOR:
    885 						isBitwiseOp = true;
    886 						break;
    887 				}
    888 
    889 				if (isFloat && isBitwiseOp)
    890 				{
    891 					// Skip float with bitwise category.
    892 					continue;
    893 				}
    894 
    895 				if (isBool && !isBitwiseOp)
    896 				{
    897 					// Skip bool when its not the bitwise category.
    898 					continue;
    899 				}
    900 
    901 				CaseDefinition caseDef = {opTypeIndex, stage, format, false};
    902 
    903 				std::ostringstream name;
    904 
    905 				std::string op = getOpTypeName(opTypeIndex);
    906 
    907 				name << de::toLower(op)
    908 					 << "_" << subgroups::getFormatNameForGLSL(format)
    909 					 << "_" << getShaderStageName(stage);
    910 
    911 				addFunctionCaseWithPrograms(group.get(), name.str(),
    912 											"", initPrograms, test, caseDef);
    913 
    914 				if (VK_SHADER_STAGE_VERTEX_BIT == stage)
    915 				{
    916 					caseDef.noSSBO = true;
    917 					addFunctionCaseWithPrograms(group.get(), name.str()+"_framebuffer", "",
    918 												initFrameBufferPrograms, test, caseDef);
    919 				}
    920 			}
    921 		}
    922 	}
    923 
    924 	return group.release();
    925 }
    926 
    927 } // subgroups
    928 } // vkt
    929