ElementWise Operators#

Author: Chenghua Wang

Warning

This document is still experimental and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!

Elementwise operators are widely used in deep learning and often serve as the first example encountered by those beginning to explore parallel programming. This tutorial will analyze several implementations of the elementwise addition operator using TileLang and compare them with the corresponding CUDA implementation. By the end of this tutorial, you will learn:

  1. How to implement an elementwise operator using TileLang.

  2. How to compile operators with dynamic shapes.

  3. How TileLang addresses boundary-related issues.

  4. The similarities and differences between operators implemented in TileLang and those implemented in CUDA/CuTe.

Please note that this tutorial does not delve deeply into the design principles of TileLang. For a broader understanding of TileLang, we recommend consulting the Overview.

Elementwise add in TileLang#

import tilelang
import tilelang.language as T

def elementwise_add(N, threads=256, dtype="bfloat16"):

    @T.prim_func
    def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
        with T.Kernel(T.ceildiv(N, threads), threads=threads) as (b_x):
            # vector add.
            for i in T.Parallel(threads):
                C[b_x * threads + i] = A[b_x * threads + i] + B[b_x * threads + i]

    return main

All logic for a TileLang kernel must be implemented within the T.Kernel(...) scope. In this example, initializing T.Kernel(...) requires specifying both the grid size and the number of threads per block. The returned value b_x corresponds to blockIdx.x in CUDA. In the provided implementation, T.Parallel is used to process the data tile (of size 1 x threads) assigned to the block.

Those familiar with CUDA programming might wonder where threadIdx fits in. Note that the code inside T.Kernel operates at the block level, not the thread level. In this example, you only need to define the block-level logic. During compilation, TileLang automatically maps the computation onto threads and applies further optimizations. The code TileLang generates can closely match carefully handcrafted kernels, as demonstrated with a concrete example in the comparison section below. TileLang also supports thread-level programming semantics, but these are covered in later discussions.

The program can be compiled using the following code:

program = elementwise_add(1024, threads=256, dtype="bfloat16")
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")

Launching the kernel is straightforward; simply call it like a function:

C = kernel(A, B)
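
Putting the pieces together, a minimal end-to-end run might look like the sketch below. It assumes PyTorch is available, that the inputs are bfloat16 CUDA tensors, and that their length matches the compiled size of 1024:

import torch

# Inputs must live on the GPU and match the compiled dtype and length (1024).
A = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
B = torch.randn(1024, device="cuda", dtype=torch.bfloat16)

# With out_idx=-1, the runtime allocates and returns the output tensor C.
C = kernel(A, B)

# Sanity-check against the PyTorch reference.
torch.testing.assert_close(C, A + B)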

The vector add operation can also be extended to the two-dimensional case, where the 1D and 2D implementations demonstrate comparable efficiency in practice. An example can be found in the tests; its kernel code is provided below:

import tilelang.language as T
def elementwise_add(
    M,
    N,
    block_M,
    block_N,
    in_dtype,
    out_dtype,
    threads,
):
    @T.prim_func
    def main(
            A: T.Tensor((M, N), in_dtype),
            B: T.Tensor((M, N), in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            start_x = bx * block_N
            start_y = by * block_M

            for (local_y, local_x) in T.Parallel(block_M, block_N):
                y = start_y + local_y
                x = start_x + local_x

                C[y, x] = A[y, x] + B[y, x]

    return main
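
Compiling and launching the two-dimensional kernel follows the same pattern. The sketch below is only illustrative: the 128 x 128 tile size and 256-thread configuration are assumptions, not values prescribed by the tests.

import torch
import tilelang

# Illustrative configuration: 1024 x 1024 bfloat16 tensors tiled into 128 x 128 blocks.
program = elementwise_add(1024, 1024, 128, 128, "bfloat16", "bfloat16", 256)
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")

A = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
B = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
C = kernel(A, B)

torch.testing.assert_close(C, A + B)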

How to compile operators with dynamic shapes?#

In the compilation process above, a fixed shape was used. In practice, however, we often want a kernel to support dynamic shapes. To compile such a kernel in TileLang, replace the fixed size with a dynamic symbolic value, which makes that dimension dynamic. The following example illustrates this:

program = elementwise_add(T.symbolic("N"), threads=256, dtype="bfloat16")
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")

The resulting CUDA code for the kernel will include an additional int N parameter after the bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, and bfloat16_t* __restrict__ C parameters.
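
Because N is now a runtime parameter, the same compiled kernel can be reused for inputs of different lengths. The sketch below assumes the dynamic dimension is inferred from the runtime tensor shapes, and keeps every length a multiple of the 256-element block tile because the kernel above performs no boundary check:

import torch

for n in (1024, 4096, 65536):
    A = torch.randn(n, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(n, device="cuda", dtype=torch.bfloat16)
    C = kernel(A, B)  # N is taken from the shapes of A and B at call time.
    torch.testing.assert_close(C, A + B)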

Comparison of TileLang, CUDA, and CuTe#

For the subsequent examples, this tutorial will use the vector add operation for simplicity and brevity.

Those new to CUDA programming often write code in a style similar to this:

// vector add
__global__ void elementwise_add(float* a, float* b, float* c, int N) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < N) {
        c[idx] = a[idx] + b[idx];
    }
}

The code above assigns each thread to compute a single element. Although its memory accesses are already coalesced, it leaves performance on the table because wide, vectorized (e.g., 128-bit) loads and stores are not used. However, TileLang code written with similar logic (e.g., a loop-based traversal) can be optimized by the compiler into a highly efficient implementation, making it more accessible for beginners. Additionally, the final code generated by the compiler remains observable, providing transparency into the optimization process.

The CUDA code generated by TileLang for the compiled kernel can be retrieved using the kernel.get_kernel_source() method. Below is the CUDA code produced for the vector addition example from Section 1:

extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
  if (((int)threadIdx.x) < 32) {
    uint4 __1;
      uint4 v_ = *(uint4*)(A + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8)));
      uint4 v__1 = *(uint4*)(B + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8)));
      ((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x);
      ((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y);
      ((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x);
      ((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y);
      ((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x);
      ((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y);
      ((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x);
      ((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y);
    *(uint4*)(C + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8))) = __1;
  }
}

In the code above, TileLang not only automatically maps block-level parallelism to threads but also applies optimizations such as vectorization and coalesced memory access.

While TileLang incorporates various optimizations in the case above, its behavior may sometimes appear counterintuitive. For example, when 256 threads are requested, vectorization can result in each thread computing 8 elements, so only 32 threads are actually active; interestingly, the kernel launch configuration still retains the original allocation of 256 threads.

In such scenarios, explicitly specifying the number of elements computed per thread can help “guide” TileLang’s code generation process, leading to implementations that are more closely aligned with the intended design.

def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):

    @T.prim_func
    def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
        with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x):
            # vector add.
            for i, j in T.Parallel(threads, num_per_thread):
                offsets = (b_x * threads + i) * num_per_thread
                C[offsets + j] = A[offsets + j] + B[offsets + j]

    return main
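
The generated code shown next contains no blockIdx.x offset, which suggests it was compiled for a single-block problem size. A compile sketch consistent with that assumption (N = threads * num_per_thread = 2048) is:

# Assumed configuration: a single block of 256 threads, 8 elements per thread.
program = elementwise_add(2048, num_per_thread=8, threads=256, dtype="bfloat16")
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
print(kernel.get_kernel_source())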

The corresponding CUDA code generated for the above example is presented below:

extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
  uint4 __1;
    uint4 v_ = *(uint4*)(A + (((int)threadIdx.x) * 8));
    uint4 v__1 = *(uint4*)(B + (((int)threadIdx.x) * 8));
    ((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x);
    ((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y);
    ((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x);
    ((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y);
    ((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x);
    ((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y);
    ((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x);
    ((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y);
  *(uint4*)(C + (((int)threadIdx.x) * 8)) = __1;
}

Aha, this CUDA code aligns closely with conventional programming practices, making it more familiar and intuitive.

But what happens if we give TileLang additional hints, for instance by explicitly staging data in registers with the T.copy(...) operation? The example below demonstrates such a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing the computation.

def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype="bfloat16"):

    @T.prim_func
    def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
        with T.Kernel(T.ceildiv(N, threads * NUM_ELE_PER_THREAD), threads=threads) as (b_x):
            A_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)
            B_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)
            C_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)

            s_start = b_x * threads * NUM_ELE_PER_THREAD
            s_end = (b_x + 1) * threads * NUM_ELE_PER_THREAD

            # LDG. 128
            T.copy(
                A[s_start:s_end],
                A_register,
            )
            T.copy(
                B[s_start:s_end],
                B_register,
            )

            # vector add.
            for tid, i in T.Parallel(threads, NUM_ELE_PER_THREAD):
                C_register[tid * NUM_ELE_PER_THREAD + i] = (
                    A_register[tid * NUM_ELE_PER_THREAD + i] +
                    B_register[tid * NUM_ELE_PER_THREAD + i])

            # STG. 128
            T.copy(
                C_register,
                C[s_start:s_end],
            )

    return main

In the example above, each thread is responsible for computing 8 elements. The T.copy(...) method functions at the block level, and TileLang automatically maps data movement operations to individual threads. This design may resonate more intuitively with CUDA developers. Let us now analyze the CUDA code generated from this implementation.

// N is set to 8192 * 8192 when compiling
extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
  bfloat16_t A_register[8];
  bfloat16_t B_register[8];
  *(uint4*)(A_register + 0) = *(uint4*)(A + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8)));
  *(uint4*)(B_register + 0) = *(uint4*)(B + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8)));
  uint4 __1;
    uint4 v_ = *(uint4*)(A_register + 0);
    uint4 v__1 = *(uint4*)(B_register + 0);
    ((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x);
    ((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y);
    ((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x);
    ((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y);
    ((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x);
    ((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y);
    ((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x);
    ((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y);
  *(uint4*)(A_register + 0) = __1;
  *(uint4*)(C + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8))) = *(uint4*)(A_register + 0);
}

Two additional register arrays, A_register and B_register, now appear in the generated code. During the actual computation, however, their contents are simply reloaded into v_ and v__1, and A_register is reused to stage the result before the final store. Such redundant register-to-register moves are likely to be eliminated by the downstream compiler, so they should not affect performance in practice.

To compare implementation complexity, one can implement the same elementwise addition operator in CuTe and set it side by side with the TileLang version. The corresponding CuTe code is provided below:

template<int NUM_ELE_PER_THREAD=8>
__global__ void elementwise_add(nv_bfloat16* C, 
                                 const nv_bfloat16* A, 
                                 const nv_bfloat16* B,
                                 int N) {
  using namespace cute;

  const int idx = threadIdx.x + blockIdx.x * blockDim.x;

  Tensor t_C = make_tensor(make_gmem_ptr(C), make_shape(N));
  Tensor t_A = make_tensor(make_gmem_ptr(A), make_shape(N));
  Tensor t_B = make_tensor(make_gmem_ptr(B), make_shape(N));

  Tensor t_C_tile = local_tile(t_C, make_shape(Int<NUM_ELE_PER_THREAD>{}), make_coord(idx));
  Tensor t_A_tile = local_tile(t_A, make_shape(Int<NUM_ELE_PER_THREAD>{}), make_coord(idx));
  Tensor t_B_tile = local_tile(t_B, make_shape(Int<NUM_ELE_PER_THREAD>{}), make_coord(idx));

  Tensor reg_buffer_A = make_tensor_like(t_A_tile);
  Tensor reg_buffer_B = make_tensor_like(t_B_tile);
  Tensor reg_buffer_C = make_tensor_like(t_C_tile);

  // LDG. 128
  copy(t_A_tile, reg_buffer_A);
  copy(t_B_tile, reg_buffer_B);

  auto reg_C_vector = recast<nv_bfloat162>(reg_buffer_C);
  auto reg_A_vector = recast<nv_bfloat162>(reg_buffer_A);
  auto reg_B_vector = recast<nv_bfloat162>(reg_buffer_B);

  // Perform vectorized addition
#pragma unroll
  for (int vec_idx = 0; vec_idx < size(reg_C_vector); ++vec_idx) {
    reg_C_vector(vec_idx) = reg_A_vector(vec_idx) + reg_B_vector(vec_idx);
  }

  auto reg_C_flat = recast<nv_bfloat16>(reg_C_vector);

  // STG. 128
  copy(reg_C_flat, t_C_tile);
}

Conclusion#

This tutorial showcases the implementation of the elementwise addition operator using TileLang, while also comparing various design approaches. TileLang significantly reduces the complexity of CUDA programming, enabling high performance with minimal code. Nevertheless, working with TileLang demands careful attention to specific implementation details. To ensure computational efficiency, it is essential to thoroughly examine the generated CUDA kernels.


Reference:

[1] The CuTe code implementation draws inspiration from the techniques discussed in this blog: https://zhuanlan.zhihu.com/p/690703999