tilelang.jit.kernel

Classes

JITKernel

A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.

Module Contents

class tilelang.jit.kernel.JITKernel(func=None, out_idx=None, execution_backend='cython', target='auto', target_host=None, verbose=False, pass_configs=None, from_database=False, compile_flags=None)

Bases: object

A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.

Parameters:
  • func (tvm.tir.PrimFunc)

  • out_idx (Union[List[int], int])

  • execution_backend (Literal['dlpack', 'ctypes', 'cython', 'nvrtc'])

  • target (Union[str, tvm.target.Target])

  • target_host (Union[str, tvm.target.Target])

  • verbose (bool)

  • pass_configs (Optional[Dict[str, Any]])

  • from_database (bool)

  • compile_flags (Optional[List[str]])
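The out_idx parameter marks which kernel parameters are outputs. A minimal sketch of how such a value might be normalized, assuming that an int is treated as a one-element list, that negative indices count back from the last parameter as in Python indexing, and that None means no outputs are allocated. normalize_out_idx is a hypothetical helper, not part of the TileLang API.

```python
from typing import List, Optional, Union


def normalize_out_idx(out_idx: Optional[Union[List[int], int]], num_params: int) -> List[int]:
    """Hypothetical helper: turn an out_idx value into a list of
    non-negative parameter positions, resolving negative indices Python-style."""
    if out_idx is None:
        # Assumption: no outputs are allocated; the caller passes every buffer.
        return []
    if isinstance(out_idx, int):
        out_idx = [out_idx]
    result: List[int] = []
    for idx in out_idx:
        if not -num_params <= idx < num_params:
            raise ValueError(f"out_idx {idx} is out of range for {num_params} parameters")
        # A negative index counts back from the last parameter (-1 -> num_params - 1).
        result.append(idx % num_params)
    return result


print(normalize_out_idx(-1, 3))       # [2]
print(normalize_out_idx([0, -1], 4))  # [0, 3]
```

With three parameters, out_idx=-1 selects the last one, which matches the common case of a kernel whose final parameter is its result buffer.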

artifact

The compiled artifact containing the runtime module and parameters.

Type:

CompiledArtifact

adapter

The adapter for the compiled function.

Type:

BaseKernelAdapter

torch_function

The compiled function that can be invoked as a PyTorch-compatible function.

Type:

Callable

prim_func: tvm.tir.PrimFunc = None
artifact: tilelang.engine.param.CompiledArtifact = None
adapter: tilelang.jit.adapter.BaseKernelAdapter = None
torch_function: Callable = None
latency: float = None
config: Dict[str, Any] = None
ref_latency: float = None
execution_backend = 'cython'
target_host = None
verbose = False
pass_configs = None
compile_flags = None
target

classmethod from_database(func, kernel_global_source, kernel_lib_path, params, target, target_host, out_idx, execution_backend, pass_configs=None, compile_flags=None)

Alternative constructor that creates a JITKernel directly from previously compiled artifacts stored in a database (the kernel source and the path to its compiled library).

Parameters:
  • func (tvm.tir.PrimFunc)

  • kernel_global_source (str)

  • kernel_lib_path (str)

  • params (List[tilelang.engine.param.KernelParam])

  • target (Union[str, tvm.target.Target])

  • target_host (Union[str, tvm.target.Target])

  • out_idx (Union[List[int], int])

  • execution_backend (Literal['dlpack', 'ctypes', 'cython', 'nvrtc'])

  • pass_configs (Optional[Dict[str, Any]])

  • compile_flags (Optional[List[str]])

__call__(*args, **kwds)

Invokes the compiled function with the given arguments.

Parameters:
  • *args (Any) – Positional arguments for the function.

  • **kwds (Any) – Keyword arguments for the function.

Returns:

The result of the function execution.

Return type:

Any
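A pure-Python sketch of the calling convention that out_idx implies: the caller passes only the input arguments, buffers for the output positions are allocated internally, the kernel writes into them, and the output(s) are returned. It assumes (not confirmed here for every execution backend) that a single output is returned on its own and several outputs as a list. call_with_outputs, add_kernel, and the list-based "buffers" are hypothetical stand-ins for the real adapter, compiled kernel, and tensors.

```python
from typing import Any, Callable, List


def call_with_outputs(kernel: Callable[..., None], out_idx: List[int], num_params: int,
                      alloc: Callable[[], Any], *inputs: Any) -> Any:
    """Hypothetical sketch: fill the parameter list with caller inputs,
    allocate fresh buffers at the out_idx positions (assumed already
    non-negative), run the kernel in place, and return the output buffer(s)."""
    expected = num_params - len(out_idx)
    if len(inputs) != expected:
        raise TypeError(f"expected {expected} inputs, got {len(inputs)}")
    args: List[Any] = []
    it = iter(inputs)
    for i in range(num_params):
        # Output slots get a newly allocated buffer; other slots consume inputs in order.
        args.append(alloc() if i in out_idx else next(it))
    kernel(*args)
    outputs = [args[i] for i in out_idx]
    return outputs[0] if len(outputs) == 1 else outputs


def add_kernel(a: List[float], b: List[float], c: List[float]) -> None:
    """Toy stand-in for a compiled kernel: writes a + b into c in place."""
    for i in range(len(a)):
        c[i] = a[i] + b[i]


c = call_with_outputs(add_kernel, [2], 3, lambda: [0.0] * 3,
                      [1.0, 2.0, 3.0], [10.0, 20.0, 30.0])
print(c)  # [11.0, 22.0, 33.0]
```

Allocating outputs inside the wrapper keeps call sites short (c = kernel(a, b)); passing out_idx=None in the real constructor would instead leave every buffer to the caller.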

classmethod from_tilelang_function(tilelang_func, **kwargs)

Alternative constructor to create a JITKernel directly from a TileLang PrimFunc.

Parameters:
  • tilelang_func (tvm.tir.PrimFunc) – The TileLang (TVM TIR) function to compile.

  • **kwargs (dict) – Additional keyword arguments to pass to the constructor.

Returns:

A JITKernel instance wrapping the compiled function.

Return type:

JITKernel

get_profiler(tensor_supply_type=TensorSupplyType.Auto)

Creates a profiler to benchmark the compiled runtime module.

Parameters:

tensor_supply_type (TensorSupplyType, optional) – The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).

Returns:

A Profiler instance for benchmarking the runtime module.

Return type:

Profiler

get_kernel_source()

Returns the source code of the compiled kernel function.

Returns:

The source code of the compiled kernel function.

Return type:

str

get_host_source()

Returns the source code of the host function.

Return type:

str

run_once(func=None)

Parameters:

func (Optional[Callable])

Return type:

None

update_tuner_result(latency, config, ref_latency)

Updates the tuning results for this kernel.

Parameters:
  • latency (float) – The measured latency of this kernel configuration.

  • config (Dict[str, Any]) – The configuration parameters used for this kernel.

  • ref_latency (float) – The reference latency to compare against.

Return type:

None

get_tuner_result()

Gets the tuning results for this kernel.

Returns:

A dictionary containing:

  • latency: The measured latency of this kernel.

  • config: The configuration parameters used.

  • ref_latency: The reference latency for comparison.

Return type:

Dict[str, Any]
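A minimal stand-in for this tuning bookkeeping, assuming update_tuner_result simply stores the three values and get_tuner_result returns them under the keys listed above. TunerResultHolder and the config keys block_M and num_stages are hypothetical; latency values are in whatever unit the profiler reports.

```python
from typing import Any, Dict, Optional


class TunerResultHolder:
    """Hypothetical stand-in mirroring JITKernel's latency/config/ref_latency attributes."""

    def __init__(self) -> None:
        self.latency: Optional[float] = None
        self.config: Optional[Dict[str, Any]] = None
        self.ref_latency: Optional[float] = None

    def update_tuner_result(self, latency: float, config: Dict[str, Any],
                            ref_latency: float) -> None:
        # Record the measured latency, the configuration that produced it,
        # and the reference latency to compare against.
        self.latency = latency
        self.config = config
        self.ref_latency = ref_latency

    def get_tuner_result(self) -> Dict[str, Any]:
        # Return the stored values under the documented keys.
        return {"latency": self.latency, "config": self.config,
                "ref_latency": self.ref_latency}


holder = TunerResultHolder()
holder.update_tuner_result(0.5, {"block_M": 128, "num_stages": 2}, 1.0)
result = holder.get_tuner_result()
speedup = result["ref_latency"] / result["latency"]
print(speedup)  # 2.0
```

Dividing ref_latency by latency gives the speedup of the tuned configuration over the reference; a value above 1.0 means the tuned kernel is faster.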

property out_idx: List[int]

Return type:

List[int]

property params: List[tilelang.engine.param.KernelParam]

Return type:

List[tilelang.engine.param.KernelParam]

property kernel_source: str

Return type:

str

property host_source: str

Return type:

str

export_library(kernel_file)

Exports the compiled kernel function to a shared library file.

Parameters:

kernel_file (str) – The path to the shared library file to create.

Return type:

None