tilelang.jit.adapter.cutedsl.wrapper
====================================

.. py:module:: tilelang.jit.adapter.cutedsl.wrapper

.. autoapi-nested-parse::

   CuTeDSL source wrapper for TileLang.

   This module generates C++ kernel launchers for the CuTeDSL backend.

   Key features:

   - Automatic C++ launcher generation using the CUDA Driver API
   - TMA descriptors are built in host memory and passed by value as ``__grid_constant__`` kernel parameters, so no device copy is needed
   - ``cuLaunchKernel`` copies each 128-byte ``CUtensorMap`` into the kernel parameter space
   - Support for both single- and multiple-kernel launches
   - Integration with the kernel cache: the generated cubin is stored alongside the launcher shared library


Attributes
----------

.. autoapisummary::

   tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_DESC_INIT_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_IM2COL_DESC_INIT_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_INIT_FUNC_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.CPP_KERNEL_INIT_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_LAUNCH_INIT_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.CPP_KERNEL_LAUNCH_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.CPP_LAUNCHER_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.CUBIN_TMA_ATOM_INIT_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.CUBIN_KERNEL_LAUNCH_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.CUBIN_FAKE_TENSOR_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.CUBIN_GEN_CODE_TEMPLATE
   tilelang.jit.adapter.cutedsl.wrapper.PYTHON_HOST_FUNC_TEMPLATE


Classes
-------

.. autoapisummary::

   tilelang.jit.adapter.cutedsl.wrapper.TLCuTeDSLSourceWrapper


Module Contents
---------------

.. py:data:: CPP_TMA_DESC_INIT_TEMPLATE
   :value: Multiline-String

   .. code-block:: python

      """
          // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name})
          {{
              uint64_t globalDim[{rank}] = {{{global_dim_values}}};
              uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
              uint32_t boxDim[{rank}] = {{{box_dim_values}}};
              uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
              result = cuTensorMapEncodeTiled(
                  &tma_descs[{desc_idx}],
                  static_cast({dtype}),
                  {rank},
                  reinterpret_cast({tensor_name}_ptr),
                  globalDim,
                  globalStrides,
                  boxDim,
                  elemStrides,
                  static_cast({interleave}),
                  static_cast({swizzle}),
                  static_cast({l2_promotion}),
                  static_cast({oob_fill})
              );
              if (result != CUDA_SUCCESS) {{
                  std::cerr << "Failed to encode TMA descriptor {desc_idx}: " << result << "\n";
                  return result;
              }}
          }}
      """
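
As a rough guide to the ``{global_dim_values}``, ``{global_stride_values}`` and ``{stride_rank}`` placeholders: ``cuTensorMapEncodeTiled`` expects ``globalDim`` ordered from the innermost (fastest-varying) dimension outward, and ``globalStrides`` to hold ``rank - 1`` byte strides, each a multiple of 16, because the innermost dimension is implicitly contiguous. The sketch below derives those values for a contiguous row-major tensor. The helper ``tiled_tma_dims`` is hypothetical, not part of TileLang, and it assumes ``{stride_rank}`` is ``rank - 1``.

.. code-block:: python

   def tiled_tma_dims(shape, elem_bytes):
       """Return (global_dim, global_strides) for a contiguous row-major tensor.

       Hypothetical helper: dimensions are reversed so the innermost
       dimension comes first, and only rank - 1 byte strides are produced.
       """
       dims = list(reversed(shape))
       strides = []
       running = elem_bytes
       for extent in dims[:-1]:
           # Stride of the next-outer dimension = product of inner extents * element size.
           running *= extent
           strides.append(running)
       # cuTensorMapEncodeTiled requires every global stride to be a multiple of 16 bytes.
       for s in strides:
           if s % 16 != 0:
               raise ValueError(f"stride {s} is not a multiple of 16 bytes")
       return dims, strides


   # A 128 x 64 fp16 matrix (2-byte elements), row-major.
   dims, strides = tiled_tma_dims((128, 64), 2)
   print(dims, strides)  # [64, 128] [128]

A batched ``(4, 128, 64)`` fp16 tensor would give ``globalDim = {64, 128, 4}`` and ``globalStrides = {128, 16384}``.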
.. py:data:: CPP_TMA_IM2COL_DESC_INIT_TEMPLATE
   :value: Multiline-String

   .. code-block:: python

      """
          // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name}) [im2col]
          {{
              uint64_t globalDim[{rank}] = {{{global_dim_values}}};
              uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
              uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
              int32_t lowerCorner[{rank_minus_two}] = {{{lower_corner_values}}};
              int32_t upperCorner[{rank_minus_two}] = {{{upper_corner_values}}};
              result = cuTensorMapEncodeIm2col(
                  &tma_descs[{desc_idx}],
                  static_cast({dtype}),
                  {rank},
                  reinterpret_cast({tensor_name}_ptr),
                  globalDim,
                  globalStrides,
                  lowerCorner,
                  upperCorner,
                  static_cast({channels_per_pixel}),
                  static_cast({pixels_per_column}),
                  elemStrides,
                  static_cast({interleave}),
                  static_cast({swizzle}),
                  static_cast({l2_promotion}),
                  static_cast({oob_fill})
              );
              if (result != CUDA_SUCCESS) {{
                  std::cerr << "Failed to encode TMA im2col descriptor {desc_idx}: " << result << "\n";
                  return result;
              }}
          }}
      """
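
These templates are Python ``str.format`` strings, so literal C++ braces are written as ``{{`` and ``}}``; the triple braces in ``{{{lower_corner_values}}}`` therefore produce a brace-enclosed initializer around the substituted value. The sketch below renders the two corner declarations for a 4-D tensor, where ``rank_minus_two`` is 2. The helper name and the corner values are made up for illustration and are not values TileLang computes.

.. code-block:: python

   # Two lines copied from CPP_TMA_IM2COL_DESC_INIT_TEMPLATE.
   CORNER_LINES = (
       "int32_t lowerCorner[{rank_minus_two}] = {{{lower_corner_values}}};\n"
       "int32_t upperCorner[{rank_minus_two}] = {{{upper_corner_values}}};"
   )


   def render_corners(rank, lower, upper):
       # Hypothetical helper: the corner arrays cover only the rank - 2 spatial dims.
       if len(lower) != rank - 2 or len(upper) != rank - 2:
           raise ValueError("corner arrays must have rank - 2 entries")
       return CORNER_LINES.format(
           rank_minus_two=rank - 2,
           lower_corner_values=", ".join(map(str, lower)),
           upper_corner_values=", ".join(map(str, upper)),
       )


   print(render_corners(4, [-1, -1], [0, 0]))
   # int32_t lowerCorner[2] = {-1, -1};
   # int32_t upperCorner[2] = {0, 0};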
.. py:data:: CPP_TMA_INIT_FUNC_TEMPLATE
   :value: Multiline-String

   .. code-block:: python

      """CUresult tma_init(CUtensorMap* tma_descs, {func_args}) {{
          // Initialize {num_descs} TMA descriptor(s) in caller-provided host array
          // cuLaunchKernel will copy 128-byte CUtensorMap to kernel param space automatically
          CUresult result;
      {desc_init_code}
          return CUDA_SUCCESS;
      }}
      """
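
``tma_init`` takes the descriptor array plus the tensor pointers listed in ``{func_args}``, and the call in ``CPP_TMA_LAUNCH_INIT_TEMPLATE`` must pass ``{tma_tensor_args}`` in the same order. A minimal sketch of generating both lists from one source so they cannot drift apart. It assumes each tensor ``X`` is exposed as a ``void* X_ptr`` (matching the ``{tensor_name}_ptr`` naming in the descriptor templates; the ``void*`` type and the deduplication are assumptions), and the helper name is hypothetical.

.. code-block:: python

   def tma_init_signature(tensor_names):
       """Hypothetical helper: build matching parameter and argument lists."""
       # Assumed: a tensor may back several descriptors, so pass each pointer once.
       unique = list(dict.fromkeys(tensor_names))
       func_args = ", ".join(f"void* {name}_ptr" for name in unique)
       call_args = ", ".join(f"{name}_ptr" for name in unique)
       return func_args, call_args


   func_args, call_args = tma_init_signature(["A", "B", "A"])
   header = "CUresult tma_init(CUtensorMap* tma_descs, {func_args}) {{".format(func_args=func_args)
   print(header)     # CUresult tma_init(CUtensorMap* tma_descs, void* A_ptr, void* B_ptr) {
   print(call_args)  # A_ptr, B_ptr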
.. py:data:: CPP_KERNEL_INIT_TEMPLATE
   :value: Multiline-String

   .. code-block:: python

      """
          // Find and configure kernel {kernel_idx}: {kernel_name}
          result = find_kernel_by_pattern(g_module, "{kernel_name}", &g_kernels[{kernel_idx}]);
          if (result != CUDA_SUCCESS) {{
              std::cerr << "Failed to find kernel {kernel_name}: " << result << "\n";
              return result;
          }}

          if ({smem_size} > 0) {{
              result = cuFuncSetAttribute(g_kernels[{kernel_idx}],
                                          CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                                          {smem_size});
              if (result != CUDA_SUCCESS) {{
                  std::cerr << "Failed to set smem for {kernel_name}: " << result << "\n";
                  return result;
              }}
          }}
      """
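
Each kernel contributes one rendered copy of this template to the launcher's ``{kernel_inits}`` slot. The sketch below uses a shortened, hypothetical stand-in for the template to show the per-kernel concatenation, plus a helper reflecting the CUDA rule that dynamic shared memory above 48 KB per block requires the ``CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES`` opt-in. The real template sets the attribute whenever ``smem_size > 0``, which is harmless below that limit.

.. code-block:: python

   # 48 KB: default dynamic shared-memory limit per block without an explicit opt-in.
   DEFAULT_DYNAMIC_SMEM_LIMIT = 48 * 1024

   # Shortened, hypothetical stand-in for CPP_KERNEL_INIT_TEMPLATE.
   KERNEL_INIT = (
       '    result = find_kernel_by_pattern(g_module, "{kernel_name}", &g_kernels[{kernel_idx}]);\n'
       "    if ({smem_size} > 0) {{ /* cuFuncSetAttribute(...) */ }}\n"
   )


   def render_kernel_inits(kernels):
       """kernels: list of (name, dynamic_smem_bytes) in launch order."""
       parts = []
       for idx, (name, smem) in enumerate(kernels):
           parts.append(KERNEL_INIT.format(kernel_idx=idx, kernel_name=name, smem_size=smem))
       return "".join(parts)


   def needs_smem_opt_in(smem_bytes):
       # True when the kernel cannot launch without raising the attribute.
       return smem_bytes > DEFAULT_DYNAMIC_SMEM_LIMIT


   print(render_kernel_inits([("main_kernel", 0), ("main_kernel_1", 98304)]))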
.. py:data:: CPP_TMA_LAUNCH_INIT_TEMPLATE
   :value: Multiline-String

   .. code-block:: python

      """
          // Declare stack-local TMA descriptor array (eliminates concurrency race)
          CUtensorMap tma_descs[{num_tma_descs}];

          // Initialize TMA descriptors (HOST memory - passed via __grid_constant__)
          // NOTE: We intentionally do NOT reuse/cached descriptors across launches.
          // Pointer-only reuse is a correctness trap (shape/stride may change with same ptr),
          // and correctness beats micro-optimizations.
          result = tma_init(tma_descs, {tma_tensor_args});
          if (result != CUDA_SUCCESS) {{
              std::cerr << "Failed to initialize TMA descriptors: " << result << "\n";
              return result;
          }}
      """
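
The stack-local ``CUtensorMap tma_descs[{num_tma_descs}]`` array holds opaque 128-byte descriptors. The ``ctypes`` mirror below is only an illustration: in ``cuda.h`` the type is declared as sixteen 64-bit words with 64-byte alignment, and this mirror does not reproduce the alignment. It makes the per-launch host footprint easy to compute.

.. code-block:: python

   import ctypes


   class CUtensorMap(ctypes.Structure):
       # Mirrors the opaque layout: 16 x 64-bit words = 128 bytes (alignment not modeled).
       _fields_ = [("opaque", ctypes.c_uint64 * 16)]


   def descriptor_array_bytes(num_tma_descs):
       # Size of a CUtensorMap[num_tma_descs] array on the host stack.
       return ctypes.sizeof(CUtensorMap * num_tma_descs)


   print(ctypes.sizeof(CUtensorMap))  # 128
   print(descriptor_array_bytes(3))   # 384

Because the array lives on the launcher's stack, every call encodes its own descriptors, so concurrent launches never share mutable descriptor state.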
.. py:data:: CPP_KERNEL_LAUNCH_TEMPLATE
   :value: Multiline-String

   .. code-block:: python

      """
          // Launch kernel {kernel_idx}: {kernel_name}
          {{
              void* args[] = {{{kernel_args}}};
              result = cuLaunchKernel(
                  g_kernels[{kernel_idx}],
                  {grid_x}, {grid_y}, {grid_z},
                  {block_x}, {block_y}, {block_z},
                  {smem_size},
                  stream,
                  args,
                  nullptr
              );
              if (result != CUDA_SUCCESS) {{
                  std::cerr << "Failed to launch kernel {kernel_name}: " << result << "\n";
                  return result;
              }}
          }}
      """
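
``cuLaunchKernel`` receives ``kernelParams`` as an array of *pointers to* each argument value, which is why ``{kernel_args}`` is a list of addresses; a TMA descriptor is passed as the address of its host-side ``CUtensorMap`` and copied into parameter space at launch. A sketch of assembling that list, assuming device pointers are locals named ``<name>_ptr`` and TMA arguments refer to slots of ``tma_descs`` (the helper and the ``kind`` labels are hypothetical):

.. code-block:: python

   def render_kernel_args(params):
       """params: list of (kind, value) pairs, kind in {"ptr", "tma", "scalar"}."""
       exprs = []
       for kind, value in params:
           if kind == "ptr":
               exprs.append(f"&{value}_ptr")  # address of a device-pointer local
           elif kind == "tma":
               exprs.append(f"&tma_descs[{value}]")  # address of a host-side CUtensorMap
           elif kind == "scalar":
               exprs.append(f"&{value}")  # address of a scalar local
           else:
               raise ValueError(f"unknown parameter kind: {kind}")
       return ", ".join(exprs)


   line = "void* args[] = {{{kernel_args}}};".format(
       kernel_args=render_kernel_args([("tma", 0), ("ptr", "C"), ("scalar", "n")])
   )
   print(line)  # void* args[] = {&tma_descs[0], &C_ptr, &n};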
.. py:data:: CPP_LAUNCHER_TEMPLATE
   :value: Multiline-String

   .. code-block:: python

      """#include
      #include
      #include
      #include
      #include
      #include
      #include

      // TVM Headers
      #include
      #include
      #include

      // Cached module handle
      static CUmodule g_module = nullptr;
      static bool g_module_initialized = false;

      // Cached kernel functions
      static CUfunction g_kernels[{num_kernels}] = {{nullptr}};
      static bool g_kernels_initialized = false;

      // Find kernel by pattern (substring match, prefer base name over _N variants)
      CUresult find_kernel_by_pattern(CUmodule module, const char* pattern, CUfunction* out_func) {{
          CUresult result;
          unsigned int num_funcs = 0;

          result = cuModuleGetFunctionCount(&num_funcs, module);
          if (result != CUDA_SUCCESS) {{
              std::cerr << "Failed to get function count: " << result << "\n";
              return result;
          }}

          std::vector func_list(num_funcs);
          result = cuModuleEnumerateFunctions(func_list.data(), num_funcs, module);
          if (result != CUDA_SUCCESS) {{
              std::cerr << "Failed to enumerate functions: " << result << "\n";
              return result;
          }}

          // Collect substring matches, separating base name from _N variants
          std::vector> base_matches;     // pattern not followed by _digit
          std::vector> variant_matches;  // pattern followed by _digit

          size_t pattern_len = std::strlen(pattern);

          for (unsigned int i = 0; i < num_funcs; i++) {{
              const char* func_name = nullptr;
              result = cuFuncGetName(&func_name, func_list[i]);
              if (result != CUDA_SUCCESS || func_name == nullptr) {{
                  std::cerr << "Failed to get function name: " << result << "\n";
                  return result;
              }}

              std::string name_str(func_name);
              size_t pos = name_str.find(pattern);

              if (pos != std::string::npos) {{
                  // Found substring match
                  size_t after_pattern = pos + pattern_len;

                  // Check what follows the pattern
                  if (after_pattern < name_str.length() &&
                      name_str[after_pattern] == '_' &&
                      after_pattern + 1 < name_str.length() &&
                      std::isdigit(name_str[after_pattern + 1])) {{
                      // Pattern followed by _digit (e.g., "main_kernel_1")
                      variant_matches.push_back({{name_str, func_list[i]}});
                  }} else {{
                      // Pattern not followed by _digit (e.g., "main_kernel" itself)
                      base_matches.push_back({{name_str, func_list[i]}});
                  }}
              }}
          }}

          // Decision logic: prefer base matches over variant matches
          if (!base_matches.empty()) {{
              if (base_matches.size() == 1) {{
                  *out_func = base_matches[0].second;
                  return CUDA_SUCCESS;
              }}

              // Multiple base matches - ambiguous
              std::cerr << "Error: Pattern '" << pattern << "' matched " << base_matches.size()
                        << " base kernels (ambiguous). Matches found:\n";
              for (const auto& match : base_matches) {{
                  std::cerr << " - " << match.first << "\n";
              }}
              std::cerr << "Please use a more specific pattern.\n";
              return CUDA_ERROR_NOT_FOUND;
          }}

          // No base matches, try variant matches
          if (!variant_matches.empty()) {{
              if (variant_matches.size() == 1) {{
                  *out_func = variant_matches[0].second;
                  return CUDA_SUCCESS;
              }}

              // Multiple variant matches - ambiguous
              std::cerr << "Error: Pattern '" << pattern << "' matched " << variant_matches.size()
                        << " variant kernels (ambiguous). Matches found:\n";
              for (const auto& match : variant_matches) {{
                  std::cerr << " - " << match.first << "\n";
              }}
              std::cerr << "Please use a more specific pattern (e.g., '" << pattern << "_1').\n";
              return CUDA_ERROR_NOT_FOUND;
          }}

          // No matches at all
          std::cerr << "Failed to find kernel matching pattern '" << pattern << "'\n";
          return CUDA_ERROR_NOT_FOUND;
      }}

      // Initialize CUDA module (called once on first launch)
      static CUresult tilelang_init_cuda_module(const std::string& cubin_path) {{
          if (g_module_initialized) return CUDA_SUCCESS;

          CUresult result;
          result = cuInit(0);
          if (result != CUDA_SUCCESS) return result;

          std::ifstream cubin_file(cubin_path.c_str(), std::ios::binary);
          if (!cubin_file) {{
              std::cerr << "Failed to open cubin file: " << cubin_path << "\n";
              return CUDA_ERROR_FILE_NOT_FOUND;
          }}

          std::vector cubin_data((std::istreambuf_iterator(cubin_file)),
                                 std::istreambuf_iterator());
          cubin_file.close();

          if (cubin_data.empty()) {{
              std::cerr << "Empty cubin file: " << cubin_path << "\n";
              return CUDA_ERROR_INVALID_IMAGE;
          }}

          result = cuModuleLoadData(&g_module, cubin_data.data());
          if (result != CUDA_SUCCESS) {{
              std::cerr << "Failed to load CUDA module: " << result << "\n";
              return result;
          }}

          g_module_initialized = true;
          return CUDA_SUCCESS;
      }}

      // Initialize all kernel functions (called once after module load)
      static CUresult tilelang_init_kernels() {{
          if (g_kernels_initialized) return CUDA_SUCCESS;
          CUresult result;
      {kernel_inits}
          g_kernels_initialized = true;
          return CUDA_SUCCESS;
      }}

      // TMA descriptor initialization (host-side)
      {tma_init_func}

      // Main kernel launcher
      extern "C" CUresult launch_kernel({launch_func_sig}, uint64_t _stream, tvm::ffi::Bytes cubin_path) {{
          CUresult result;

          std::string cubin_path_str(reinterpret_cast(cubin_path.data()), cubin_path.size());
          result = tilelang_init_cuda_module(cubin_path_str);
          if (result != CUDA_SUCCESS) return result;

          result = tilelang_init_kernels();
          if (result != CUDA_SUCCESS) return result;

      {get_ptr_code}
          CUstream stream = (CUstream)_stream;
      {tma_init_in_launch}
      {kernel_launches}
          return CUDA_SUCCESS;
      }}

      // Cleanup function
      extern "C" CUresult cleanup_module() {{
          if (g_module_initialized && g_module != nullptr) {{
              cuModuleUnload(g_module);
              g_module = nullptr;
              g_module_initialized = false;
          }}
          g_kernels_initialized = false;
          return CUDA_SUCCESS;
      }}

      TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_kernel, launch_kernel);
      TVM_FFI_DLL_EXPORT_TYPED_FUNC(cleanup_module, cleanup_module);
      """
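
The matching rules in ``find_kernel_by_pattern`` are easiest to check in isolation. The Python port below mirrors the C++ decision logic of the template (only the first substring occurrence is examined, a ``_<digit>`` suffix marks a variant, a unique base match wins over variants, and any ambiguity is an error). It is a sketch for reasoning about kernel names, not code that TileLang ships; ``LookupError`` stands in for ``CUDA_ERROR_NOT_FOUND``.

.. code-block:: python

   def find_kernel_by_pattern(names, pattern):
       """Python port of the launcher's kernel-name matching (illustrative only)."""
       base, variant = [], []
       n = len(pattern)
       for name in names:
           pos = name.find(pattern)  # first occurrence only, like std::string::find
           if pos == -1:
               continue
           after = pos + n
           if after + 1 < len(name) and name[after] == "_" and name[after + 1].isdigit():
               variant.append(name)  # e.g. "main_kernel_1" for pattern "main_kernel"
           else:
               base.append(name)  # e.g. "main_kernel" itself
       # Base matches are decisive; variants are consulted only when there is no base match.
       for kind, matches in (("base", base), ("variant", variant)):
           if len(matches) == 1:
               return matches[0]
           if len(matches) > 1:
               raise LookupError(f"pattern {pattern!r} matched {len(matches)} {kind} kernels: {matches}")
       raise LookupError(f"no kernel matches pattern {pattern!r}")


   print(find_kernel_by_pattern(["main_kernel", "main_kernel_1"], "main_kernel"))  # main_kernel
   print(find_kernel_by_pattern(["main_kernel_1"], "main_kernel"))                 # main_kernel_1

One consequence of the substring rule: for the pattern ``main_kernel_1``, a kernel named ``main_kernel_10`` counts as a *base* match (the character after the pattern is ``0``, not ``_``), so a module containing both ``main_kernel_1`` and ``main_kernel_10`` is reported as ambiguous.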
.. py:data:: CUBIN_TMA_ATOM_INIT_TEMPLATE
   :value: ' {desc_name} = tl.Gemm_SM90.get_tma_atom(__fake_tensor__, (32, 32))'

.. py:data:: CUBIN_KERNEL_LAUNCH_TEMPLATE
   :value: Multiline-String

   .. code-block:: python

      """
              {function_name}({call_args}).launch(
                  grid=[{grid_x}, {grid_y}, {grid_z}],
                  block=[{block_x}, {block_y}, {block_z}],
                  smem={smem_size},
                  stream=stream,
              )"""
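
``CUBIN_KERNEL_LAUNCH_TEMPLATE`` is emitted inside ``kernel_wrapper`` once per kernel while the cubin is being compiled. A sketch of rendering it follows; the helper ``render_launch`` is hypothetical, and it pads missing grid and block dimensions with 1, CUDA's convention for unused dimensions.

.. code-block:: python

   # Copy of CUBIN_KERNEL_LAUNCH_TEMPLATE with its leading indentation dropped.
   LAUNCH = (
       "{function_name}({call_args}).launch(\n"
       "    grid=[{grid_x}, {grid_y}, {grid_z}],\n"
       "    block=[{block_x}, {block_y}, {block_z}],\n"
       "    smem={smem_size},\n"
       "    stream=stream,\n"
       ")"
   )


   def render_launch(function_name, call_args, grid, block, smem_size):
       # Pad to three dimensions with 1s for any unused axis.
       gx, gy, gz = (tuple(grid) + (1, 1, 1))[:3]
       bx, by, bz = (tuple(block) + (1, 1, 1))[:3]
       return LAUNCH.format(
           function_name=function_name,
           call_args=", ".join(call_args),
           grid_x=gx, grid_y=gy, grid_z=gz,
           block_x=bx, block_y=by, block_z=bz,
           smem_size=smem_size,
       )


   print(render_launch("main_kernel", ["A", "B"], (8, 4), (128,), 0))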
.. py:data:: CUBIN_FAKE_TENSOR_TEMPLATE
   :value: ' __fake_{arg_name}__ = make_fake_compact_tensor(_DTYPE_MAP[str({arg_name}.dtype)],...

.. py:data:: CUBIN_GEN_CODE_TEMPLATE
   :value: Multiline-String

   .. code-block:: python

      """{lib_code}

          @cute.jit
          def kernel_wrapper({wrapper_args}):
      {tma_init_code}{kernel_launches}

          # Compile kernels to generate cubin
      {fake_tensor_code}{fake_tma_tensor_code}
          __fake_stream__ = make_fake_stream()
          # Always generate cubin under a unique staging directory to avoid concurrent
          # processes clobbering each other's intermediate artifacts.
          _staging_dir = Path(tempfile.mkdtemp(
              prefix=Path(__file__).stem + ".cubin.staging.",
              dir=_module_dir,
          ))
          try:
              _kernel_wrapper = cute.compile(
                  kernel_wrapper,
                  {compile_args},
                  options=f"--enable-tvm-ffi --keep-cubin --dump-dir={{_staging_dir.as_posix()}}",
              )

              # CuTeDSL generates a long, mangled cubin filename that includes argument/type info,
              # e.g. "cutlass_kernel_wrapper_FakeTensor...sm_90a.cubin". We expect exactly one cubin.
              _cubin_files = sorted(_staging_dir.glob("*.cubin"), key=lambda p: p.stat().st_mtime)
              if len(_cubin_files) != 1:
                  raise RuntimeError(
                      f"Expected exactly one .cubin under {{_staging_dir}}, got {{len(_cubin_files)}}: {{_cubin_files}}"
                  )
              os.replace(_cubin_files[0], _cubin_path)
          finally:
              shutil.rmtree(_staging_dir, ignore_errors=True)"""
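
The staging logic above follows a common "build privately, publish atomically" pattern: the staging directory is created with ``tempfile.mkdtemp`` *inside* the destination directory, so the final ``os.replace`` stays on one filesystem and is an atomic rename on POSIX. The runnable sketch below reproduces the pattern with the standard library only; the function names are hypothetical and a dummy file stands in for the ``cute.compile(..., --keep-cubin --dump-dir=...)`` output.

.. code-block:: python

   import os
   import shutil
   import tempfile
   from pathlib import Path


   def publish_single_artifact(staging_dir, final_path, suffix=".cubin"):
       """Move the only *suffix file in staging_dir to final_path."""
       found = sorted(staging_dir.glob("*" + suffix))
       if len(found) != 1:
           raise RuntimeError(f"expected exactly one {suffix} in {staging_dir}, got {len(found)}")
       # Atomic rename on POSIX because staging_dir and final_path share a filesystem.
       os.replace(found[0], final_path)


   def demo(module_dir):
       final = module_dir / "kernel.cubin"
       staging = Path(tempfile.mkdtemp(prefix="kernel.cubin.staging.", dir=module_dir))
       try:
           # Stand-in for the compiler dumping a mangled cubin into the staging directory.
           (staging / "mangled_name_sm_90a.cubin").write_bytes(b"\x7fELF")
           publish_single_artifact(staging, final)
       finally:
           # Always clean up, even if compilation or publishing failed.
           shutil.rmtree(staging, ignore_errors=True)
       return final

Readers of ``kernel.cubin`` therefore see either the previous file or the complete new one, never a partially written cubin.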
.. py:data:: PYTHON_HOST_FUNC_TEMPLATE
   :value: Multiline-String

   .. code-block:: python

      """import os
      from pathlib import Path

      # Minimal imports for runtime (no cutlass/cute - only needed for cubin generation)
      import tvm.runtime as runtime

      _cpp_launcher = None
      _cpp_launcher_lib = None
      _cubin_generated = False

      # Pre-compute paths - cubin is stored alongside the launcher .so
      # Use module basename to avoid conflicts when multiple kernels run concurrently
      # e.g., "/tmp/tmp8liu__ho.py" -> "/tmp/tmp8liu__ho.cubin"
      #       "kernel.py" (in cache) -> "kernel.cubin"
      _module_dir = Path(os.path.dirname(__file__))
      _cubin_path = _module_dir / (Path(__file__).stem + ".cubin")
      _cubin_path_bytes = _cubin_path.as_posix().encode('utf-8')
      _cubin_needs_generation = not _cubin_path.exists()

      def _generate_cubin_if_needed({cubin_gen_params}):
          """Generate cubin file on first call.

          All CuTeDSL imports are inside this function to avoid slow
          module-level initialization when loading from cache.
          """
          global _cubin_generated, _cubin_path

          # Lazy import CuTeDSL only when cubin generation is needed
          from cuda.bindings.driver import CUstream
          import cutlass
          import cutlass.cute as cute
          from cutlass.cute.runtime import make_fake_stream, make_fake_compact_tensor
          import tilelang.contrib.cutedsl as tl
          # We rely on CuTeDSL's keep-cubin artifact rather than custom extraction.
          import tempfile
          import shutil

          _DTYPE_MAP = {{
              "torch.float32": cutlass.Float32,
              "torch.float16": cutlass.Float16,
              "torch.bfloat16": cutlass.BFloat16,
              "torch.float8_e4m3fnuz": cutlass.Float8E4M3FN,
              "torch.float8_e4m3fn": cutlass.Float8E4M3FN,
              "torch.float8_e5m2": cutlass.Float8E5M2,
              "torch.float64": cutlass.Float64,
              "torch.int64": cutlass.Int64,
              "torch.int32": cutlass.Int32,
              "torch.uint32": cutlass.Uint32,
              "torch.bool": cutlass.Boolean,
              "torch.int8": cutlass.Int8,
              "torch.uint8": cutlass.Uint8,
              "torch.int16": cutlass.Int16,
              "torch.uint16": cutlass.Uint16,
              "torch.uchar": cutlass.Uint8,
          }}

      {cubin_gen_code}

          _cubin_generated = True

      def _load_cpp_launcher():
          """Load C++ kernel launcher."""
          global _cpp_launcher, _cpp_launcher_lib
          if _cpp_launcher is not None:
              return _cpp_launcher

          lib_path = os.path.join(os.path.dirname(__file__), "{launcher_lib_name}")
          if not os.path.exists(lib_path):
              raise FileNotFoundError(f"Launcher not found: {{lib_path}}")

          _cpp_launcher_lib = runtime.load_module(lib_path)
          _cpp_launcher = _cpp_launcher_lib["launch_kernel"]
          return _cpp_launcher

      def call({call_func_params}, stream):
          """Kernel dispatch function."""
          global _cubin_path_bytes, _cubin_needs_generation

          if _cubin_needs_generation:
              _generate_cubin_if_needed({cubin_gen_call_args})
              _cubin_needs_generation = False

      {arg_prep_code}

          launcher = _load_cpp_launcher()
          result = launcher({launcher_call_args}, stream, _cubin_path_bytes)

          if result != 0:
              raise RuntimeError(f"Kernel launch failed with CUDA error {{result}}")
      """
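
Two small mechanisms in the generated host module are worth isolating: the cubin path is derived from the module file (same directory and stem, ``.cubin`` suffix), and cubin generation is guarded by a module-level flag so it runs on the first ``call`` only. The sketch below reproduces both with standard-library code; ``cubin_path_for``, ``generated`` and the doubling "launch" are hypothetical stand-ins, and, like the template's globals, the guard uses no lock.

.. code-block:: python

   from pathlib import Path


   def cubin_path_for(module_file):
       """Same rule as the template: same directory and stem, ".cubin" suffix."""
       p = Path(module_file)
       return p.parent / (p.stem + ".cubin")


   _needs_generation = True
   generated = []  # records which call triggered the one-time generation


   def call(x):
       global _needs_generation
       if _needs_generation:
           generated.append(x)  # stand-in for _generate_cubin_if_needed(...)
           _needs_generation = False
       return x * 2  # stand-in for the C++ launcher call


   print(cubin_path_for("/tmp/tmp8liu__ho.py").as_posix())  # /tmp/tmp8liu__ho.cubin
   print(cubin_path_for("kernel.py").as_posix())            # kernel.cubin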
.. py:class:: TLCuTeDSLSourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)

   Bases: :py:obj:`tilelang.jit.adapter.wrapper.TLCUDASourceWrapper`

   Wrapper class for the TileLang CuTeDSL backend with a C++ launcher.

   It generates C++ launcher code that:

   - Loads the cubin through the CUDA Driver API
   - Passes TMA descriptors by value (built on the host, no device copy)
   - Launches kernels with minimal Python overhead
   - Supports both single- and multiple-kernel scenarios

   .. py:property:: host_func

      Overrides the parent's ``host_func`` to return the generated Python host code.

   .. py:method:: generate_tma_descriptor_args(desc_name_map, desc_name_var_map, tma_desc_code_map)

      Generate TMA descriptor information for C++ code generation.

      :returns: List of descriptor variable names in the order they were processed.

   .. py:method:: create_dispatch_func(code, function_informations)

      Create the dispatch function. This backend always uses the C++ launcher.

   .. py:method:: create_dispatch_func_cpp_launcher(code, function_informations)

      Create the dispatch function using the C++ launcher.

   .. py:method:: get_launcher_cpp_code()

      Return the generated C++ launcher code.

   .. py:method:: update_lib_code(code)

      Update the library code with the given code string.