# General Matrix-Vector Multiplication (GEMV)
:::{warning}
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
:::{tip}
Example code can be found at `examples/gemv/example_gemv.py`.
:::
General matrix-vector multiplication (GEMV) can be viewed as a specialized case of general matrix-matrix multiplication (GEMM). It plays a critical role in deep learning, especially during the inference phase of large language models. In this tutorial, we will optimize GEMV from a thread-level perspective step by step using `TileLang`.
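To pin down what every kernel in this tutorial must compute, here is a minimal pure-Python reference, assuming the same layout the later kernels use: a length-`K` vector `A`, an `N x K` matrix `B`, and a length-`N` output `C`. The function name `gemv_reference` is only illustrative and is not part of `TileLang`.

```python
def gemv_reference(A, B):
    """Compute C[n] = sum_k A[k] * B[n][k] for a length-K vector A and an N x K matrix B."""
    K = len(A)
    assert all(len(row) == K for row in B), "every row of B must have K entries"
    # Each output element is an independent dot product between A and one row of B.
    return [sum(a * b for a, b in zip(A, row)) for row in B]


A = [1.0, 2.0, 3.0]
B = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [1.0, 1.0, 1.0],
    [2.0, 2.0, 2.0],
]
C = gemv_reference(A, B)
print(C)  # [1.0, 2.0, 6.0, 12.0]
```

Because the `N` dot products are independent, the optimizations below are all about how to distribute rows and slices of `K` across threads.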
## Triton Implementation
When implementing a GEMV kernel, you might start with a high-level approach using a tool like `Triton`.
A simple Triton kernel for GEMV might look like this:
```python
import triton
import triton.language as tl


@triton.jit
def _gemv_naive(
    x_ptr, A_ptr, y_ptr,
    N, K,
    BLOCK_SIZE_K: tl.constexpr,
):
    # One program per output row. There is no loop over K, so this
    # assumes BLOCK_SIZE_K >= K and relies on the mask for the tail.
    n = tl.program_id(0)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    mask = offs_k < K
    a_ptrs = A_ptr + n * K + offs_k
    a_vals = tl.load(a_ptrs, mask=mask, other=0.0)
    x_vals = tl.load(x_ptr + offs_k, mask=mask, other=0.0)
    dot = tl.sum(a_vals * x_vals, axis=0)
    tl.store(y_ptr + n, dot)
```
`Triton` is straightforward to use, as it operates at the block level. However, this approach may not allow for fine-grained thread-level optimization. In this tutorial, we will demonstrate how to write an optimized GEMV kernel in `TileLang` that exposes more low-level control.
## Naive Implementation in TileLang
If you have a basic understanding of CUDA C, it is natural to start with a naive GEMV kernel by adapting a GEMM tiling strategy. You can think of GEMV as a `(1, k) * (k, n)` GEMM. Below is a simple example:
```python
import tilelang.language as T


def naive_gemv(
    N: int,
    K: int,
    BLOCK_N: int,
    BLOCK_K: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):

    @T.prim_func
    def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn:
            tn = T.get_thread_binding(0)  # tn = threadIdx.x

            A_shared = T.alloc_shared((BLOCK_K,), dtype)
            B_shared = T.alloc_shared((BLOCK_N, BLOCK_K), dtype)
            C_reg = T.alloc_local((1,), accum_dtype)
            T.clear(C_reg)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                # Stage one BLOCK_K slice of A and one row slice of B in shared memory.
                for tk in T.serial(BLOCK_K):
                    A_shared[tk] = A[bk * BLOCK_K + tk]
                    B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
                # Each thread accumulates the dot product for its own output row.
                for tk in T.serial(BLOCK_K):
                    C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype)
            C[bn * BLOCK_N + tn] = C_reg[0]

    return main
```
`TileLang` compiles this kernel into CUDA; the generated source is cached under `~/.tilelang/cache`:
```C++
extern "C" __global__ void __launch_bounds__(256, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  float C_reg[1];
  __shared__ uint64_t _mbarrier[2];
  if (((int)threadIdx.x) == 0) {
    tl::mbarrier_init(_mbarrier[0], 128);
    tl::mbarrier_init(_mbarrier[1], 128);
  }
  __syncthreads();
  if (128 <= ((int)threadIdx.x)) {
    tl::warpgroup_reg_dealloc<24>();
    for (int bk = 0; bk < 8; ++bk) {
      tl::mbarrier_wait(_mbarrier[1], ((bk & 1) ^ 1));
      for (int tk = 0; tk < 128; ++tk) {
        ((half_t*)buf_dyn_shmem)[tk] = A[((bk * 128) + tk)];
        ((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk) - 16256)] = B[(((((((int)blockIdx.x) * 131072) + (((int)threadIdx.x) * 1024)) + (bk * 128)) + tk) - 131072)];
      }
      tl::fence_proxy_async();
      tl::mbarrier_cp_async_arrive(_mbarrier[0]);
      tl::mbarrier_arrive(_mbarrier[0]);
    }
  } else {
    tl::warpgroup_reg_alloc<240>();
    C_reg[0] = 0.000000e+00f;
    for (int bk_1 = 0; bk_1 < 8; ++bk_1) {
      tl::mbarrier_wait(_mbarrier[0], (bk_1 & 1));
      for (int tk_1 = 0; tk_1 < 128; ++tk_1) {
        C_reg[0] = (C_reg[0] + (((float)((half_t*)buf_dyn_shmem)[tk_1]) * ((float)((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk_1) + 128)])));
      }
      tl::fence_proxy_async();
      tl::mbarrier_arrive(_mbarrier[1]);
    }
    C[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = ((half_t)C_reg[0]);
  }
}
```
In this design, the last 128 threads of the block (`threadIdx.x >= 128`) act as the data producer, copying tiles of `A` and `B` into shared memory, while the first 128 threads act as the consumer and accumulate the dot products (assuming a 1D block).
At this level, we use only a small fraction of the GPU's capability: the kernel takes **~0.17 ms**, compared to **~0.008 ms** for torch/cuBLAS, which makes it roughly 20x slower.
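The magic constants in the generated CUDA above (`16256`, `131072`, `1024`) can be derived from the tile sizes. The sketch below checks that derivation in plain Python. It assumes `BLOCK_N = 128`, `BLOCK_K = 128`, and `K = 1024`; these values are inferred from the loop bounds and strides in the generated code and are not stated explicitly in the text.

```python
BLOCK_N = 128  # assumed: rows of B handled per block
BLOCK_K = 128  # assumed: K elements staged per iteration (inner loop bound)
K = 1024       # assumed: 8 iterations of BLOCK_K (outer loop bound)


def shared_offset_expected(tx, tk):
    # Producer threads are threadIdx.x in 128..255, so the B_shared row is tx - 128.
    # B_shared is placed right after A_shared, which occupies BLOCK_K elements.
    return BLOCK_K + (tx - 128) * BLOCK_K + tk


def shared_offset_generated(tx, tk):
    return tx * 128 + tk - 16256


def global_offset_expected(bx, tx, bk, tk):
    row = bx * BLOCK_N + (tx - 128)
    return row * K + bk * BLOCK_K + tk


def global_offset_generated(bx, tx, bk, tk):
    return bx * 131072 + tx * 1024 + bk * 128 + tk - 131072


mismatches = 0
for bx in (0, 1, 7):
    for tx in range(128, 256):
        for bk in range(K // BLOCK_K):
            for tk in (0, 1, BLOCK_K - 1):
                if shared_offset_expected(tx, tk) != shared_offset_generated(tx, tk):
                    mismatches += 1
                if global_offset_expected(bx, tx, bk, tk) != global_offset_generated(bx, tx, bk, tk):
                    mismatches += 1
print("mismatches:", mismatches)
```

Under these assumptions, `16256` is `128 * BLOCK_K - BLOCK_K` and `131072` is `BLOCK_N * K`, so the compiler has simply folded the `tx - 128` producer offset into the index expressions.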
## More Concurrency
To further increase the concurrency of our kernel, we can exploit finer thread-level parallelism. Instead of assigning each thread to compute a single output element in C, you can introduce parallelism along the K dimension. Each thread computes a partial accumulation, and you then combine these partial results. This approach requires primitives like `atomicAdd` in CUDA.
Here’s a simplified version:
```python
def naive_splitk_gemv(
    N: int,
    K: int,
    BLOCK_N: int,
    BLOCK_K: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):

    @T.prim_func
    def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((1,), dtype)
            B_local = T.alloc_local((1,), dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
            if tk == 0:
                C_shared[tn] = 0
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                A_local[0] = A[bk * BLOCK_K + tk]
                B_local[0] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
                C_accum[0] += A_local[0].astype(accum_dtype) * B_local[0].astype(accum_dtype)
            T.atomic_add(C_shared[tn], C_accum[0])
            C[bn * BLOCK_N + tn] = C_shared[tn]

    return main
```
By introducing parallelism along the K dimension, our kernel now achieves **~0.024 ms**. This is a clear improvement, but still not on par with torch/cuBLAS.
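The split-K idea can be illustrated without a GPU. The following pure-Python sketch, a simulation rather than `TileLang` code, mimics one output row: each simulated thread `tk` accumulates a strided partial sum (the role of `C_accum`), and the partials are then combined into one value (the role of `T.atomic_add` on `C_shared`). The helper name `splitk_row` is hypothetical.

```python
def splitk_row(A, B_row, BLOCK_K):
    """Simulate split-K for one output row with BLOCK_K cooperating threads."""
    K = len(A)
    assert K % BLOCK_K == 0, "this sketch assumes K is a multiple of BLOCK_K"
    partial = [0.0] * BLOCK_K  # one private accumulator per thread tk
    for tk in range(BLOCK_K):
        for bk in range(K // BLOCK_K):
            k = bk * BLOCK_K + tk  # same index as A[bk * BLOCK_K + tk]
            partial[tk] += A[k] * B_row[k]
    total = 0.0  # plays the role of C_shared[tn]
    for tk in range(BLOCK_K):
        total += partial[tk]  # stands in for the atomic add
    return partial, total


A = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
B_row = [1.0] * 8
partial, total = splitk_row(A, B_row, BLOCK_K=4)
print(partial, total)  # [6.0, 8.0, 10.0, 12.0] 36.0
```

On real hardware the order in which atomic adds land is not fixed, so floating-point results can differ slightly between runs; the sequential loop here hides that effect.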
### Customizing Parallelism in K Dimension
If your K dimension is large, you can further customize how many elements each thread processes by introducing a `reduce_threads` parameter. This way, each thread handles multiple elements per iteration:
```python
def splitk_gemv(
    N: int,
    K: int,
    BLOCK_N: int,
    BLOCK_K: int,
    reduce_threads: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    TILE_K = T.ceildiv(BLOCK_K, reduce_threads)

    @T.prim_func
    def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            if tk == 0:
                C_shared[tn] = 0
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                for k in T.serial(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            T.atomic_add(C_shared[tn], C_accum[0])
            C[bn * BLOCK_N + tn] = C_shared[tn]

    return main
```
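To make the index expression `bk * BLOCK_K + tk * TILE_K + k` concrete, the sketch below enumerates which `K` indices each reduction thread touches. It is plain Python with a hypothetical helper `thread_k_indices`, and it assumes `BLOCK_K` divides `K` and `reduce_threads` divides `BLOCK_K`, so there is no tail handling.

```python
def thread_k_indices(K, BLOCK_K, reduce_threads):
    """Return {tk: [k indices]} following the kernel's indexing scheme."""
    assert K % BLOCK_K == 0 and BLOCK_K % reduce_threads == 0
    TILE_K = -(-BLOCK_K // reduce_threads)  # integer ceildiv, as in T.ceildiv
    owned = {tk: [] for tk in range(reduce_threads)}
    for bk in range(K // BLOCK_K):
        for tk in range(reduce_threads):
            for k in range(TILE_K):
                owned[tk].append(bk * BLOCK_K + tk * TILE_K + k)
    return owned


owned = thread_k_indices(K=16, BLOCK_K=8, reduce_threads=4)
for tk, ks in owned.items():
    print(tk, ks)
# 0 [0, 1, 8, 9]
# 1 [2, 3, 10, 11]
# 2 [4, 5, 12, 13]
# 3 [6, 7, 14, 15]
```

Each thread owns a contiguous run of `TILE_K` elements per iteration, and together the threads cover every index exactly once. That contiguity is what makes the vectorized loads in the next section possible.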
## Vectorized Reads
GEMV has a much lower arithmetic intensity than GEMM: every element of the matrix is read once and used in a single multiply-add, so memory throughput, not compute, is the bottleneck. One effective strategy is to use vectorized load/store operations (e.g., `float2`, `float4`). In `TileLang`, you can specify vectorized operations via `T.vectorized`:
```python
from tvm import DataType  # used to query the bit width of `dtype`


def splitk_gemv_vectorized(
    N: int,
    K: int,
    BLOCK_N: int,
    reduce_threads: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    # Size each thread's tile so that it fills one 128-bit memory transaction.
    MAX_TRANSACTION_SIZE_IN_BITS = 128
    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
    BLOCK_K = reduce_threads * TILE_K

    @T.prim_func
    def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            if tk == 0:
                C_shared[tn] = 0
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                for k in T.vectorized(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            T.atomic_add(C_shared[tn], C_accum[0])
            C[bn * BLOCK_N + tn] = C_shared[tn]

    return main
```
With vectorized reads, the kernel now finishes in **~0.0084 ms**, which is close to cuBLAS performance.
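Two numbers behind this section can be checked with a few lines of plain Python: how many elements fit in one 128-bit load, and how few floating-point operations GEMV performs per byte moved. This is a back-of-the-envelope sketch; the helper names are hypothetical, and the problem size `N = K = 1024` is an assumption made for illustration.

```python
DTYPE_BITS = {"float16": 16, "bfloat16": 16, "float32": 32, "int8": 8}


def elements_per_load(dtype, max_transaction_bits=128):
    """Same arithmetic as TILE_K in the kernel: bits per load // bits per element."""
    return max_transaction_bits // DTYPE_BITS[dtype]


def gemv_arithmetic_intensity(N, K, dtype):
    """FLOPs per byte, counting one read of B and A and one write of C."""
    bytes_per_elem = DTYPE_BITS[dtype] // 8
    flops = 2 * N * K  # one multiply and one add per element of B
    bytes_moved = (N * K + K + N) * bytes_per_elem
    return flops / bytes_moved


print(elements_per_load("float16"))  # 8
print(elements_per_load("float32"))  # 4
intensity = gemv_arithmetic_intensity(1024, 1024, "float16")  # assumed problem size
print(f"{intensity:.3f} FLOPs/byte")
```

For `float16`, the intensity comes out just under 1 FLOP per byte, which is why issuing fewer, wider memory transactions pays off more than optimizing the arithmetic.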
## `tvm_thread_allreduce` Instead of `atomicAdd`
[`tvm_thread_allreduce`](https://tvm.apache.org/docs/reference/api/python/tir/tir.html#tvm.tir.tvm_thread_allreduce) implements an optimized all-reduce across a group of threads, which should outperform our plain shared memory + `atomicAdd` approach:
```python
def splitk_gemv_vectorized_tvm(
    N: int,
    K: int,
    BLOCK_N: int,
    reduce_threads: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    MAX_TRANSACTION_SIZE_IN_BITS = 128
    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
    BLOCK_K = reduce_threads * TILE_K

    @T.prim_func
    def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                for k in T.vectorized(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            C_reduced = T.alloc_local((1,), accum_dtype)
            with T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
            ):
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        C_accum[0],
                        True,
                        C_reduced[0],
                        tk,
                        dtype="handle",
                    ))
            C[bn * BLOCK_N + tn] = C_reduced[0]

    return main
```
With this optimization, the kernel latency drops from **~0.0084 ms** to **~0.0069 ms**, which is faster than torch/cuBLAS!
## Autotune
`BLOCK_N` and `reduce_threads` are hyperparameters of our kernel (`BLOCK_K` is derived from `reduce_threads`), and they can be tuned to improve performance. We can use the `tilelang.autotune` feature to automatically search for optimal configurations:
```python
import itertools

# `autotune`, `jit`, `tl` (the tilelang package), and `ref_program` are
# imported or defined in examples/gemv/example_gemv.py.


def get_best_config(N, K):

    def get_configs():
        BLOCK_N = [2, 4, 8, 32, 64, 128]
        reduce_threads = [4, 8, 32]
        _configs = list(itertools.product(
            BLOCK_N,
            reduce_threads,
        ))
        configs = [{
            "BLOCK_N": c[0],
            "reduce_threads": c[1],
        } for c in _configs]
        return configs

    @autotune(
        configs=get_configs(),
        warmup=3,
        rep=20,
    )
    @jit(
        out_idx=[-1],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=False,
        target="auto",
    )
    def kernel(
        BLOCK_N=None,
        reduce_threads=None,
    ):
        dtype = "float16"
        accum_dtype = "float"
        MAX_TRANSACTION_SIZE_IN_BITS = 128
        TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
        BLOCK_K = reduce_threads * TILE_K

        @T.prim_func
        def main(
                A: T.Buffer((K,), dtype),
                B: T.Buffer((N, K), dtype),
                C: T.Buffer((N,), dtype),
        ):
            with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
                tn = T.get_thread_binding(0)
                tk = T.get_thread_binding(1)
                A_local = T.alloc_local((TILE_K,), dtype)
                B_local = T.alloc_local((TILE_K,), dtype)
                C_accum = T.alloc_local((1,), accum_dtype)
                T.clear(C_accum)
                for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                    for k in T.vectorized(TILE_K):
                        A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                        B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                    for k in T.serial(TILE_K):
                        C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
                C_reduced = T.alloc_local((1,), accum_dtype)
                with T.attr(
                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                ):
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            C_accum[0],
                            True,
                            C_reduced[0],
                            tk,
                            dtype="handle",
                        ))
                C[bn * BLOCK_N + tn] = C_reduced[0]

        return main

    return kernel()
```
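It can help to look at the search space before running the tuner. The sketch below enumerates the same `BLOCK_N` x `reduce_threads` grid in plain Python and derives `BLOCK_K` and the block size for each point. The helper `describe_configs` is hypothetical, and the 1024-threads-per-block limit is the usual CUDA limit, used here only to flag oversized configurations; how the autotuner treats such configurations is not covered here.

```python
import itertools

BLOCK_N_CHOICES = [2, 4, 8, 32, 64, 128]
REDUCE_THREADS_CHOICES = [4, 8, 32]
TILE_K = 128 // 16  # 128-bit loads of float16 elements


def describe_configs(max_threads_per_block=1024):
    rows = []
    for block_n, reduce_threads in itertools.product(BLOCK_N_CHOICES, REDUCE_THREADS_CHOICES):
        threads = block_n * reduce_threads  # threads=(BLOCK_N, reduce_threads)
        rows.append({
            "BLOCK_N": block_n,
            "reduce_threads": reduce_threads,
            "BLOCK_K": reduce_threads * TILE_K,
            "threads": threads,
            "fits": threads <= max_threads_per_block,
        })
    return rows


configs = describe_configs()
print(len(configs), "configurations")
for row in configs:
    if not row["fits"]:
        print("exceeds block limit:", row)
```

The grid has 18 points. The point `BLOCK_N=2, reduce_threads=32` gives 64 threads per block and `BLOCK_K=256`.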
After autotuning, our kernel reaches **~0.0067 ms**. The final generated CUDA kernel might look like this:
```C++
extern "C" __global__ void __launch_bounds__(64, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
  float C_accum[1];
  half_t A_local[8];
  half_t B_local[8];
  __shared__ float red_buf0[64];
  C_accum[0] = 0.000000e+00f;
  for (int bk = 0; bk < 4; ++bk) {
    *(uint4*)(A_local + 0) = *(uint4*)(A + ((bk * 256) + (((int)threadIdx.y) * 8)));
    *(uint4*)(B_local + 0) = *(uint4*)(B + ((((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 1024)) + (bk * 256)) + (((int)threadIdx.y) * 8)));
    for (int k = 0; k < 8; ++k) {
      C_accum[0] = (C_accum[0] + (((float)A_local[k]) * ((float)B_local[k])));
    }
  }
  tl::fence_proxy_async();
  __syncthreads();
  red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = C_accum[0];
  __syncthreads();
  if (((int)threadIdx.y) < 16) {
    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 16)]);
  }
  __syncthreads();
  if (((int)threadIdx.y) < 8) {
    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 8)]);
  }
  __syncthreads();
  if (((int)threadIdx.y) < 4) {
    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 4)]);
  }
  __syncthreads();
  if (((int)threadIdx.y) < 2) {
    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 2)]);
  }
  __syncthreads();
  if (((int)threadIdx.y) < 1) {
    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 1)]);
  }
  __syncthreads();
  C[((((int)blockIdx.x) * 2) + ((int)threadIdx.x))] = ((half_t)red_buf0[(((int)threadIdx.x) * 32)]);
}
```
This corresponds closely to our `TileLang` program, with necessary synchronization and low-level optimizations inserted automatically.
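The chain of `if (threadIdx.y < 16)`, `< 8`, ..., `< 1` steps in the generated code is a shared-memory tree reduction. The following plain-Python simulation is a sketch of that pattern, not of TVM's actual lowering logic: it models `red_buf0` as a flat list with 2 rows (`threadIdx.x`) of 32 partial sums (`threadIdx.y`). The helper name `tree_reduce_rows` is hypothetical.

```python
def tree_reduce_rows(red_buf, rows, width):
    """Reduce each row of `width` values in place; width must be a power of two.

    Afterwards red_buf[r * width] holds the sum of row r. Returns the step count.
    """
    assert width > 0 and width & (width - 1) == 0
    steps = 0
    stride = width // 2
    while stride >= 1:
        for r in range(rows):          # threadIdx.x
            for ty in range(stride):   # only threads with threadIdx.y < stride are active
                red_buf[r * width + ty] += red_buf[r * width + ty + stride]
        # On the GPU, a __syncthreads() separates consecutive steps.
        stride //= 2
        steps += 1
    return steps


red_buf = [float(i) for i in range(64)]  # stand-in for the per-thread C_accum values
steps = tree_reduce_rows(red_buf, rows=2, width=32)
print(steps, red_buf[0], red_buf[32])  # 5 496.0 1520.0
```

With 32 reduction threads, the tree needs five combining steps (strides 16, 8, 4, 2, 1), matching the five guarded updates in the generated kernel.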
## Conclusion
### Benchmark Table on Hopper GPU
| Kernel Name | Latency |
|------------|------------|
| torch/cuBLAS | 0.00784 ms |
| Triton | 0.00773 ms |
| naive_gemv | 0.16607 ms |
| splitk_gemv | 0.02419 ms |
| splitk_gemv_vectorized | 0.00809 ms |
| splitk_gemv_vectorized_tvm | 0.00675 ms |
In this tutorial, we implemented a simple GEMV kernel and optimized it step by step, and we learned that `TileLang` exposes low-level control to the user, such as thread-level programming and CUDA primitives.
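The relative speedups implied by the benchmark table can be recomputed with a short script. The latencies below are copied from the table; everything else is plain arithmetic.

```python
latency_ms = {
    "torch/cuBLAS": 0.00784,
    "Triton": 0.00773,
    "naive_gemv": 0.16607,
    "splitk_gemv": 0.02419,
    "splitk_gemv_vectorized": 0.00809,
    "splitk_gemv_vectorized_tvm": 0.00675,
}

baseline = latency_ms["torch/cuBLAS"]
speedup = {name: baseline / t for name, t in latency_ms.items()}
for name, t in latency_ms.items():
    # A value above 1.00x means the kernel is faster than torch/cuBLAS.
    print(f"{name:28s} {t:.5f} ms  {speedup[name]:5.2f}x vs torch/cuBLAS")
```

In these numbers, the naive kernel is about 21x slower than torch/cuBLAS, while the final `tvm_thread_allreduce` version is about 1.16x faster.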