tilelang.intrinsics.metal_macro_generator

Classes

Module Contents

class tilelang.intrinsics.metal_macro_generator.MPSIntrinEmitter(a_dtype='float16', b_dtype='float16', accum_dtype='float32', a_transposed=False, b_transposed=False, block_row_warps=1, block_col_warps=1, warp_row_tiles=8, warp_col_tiles=8, chunk=32, thread_var=None)
Parameters:
  • a_dtype (str)

  • b_dtype (str)

  • accum_dtype (str)

  • a_transposed (bool)

  • b_transposed (bool)

  • block_row_warps (int)

  • block_col_warps (int)

  • warp_row_tiles (int)

  • warp_col_tiles (int)

  • chunk (int)

  • thread_var (tvm.tirx.Var | None)

WARP_SIZE = 32
a_dtype = 'float16'
b_dtype = 'float16'
accum_dtype = 'float32'
a_transposed = False
b_transposed = False
block_row_warps = 1
block_col_warps = 1
warp_row_tiles = 8
warp_col_tiles = 8
chunk = 32
thread_var = None
micro_size_x = 8
micro_size_y = 8
micro_size_k = 8
warp_rows = 1
warp_cols = 1
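
The default values listed above are consistent with each warp tile being split into 8×8 micro tiles. The sketch below illustrates that arithmetic only. It assumes `warp_rows = warp_row_tiles // micro_size_x`, `warp_cols = warp_col_tiles // micro_size_y`, and `chunk // micro_size_k` steps along K; these relations are inferred from the defaults on this page rather than confirmed by it. `WarpTiling` and `derive_warp_tiling` are hypothetical names, not part of tilelang.

```python
from dataclasses import dataclass

# Micro-tile edge lengths, copied from the attribute list on this page.
MICRO_SIZE_X = 8
MICRO_SIZE_Y = 8
MICRO_SIZE_K = 8


@dataclass(frozen=True)
class WarpTiling:
    """Hypothetical container for the derived per-warp tile counts."""
    warp_rows: int  # micro tiles along the M dimension of one warp tile
    warp_cols: int  # micro tiles along the N dimension of one warp tile
    k_steps: int    # micro-tile steps along K within one chunk (assumed)


def derive_warp_tiling(warp_row_tiles=8, warp_col_tiles=8, chunk=32):
    """Hypothetical helper: ASSUMES the derived values are plain divisions
    of the tile sizes by the 8x8x8 micro-tile shape."""
    checks = (
        ("warp_row_tiles", warp_row_tiles, MICRO_SIZE_X),
        ("warp_col_tiles", warp_col_tiles, MICRO_SIZE_Y),
        ("chunk", chunk, MICRO_SIZE_K),
    )
    for name, value, micro in checks:
        if value <= 0 or value % micro != 0:
            raise ValueError(f"{name}={value} must be a positive multiple of {micro}")
    return WarpTiling(
        warp_rows=warp_row_tiles // MICRO_SIZE_X,
        warp_cols=warp_col_tiles // MICRO_SIZE_Y,
        k_steps=chunk // MICRO_SIZE_K,
    )


# With the documented defaults this reproduces warp_rows = 1 and warp_cols = 1.
default_tiling = derive_warp_tiling()
print(default_tiling)
```

Under these assumptions, tile sizes that are not multiples of 8 would not map onto whole micro tiles, which is why the sketch rejects them.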
get_thread_binding()
ldmatrix_a(A_local_buf, A_shared_buf, ki)
Parameters:
  • A_shared_buf (tvm.tirx.Buffer | tvm.tirx.BufferRegion)

ldmatrix_b(B_local_buf, B_shared_buf, ki)
Parameters:
  • B_shared_buf (tvm.tirx.Buffer | tvm.tirx.BufferRegion)

mma(A_local_buf, B_local_buf, C_local_buf)
simdgroup_copy(C_simd_buf, C_dst, is_store=True)
simd_store(C_simd_buf, C_dst)
simd_load(C_simd_buf, C_src)
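
The method names suggest a load, multiply-accumulate, store sequence over 8×8 tiles. The pure-Python sketch below models only that arithmetic for a single 8×8 output tile. It does not call tilelang or Metal, and the loop order, the reading of `ki` as the index of an 8-wide K slice within `chunk`, and every helper name (`load_tile`, `mma_tile`, `gemm_one_tile`) are assumptions made for illustration.

```python
MICRO = 8   # micro_size_x = micro_size_y = micro_size_k = 8 on this page
CHUNK = 32  # default chunk value on this page


def load_tile(src, row0, col0):
    """Copy an 8x8 tile out of a larger row-major matrix.

    Stand-in for the role ldmatrix_a / ldmatrix_b appear to play (assumed).
    """
    return [row[col0:col0 + MICRO] for row in src[row0:row0 + MICRO]]


def mma_tile(a, b, c):
    """Accumulate c += a @ b in place for 8x8 tiles (stand-in for mma)."""
    for i in range(MICRO):
        for j in range(MICRO):
            c[i][j] += sum(a[i][k] * b[k][j] for k in range(MICRO))


def gemm_one_tile(A, B):
    """Compute one 8x8 output tile of A (8 x CHUNK) times B (CHUNK x 8)."""
    c = [[0.0] * MICRO for _ in range(MICRO)]
    # ASSUMPTION: ki walks the K dimension of the chunk in 8-wide slices.
    for ki in range(CHUNK // MICRO):
        a = load_tile(A, 0, ki * MICRO)
        b = load_tile(B, ki * MICRO, 0)
        mma_tile(a, b, c)
    return c


# Small deterministic inputs: B selects the first 8 columns of A.
A = [[float(i + k) for k in range(CHUNK)] for i in range(MICRO)]
B = [[float(k == j) for j in range(MICRO)] for k in range(CHUNK)]
C = gemm_one_tile(A, B)
```

Accumulating into `c` across the `ki` steps mirrors the separate `accum_dtype` parameter: partial products from each K slice are summed into one accumulator tile before it is written out.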