tilelang.jit.adapter.tvm_ffi
Utilities to adapt TVM FFI kernels to Torch tensors.
This adapter intentionally captures PyTorch’s current CUDA stream and device via lightweight callables so that, when the wrapped function is invoked, execution observes the same stream context as the active Torch code. On builds without CUDA, the stream handle falls back to 0 and the device to CPU.
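The just-in-time capture described above can be illustrated without Torch or TVM. The sketch below is a minimal stand-in: the module-level `_current_stream` slot and the `stream_guard` context manager are hypothetical replacements for Torch's real current-stream state, and `EagerAdapter` / `LazyAdapter` are hypothetical classes that only contrast reading the stream at construction time versus at call time.

```python
import contextlib
from typing import Callable, Iterator

# Hypothetical stand-in for Torch's "current stream" state (not a real Torch API).
_current_stream = [0]


@contextlib.contextmanager
def stream_guard(handle: int) -> Iterator[None]:
    """Temporarily make `handle` the current stream, restoring the old one on exit."""
    prev = _current_stream[0]
    _current_stream[0] = handle
    try:
        yield
    finally:
        _current_stream[0] = prev


class EagerAdapter:
    """Reads the stream once, at construction time, and keeps that value forever."""

    def __init__(self) -> None:
        self.stream = _current_stream[0]

    def __call__(self) -> int:
        return self.stream


class LazyAdapter:
    """Stores a thunk and reads the stream every time the kernel is invoked."""

    def __init__(self, get_stream: Callable[[], int] = lambda: _current_stream[0]) -> None:
        self.get_stream = get_stream

    def __call__(self) -> int:
        return self.get_stream()


eager, lazy = EagerAdapter(), LazyAdapter()
with stream_guard(42):
    print(eager(), lazy())  # 0 42: only the lazy adapter sees the guarded stream
print(eager(), lazy())  # 0 0
```

Because the thunk is evaluated on every call, an adapter built once can be reused under different stream guards without being rebuilt.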
Classes
TVMFFIKernelAdapter | Adapter that runs a TVM runtime.Executable with Torch tensors.
Module Contents
- class tilelang.jit.adapter.tvm_ffi.TVMFFIKernelAdapter(params, result_idx, target, func_or_mod, host_mod=None, device_mod=None, rt_mod=None, host_kernel_source=None, device_kernel_source=None, verbose=False, pass_configs=None, compile_flags=None)
Bases: tilelang.jit.adapter.base.BaseKernelAdapter
Adapter that runs a TVM runtime.Executable with Torch tensors.
Notes
- We capture the “current” PyTorch CUDA stream/device as thunks (callables) rather than materializing them at construction time. This ensures the actual stream/device is read just-in-time when the function runs, matching the user’s current Torch context (e.g., after a stream guard/switch).
- The stream pointer returned is a raw CUDA stream handle compatible with TVM’s device API; on CPU, or when CUDA is unavailable, we return 0.
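A minimal sketch of such thunks, assuming PyTorch is the only source of stream/device state. `make_stream_thunks` is a hypothetical helper, not part of tilelang; it relies only on `torch.cuda.is_available`, `torch.cuda.current_stream().cuda_stream`, and `torch.cuda.current_device`, and falls back to a 0 handle and a CPU device when Torch or CUDA is missing.

```python
from typing import Callable, Tuple


def make_stream_thunks() -> Tuple[Callable[[], int], Callable[[], str]]:
    """Hypothetical helper: return (get_stream, get_device) thunks.

    Only CUDA availability is checked here; the stream and device are
    queried from Torch each time a thunk is called.
    """
    try:
        import torch
    except ImportError:
        torch = None

    if torch is None or not torch.cuda.is_available():
        # No Torch or no CUDA: null stream handle and CPU device.
        return (lambda: 0), (lambda: "cpu")

    def get_stream() -> int:
        # Raw cudaStream_t value of the stream that is current at call time.
        return torch.cuda.current_stream().cuda_stream

    def get_device() -> str:
        # Device index that is current at call time.
        return f"cuda:{torch.cuda.current_device()}"

    return get_stream, get_device


get_stream, get_device = make_stream_thunks()
print(get_stream(), get_device())  # "0 cpu" on a machine without CUDA
```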
- Parameters:
params (list[tilelang.engine.param.KernelParam])
result_idx (list[int])
target (str | tvm.target.Target)
func_or_mod (tvm.tir.PrimFunc | tilelang.tvm.IRModule)
host_mod (tilelang.tvm.IRModule | None)
device_mod (tilelang.tvm.IRModule | None)
rt_mod (tilelang.tvm.runtime.Module | None)
host_kernel_source (str | None)
device_kernel_source (str | None)
verbose (bool)
pass_configs (dict[str, Any] | None)
compile_flags (list[str] | None)
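The constructor pairs `params` with `result_idx`. Assuming that `result_idx` names the positions in `params` treated as outputs, and that negative entries count from the end as in Python indexing, the input/output split could be sketched as follows. `split_io` is a hypothetical helper for illustration, not the adapter's actual method.

```python
from typing import List, Tuple


def split_io(num_params: int, result_idx: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: split parameter positions into (inputs, outputs).

    Assumes `result_idx` lists output positions and that negative entries
    count from the end of the parameter list.
    """
    outputs = set()
    for idx in result_idx:
        pos = idx + num_params if idx < 0 else idx
        if not 0 <= pos < num_params:
            raise IndexError(f"result index {idx} out of range for {num_params} params")
        outputs.add(pos)
    inputs = [i for i in range(num_params) if i not in outputs]
    return inputs, sorted(outputs)


# A kernel C = A @ B with params (A, B, C) and result_idx=[-1]:
print(split_io(3, [-1]))  # ([0, 1], [2])
```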
- target: str | tvm.target.Target = 'cuda'
- ir_module: tilelang.tvm.IRModule | None = None
- host_kernel_source: str | None = None
- device_kernel_source: str | None = None
- executable: tilelang.tvm.runtime.Executable | None = None
- pass_configs: dict[str, Any] | None = None
- host_mod: tilelang.tvm.IRModule | None = None
- device_mod: tilelang.tvm.IRModule | None = None
- rt_mod: tilelang.tvm.runtime.Module | None = None
- dynamic_symbolic_map: dict[tvm.tir.Var, tuple[int, int, int]] | None = None
- params
- result_idx
- verbose = False
- compile_flags = None
- classmethod from_database(params, result_idx, target, func_or_mod, host_kernel_source, device_kernel_source, kernel_lib_path, verbose=False, pass_configs=None, compile_flags=None)
- Parameters:
params (list[tvm.relax.TensorType])
result_idx (list[int])
target (str)
func_or_mod (tvm.tir.PrimFunc | tilelang.tvm.IRModule)
host_kernel_source (str)
device_kernel_source (str)
kernel_lib_path (str)
verbose (bool)
pass_configs (dict[str, Any] | None)
compile_flags (list[str] | None)
- get_host_source()
Returns the source code of the host module.
- get_device_source()
Returns the source code of the device module.
- get_kernel_source(kernel_only=False)
Returns the source code of the compiled kernel.
- Parameters:
kernel_only (bool)
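The three source accessors can be pictured with a small stand-in class. `SourceHolder` is hypothetical, and the `kernel_only` behavior shown here (device source only when True, host and device source joined when False) is an assumption for illustration, not a description of tilelang's implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SourceHolder:
    """Hypothetical stand-in holding generated host and device source strings."""

    host_kernel_source: Optional[str] = None
    device_kernel_source: Optional[str] = None

    def get_host_source(self) -> str:
        return self.host_kernel_source or ""

    def get_device_source(self) -> str:
        return self.device_kernel_source or ""

    def get_kernel_source(self, kernel_only: bool = False) -> str:
        # Assumed semantics: kernel_only=True returns just the device kernel.
        if kernel_only:
            return self.get_device_source()
        # Otherwise join host and device code, skipping whichever is missing.
        parts = [self.get_host_source(), self.get_device_source()]
        return "\n".join(p for p in parts if p)


src = SourceHolder(host_kernel_source="// host", device_kernel_source="// device")
print(src.get_kernel_source(kernel_only=True))  # // device
print(src.get_kernel_source())  # "// host", then "// device" on the next line
```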
- property prim_func: tvm.tir.PrimFunc
Returns the primary TIR function from the IR module.
- Return type:
tvm.tir.PrimFunc