tilelang.jit.kernel module
- class tilelang.jit.kernel.JITKernel(func: Optional[PrimFunc] = None, out_idx: Optional[Union[List[int], int]] = None, execution_backend: Literal['dlpack', 'ctypes', 'cython'] = 'cython', target: Union[str, Target] = 'auto', target_host: Optional[Union[str, Target]] = None, verbose: bool = False, pass_configs: Optional[Dict[str, Any]] = None, from_database: bool = False)
Bases: object
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
- artifact
The compiled artifact containing the runtime module and parameters.
- Type:
CompiledArtifact
- adapter
The adapter for the compiled function.
- Type:
BaseKernelAdapter
- torch_function
The compiled function that can be invoked as a PyTorch-compatible function.
- Type:
Callable
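To illustrate what "PyTorch-compatible" calling can mean for a kernel compiled with `out_idx`, the sketch below models one plausible calling convention in plain Python. The caller passes only the input arguments, and the wrapper allocates buffers at the (already non-negative) `out_idx` positions, runs the destination-passing kernel, and returns the outputs. `call_with_out_idx`, `allocate`, and `vec_add` are hypothetical names, Python lists stand in for tensors, and returning a single object for one output versus a list for several is an assumption, not confirmed TileLang behavior.

```python
from typing import Callable, List


def call_with_out_idx(
    kernel: Callable[..., None],
    num_params: int,
    out_idx: List[int],
    allocate: Callable[[int], list],
    *inputs: list,
):
    """Hypothetical model of an out_idx-aware wrapper (not TileLang's code)."""
    expected = num_params - len(out_idx)
    if len(inputs) != expected:
        raise TypeError(f"expected {expected} inputs, got {len(inputs)}")
    remaining = iter(inputs)
    args, outputs = [], []
    for i in range(num_params):
        if i in out_idx:
            # Output slot: allocate a fresh buffer and keep it for the return value.
            buf = allocate(i)
            outputs.append(buf)
            args.append(buf)
        else:
            # Input slot: consume the next caller-supplied argument in order.
            args.append(next(remaining))
    kernel(*args)  # the kernel writes its results into the output buffers
    return outputs[0] if len(outputs) == 1 else outputs


# Hypothetical destination-passing kernel: c[i] = a[i] + b[i].
def vec_add(a: list, b: list, c: list) -> None:
    for i in range(len(c)):
        c[i] = a[i] + b[i]


c = call_with_out_idx(vec_add, 3, [2], lambda i: [0] * 4, [1, 2, 3, 4], [10, 20, 30, 40])
print(c)  # [11, 22, 33, 44]
```

The design choice being modeled is destination passing: the compiled kernel never allocates, so the wrapper owns output allocation and the caller sees an ordinary function that returns its result.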
- adapter: BaseKernelAdapter = None
- artifact: CompiledArtifact = None
- export_library(kernel_file: str) → None
Exports the compiled kernel function to a shared library file.
- Parameters:
kernel_file (str) – The path to the shared library file to create.
- classmethod from_database(func: PrimFunc, kernel_global_source: str, kernel_lib_path: str, params: List[KernelParam], target: Union[str, Target], target_host: Union[str, Target], out_idx: Union[List[int], int], execution_backend: Literal['dlpack', 'ctypes', 'cython'], pass_configs: Optional[Dict[str, Any]] = None)
Alternative constructor to create a JITKernel directly from a database entry, reusing a previously generated kernel source (kernel_global_source) and compiled library (kernel_lib_path).
- classmethod from_tilelang_function(tilelang_func: PrimFunc, **kwargs)
Alternative constructor to create a JITKernel directly from a TileLang PrimFunc.
- Parameters:
tilelang_func (tvm.tir.PrimFunc) – The TileLang (TVM TIR) function to compile.
**kwargs (dict) – Additional keyword arguments to pass to the constructor.
- Returns:
A JITKernel instance wrapping the compiled function.
- Return type:
JITKernel
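The following minimal sketch shows the alternative-constructor pattern that `from_tilelang_function` describes: the classmethod passes the function as `func` and forwards every other keyword argument unchanged to the constructor. `KernelSketch` is a hypothetical stand-in that only records its arguments instead of compiling anything, and a plain string takes the place of a real `tvm.tir.PrimFunc`; its default values mirror a subset of the `JITKernel` constructor signature.

```python
from typing import Any, Dict, List, Optional, Union


class KernelSketch:
    """Hypothetical stand-in for JITKernel: stores its arguments, compiles nothing."""

    def __init__(
        self,
        func: Optional[Any] = None,
        out_idx: Optional[Union[List[int], int]] = None,
        execution_backend: str = "cython",
        target: str = "auto",
        verbose: bool = False,
        pass_configs: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.func = func
        self.out_idx = out_idx
        self.execution_backend = execution_backend
        self.target = target
        self.verbose = verbose
        self.pass_configs = pass_configs or {}

    @classmethod
    def from_tilelang_function(cls, tilelang_func: Any, **kwargs: Any) -> "KernelSketch":
        # Alternative constructor: the function becomes `func`, and all other
        # keyword arguments are forwarded unchanged to __init__.
        return cls(func=tilelang_func, **kwargs)


kernel = KernelSketch.from_tilelang_function("placeholder_prim_func", out_idx=[-1], target="cuda")
print(kernel.target, kernel.execution_backend)  # cuda cython
```

Because unspecified keywords fall back to the constructor defaults, `execution_backend` stays `"cython"` here even though only `out_idx` and `target` were passed.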
- get_host_source() → str
Returns the source code of the host function.
- get_kernel_source() → str
Returns the source code of the compiled kernel function.
- Returns:
The source code of the compiled kernel function.
- Return type:
str
- get_profiler(tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) → Profiler
Creates a profiler to benchmark the compiled runtime module.
- Parameters:
tensor_supply_type (TensorSupplyType, optional) – The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).
- Returns:
A Profiler instance for benchmarking the runtime module.
- Return type:
Profiler
- property host_source: str
- property kernel_source: str
- property out_idx: List[int]
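`out_idx` is typed `List[int]` here, while the constructor also accepts a single `int`. The sketch below shows one normalization that would reconcile the two, assuming negative indices count from the end of the parameter list as in Python indexing. `normalize_out_idx` is a hypothetical helper, and whether TileLang stores the resolved or the raw indices is not stated in this reference.

```python
from typing import List, Optional, Union


def normalize_out_idx(out_idx: Optional[Union[List[int], int]], num_params: int) -> List[int]:
    """Hypothetical helper: turn an int or list out_idx into non-negative parameter positions."""
    if out_idx is None:
        return []  # no parameters are treated as outputs
    if isinstance(out_idx, int):
        out_idx = [out_idx]  # a single index becomes a one-element list
    resolved = []
    for idx in out_idx:
        if not -num_params <= idx < num_params:
            raise IndexError(f"out_idx {idx} is out of range for {num_params} parameters")
        resolved.append(idx % num_params)  # e.g. -1 -> num_params - 1
    return resolved


print(normalize_out_idx(-1, 3))  # [2]
print(normalize_out_idx([0, -1], 3))  # [0, 2]
```

With three parameters, `out_idx=-1` selects position 2, which fits a kernel such as `C = A + B` whose output buffer is the last parameter.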
- property params: List[KernelParam]
- run_once(func: Optional[Callable] = None) → None
- torch_function: Callable = None