General Matrix-Vector Multiplication (GEMV)
===========================================
Warning
This document is still experimental and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
Tip
Example code can be found at examples/gemv/example_gemv.py.
General matrix-vector multiplication (GEMV) can be viewed as a specialized case of general matrix-matrix multiplication (GEMM). It plays a critical role in deep learning, especially during the inference phase of large language models. In this tutorial, we will optimize GEMV from a thread-level perspective step by step using TileLang.
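Throughout this tutorial, A is the length-K input vector, B the (N, K) matrix, and C the length-N output, so C[n] = sum_k B[n, k] * A[k]. A minimal PyTorch reference for checking correctness might look like the sketch below; the ref_program passed to the autotuner later in this tutorial is assumed to be of this form.

```python
import torch

def ref_program(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # C[n] = sum_k B[n, k] * A[k]; accumulate in float32 before casting back.
    return (B.float() @ A.float()).to(A.dtype)
```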
Triton Implementation
When implementing a GEMV kernel, you might start with a high-level approach using a tool like Triton.
A simple Triton kernel for GEMV might look like this:
```python
import triton
import triton.language as tl


@triton.jit
def _gemv_naive(
    x_ptr, A_ptr, y_ptr,
    N, K,
    BLOCK_SIZE_K: tl.constexpr,
):
    # One program per output element: row n of A dotted with x.
    n = tl.program_id(0)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    mask = offs_k < K
    a_ptrs = A_ptr + n * K + offs_k
    a_vals = tl.load(a_ptrs, mask=mask, other=0.0)
    x_vals = tl.load(x_ptr + offs_k, mask=mask, other=0.0)
    dot = tl.sum(a_vals * x_vals, axis=0)
    tl.store(y_ptr + n, dot)
```
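A minimal host-side launch for this kernel might look like the following sketch (an assumption for illustration, not part of the original example). It relies on the entire K dimension fitting into a single block, since the kernel loads all of K at once.

```python
import torch
import triton

def gemv_triton(A: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # A: (N, K) matrix, x: (K,) vector -> y: (N,); one Triton program per row of A.
    N, K = A.shape
    y = torch.empty(N, device=A.device, dtype=A.dtype)
    # tl.arange requires a power-of-two extent, so round BLOCK_SIZE_K up;
    # the mask inside the kernel covers the padded tail.
    BLOCK_SIZE_K = triton.next_power_of_2(K)
    _gemv_naive[(N,)](x, A, y, N, K, BLOCK_SIZE_K=BLOCK_SIZE_K)
    return y
```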
Triton is straightforward to use because it operates at the block level. However, this abstraction may not allow for fine-grained thread-level optimization. In this tutorial, we will demonstrate how to write an optimized GEMV kernel in TileLang, which exposes more low-level control.
Naive Implementation in TileLang
If you have a basic understanding of CUDA C, it is natural to start with a naive GEMV kernel by adapting a GEMM tiling strategy. You can think of GEMV as a (1, k) * (k, n) GEMM. Below is a simple example:
```python
# Assumed imports, as in examples/gemv/example_gemv.py:
#   import tilelang.language as T
def naive_gemv(
    N: int,
    K: int,
    BLOCK_N: int,
    BLOCK_K: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):

    @T.prim_func
    def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn:
            tn = T.get_thread_binding(0)  # tn = threadIdx.x
            A_shared = T.alloc_shared((BLOCK_K,), dtype)
            B_shared = T.alloc_shared((BLOCK_N, BLOCK_K), dtype)
            C_reg = T.alloc_local((1,), accum_dtype)
            T.clear(C_reg)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                for tk in T.serial(BLOCK_K):
                    A_shared[tk] = A[bk * BLOCK_K + tk]
                    B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
                for tk in T.serial(BLOCK_K):
                    C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype)
            C[bn * BLOCK_N + tn] = C_reg[0]

    return main
```
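To make the block/thread index arithmetic explicit, here is a plain-NumPy sketch of the same tiling (an illustration written for this tutorial, not TileLang code; it assumes N and K are divisible by BLOCK_N and BLOCK_K, as the kernel above does):

```python
import numpy as np

def naive_gemv_tiling_reference(A, B, BLOCK_N, BLOCK_K):
    # A: (K,), B: (N, K); mirrors the bn / tn / bk / tk loop structure above.
    N, K = B.shape
    C = np.zeros(N, dtype=np.float32)
    for bn in range(N // BLOCK_N):          # grid: one block per BLOCK_N rows
        for tn in range(BLOCK_N):           # threads within a block, one output row each
            acc = np.float32(0.0)
            for bk in range(K // BLOCK_K):  # serial sweep over the K tiles
                for tk in range(BLOCK_K):
                    acc += np.float32(A[bk * BLOCK_K + tk]) * \
                           np.float32(B[bn * BLOCK_N + tn, bk * BLOCK_K + tk])
            C[bn * BLOCK_N + tn] = acc
    return C
```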
Your kernel will be compiled into CUDA by TileLang (the generated source can be found in ~/.tilelang/cache):
extern "C" __global__ void __launch_bounds__(256, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
float C_reg[1];
__shared__ uint64_t _mbarrier[2];
if (((int)threadIdx.x) == 0) {
tl::mbarrier_init(_mbarrier[0], 128);
tl::mbarrier_init(_mbarrier[1], 128);
}
__syncthreads();
if (128 <= ((int)threadIdx.x)) {
tl::warpgroup_reg_dealloc<24>();
for (int bk = 0; bk < 8; ++bk) {
tl::mbarrier_wait(_mbarrier[1], ((bk & 1) ^ 1));
for (int tk = 0; tk < 128; ++tk) {
((half_t*)buf_dyn_shmem)[tk] = A[((bk * 128) + tk)];
((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk) - 16256)] = B[(((((((int)blockIdx.x) * 131072) + (((int)threadIdx.x) * 1024)) + (bk * 128)) + tk) - 131072)];
}
tl::fence_proxy_async();
tl::mbarrier_cp_async_arrive(_mbarrier[0]);
tl::mbarrier_arrive(_mbarrier[0]);
}
} else {
tl::warpgroup_reg_alloc<240>();
C_reg[0] = 0.000000e+00f;
for (int bk_1 = 0; bk_1 < 8; ++bk_1) {
tl::mbarrier_wait(_mbarrier[0], (bk_1 & 1));
for (int tk_1 = 0; tk_1 < 128; ++tk_1) {
C_reg[0] = (C_reg[0] + (((float)((half_t*)buf_dyn_shmem)[tk_1]) * ((float)((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk_1) + 128)])));
}
tl::fence_proxy_async();
tl::mbarrier_arrive(_mbarrier[1]);
}
C[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = ((half_t)C_reg[0]);
}
}
In this design, the first 128 threads of each block act as data producers and the last 128 threads as consumers (assuming a 1D block).
At this level we extract very little of the GPU's compute capability: the kernel takes around ~0.17 ms, roughly 20x slower than torch/cuBLAS at ~0.008 ms.
More Concurrency
To further increase the concurrency of our kernel, we can exploit finer thread-level parallelism. Instead of assigning each thread to compute a single output element of C, we can introduce parallelism along the K dimension: each thread computes a partial accumulation, and the partial results are then combined. This approach requires primitives such as atomicAdd in CUDA.
Here's a simplified version:
```python
def naive_splitk_gemv(
    N: int,
    K: int,
    BLOCK_N: int,
    BLOCK_K: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):

    @T.prim_func
    def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((1,), dtype)
            B_local = T.alloc_local((1,), dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
            if tk == 0:
                C_shared[tn] = 0
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                A_local[0] = A[bk * BLOCK_K + tk]
                B_local[0] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
                C_accum[0] += A_local[0].astype(accum_dtype) * B_local[0].astype(accum_dtype)
            T.atomic_add(C_shared[tn], C_accum[0])
            C[bn * BLOCK_N + tn] = C_shared[tn]

    return main
```
By introducing parallelism along the K dimension, our kernel now achieves ~0.024 ms, an improvement, but still not on par with torch/cuBLAS.
Customizing Parallelism in the K Dimension
If your K dimension is large, launching BLOCK_N x BLOCK_K threads per block quickly runs into the hardware limit on threads per block. Instead, you can customize how many elements each thread processes by introducing a reduce_threads parameter, so that each thread handles TILE_K = ceildiv(BLOCK_K, reduce_threads) elements per iteration:
```python
def splitk_gemv(
    N: int,
    K: int,
    BLOCK_N: int,
    BLOCK_K: int,
    reduce_threads: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    TILE_K = T.ceildiv(BLOCK_K, reduce_threads)

    @T.prim_func
    def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            if tk == 0:
                C_shared[tn] = 0
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                for k in T.serial(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            T.atomic_add(C_shared[tn], C_accum[0])
            C[bn * BLOCK_N + tn] = C_shared[tn]

    return main
```
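As a concrete illustration (sizes chosen for this example, not taken from the benchmark): with BLOCK_K = 128 and reduce_threads = 32, each thread covers TILE_K = ceildiv(128, 32) = 4 elements of every K tile, and a block launches BLOCK_N x 32 threads.

```python
# Illustrative instantiation of the kernel factory above (hypothetical sizes).
prog = splitk_gemv(N=1024, K=1024, BLOCK_N=8, BLOCK_K=128, reduce_threads=32)
# Inside: TILE_K = ceildiv(128, 32) = 4; threads per block = 8 * 32 = 256.
```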
Vectorized Reads
GEMV is less compute-intensive than GEMM, so memory throughput, rather than arithmetic, becomes the main optimization bottleneck. One effective strategy is to use vectorized load/store operations (e.g., float2, float4). In TileLang, you can specify vectorized operations via T.vectorized:
```python
def splitk_gemv_vectorized(
    N: int,
    K: int,
    BLOCK_N: int,
    reduce_threads: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    MAX_TRANSACTION_SIZE_IN_BITS = 128
    # DataType(dtype).bits is the element width in bits (DataType is TVM's dtype helper, assumed imported).
    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
    BLOCK_K = reduce_threads * TILE_K

    @T.prim_func
    def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            if tk == 0:
                C_shared[tn] = 0
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                for k in T.vectorized(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            T.atomic_add(C_shared[tn], C_accum[0])
            C[bn * BLOCK_N + tn] = C_shared[tn]

    return main
```
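To make the derived sizes concrete (illustrative values only): with dtype="float16", each element is 16 bits, so one 128-bit transaction carries TILE_K = 128 // 16 = 8 elements and BLOCK_K = reduce_threads * 8.

```python
# Illustrative instantiation: each thread issues 128-bit (8 x fp16) vector loads.
prog = splitk_gemv_vectorized(N=1024, K=1024, BLOCK_N=32, reduce_threads=32)
# Inside: TILE_K = 128 // 16 = 8; BLOCK_K = 32 * 8 = 256.
```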
With vectorized reads, the kernel now finishes in ~0.0084 ms, which is getting close to cuBLAS performance.
tvm_thread_allreduce Instead of atomicAdd
tvm_thread_allreduce implements an optimized all-reduce across a group of threads, which should outperform our plain shared-memory + atomicAdd approach:
```python
def splitk_gemv_vectorized_tvm(
    N: int,
    K: int,
    BLOCK_N: int,
    reduce_threads: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    MAX_TRANSACTION_SIZE_IN_BITS = 128
    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
    BLOCK_K = reduce_threads * TILE_K

    @T.prim_func
    def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                for k in T.vectorized(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            C_reduced = T.alloc_local((1,), accum_dtype)
            with T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
            ):
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        C_accum[0],
                        True,
                        C_reduced[0],
                        tk,
                        dtype="handle",
                    ))
            C[bn * BLOCK_N + tn] = C_reduced[0]

    return main
```
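Conceptually, the allreduce combines the reduce_threads partial sums along the tk axis with a tree reduction; the shared-memory reduction with strides 16, 8, 4, 2, 1 in the generated CUDA shown at the end of this tutorial follows the same pattern. The plain-Python sketch below illustrates the idea only and is not the actual lowering.

```python
def tree_allreduce(partials):
    # partials: one partial sum per tk lane; length is a power of two (reduce_threads).
    stride = len(partials) // 2
    while stride > 0:
        for i in range(stride):              # lane i folds in lane i + stride
            partials[i] += partials[i + stride]
        stride //= 2
    return partials[0]                       # the fully reduced value
```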
With this optimization, the kernel latency drops from ~0.0084 ms to ~0.0069 ms, which is faster than torch/cuBLAS!
Autotune
BLOCK_N, BLOCK_K, and reduce_threads are hyperparameters of our kernel that can be tuned to improve performance. We can use the tilelang.autotune feature to automatically search for optimal configurations:
```python
def get_best_config(N, K):

    def get_configs():
        BLOCK_N = [2, 4, 8, 32, 64, 128]
        reduce_threads = [4, 8, 32]
        _configs = list(itertools.product(
            BLOCK_N,
            reduce_threads,
        ))
        configs = [{
            "BLOCK_N": c[0],
            "reduce_threads": c[1],
        } for c in _configs]
        return configs

    @autotune(
        configs=get_configs(),
        warmup=3,
        rep=20,
    )
    @jit(
        out_idx=[-1],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=False,
        target="auto",
    )
    def kernel(
        BLOCK_N=None,
        reduce_threads=None,
    ):
        dtype = "float16"
        accum_dtype = "float"
        MAX_TRANSACTION_SIZE_IN_BITS = 128
        TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
        BLOCK_K = reduce_threads * TILE_K

        @T.prim_func
        def main(
                A: T.Buffer((K,), dtype),
                B: T.Buffer((N, K), dtype),
                C: T.Buffer((N,), dtype),
        ):
            with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
                tn = T.get_thread_binding(0)
                tk = T.get_thread_binding(1)
                A_local = T.alloc_local((TILE_K,), dtype)
                B_local = T.alloc_local((TILE_K,), dtype)
                C_accum = T.alloc_local((1,), accum_dtype)
                T.clear(C_accum)
                for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                    for k in T.vectorized(TILE_K):
                        A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                        B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                    for k in T.serial(TILE_K):
                        C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
                C_reduced = T.alloc_local((1,), accum_dtype)
                with T.attr(
                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                ):
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            C_accum[0],
                            True,
                            C_reduced[0],
                            tk,
                            dtype="handle",
                        ))
                C[bn * BLOCK_N + tn] = C_reduced[0]

        return main

    return kernel()
```
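Usage is then a single call per problem size. The sketch below is hypothetical; the exact structure of the object returned by the autotuner depends on your tilelang version.

```python
# Hypothetical driver: searches the 6 x 3 = 18 configurations defined above,
# validating each candidate against ref_program (skip_check=False) and timing it.
best = get_best_config(N=1024, K=1024)
print(best)  # inspect the best configuration and its measured latency
```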
After autotuning, our kernel now achieves ~0.0067 ms; the final generated CUDA kernel might look like this:
extern "C" __global__ void __launch_bounds__(64, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
float C_accum[1];
half_t A_local[8];
half_t B_local[8];
__shared__ float red_buf0[64];
C_accum[0] = 0.000000e+00f;
for (int bk = 0; bk < 4; ++bk) {
*(uint4*)(A_local + 0) = *(uint4*)(A + ((bk * 256) + (((int)threadIdx.y) * 8)));
*(uint4*)(B_local + 0) = *(uint4*)(B + ((((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 1024)) + (bk * 256)) + (((int)threadIdx.y) * 8)));
for (int k = 0; k < 8; ++k) {
C_accum[0] = (C_accum[0] + (((float)A_local[k]) * ((float)B_local[k])));
}
}
tl::fence_proxy_async();
__syncthreads();
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = C_accum[0];
__syncthreads();
if (((int)threadIdx.y) < 16) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 16)]);
}
__syncthreads();
if (((int)threadIdx.y) < 8) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 8)]);
}
__syncthreads();
if (((int)threadIdx.y) < 4) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 4)]);
}
__syncthreads();
if (((int)threadIdx.y) < 2) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 2)]);
}
__syncthreads();
if (((int)threadIdx.y) < 1) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 1)]);
}
__syncthreads();
C[((((int)blockIdx.x) * 2) + ((int)threadIdx.x))] = ((half_t)red_buf0[(((int)threadIdx.x) * 32)]);
}
This corresponds closely to our TileLang program, with the necessary synchronization and low-level optimizations inserted automatically.
Conclusion
Benchmark Table on Hopper GPU

| Kernel Name | Latency |
|---|---|
| torch/cuBLAS | 0.00784 ms |
| Triton | 0.00773 ms |
| naive_gemv | 0.16607 ms |
| splitk_gemv | 0.02419 ms |
| splitk_gemv_vectorized | 0.00809 ms |
| splitk_gemv_vectorized_tvm | 0.00675 ms |
In this tutorial, we implemented and progressively optimized a GEMV kernel, and saw that TileLang exposes low-level control to the user, such as thread-level programming and access to CUDA primitives.