Home | History | Annotate | Download | only in NVPTX
      1 # This test generates all variants of wmma intrinsics and verifies that LLVM
      2 # generates correct instructions for them.
      3 
      4 # RUN: python %s > %t.ll
      5 # RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll
      6 
      7 from itertools import product
      8 from string import Template
      9 
     10 def make_wmma_slice_ty(abcd, itype):
     11   elt_ty = "<2 x half>" if itype == "f16" else "float"
     12   num_elts = 4 if abcd in "cd" and itype == "f16" else 8;
     13   return [elt_ty] * num_elts
     14 
     15 def make_wmma_ld_ret_ty(abc, itype):
     16   return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
     17 
     18 # returns address space
     19 def get_aspace(space):
     20   space_map = {
     21       ".global" : 1,
     22       ".shared" : 3,
     23       ".const"  : 4,
     24       ".local"  : 5,
     25       ".param"  : 101,
     26       ""        : 0,
     27       ".generic": 0
     28   }
     29   return space_map[space];
     30 
     31 def get_pspace(space):
     32   return "p%di8" % get_aspace(space);
     33 
     34 # Convenient test patterns.
     35 check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
     36 check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
     37 check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
     38 
     39 known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
     40 
     41 def gen_wmma_load_tests():
     42   load_template = """
     43 declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
     44 
     45 ; CHECK-LABEL: .func {{.*}}test_${function}(
     46 define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
     47 ; CHECK: ${instruction}
     48 ; CHECK: {${check_result}}
     49 ; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
     50   %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
     51   ret ${ret_ty} %v0;
     52 }
     53 
     54 ; CHECK-LABEL: .func{{.*}}test_${function}_o(
     55 define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
     56 ; CHECK: ${instruction}
     57 ; CHECK: {${check_result}}
     58 ; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
     59   %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
     60   %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
     61   ret ${ret_ty} %v0;
     62 }
     63 """
     64   intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
     65   instruction_template = "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}"
     66 
     67   for geom, abc, layout, space, stride, itype in product(
     68       known_geoms,
     69       "abc",
     70       ["row","col"],
     71       ["",".shared",".global"],
     72       ["", ".stride"],
     73       ["f16", "f32"]):
     74 
     75     params = {
     76         "abc" : abc,
     77         "layout" : layout,
     78         "space" : space,
     79         "stride" : stride,
     80         "itype" : itype,
     81         "pspace" : get_pspace(space),
     82         "as"     : "addrspace(%d)" % get_aspace(space),
     83         "geom"   : geom,
     84     }
     85 
     86     if itype == "f32" and abc != "c":
     87       continue
     88 
     89     test_params = params
     90     test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
     91     test_params["function"] = test_params["intrinsic"].replace(".","_")
     92     test_params["instruction"] = Template(instruction_template).substitute(params)
     93     test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
     94     if abc == "c" :
     95       test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
     96     else:
     97       test_params["check_result"] = check_f16_8
     98 
     99     if stride:
    100       test_params["extra_args"] = ", i32 %stride";
    101       test_params["stride_pattern"] = ", %r{{[0-9]+}}"
    102     else:
    103       test_params["extra_args"] = ""
    104       test_params["stride_pattern"] = ""
    105 
    106     print(Template(load_template).substitute(test_params))
    107 
    108 def make_wmma_slice_args(itype, abcd, prefix="v"):
    109   return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t
    110                   in enumerate(make_wmma_slice_ty(abcd, itype))])
    111 
    112 def gen_wmma_store_tests():
    113   store_template = """
    114 declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});
    115 
    116 ; CHECK-LABEL: .func {{.*}}test_${function}(
    117 define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
    118 ; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
    119 ; CHECK: {${check_args}}
    120 ; CHECK: ${stride_pattern}
    121   call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
    122   ret void
    123 }
    124 
    125 ; CHECK-LABEL: .func{{.*}}test_${function}_o(
    126 define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
    127 ; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
    128 ; CHECK: ${check_args}
    129 ; CHECK: ${stride_pattern}
    130   %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
    131   call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
    132   ret void
    133 }
    134 """
    135   intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    136   instruction_template = "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}"
    137 
    138   for geom, abc, layout, space, stride, itype in product(
    139       known_geoms,
    140       "d",
    141       ["row","col"],
    142       ["",".shared",".global"],
    143       ["", ".stride"],
    144       ["f16", "f32"]):
    145 
    146     params = {
    147         "abc" : abc,
    148         "layout" : layout,
    149         "space" : space,
    150         "stride" : stride,
    151         "itype" : itype,
    152         "pspace" : get_pspace(space),
    153         "as"     : "addrspace(%d)" % get_aspace(space),
    154         "geom"   : geom,
    155     }
    156 
    157     test_params = params
    158     test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    159     test_params["function"] = test_params["intrinsic"].replace(".","_")
    160     test_params["instruction"] = Template(instruction_template).substitute(params)
    161     test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
    162     test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
    163     if stride:
    164       test_params["extra_args"] = ", i32 %stride";
    165       test_params["stride_pattern"] = ", %r{{[0-9]+}};"
    166     else:
    167       test_params["extra_args"] = ""
    168       test_params["stride_pattern"] = ";"
    169     test_params["args"] = make_wmma_slice_args(itype, "d");
    170 
    171     print(Template(store_template).substitute(test_params))
    172 
    173 def gen_wmma_mma_tests():
    174   mma_template = """
    175 declare ${ret_ty} @${intrinsic}(
    176         ${args});
    177 
    178 ; CHECK-LABEL: .func {{.*}}test_${function}(
    179 define ${ret_ty} @test_${function}(
    180         ${args}) {
    181 ; CHECK: ${instruction}
    182 ; CHECK-NEXT: ${check_d}
    183 ; CHECK-NEXT: ${check_ab}
    184 ; CHECK-NEXT: ${check_ab}
    185 ; CHECK-NEXT: ${check_c}
    186   %r = call ${ret_ty} @${intrinsic}(
    187         ${args});
    188   ret ${ret_ty} %r;
    189 }
    190 """
    191   intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
    192   instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
    193 
    194   for geom, alayout, blayout, ctype, dtype, satf in product(
    195       known_geoms,
    196       ["row","col"],
    197       ["row","col"],
    198       ["f16", "f32"],
    199       ["f16", "f32"],
    200       [".satfinite", ""]):
    201 
    202     params = {
    203         "alayout" : alayout,
    204         "blayout" : blayout,
    205         "ctype" : ctype,
    206         "dtype" : dtype,
    207         "satf"  : satf,
    208         "geom"  : geom,
    209     }
    210 
    211     test_params = params
    212     test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    213     test_params["function"] = test_params["intrinsic"].replace(".", "_")
    214     test_params["instruction"] = Template(instruction_template).substitute(params)
    215     test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
    216     test_params["check_ab"] = check_f16_8
    217     test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
    218     test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8
    219     args = ",\n        ".join(make_wmma_slice_args(t, abcd, prefix=abcd)
    220                               for abcd, t in (("a", "f16"),
    221                                               ("b", "f16"),
    222                                               ("c", ctype)))
    223     test_params["args"] = args
    224     print(Template(mma_template).substitute(test_params))
    225 
    226 def main():
    227   gen_wmma_load_tests()
    228   gen_wmma_store_tests()
    229   gen_wmma_mma_tests()
    230 
    231 main()
    232