tilelang.jit.kernel module

class tilelang.jit.kernel.JITKernel(func: Optional[PrimFunc] = None, out_idx: Optional[Union[List[int], int]] = None, execution_backend: Literal['dlpack', 'ctypes', 'cython'] = 'cython', target: Union[str, Target] = 'auto', target_host: Optional[Union[str, Target]] = None, verbose: bool = False, pass_configs: Optional[Dict[str, Any]] = None, from_database: bool = False)

Bases: object

A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.

artifact

The compiled artifact containing the runtime module and parameters.

Type:

CompiledArtifact

adapter

The adapter for the compiled function.

Type:

BaseKernelAdapter

torch_function

The compiled function that can be invoked as a PyTorch-compatible function.

Type:

Callable
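A compiled TileLang kernel writes its results into buffer arguments. When out_idx is set, the Python-side wrapper allocates those output buffers itself, so the caller passes only the inputs and gets the outputs back as return values. For example, out_idx=[-1] makes the last parameter the output. The toy model below shows this calling convention in plain Python. It is a sketch under stated assumptions, not TileLang code. wrap_with_out_idx and vector_add are hypothetical names, and lists stand in for tensors. Output sizes are passed explicitly, whereas JITKernel would derive shapes from the PrimFunc's buffers. The sketch also assumes that a single output is returned directly and that several outputs are returned as a list.

```python
from typing import Callable, List, Sequence


def wrap_with_out_idx(raw_kernel: Callable[..., None], out_idx: List[int],
                      out_sizes: Sequence[int]) -> Callable[..., object]:
    """Hypothetical helper (not TileLang API): model the out_idx calling convention.

    raw_kernel writes results into its buffer arguments in place, like a
    lowered TIR function. The returned wrapper allocates the buffers at the
    positions in out_idx, so callers pass only inputs and get outputs back.
    out_idx is assumed already normalized (non-negative, in range), and
    out_sizes gives one length per output, in parameter order.
    """
    positions = sorted(out_idx)
    if len(positions) != len(out_sizes):
        raise ValueError("need one size per output index")

    def call(*inputs):
        # Allocate zero-filled output buffers (lists stand in for tensors).
        outputs = [[0.0] * n for n in out_sizes]
        in_iter, out_iter = iter(inputs), iter(outputs)
        total = len(inputs) + len(outputs)
        # Interleave caller inputs and fresh outputs in parameter order.
        args = [next(out_iter) if i in positions else next(in_iter) for i in range(total)]
        raw_kernel(*args)
        # Assumed convention: one output returned directly, several as a list.
        return outputs[0] if len(outputs) == 1 else outputs

    return call


def vector_add(a, b, c):
    # Toy "kernel": c[i] = a[i] + b[i], written in place.
    for i in range(len(c)):
        c[i] = a[i] + b[i]


add = wrap_with_out_idx(vector_add, out_idx=[2], out_sizes=[3])
print(add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]))  # [11.0, 22.0, 33.0]
```

The caller never handles the third buffer. This is the same reason a JITKernel built with an output index can be called like an ordinary PyTorch function that returns its result.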

adapter: BaseKernelAdapter = None
artifact: CompiledArtifact = None
export_library(kernel_file: str) → None

Exports the compiled kernel function to a shared library file.

Parameters:

kernel_file (str) – The path to the shared library file to create.

classmethod from_database(func: PrimFunc, kernel_global_source: str, kernel_lib_path: str, params: List[KernelParam], target: Union[str, Target], target_host: Union[str, Target], out_idx: Union[List[int], int], execution_backend: Literal['dlpack', 'ctypes', 'cython'], pass_configs: Optional[Dict[str, Any]] = None)

Alternative constructor that creates a JITKernel directly from a database. It reuses previously stored artifacts (the kernel source in kernel_global_source and the shared library at kernel_lib_path) instead of recompiling func.
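from_database fits a lookup-or-compile pattern. When a kernel's source and shared library have already been stored, the kernel object can be rebuilt from them instead of being compiled again. The sketch below models that pattern in plain Python. ToyKernel, from_stored_artifacts, get_or_compile, and the in-memory dict that stands in for the database are all hypothetical. The key format and storage layout that TileLang actually uses are not described here.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass
class ToyKernel:
    """Hypothetical stand-in for JITKernel (not the TileLang class)."""
    kernel_source: str
    lib_path: str
    loaded_from_database: bool = False

    @classmethod
    def from_stored_artifacts(cls, kernel_source: str, lib_path: str) -> "ToyKernel":
        # Plays the role of JITKernel.from_database: rebuild the object from
        # stored artifacts instead of compiling again.
        return cls(kernel_source, lib_path, loaded_from_database=True)


def get_or_compile(key: str, store: Dict[str, Tuple[str, str]],
                   compile_fn: Callable[[], ToyKernel]) -> ToyKernel:
    """Return a kernel for key, compiling only on a store miss (hypothetical helper)."""
    if key in store:
        source, lib_path = store[key]
        return ToyKernel.from_stored_artifacts(source, lib_path)
    kernel = compile_fn()
    # Persist the artifacts so the next lookup can skip compilation.
    store[key] = (kernel.kernel_source, kernel.lib_path)
    return kernel


compile_calls = []


def compile_matmul() -> ToyKernel:
    # Hypothetical "compiler" that records how often it runs.
    compile_calls.append(1)
    return ToyKernel("// kernel source", "/tmp/matmul.so")


store: Dict[str, Tuple[str, str]] = {}
first = get_or_compile("matmul-fp16", store, compile_matmul)
second = get_or_compile("matmul-fp16", store, compile_matmul)
print(len(compile_calls), second.loaded_from_database)  # 1 True
```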

classmethod from_tilelang_function(tilelang_func: PrimFunc, **kwargs)

Alternative constructor to create a JITKernel directly from a TileLang PrimFunc.

Parameters:
  • tilelang_func (tvm.tir.PrimFunc) – The TileLang (TVM TIR) function to compile.

  • **kwargs (dict) – Additional keyword arguments to pass to the constructor.

Returns:

A JITKernel instance wrapping the compiled function.

Return type:

JITKernel
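Because from_tilelang_function is a classmethod that forwards its keyword arguments to the constructor, it should behave like calling the constructor with func=tilelang_func plus the same keywords. The sketch below shows that forwarding pattern on a hypothetical ToyJITKernel class with a reduced parameter list. It does not compile anything, and "my_prim_func" is only a placeholder for a real PrimFunc.

```python
from typing import Any, List, Optional, Union


class ToyJITKernel:
    """Hypothetical stand-in with a reduced JITKernel-like signature."""

    def __init__(self, func: Any = None,
                 out_idx: Optional[Union[List[int], int]] = None,
                 target: str = "auto") -> None:
        self.func = func
        self.out_idx = out_idx
        self.target = target

    @classmethod
    def from_tilelang_function(cls, tilelang_func: Any, **kwargs: Any) -> "ToyJITKernel":
        # Pass the function as `func`; every other keyword goes to __init__ unchanged.
        return cls(func=tilelang_func, **kwargs)


class ProfiledKernel(ToyJITKernel):
    """Subclass used to show that `cls` preserves the caller's type."""


k = ToyJITKernel.from_tilelang_function("my_prim_func", out_idx=[-1], target="cuda")
print(k.out_idx, k.target)  # [-1] cuda
print(type(ProfiledKernel.from_tilelang_function("f")).__name__)  # ProfiledKernel
```

The sketch builds the instance through cls instead of hard-coding the class name, so a subclass that calls the alternative constructor gets an instance of itself.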

get_host_source() → str

Returns the source code of the host function.

get_kernel_source() → str

Returns the source code of the compiled kernel function.

Returns:

The source code of the compiled kernel function.

Return type:

str

get_profiler(tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) → Profiler

Creates a profiler to benchmark the compiled runtime module.

Parameters:

tensor_supply_type (TensorSupplyType, optional) – The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).

Returns:

A Profiler instance for benchmarking the runtime module.

Return type:

Profiler

property host_source: str
property kernel_source: str
property out_idx: List[int]
property params: List[KernelParam]
run_once(func: Optional[Callable] = None) → None
torch_function: Callable = None
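The constructor accepts out_idx as either a single int or a list, but the out_idx property is typed List[int], so the value must be normalized at some point. Below is a minimal sketch of such a normalization. It assumes that negative indices count from the end of the parameter list, so out_idx=-1 selects the last parameter. legalize_out_idx is a hypothetical name, and TileLang's actual validation rules may differ.

```python
from typing import List, Optional, Union


def legalize_out_idx(out_idx: Optional[Union[List[int], int]], num_params: int) -> List[int]:
    """Hypothetical helper: normalize out_idx to a list of non-negative indices."""
    if out_idx is None:
        # No outputs requested: every parameter is supplied by the caller.
        return []
    indices = [out_idx] if isinstance(out_idx, int) else list(out_idx)
    result = []
    for idx in indices:
        if not -num_params <= idx < num_params:
            raise IndexError(f"out_idx {idx} out of range for {num_params} parameters")
        # Negative indices count from the end, as in Python sequence indexing.
        result.append(idx + num_params if idx < 0 else idx)
    return result


print(legalize_out_idx(-1, 3))       # [2]
print(legalize_out_idx([0, -1], 4))  # [0, 3]
```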