tilelang.jit.adapter.cutedsl.wrapper
====================================
.. py:module:: tilelang.jit.adapter.cutedsl.wrapper
.. autoapi-nested-parse::
CuTeDSL Source Wrapper for TileLang.
This module provides C++ kernel launcher generation for the CuTeDSL backend.
Key features:
- Automatic C++ launcher generation using the CUDA Driver API
- TMA descriptors built in host memory and passed via __grid_constant__ (no device copy needed)
- cuLaunchKernel automatically copies each 128-byte CUtensorMap into kernel parameter space
- Support for single and multiple kernel launches
- Full integration with the TileLang cache system
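The templates below are ordinary Python format strings; the wrapper renders each one and splices the results into CPP_LAUNCHER_TEMPLATE to form a single C++ translation unit. An illustrative sketch (the field values are made up and do not come from this module):

.. code-block:: python

   from tilelang.jit.adapter.cutedsl.wrapper import CPP_LAUNCHER_TEMPLATE

   # Hypothetical composition; the real pipeline is driven by TLCuTeDSLSourceWrapper.
   launcher_cpp = CPP_LAUNCHER_TEMPLATE.format(
       num_kernels=1,
       kernel_inits="...",          # rendered from CPP_KERNEL_INIT_TEMPLATE
       tma_init_func="...",         # rendered from CPP_TMA_INIT_FUNC_TEMPLATE
       launch_func_sig="void* A_ptr, void* B_ptr",   # hypothetical argument list
       get_ptr_code="",
       tma_init_in_launch="...",    # rendered from CPP_TMA_LAUNCH_INIT_TEMPLATE
       kernel_launches="...",       # rendered from CPP_KERNEL_LAUNCH_TEMPLATE
   )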
Attributes
----------
.. autoapisummary::
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_DESC_INIT_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_IM2COL_DESC_INIT_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_INIT_FUNC_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_KERNEL_INIT_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_LAUNCH_INIT_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_KERNEL_LAUNCH_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CPP_LAUNCHER_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_TMA_ATOM_INIT_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_KERNEL_LAUNCH_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_FAKE_TENSOR_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_GEN_CODE_TEMPLATE
tilelang.jit.adapter.cutedsl.wrapper.PYTHON_HOST_FUNC_TEMPLATE
Classes
-------
.. autoapisummary::
tilelang.jit.adapter.cutedsl.wrapper.TLCuTeDSLSourceWrapper
Module Contents
---------------
.. py:data:: CPP_TMA_DESC_INIT_TEMPLATE
:value: Multiline-String
.. code-block:: python
""" // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name})
{{
uint64_t globalDim[{rank}] = {{{global_dim_values}}};
uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
uint32_t boxDim[{rank}] = {{{box_dim_values}}};
uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
result = cuTensorMapEncodeTiled(
&tma_descs[{desc_idx}],
static_cast<CUtensorMapDataType>({dtype}),
{rank},
reinterpret_cast<void*>({tensor_name}_ptr),
globalDim,
globalStrides,
boxDim,
elemStrides,
static_cast<CUtensorMapInterleave>({interleave}),
static_cast<CUtensorMapSwizzle>({swizzle}),
static_cast<CUtensorMapL2promotion>({l2_promotion}),
static_cast<CUtensorMapFloatOOBfill>({oob_fill})
);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to encode TMA descriptor {desc_idx}: " << result << "\n";
return result;
}}
}}
"""
.. py:data:: CPP_TMA_IM2COL_DESC_INIT_TEMPLATE
:value: Multiline-String
.. code-block:: python
""" // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name}) [im2col]
{{
uint64_t globalDim[{rank}] = {{{global_dim_values}}};
uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
int32_t lowerCorner[{rank_minus_two}] = {{{lower_corner_values}}};
int32_t upperCorner[{rank_minus_two}] = {{{upper_corner_values}}};
result = cuTensorMapEncodeIm2col(
&tma_descs[{desc_idx}],
static_cast<CUtensorMapDataType>({dtype}),
{rank},
reinterpret_cast<void*>({tensor_name}_ptr),
globalDim,
globalStrides,
lowerCorner,
upperCorner,
static_cast<uint32_t>({channels_per_pixel}),
static_cast<uint32_t>({pixels_per_column}),
elemStrides,
static_cast<CUtensorMapInterleave>({interleave}),
static_cast<CUtensorMapSwizzle>({swizzle}),
static_cast<CUtensorMapL2promotion>({l2_promotion}),
static_cast<CUtensorMapFloatOOBfill>({oob_fill})
);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to encode TMA im2col descriptor {desc_idx}: " << result << "\n";
return result;
}}
}}
"""
.. py:data:: CPP_TMA_INIT_FUNC_TEMPLATE
:value: Multiline-String
.. code-block:: python
"""CUresult tma_init(CUtensorMap* tma_descs, {func_args}) {{
// Initialize {num_descs} TMA descriptor(s) in caller-provided host array
// cuLaunchKernel will copy 128-byte CUtensorMap to kernel param space automatically
CUresult result;
{desc_init_code}
return CUDA_SUCCESS;
}}
"""
.. py:data:: CPP_KERNEL_INIT_TEMPLATE
:value: Multiline-String
.. code-block:: python
""" // Find and configure kernel {kernel_idx}: {kernel_name}
result = find_kernel_by_pattern(g_module, "{kernel_name}", &g_kernels[{kernel_idx}]);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to find kernel {kernel_name}: " << result << "\n";
return result;
}}
if ({smem_size} > 0) {{
result = cuFuncSetAttribute(g_kernels[{kernel_idx}],
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
{smem_size});
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to set smem for {kernel_name}: " << result << "\n";
return result;
}}
}}
"""
.. py:data:: CPP_TMA_LAUNCH_INIT_TEMPLATE
:value: Multiline-String
.. code-block:: python
""" // Declare stack-local TMA descriptor array (eliminates concurrency race)
CUtensorMap tma_descs[{num_tma_descs}];
// Initialize TMA descriptors (HOST memory - passed via __grid_constant__)
// NOTE: We intentionally do NOT reuse/cached descriptors across launches.
// Pointer-only reuse is a correctness trap (shape/stride may change with same ptr),
// and correctness beats micro-optimizations.
result = tma_init(tma_descs, {tma_tensor_args});
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to initialize TMA descriptors: " << result << "\n";
return result;
}}
"""
.. py:data:: CPP_KERNEL_LAUNCH_TEMPLATE
:value: Multiline-String
.. code-block:: python
""" // Launch kernel {kernel_idx}: {kernel_name}
{{
void* args[] = {{{kernel_args}}};
result = cuLaunchKernel(
g_kernels[{kernel_idx}],
{grid_x}, {grid_y}, {grid_z},
{block_x}, {block_y}, {block_z},
{smem_size},
stream,
args,
nullptr
);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to launch kernel {kernel_name}: " << result << "\n";
return result;
}}
}}
"""
.. py:data:: CPP_LAUNCHER_TEMPLATE
:value: Multiline-String
.. code-block:: python
"""#include
#include
#include
#include
#include
#include
#include
// TVM Headers
#include <tvm/ffi/function.h>
#include <tvm/ffi/string.h>
#include <tvm/ffi/container/tensor.h>
// Cached module handle
static CUmodule g_module = nullptr;
static bool g_module_initialized = false;
// Cached kernel functions
static CUfunction g_kernels[{num_kernels}] = {{nullptr}};
static bool g_kernels_initialized = false;
// Find kernel by pattern (substring match, prefer base name over _N variants)
CUresult find_kernel_by_pattern(CUmodule module, const char* pattern, CUfunction* out_func) {{
CUresult result;
unsigned int num_funcs = 0;
result = cuModuleGetFunctionCount(&num_funcs, module);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to get function count: " << result << "\n";
return result;
}}
std::vector<CUfunction> func_list(num_funcs);
result = cuModuleEnumerateFunctions(func_list.data(), num_funcs, module);
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to enumerate functions: " << result << "\n";
return result;
}}
// Collect substring matches, separating base name from _N variants
std::vector<std::pair<std::string, CUfunction>> base_matches;    // pattern not followed by _digit
std::vector<std::pair<std::string, CUfunction>> variant_matches; // pattern followed by _digit
size_t pattern_len = std::strlen(pattern);
for (unsigned int i = 0; i < num_funcs; i++) {{
const char* func_name = nullptr;
result = cuFuncGetName(&func_name, func_list[i]);
if (result != CUDA_SUCCESS || func_name == nullptr) {{
std::cerr << "Failed to get function name: " << result << "\n";
return result;
}}
std::string name_str(func_name);
size_t pos = name_str.find(pattern);
if (pos != std::string::npos) {{
// Found substring match
size_t after_pattern = pos + pattern_len;
// Check what follows the pattern
if (after_pattern < name_str.length() &&
name_str[after_pattern] == '_' &&
after_pattern + 1 < name_str.length() &&
std::isdigit(name_str[after_pattern + 1])) {{
// Pattern followed by _digit (e.g., "main_kernel_1")
variant_matches.push_back({{name_str, func_list[i]}});
}} else {{
// Pattern not followed by _digit (e.g., "main_kernel" itself)
base_matches.push_back({{name_str, func_list[i]}});
}}
}}
}}
// Decision logic: prefer base matches over variant matches
if (!base_matches.empty()) {{
if (base_matches.size() == 1) {{
*out_func = base_matches[0].second;
return CUDA_SUCCESS;
}}
// Multiple base matches - ambiguous
std::cerr << "Error: Pattern '" << pattern << "' matched " << base_matches.size()
<< " base kernels (ambiguous). Matches found:\n";
for (const auto& match : base_matches) {{
std::cerr << " - " << match.first << "\n";
}}
std::cerr << "Please use a more specific pattern.\n";
return CUDA_ERROR_NOT_FOUND;
}}
// No base matches, try variant matches
if (!variant_matches.empty()) {{
if (variant_matches.size() == 1) {{
*out_func = variant_matches[0].second;
return CUDA_SUCCESS;
}}
// Multiple variant matches - ambiguous
std::cerr << "Error: Pattern '" << pattern << "' matched " << variant_matches.size()
<< " variant kernels (ambiguous). Matches found:\n";
for (const auto& match : variant_matches) {{
std::cerr << " - " << match.first << "\n";
}}
std::cerr << "Please use a more specific pattern (e.g., '" << pattern << "_1').\n";
return CUDA_ERROR_NOT_FOUND;
}}
// No matches at all
std::cerr << "Failed to find kernel matching pattern '" << pattern << "'\n";
return CUDA_ERROR_NOT_FOUND;
}}
// Initialize CUDA module (called once on first launch)
static CUresult tilelang_init_cuda_module(const std::string& cubin_path) {{
if (g_module_initialized) return CUDA_SUCCESS;
CUresult result;
result = cuInit(0);
if (result != CUDA_SUCCESS) return result;
std::ifstream cubin_file(cubin_path.c_str(), std::ios::binary);
if (!cubin_file) {{
std::cerr << "Failed to open cubin file: " << cubin_path << "\n";
return CUDA_ERROR_FILE_NOT_FOUND;
}}
std::vector<char> cubin_data((std::istreambuf_iterator<char>(cubin_file)),
std::istreambuf_iterator<char>());
cubin_file.close();
if (cubin_data.empty()) {{
std::cerr << "Empty cubin file: " << cubin_path << "\n";
return CUDA_ERROR_INVALID_IMAGE;
}}
result = cuModuleLoadData(&g_module, cubin_data.data());
if (result != CUDA_SUCCESS) {{
std::cerr << "Failed to load CUDA module: " << result << "\n";
return result;
}}
g_module_initialized = true;
return CUDA_SUCCESS;
}}
// Initialize all kernel functions (called once after module load)
static CUresult tilelang_init_kernels() {{
if (g_kernels_initialized) return CUDA_SUCCESS;
CUresult result;
{kernel_inits}
g_kernels_initialized = true;
return CUDA_SUCCESS;
}}
// TMA descriptor initialization (host-side)
{tma_init_func}
// Main kernel launcher
extern "C" CUresult launch_kernel({launch_func_sig}, uint64_t _stream, tvm::ffi::Bytes cubin_path) {{
CUresult result;
std::string cubin_path_str(reinterpret_cast<const char*>(cubin_path.data()), cubin_path.size());
result = tilelang_init_cuda_module(cubin_path_str);
if (result != CUDA_SUCCESS) return result;
result = tilelang_init_kernels();
if (result != CUDA_SUCCESS) return result;
{get_ptr_code}
CUstream stream = (CUstream)_stream;
{tma_init_in_launch}
{kernel_launches}
return CUDA_SUCCESS;
}}
// Cleanup function
extern "C" CUresult cleanup_module() {{
if (g_module_initialized && g_module != nullptr) {{
cuModuleUnload(g_module);
g_module = nullptr;
g_module_initialized = false;
}}
g_kernels_initialized = false;
return CUDA_SUCCESS;
}}
TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_kernel, launch_kernel);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(cleanup_module, cleanup_module);
"""
.. py:data:: CUBIN_TMA_ATOM_INIT_TEMPLATE
:value: ' {desc_name} = tl.Gemm_SM90.get_tma_atom(__fake_tensor__, (32, 32))'
.. py:data:: CUBIN_KERNEL_LAUNCH_TEMPLATE
:value: Multiline-String
.. code-block:: python
""" {function_name}({call_args}).launch(
grid=[{grid_x}, {grid_y}, {grid_z}],
block=[{block_x}, {block_y}, {block_z}],
smem={smem_size},
stream=stream,
)"""
.. py:data:: CUBIN_FAKE_TENSOR_TEMPLATE
:value: ' __fake_{arg_name}__ = make_fake_compact_tensor(_DTYPE_MAP[str({arg_name}.dtype)],...
.. py:data:: CUBIN_GEN_CODE_TEMPLATE
:value: Multiline-String
.. code-block:: python
"""{lib_code}
@cute.jit
def kernel_wrapper({wrapper_args}):
{tma_init_code}{kernel_launches}
# Compile kernels to generate cubin
{fake_tensor_code}{fake_tma_tensor_code} __fake_stream__ = make_fake_stream()
# Always generate cubin under a unique staging directory to avoid concurrent
# processes clobbering each other's intermediate artifacts.
_staging_dir = Path(tempfile.mkdtemp(
prefix=Path(__file__).stem + ".cubin.staging.",
dir=_module_dir,
))
try:
_kernel_wrapper = cute.compile(
kernel_wrapper,
{compile_args},
options=f"--enable-tvm-ffi --keep-cubin --dump-dir={{_staging_dir.as_posix()}}",
)
# CuTeDSL generates a long, mangled cubin filename that includes argument/type info,
# e.g. "cutlass_kernel_wrapper_FakeTensor...sm_90a.cubin". We expect exactly one cubin.
_cubin_files = sorted(_staging_dir.glob("*.cubin"), key=lambda p: p.stat().st_mtime)
if len(_cubin_files) != 1:
raise RuntimeError(
f"Expected exactly one .cubin under {{_staging_dir}}, got {{len(_cubin_files)}}: {{_cubin_files}}"
)
os.replace(_cubin_files[0], _cubin_path)
finally:
shutil.rmtree(_staging_dir, ignore_errors=True)"""
.. py:data:: PYTHON_HOST_FUNC_TEMPLATE
:value: Multiline-String
.. code-block:: python
"""import os
from pathlib import Path
# Minimal imports for runtime (no cutlass/cute - only needed for cubin generation)
import tvm.runtime as runtime
_cpp_launcher = None
_cpp_launcher_lib = None
_cubin_generated = False
# Pre-compute paths - cubin is stored alongside the launcher .so
# Use module basename to avoid conflicts when multiple kernels run concurrently
# e.g., "/tmp/tmp8liu__ho.py" -> "/tmp/tmp8liu__ho.cubin"
# "kernel.py" (in cache) -> "kernel.cubin"
_module_dir = Path(os.path.dirname(__file__))
_cubin_path = _module_dir / (Path(__file__).stem + ".cubin")
_cubin_path_bytes = _cubin_path.as_posix().encode('utf-8')
_cubin_needs_generation = not _cubin_path.exists()
def _generate_cubin_if_needed({cubin_gen_params}):
"""Generate cubin file on first call.
All CuTeDSL imports are inside this function to avoid slow
module-level initialization when loading from cache.
"""
global _cubin_generated, _cubin_path
# Lazy import CuTeDSL only when cubin generation is needed
from cuda.bindings.driver import CUstream
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import make_fake_stream, make_fake_compact_tensor
import tilelang.contrib.cutedsl as tl
# We rely on CuTeDSL's keep-cubin artifact rather than custom extraction.
import tempfile
import shutil
_DTYPE_MAP = {{
"torch.float32": cutlass.Float32,
"torch.float16": cutlass.Float16,
"torch.bfloat16": cutlass.BFloat16,
"torch.float8_e4m3fnuz": cutlass.Float8E4M3FN,
"torch.float8_e4m3fn": cutlass.Float8E4M3FN,
"torch.float8_e5m2": cutlass.Float8E5M2,
"torch.float64": cutlass.Float64,
"torch.int64": cutlass.Int64,
"torch.int32": cutlass.Int32,
"torch.uint32": cutlass.Uint32,
"torch.bool": cutlass.Boolean,
"torch.int8": cutlass.Int8,
"torch.uint8": cutlass.Uint8,
"torch.int16": cutlass.Int16,
"torch.uint16": cutlass.Uint16,
"torch.uchar": cutlass.Uint8,
}}
{cubin_gen_code}
_cubin_generated = True
def _load_cpp_launcher():
"""Load C++ kernel launcher."""
global _cpp_launcher, _cpp_launcher_lib
if _cpp_launcher is not None:
return _cpp_launcher
lib_path = os.path.join(os.path.dirname(__file__), "{launcher_lib_name}")
if not os.path.exists(lib_path):
raise FileNotFoundError(f"Launcher not found: {{lib_path}}")
_cpp_launcher_lib = runtime.load_module(lib_path)
_cpp_launcher = _cpp_launcher_lib["launch_kernel"]
return _cpp_launcher
def call({call_func_params}, stream):
"""Kernel dispatch function."""
global _cubin_path_bytes, _cubin_needs_generation
if _cubin_needs_generation:
_generate_cubin_if_needed({cubin_gen_call_args})
_cubin_needs_generation = False
{arg_prep_code}
launcher = _load_cpp_launcher()
result = launcher({launcher_call_args}, stream, _cubin_path_bytes)
if result != 0:
raise RuntimeError(f"Kernel launch failed with CUDA error {{result}}")
"""
.. py:class:: TLCuTeDSLSourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)
Bases: :py:obj:`tilelang.jit.adapter.wrapper.TLCUDASourceWrapper`
Wrapper class for TileLang CuTe DSL backend with C++ launcher.
Generates optimized C++ launcher code that:
- Loads cubin via CUDA Driver API
- Passes TMA descriptors by value (host-side, no device copy)
- Launches kernels with minimal Python overhead
- Supports both single and multiple kernel scenarios
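A minimal usage sketch (illustration only; in practice the TileLang JIT adapter constructs the wrapper and supplies the scheduled module, kernel source and target itself):

.. code-block:: python

   def build_launcher_sources(scheduled_mod, kernel_source, target,
                              device_mod=None, host_mod=None, pass_configs=None):
       # Sketch: produce the C++ launcher and the Python host code for one scheduled module.
       from tilelang.jit.adapter.cutedsl.wrapper import TLCuTeDSLSourceWrapper
       wrapper = TLCuTeDSLSourceWrapper(scheduled_mod, kernel_source, target,
                                        device_mod=device_mod, host_mod=host_mod,
                                        pass_configs=pass_configs)
       wrapper.update_lib_code(kernel_source)          # build dispatch code around the kernel source
       cpp_launcher = wrapper.get_launcher_cpp_code()  # C++ source compiled into the launcher .so
       host_py = wrapper.host_func                     # Python host code (see PYTHON_HOST_FUNC_TEMPLATE)
       return cpp_launcher, host_py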
.. py:property:: host_func
Override parent's host_func to return generated Python code.
.. py:method:: generate_tma_descriptor_args(desc_name_map, desc_name_var_map, tma_desc_code_map)
Generate TMA descriptor information for C++ code generation.
:returns: List of descriptor variable names in the order they were processed.
.. py:method:: create_dispatch_func(code, function_informations)
Create dispatch function - always use C++ launcher.
.. py:method:: create_dispatch_func_cpp_launcher(code, function_informations)
Create dispatch function using C++ launcher.
.. py:method:: get_launcher_cpp_code()
Get the generated C++ launcher code.
.. py:method:: update_lib_code(code)
Update the library code with the given code string.