tilelang.jit

This module provides an auto-tuning infrastructure for TileLang (tl) programs. It includes functionality to JIT-compile TileLang programs into a runnable kernel adapter using TVM.

Submodules

Attributes

Classes

JITImpl

Just-In-Time compilation wrapper for TileLang programs.

Functions

compile([func, out_idx, execution_backend, target, ...])

Compile the given TileLang PrimFunc with TVM and build a JITKernel.

par_compile(funcs[, out_idx, execution_backend, ...])

Compile multiple TileLang PrimFuncs in parallel with TVM and build JITKernels.

jit(…)

JIT compiler decorator for TileLang functions.

Package Contents

tilelang.jit.logger
tilelang.jit.compile(func=None, out_idx=None, execution_backend=None, target=None, target_host=None, verbose=None, pass_configs=None, compile_flags=None)

Compile the given TileLang PrimFunc with TVM and build a JITKernel.

Parameters:
  • func (tvm.tir.PrimFunc, optional) – The TileLang TIR function to compile and wrap.

  • out_idx (Union[List[int], int], optional) – Index(es) of the output tensors to return (default: None).

  • execution_backend (Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], optional) – Execution backend to use for kernel execution. If None, reads from TILELANG_EXECUTION_BACKEND environment variable (defaults to “auto”).

  • target (Union[str, Target], optional) – Compilation target, either as a string or a TVM Target object. If None, reads from TILELANG_TARGET environment variable (defaults to “auto”).

  • target_host (Union[str, Target], optional) – Target host for cross-compilation (default: None).

  • verbose (bool, optional) – Whether to enable verbose output. If None, reads from TILELANG_VERBOSE environment variable (defaults to False).

  • pass_configs (dict, optional) – Additional configuration options passed to the compiler's PassContext. Refer to tilelang.transform.PassConfigKey for supported options.

  • compile_flags (list[str] | str | None) – Additional compiler flags.

Environment Variables:
  • TILELANG_TARGET (str) – Default compilation target (e.g., “cuda”, “llvm”). Defaults to “auto”.

  • TILELANG_EXECUTION_BACKEND (str) – Default execution backend. Defaults to “auto”.

  • TILELANG_VERBOSE (str) – Set to “1”, “true”, “yes”, or “on” to enable verbose compilation by default.
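The None-fallback rules for target, execution_backend, and verbose, together with the TILELANG_* variables listed above, can be illustrated in plain Python. This is a minimal sketch, not TileLang's implementation: the helper names are hypothetical, only string targets are handled (a TVM Target object would simply be passed through as an explicit argument), and case-insensitive matching of TILELANG_VERBOSE is an assumption.

```python
import os
from typing import Any, Dict, Optional

# Spellings that TILELANG_VERBOSE accepts, per the documented list.
# Assumption: matching is case-insensitive and ignores surrounding spaces.
_TRUTHY = {"1", "true", "yes", "on"}


def resolve_option(value: Optional[str], env_name: str, default: str) -> str:
    """Explicit argument first, then the environment variable, then the default."""
    if value is not None:
        return value
    return os.environ.get(env_name, default)


def resolve_verbose(value: Optional[bool]) -> bool:
    """verbose=None falls back to TILELANG_VERBOSE; unrecognised values mean False."""
    if value is not None:
        return value
    return os.environ.get("TILELANG_VERBOSE", "").strip().lower() in _TRUTHY


def resolve_compile_options(
    target: Optional[str] = None,
    execution_backend: Optional[str] = None,
    verbose: Optional[bool] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: apply the fallback order to each option."""
    return {
        "target": resolve_option(target, "TILELANG_TARGET", "auto"),
        "execution_backend": resolve_option(
            execution_backend, "TILELANG_EXECUTION_BACKEND", "auto"
        ),
        "verbose": resolve_verbose(verbose),
    }


# Usage: the environment supplies defaults, explicit arguments override them.
os.environ["TILELANG_TARGET"] = "cuda"
os.environ["TILELANG_VERBOSE"] = "yes"
print(resolve_compile_options())
print(resolve_compile_options(target="llvm", verbose=False))
```

Keeping the "explicit argument wins" rule in one helper means every option follows the same precedence, which matches how each parameter above is described.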

Return type:

kernel.JITKernel[_KP, _T]

tilelang.jit.par_compile(funcs, out_idx=None, execution_backend=None, target=None, target_host=None, verbose=None, pass_configs=None, compile_flags=None, num_workers=None, ignore_error=False)

Compile multiple TileLang PrimFuncs in parallel with TVM and build JITKernels.

Parameters:
  • funcs (Iterable[tvm.tir.PrimFunc]) – The TileLang TIR functions to compile and wrap.

  • out_idx (Union[List[int], int], optional) – Index(es) of the output tensors to return (default: None).

  • execution_backend (Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], optional) – Execution backend to use for kernel execution. If None, reads from TILELANG_EXECUTION_BACKEND environment variable (defaults to “auto”).

  • target (Union[str, Target], optional) – Compilation target, either as a string or a TVM Target object. If None, reads from TILELANG_TARGET environment variable (defaults to “auto”).

  • target_host (Union[str, Target], optional) – Target host for cross-compilation (default: None).

  • verbose (bool, optional) – Whether to enable verbose output. If None, reads from TILELANG_VERBOSE environment variable (defaults to False).

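  • pass_configs (dict, optional) – Additional configuration options passed to the compiler's PassContext. Refer to tilelang.transform.PassConfigKey for supported options.

  • compile_flags (list[str] | str | None) – Additional compiler flags.

  • num_workers (int | None) – Number of parallel workers to use for compilation. Defaults to None, which lets the system decide.

  • ignore_error (bool) – If True, compilation errors for individual functions are logged as warnings and the corresponding result is None. If False, any compilation error raises an exception. Defaults to False.

Environment Variables:
  • TILELANG_TARGET (str) – Default compilation target (e.g., “cuda”, “llvm”). Defaults to “auto”.

  • TILELANG_EXECUTION_BACKEND (str) – Default execution backend. Defaults to “auto”.

  • TILELANG_VERBOSE (str) – Set to “1”, “true”, “yes”, or “on” to enable verbose compilation by default.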

Return type:

list[kernel.JITKernel[_KP, _T]]
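The num_workers and ignore_error semantics (as documented for JITImpl.par_compile below) can be sketched with the standard library's thread pool. This is an illustration under stated assumptions, not TileLang's code: `par_compile_sketch` and `fake_compile` are hypothetical, and whether TileLang uses threads, processes, or something else is not stated here.

```python
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, List, Optional, TypeVar

F = TypeVar("F")  # item to compile (e.g. a PrimFunc)
K = TypeVar("K")  # compiled result (e.g. a JITKernel)

logger = logging.getLogger("par_compile_sketch")


def par_compile_sketch(
    funcs: Iterable[F],
    compile_one: Callable[[F], K],
    num_workers: Optional[int] = None,
    ignore_error: bool = False,
) -> List[Optional[K]]:
    """Compile every item concurrently and return results in input order.

    With ignore_error=True a failing item is logged as a warning and yields
    None; otherwise the first failure (in input order) is re-raised.
    """
    items = list(funcs)
    # max_workers=None lets ThreadPoolExecutor pick its default pool size.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [pool.submit(compile_one, item) for item in items]
        results: List[Optional[K]] = []
        for i, fut in enumerate(futures):
            try:
                results.append(fut.result())
            except Exception as exc:
                if not ignore_error:
                    raise
                logger.warning("compilation of item %d failed: %s", i, exc)
                results.append(None)
    return results


def fake_compile(n: int) -> str:
    # Hypothetical stand-in for building a kernel; odd sizes "fail".
    if n % 2:
        raise ValueError(f"unsupported size {n}")
    return f"kernel_{n}"


print(par_compile_sketch([2, 3, 4], fake_compile, num_workers=2, ignore_error=True))
# ['kernel_2', None, 'kernel_4']
```

Collecting futures in submission order keeps the output list aligned with the input, so a None entry identifies exactly which function failed.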

class tilelang.jit.JITImpl

Bases: Generic[_P, _KP, _T, _Ret]

Just-In-Time compilation wrapper for TileLang programs.

This class provides a unified interface for compiling and executing TileLang kernels. It supports two execution modes that are automatically inferred:

Execution Modes

  • lazy: The decorated function returns a PrimFunc explicitly. Calling the JIT wrapper returns a compiled kernel object, which can be invoked separately. This mode is useful when you want to inspect or reuse the kernel object.

    Example (lazy mode):

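    import tilelang
    import tilelang.language as T

    @tilelang.jit(out_idx=[-1])  # the last tensor parameter is allocated and returned
    def matmul(M, N, K, block_M, block_N, block_K, dtype="float16"):
        @T.prim_func
        def kernel(A: T.Tensor((M, K), dtype), ...):
            ...  # kernel body elided
        return kernel  # explicitly return the PrimFunc

    kernel = matmul(1024, 1024, 1024, 128, 128, 32)  # compiles and returns the kernel object
    result = kernel(a, b)  # a, b: input tensors; the output tensor is returned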
    
  • eager: The decorated function uses the DSL builder pattern with tensor type annotations. Calling the JIT wrapper compiles and immediately executes the kernel, returning the result directly.

    Example (eager mode):

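    import tilelang
    import tilelang.language as T

    dtype = "float16"  # element type used in the annotations below

    @tilelang.jit
    def gemm(A, B, C, block_M: int = 64):
        M, N, K = T.const("M N K")
        A: T.Tensor[[M, K], dtype]  # tensor shape via annotation
        B: T.Tensor[[K, N], dtype]
        C: T.Tensor[[M, N], dtype]
        with T.Kernel(...):
            ...  # kernel body elided

    gemm(A, B, C)  # A, B, C: tensors on the target device; compiles and executes immediately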
    

The mode is automatically inferred based on whether the function returns a PrimFunc (lazy) or uses the builder pattern (eager).
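To make the inference rule concrete, here is a deliberately simplified sketch in plain Python. It is not TileLang's code: `PrimFuncStub` stands in for `tvm.tir.PrimFunc`, and `infer_mode` is a hypothetical helper that calls the function and checks the type of its return value. The real check (presumably `initialize_jit_mode`, whose return type is `Literal['lazy', 'eager']`) may work differently, for example by tracing the body inside the DSL builder.

```python
from typing import Any, Callable, Literal


class PrimFuncStub:
    """Hypothetical stand-in for tvm.tir.PrimFunc in this sketch."""


def infer_mode(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Literal["lazy", "eager"]:
    # Lazy-style functions build and *return* a PrimFunc; eager-style
    # functions build the kernel through the DSL and return no PrimFunc.
    result = func(*args, **kwargs)
    return "lazy" if isinstance(result, PrimFuncStub) else "eager"


def lazy_style(M, N):
    return PrimFuncStub()  # explicitly returns the "PrimFunc"


def eager_style(A, B):
    pass  # builder pattern: nothing of PrimFunc type is returned


print(infer_mode(lazy_style, 1024, 1024))  # lazy
print(infer_mode(eager_style, None, None))  # eager
```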

out_idx

Index(es) of output tensor(s) to return (lazy mode only).

Type:

list[int] | int | None
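In the lazy example above, `out_idx=[-1]` lets the kernel be called as `kernel(a, b)` while the last tensor is returned. The sketch below shows one plausible way to normalize `out_idx` (a single int or a list, with Python-style negative indices) and split parameters into caller-supplied inputs and returned outputs. The helper names are hypothetical, and keeping indices in the order given is an assumption rather than documented behaviour.

```python
from typing import List, Sequence, Tuple, Union


def normalize_out_idx(out_idx: Union[List[int], int, None], num_params: int) -> List[int]:
    """Convert out_idx into non-negative parameter positions, keeping the given order."""
    if out_idx is None:
        return []
    if isinstance(out_idx, int):
        out_idx = [out_idx]
    positions = []
    for i in out_idx:
        pos = i + num_params if i < 0 else i  # Python-style negative indexing
        if not 0 <= pos < num_params:
            raise IndexError(f"out_idx {i} is out of range for {num_params} parameters")
        positions.append(pos)
    return positions


def split_params(
    param_names: Sequence[str], out_idx: Union[List[int], int, None]
) -> Tuple[List[str], List[str]]:
    """Separate the parameters the caller passes from those returned as outputs."""
    outs = normalize_out_idx(out_idx, len(param_names))
    inputs = [name for k, name in enumerate(param_names) if k not in outs]
    outputs = [param_names[k] for k in outs]
    return inputs, outputs


# With out_idx=[-1], a kernel over (A, B, C) is called with A and B and returns C.
print(split_params(["A", "B", "C"], [-1]))  # (['A', 'B'], ['C'])
```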

execution_backend

Backend for kernel execution (“auto”, “dlpack”, “tvm_ffi”, etc.).

Type:

str | None

target

TVM compilation target (e.g., “cuda”, “llvm”, “auto”).

Type:

str | Target | None

target_host

Host target for cross-compilation.

Type:

str | Target | None

verbose

Enable verbose compilation output.

Type:

bool | None

pass_configs

TVM pass configuration options.

Type:

dict[str, Any] | None

debug_root_path

Directory to save compiled kernel source for debugging.

Type:

str | None

compile_flags

Additional compiler flags.

Type:

list[str] | str | None

func_source

Original Python source code of the decorated function.

Type:

str

signature

Function signature of the original function.

Type:

inspect.Signature
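Because `signature` is a standard `inspect.Signature`, it supports binding call arguments and filling in defaults. The sketch below uses a stand-in with the same parameter list as the eager `gemm` example; whether JITImpl binds arguments exactly like this is an assumption, but the calls shown are standard-library API.

```python
import inspect


def gemm(A, B, C, block_M: int = 64):
    """Stand-in with the same parameter list as the eager example."""


sig = inspect.signature(gemm)
bound = sig.bind("a", "b", "c")  # raises TypeError if the call does not fit
bound.apply_defaults()            # fills in block_M=64
print(dict(bound.arguments))      # {'A': 'a', 'B': 'b', 'C': 'c', 'block_M': 64}
```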

mode

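Execution mode. “auto” infers “lazy” or “eager” from the function's behavior: returning a PrimFunc selects lazy mode, while using the builder pattern selects eager mode.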

Type:

Literal['auto', 'lazy', 'eager']

func

The wrapped function object.

Type:

JITFunc

out_idx: list[int] | int | None
execution_backend: Literal['auto', 'dlpack', 'tvm_ffi', 'cython', 'nvrtc', 'torch', 'cutedsl'] | None
target: str | tvm.target.Target | None
target_host: str | tvm.target.Target | None
verbose: bool | None
pass_configs: dict[str, Any] | None
debug_root_path: str | None
compile_flags: list[str] | str | None
func_source: str
signature: inspect.Signature
mode: Literal['auto', 'lazy', 'eager']
func: tilelang.language.eager.JITFunc[_KP, _T]
__post_init__()
get_tir(*args, **kwargs)

Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object.

Parameters:
  • args (_P)

  • kwargs (_P)

Return type:

tilelang.language.eager.PrimFunc[_KP, _T]

initialize_jit_mode(*args, **kwargs)
Parameters:
  • args (_P)

  • kwargs (_P)

Return type:

Literal['lazy', 'eager']

par_compile(configs, num_workers=None, ignore_error=False)

Compile multiple TileLang PrimFuncs in parallel with TVM and build JITKernels.

Parameters:
  • configs (Iterable[Union[dict[str, Any], tuple[Any, ...]]]) – The configurations to elaborate and compile. Each config is either a dictionary mapping keyword arguments to values or a tuple of positional arguments.

  • num_workers (int, optional) – Number of parallel workers to use for compilation. Defaults to None, which lets the system decide.

  • ignore_error (bool, optional) – If True, compilation errors for individual configs are logged as warnings and the corresponding result is None. If False, any compilation error raises an exception. Defaults to False.

Returns:

A list of compiled JITKernel objects, one per config in the order given (None for a config that failed when ignore_error=True).

Return type:

List[JITKernel]
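The two accepted shapes of each entry in `configs` (a dict of keyword arguments or a tuple of positional arguments) can be shown with plain Python. `elaborate` and the stand-in `matmul` below are hypothetical; with a real `@tilelang.jit` function, the same list would instead be passed as `matmul.par_compile(configs, num_workers=4)`.

```python
from typing import Any, Callable, Dict, List, Tuple, Union

Config = Union[Dict[str, Any], Tuple[Any, ...]]


def elaborate(factory: Callable[..., Any], config: Config) -> Any:
    """Call factory with a config given as kwargs (dict) or positional args (tuple)."""
    if isinstance(config, dict):
        return factory(**config)
    if isinstance(config, tuple):
        return factory(*config)
    raise TypeError(f"config must be a dict or tuple, got {type(config).__name__}")


def matmul(M, N, K, block_M=128, block_N=128, block_K=32):
    # Hypothetical stand-in for a @tilelang.jit-decorated kernel factory.
    return (M, N, K, block_M, block_N, block_K)


configs: List[Config] = [
    (1024, 1024, 1024, 128, 128, 32),                  # positional form
    {"M": 1024, "N": 1024, "K": 1024, "block_K": 64},  # keyword form
]
print([elaborate(matmul, c) for c in configs])
# [(1024, 1024, 1024, 128, 128, 32), (1024, 1024, 1024, 128, 128, 64)]
```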

compile(*args, **kwargs)
Parameters:
  • args (_P)

  • kwargs (_P)

Return type:

_Ret

parse_cache_key(*args, **kwargs)
Parameters:
  • args (_P)

  • kwargs (_P)

get_kernel_source(*args, **kwargs)
Parameters:
  • args (_P)

  • kwargs (_P)

Return type:

str

__call__(*args, **kwargs)
Parameters:
  • args (_P)

  • kwargs (_P)

Return type:

_Ret

tilelang.jit.ExecutionBackend
tilelang.jit.jit(func: Callable[_KP, _T]) → JITImpl[_KP, _KP, _T, _T]
tilelang.jit.jit(*, out_idx: Any = None, target: str | tvm.target.Target | None = None, target_host: str | tvm.target.Target | None = None, execution_backend: ExecutionBackend | None = None, verbose: bool | None = None, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None) → Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]

JIT compiler decorator for TileLang functions.

Supports two execution modes (automatically inferred):
  • lazy: The function returns a PrimFunc explicitly. Calling the decorated function returns a compiled kernel object.

  • eager: The function uses the DSL builder pattern. Calling the decorated function executes the kernel immediately.

Parameters:
  • out_idx (list[int] | int | None) – Output tensor index(es). Only supported in lazy mode.

  • target (str | Target | None) – TVM compilation target (e.g., “cuda”, “llvm”, “auto”).

  • target_host (str | Target | None) – Host target for cross-compilation.

  • execution_backend (ExecutionBackend | None) – Backend for kernel execution.

  • verbose (bool | None) – Enable verbose compilation output.

  • pass_configs (dict[str, Any] | None) – TVM pass configuration options.

  • debug_root_path (str | None) – Directory to save compiled kernel source for debugging.

  • compile_flags (list[str] | str | None) – Additional compiler flags.
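The two overloads above let the decorator be used bare (`@tilelang.jit`) or with options (`@tilelang.jit(out_idx=[-1])`). The sketch below shows that common Python pattern in isolation; `jit_sketch` and `JITSketch` are hypothetical stand-ins for `jit` and `JITImpl` that only store options and forward calls, with no compilation.

```python
import functools
from typing import Any, Callable, Optional


class JITSketch:
    """Hypothetical stand-in for JITImpl: stores the function plus its options."""

    def __init__(self, func: Callable[..., Any], **options: Any):
        functools.update_wrapper(self, func)  # copy __name__, __doc__, etc.
        self.func = func
        self.options = options

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.func(*args, **kwargs)


def jit_sketch(func: Optional[Callable[..., Any]] = None, **options: Any):
    # Bare form: @jit_sketch receives the function directly.
    if func is not None:
        return JITSketch(func)

    # Configured form: @jit_sketch(out_idx=[-1]) returns the actual decorator.
    def decorator(f: Callable[..., Any]) -> JITSketch:
        return JITSketch(f, **options)

    return decorator


@jit_sketch
def add(a, b):
    return a + b


@jit_sketch(out_idx=[-1], verbose=True)
def mul(a, b):
    return a * b


print(type(add).__name__, add.options, add(2, 3))  # JITSketch {} 5
print(mul.options, mul(2, 3))  # {'out_idx': [-1], 'verbose': True} 6
```

The keyword-only options in the second overload are what make this dispatch unambiguous: a positional argument can only be the function being decorated.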