tilelang.jit.adapter.cutedsl.wrapper

CuTeDSL Source Wrapper for TileLang.

This module provides C++ kernel launcher generation for the CuTeDSL backend.

Key features:

• Automatic C++ launcher generation using the CUDA Driver API
• TMA descriptors kept in HOST memory and passed via __grid_constant__ (no device copy needed)
• cuLaunchKernel automatically copies the 128-byte CUtensorMap into kernel parameter space
• Support for single and multiple kernel launches
• Complete cache-system integration
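
For illustration, here is a minimal, hypothetical sketch of how these templates are consumed: each CPP_*_TEMPLATE below is a str.format() template, and the rendered C++ passes the address of a host-memory CUtensorMap so that cuLaunchKernel copies the 128-byte struct into kernel parameter space. All field values here are made up for the example.

from tilelang.jit.adapter.cutedsl.wrapper import CPP_KERNEL_LAUNCH_TEMPLATE

launch_block = CPP_KERNEL_LAUNCH_TEMPLATE.format(
    kernel_idx=0,
    kernel_name="main_kernel",
    # "&tma_descs[0]" is the address of a HOST-memory CUtensorMap;
    # cuLaunchKernel copies the 128-byte struct into kernel parameter
    # space, where the kernel sees it as a __grid_constant__ parameter.
    kernel_args="&tma_descs[0], &A_ptr, &B_ptr",
    grid_x=128, grid_y=1, grid_z=1,
    block_x=256, block_y=1, block_z=1,
    smem_size=49152,
)
print(launch_block)  # the rendered C++ launch snippet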

Attributes

CPP_TMA_DESC_INIT_TEMPLATE
CPP_TMA_IM2COL_DESC_INIT_TEMPLATE
CPP_TMA_INIT_FUNC_TEMPLATE
CPP_KERNEL_INIT_TEMPLATE
CPP_TMA_LAUNCH_INIT_TEMPLATE
CPP_KERNEL_LAUNCH_TEMPLATE
CPP_LAUNCHER_TEMPLATE
CUBIN_TMA_ATOM_INIT_TEMPLATE
CUBIN_KERNEL_LAUNCH_TEMPLATE
CUBIN_FAKE_TENSOR_TEMPLATE
CUBIN_GEN_CODE_TEMPLATE
PYTHON_HOST_FUNC_TEMPLATE

Classes

TLCuTeDSLSourceWrapper

Wrapper class for TileLang CuTe DSL backend with C++ launcher.

Module Contents

tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_DESC_INIT_TEMPLATE = Multiline-String
"""  // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name})
  {{
    uint64_t globalDim[{rank}] = {{{global_dim_values}}};
    uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
    uint32_t boxDim[{rank}] = {{{box_dim_values}}};
    uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};

    result = cuTensorMapEncodeTiled(
        &tma_descs[{desc_idx}],
        static_cast<CUtensorMapDataType>({dtype}),
        {rank},
        reinterpret_cast<void*>({tensor_name}_ptr),
        globalDim,
        globalStrides,
        boxDim,
        elemStrides,
        static_cast<CUtensorMapInterleave>({interleave}),
        static_cast<CUtensorMapSwizzle>({swizzle}),
        static_cast<CUtensorMapL2promotion>({l2_promotion}),
        static_cast<CUtensorMapFloatOOBfill>({oob_fill})
    );

    if (result != CUDA_SUCCESS) {{
      std::cerr << "Failed to encode TMA descriptor {desc_idx}: " << result << "\n";
      return result;
    }}
  }}
"""
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_IM2COL_DESC_INIT_TEMPLATE = Multiline-String
"""  // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name}) [im2col]
  {{
    uint64_t globalDim[{rank}] = {{{global_dim_values}}};
    uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}};
    uint32_t elemStrides[{rank}] = {{{elem_stride_values}}};
    int32_t lowerCorner[{rank_minus_two}] = {{{lower_corner_values}}};
    int32_t upperCorner[{rank_minus_two}] = {{{upper_corner_values}}};

    result = cuTensorMapEncodeIm2col(
        &tma_descs[{desc_idx}],
        static_cast<CUtensorMapDataType>({dtype}),
        {rank},
        reinterpret_cast<void*>({tensor_name}_ptr),
        globalDim,
        globalStrides,
        lowerCorner,
        upperCorner,
        static_cast<uint32_t>({channels_per_pixel}),
        static_cast<uint32_t>({pixels_per_column}),
        elemStrides,
        static_cast<CUtensorMapInterleave>({interleave}),
        static_cast<CUtensorMapSwizzle>({swizzle}),
        static_cast<CUtensorMapL2promotion>({l2_promotion}),
        static_cast<CUtensorMapFloatOOBfill>({oob_fill})
    );

    if (result != CUDA_SUCCESS) {{
      std::cerr << "Failed to encode TMA im2col descriptor {desc_idx}: " << result << "\n";
      return result;
    }}
  }}
"""
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_INIT_FUNC_TEMPLATE = Multiline-String
"""CUresult tma_init(CUtensorMap* tma_descs, {func_args}) {{
  // Initialize {num_descs} TMA descriptor(s) in caller-provided host array
  // cuLaunchKernel will copy 128-byte CUtensorMap to kernel param space automatically
  CUresult result;

{desc_init_code}

  return CUDA_SUCCESS;
}}
"""
tilelang.jit.adapter.cutedsl.wrapper.CPP_KERNEL_INIT_TEMPLATE = Multiline-String
"""  // Find and configure kernel {kernel_idx}: {kernel_name}
  result = find_kernel_by_pattern(g_module, "{kernel_name}", &g_kernels[{kernel_idx}]);
  if (result != CUDA_SUCCESS) {{
    std::cerr << "Failed to find kernel {kernel_name}: " << result << "\n";
    return result;
  }}

  if ({smem_size} > 0) {{
    result = cuFuncSetAttribute(g_kernels[{kernel_idx}],
                                CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                                {smem_size});
    if (result != CUDA_SUCCESS) {{
      std::cerr << "Failed to set smem for {kernel_name}: " << result << "\n";
      return result;
    }}
  }}
"""
tilelang.jit.adapter.cutedsl.wrapper.CPP_TMA_LAUNCH_INIT_TEMPLATE = Multiline-String
"""  // Declare stack-local TMA descriptor array (eliminates concurrency race)
  CUtensorMap tma_descs[{num_tma_descs}];

  // Initialize TMA descriptors (HOST memory - passed via __grid_constant__)
  // NOTE: We intentionally do NOT reuse/cached descriptors across launches.
  // Pointer-only reuse is a correctness trap (shape/stride may change with same ptr),
  // and correctness beats micro-optimizations.
  result = tma_init(tma_descs, {tma_tensor_args});
  if (result != CUDA_SUCCESS) {{
    std::cerr << "Failed to initialize TMA descriptors: " << result << "\n";
    return result;
  }}
"""
tilelang.jit.adapter.cutedsl.wrapper.CPP_KERNEL_LAUNCH_TEMPLATE = Multiline-String
"""  // Launch kernel {kernel_idx}: {kernel_name}
  {{
    void* args[] = {{{kernel_args}}};
    result = cuLaunchKernel(
        g_kernels[{kernel_idx}],
        {grid_x}, {grid_y}, {grid_z},
        {block_x}, {block_y}, {block_z},
        {smem_size},
        stream,
        args,
        nullptr
    );
    if (result != CUDA_SUCCESS) {{
      std::cerr << "Failed to launch kernel {kernel_name}: " << result << "\n";
      return result;
    }}
  }}
"""
tilelang.jit.adapter.cutedsl.wrapper.CPP_LAUNCHER_TEMPLATE = Multiline-String
"""#include <cuda.h>
#include <cstdint>
#include <iostream>
#include <fstream>
#include <vector>
#include <cstring>
#include <cctype>
#include <string>

// TVM Headers
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>

// Cached module handle
static CUmodule g_module = nullptr;
static bool g_module_initialized = false;

// Cached kernel functions
static CUfunction g_kernels[{num_kernels}] = {{nullptr}};
static bool g_kernels_initialized = false;

// Find kernel by pattern (substring match, prefer base name over _N variants)
CUresult find_kernel_by_pattern(CUmodule module, const char* pattern, CUfunction* out_func) {{
  CUresult result;
  unsigned int num_funcs = 0;

  result = cuModuleGetFunctionCount(&num_funcs, module);
  if (result != CUDA_SUCCESS) {{
    std::cerr << "Failed to get function count: " << result << "\n";
    return result;
  }}

  std::vector<CUfunction> func_list(num_funcs);
  result = cuModuleEnumerateFunctions(func_list.data(), num_funcs, module);
  if (result != CUDA_SUCCESS) {{
    std::cerr << "Failed to enumerate functions: " << result << "\n";
    return result;
  }}

  // Collect substring matches, separating base name from _N variants
  std::vector<std::pair<std::string, CUfunction>> base_matches;     // pattern not followed by _digit
  std::vector<std::pair<std::string, CUfunction>> variant_matches;  // pattern followed by _digit

  size_t pattern_len = std::strlen(pattern);

  for (unsigned int i = 0; i < num_funcs; i++) {{
    const char* func_name = nullptr;
    result = cuFuncGetName(&func_name, func_list[i]);
    if (result != CUDA_SUCCESS || func_name == nullptr) {{
      std::cerr << "Failed to get function name: " << result << "\n";
      return result;
    }}

    std::string name_str(func_name);
    size_t pos = name_str.find(pattern);

    if (pos != std::string::npos) {{
      // Found substring match
      size_t after_pattern = pos + pattern_len;

      // Check what follows the pattern
      if (after_pattern < name_str.length() &&
          name_str[after_pattern] == '_' &&
          after_pattern + 1 < name_str.length() &&
          std::isdigit(name_str[after_pattern + 1])) {{
        // Pattern followed by _digit (e.g., "main_kernel_1")
        variant_matches.push_back({{name_str, func_list[i]}});
      }} else {{
        // Pattern not followed by _digit (e.g., "main_kernel" itself)
        base_matches.push_back({{name_str, func_list[i]}});
      }}
    }}
  }}

  // Decision logic: prefer base matches over variant matches
  if (!base_matches.empty()) {{
    if (base_matches.size() == 1) {{
      *out_func = base_matches[0].second;
      return CUDA_SUCCESS;
    }}

    // Multiple base matches - ambiguous
    std::cerr << "Error: Pattern '" << pattern << "' matched " << base_matches.size()
              << " base kernels (ambiguous). Matches found:\n";
    for (const auto& match : base_matches) {{
      std::cerr << "  - " << match.first << "\n";
    }}
    std::cerr << "Please use a more specific pattern.\n";
    return CUDA_ERROR_NOT_FOUND;
  }}

  // No base matches, try variant matches
  if (!variant_matches.empty()) {{
    if (variant_matches.size() == 1) {{
      *out_func = variant_matches[0].second;
      return CUDA_SUCCESS;
    }}

    // Multiple variant matches - ambiguous
    std::cerr << "Error: Pattern '" << pattern << "' matched " << variant_matches.size()
              << " variant kernels (ambiguous). Matches found:\n";
    for (const auto& match : variant_matches) {{
      std::cerr << "  - " << match.first << "\n";
    }}
    std::cerr << "Please use a more specific pattern (e.g., '" << pattern << "_1').\n";
    return CUDA_ERROR_NOT_FOUND;
  }}

  // No matches at all
  std::cerr << "Failed to find kernel matching pattern '" << pattern << "'\n";
  return CUDA_ERROR_NOT_FOUND;
}}


// Initialize CUDA module (called once on first launch)
static CUresult tilelang_init_cuda_module(const std::string& cubin_path) {{
  if (g_module_initialized) return CUDA_SUCCESS;

  CUresult result;
  result = cuInit(0);
  if (result != CUDA_SUCCESS) return result;

  std::ifstream cubin_file(cubin_path.c_str(), std::ios::binary);
  if (!cubin_file) {{
    std::cerr << "Failed to open cubin file: " << cubin_path << "\n";
    return CUDA_ERROR_FILE_NOT_FOUND;
  }}

  std::vector<char> cubin_data((std::istreambuf_iterator<char>(cubin_file)),
                                std::istreambuf_iterator<char>());
  cubin_file.close();

  if (cubin_data.empty()) {{
    std::cerr << "Empty cubin file: " << cubin_path << "\n";
    return CUDA_ERROR_INVALID_IMAGE;
  }}

  result = cuModuleLoadData(&g_module, cubin_data.data());
  if (result != CUDA_SUCCESS) {{
    std::cerr << "Failed to load CUDA module: " << result << "\n";
    return result;
  }}

  g_module_initialized = true;
  return CUDA_SUCCESS;
}}

// Initialize all kernel functions (called once after module load)
static CUresult tilelang_init_kernels() {{
  if (g_kernels_initialized) return CUDA_SUCCESS;
  CUresult result;

{kernel_inits}

  g_kernels_initialized = true;
  return CUDA_SUCCESS;
}}

// TMA descriptor initialization (host-side)
{tma_init_func}

// Main kernel launcher
extern "C" CUresult launch_kernel({launch_func_sig}, uint64_t _stream, tvm::ffi::Bytes cubin_path) {{
  CUresult result;

  std::string cubin_path_str(reinterpret_cast<const char*>(cubin_path.data()), cubin_path.size());
  result = tilelang_init_cuda_module(cubin_path_str);
  if (result != CUDA_SUCCESS) return result;

  result = tilelang_init_kernels();
  if (result != CUDA_SUCCESS) return result;

{get_ptr_code}
  CUstream stream = (CUstream)_stream;

{tma_init_in_launch}

{kernel_launches}

  return CUDA_SUCCESS;
}}

// Cleanup function
extern "C" CUresult cleanup_module() {{
  if (g_module_initialized && g_module != nullptr) {{
    cuModuleUnload(g_module);
    g_module = nullptr;
    g_module_initialized = false;
  }}

  g_kernels_initialized = false;

  return CUDA_SUCCESS;
}}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_kernel, launch_kernel);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(cleanup_module, cleanup_module);
"""
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_TMA_ATOM_INIT_TEMPLATE = '    {desc_name} = tl.Gemm_SM90.get_tma_atom(__fake_tensor__, (32, 32))'
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_KERNEL_LAUNCH_TEMPLATE = Multiline-String
"""    {function_name}({call_args}).launch(
      grid=[{grid_x}, {grid_y}, {grid_z}],
      block=[{block_x}, {block_y}, {block_z}],
      smem={smem_size},
      stream=stream,
    )"""
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_FAKE_TENSOR_TEMPLATE = '  __fake_{arg_name}__ = make_fake_compact_tensor(_DTYPE_MAP[str({arg_name}.dtype)],...
tilelang.jit.adapter.cutedsl.wrapper.CUBIN_GEN_CODE_TEMPLATE = Multiline-String
"""{lib_code}

  @cute.jit
  def kernel_wrapper({wrapper_args}):
{tma_init_code}{kernel_launches}

  # Compile kernels to generate cubin
{fake_tensor_code}{fake_tma_tensor_code}  __fake_stream__ = make_fake_stream()
  # Always generate cubin under a unique staging directory to avoid concurrent
  # processes clobbering each other's intermediate artifacts.
  _staging_dir = Path(tempfile.mkdtemp(
      prefix=Path(__file__).stem + ".cubin.staging.",
      dir=_module_dir,
  ))
  try:
    _kernel_wrapper = cute.compile(
        kernel_wrapper,
        {compile_args},
        options=f"--enable-tvm-ffi --keep-cubin --dump-dir={{_staging_dir.as_posix()}}",
    )

    # CuTeDSL generates a long, mangled cubin filename that includes argument/type info,
    # e.g. "cutlass_kernel_wrapper_FakeTensor...sm_90a.cubin". We expect exactly one cubin.
    _cubin_files = sorted(_staging_dir.glob("*.cubin"), key=lambda p: p.stat().st_mtime)
    if len(_cubin_files) != 1:
      raise RuntimeError(
          f"Expected exactly one .cubin under {{_staging_dir}}, got {{len(_cubin_files)}}: {{_cubin_files}}"
      )
    os.replace(_cubin_files[0], _cubin_path)
  finally:
    shutil.rmtree(_staging_dir, ignore_errors=True)"""
tilelang.jit.adapter.cutedsl.wrapper.PYTHON_HOST_FUNC_TEMPLATE = Multiline-String
"""import os
from pathlib import Path

# Minimal imports for runtime (no cutlass/cute - only needed for cubin generation)
import tvm.runtime as runtime

_cpp_launcher = None
_cpp_launcher_lib = None
_cubin_generated = False

# Pre-compute paths - cubin is stored alongside the launcher .so
# Use module basename to avoid conflicts when multiple kernels run concurrently
# e.g., "/tmp/tmp8liu__ho.py" -> "/tmp/tmp8liu__ho.cubin"
#       "kernel.py" (in cache) -> "kernel.cubin"
_module_dir = Path(os.path.dirname(__file__))
_cubin_path = _module_dir / (Path(__file__).stem + ".cubin")
_cubin_path_bytes = _cubin_path.as_posix().encode('utf-8')
_cubin_needs_generation = not _cubin_path.exists()

def _generate_cubin_if_needed({cubin_gen_params}):
  """Generate cubin file on first call.

  All CuTeDSL imports are inside this function to avoid slow
  module-level initialization when loading from cache.
  """
  global _cubin_generated, _cubin_path

  # Lazy import CuTeDSL only when cubin generation is needed
  from cuda.bindings.driver import CUstream
  import cutlass
  import cutlass.cute as cute
  from cutlass.cute.runtime import make_fake_stream, make_fake_compact_tensor
  import tilelang.contrib.cutedsl as tl
  # We rely on CuTeDSL's keep-cubin artifact rather than custom extraction.
  import tempfile
  import shutil

  _DTYPE_MAP = {{
      "torch.float32": cutlass.Float32,
      "torch.float16": cutlass.Float16,
      "torch.bfloat16": cutlass.BFloat16,
      "torch.float8_e4m3fnuz": cutlass.Float8E4M3FN,
      "torch.float8_e4m3fn": cutlass.Float8E4M3FN,
      "torch.float8_e5m2": cutlass.Float8E5M2,
      "torch.float64": cutlass.Float64,
      "torch.int64": cutlass.Int64,
      "torch.int32": cutlass.Int32,
      "torch.uint32": cutlass.Uint32,
      "torch.bool": cutlass.Boolean,
      "torch.int8": cutlass.Int8,
      "torch.uint8": cutlass.Uint8,
      "torch.int16": cutlass.Int16,
      "torch.uint16": cutlass.Uint16,
      "torch.uchar": cutlass.Uint8,
  }}

{cubin_gen_code}

  _cubin_generated = True

def _load_cpp_launcher():
  """Load C++ kernel launcher."""
  global _cpp_launcher, _cpp_launcher_lib
  if _cpp_launcher is not None:
    return _cpp_launcher

  lib_path = os.path.join(os.path.dirname(__file__), "{launcher_lib_name}")
  if not os.path.exists(lib_path):
    raise FileNotFoundError(f"Launcher not found: {{lib_path}}")

  _cpp_launcher_lib = runtime.load_module(lib_path)
  _cpp_launcher = _cpp_launcher_lib["launch_kernel"]
  return _cpp_launcher

def call({call_func_params}, stream):
  """Kernel dispatch function."""
  global _cubin_path_bytes, _cubin_needs_generation

  if _cubin_needs_generation:
    _generate_cubin_if_needed({cubin_gen_call_args})
    _cubin_needs_generation = False

{arg_prep_code}

  launcher = _load_cpp_launcher()
  result = launcher({launcher_call_args}, stream, _cubin_path_bytes)

  if result != 0:
    raise RuntimeError(f"Kernel launch failed with CUDA error {{result}}")
"""
class tilelang.jit.adapter.cutedsl.wrapper.TLCuTeDSLSourceWrapper(scheduled_ir_module, source, target, device_mod=None, host_mod=None, pass_configs=None)

Bases: tilelang.jit.adapter.wrapper.TLCUDASourceWrapper

Wrapper class for TileLang CuTe DSL backend with C++ launcher.

Generates optimized C++ launcher code that:

• Loads the cubin via the CUDA Driver API
• Passes TMA descriptors by value (host-side, no device copy)
• Launches kernels with minimal Python overhead
• Supports both single- and multiple-kernel scenarios

See the usage sketch after the parameter list below.

Parameters:
  • scheduled_ir_module (tvm.IRModule)

  • source (str)

  • target (tvm.target.Target)

  • device_mod (tvm.IRModule | None)

  • host_mod (tvm.IRModule | None)

  • pass_configs (dict[str, Any] | None)
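
A hedged usage sketch, assuming an already-scheduled IRModule and generated CuTeDSL source; my_mod and cutedsl_src are placeholders, and the optional arguments keep their documented defaults:

import tvm
from tilelang.jit.adapter.cutedsl.wrapper import TLCuTeDSLSourceWrapper

wrapper = TLCuTeDSLSourceWrapper(
    scheduled_ir_module=my_mod,   # tvm.IRModule (placeholder)
    source=cutedsl_src,           # generated CuTeDSL source str (placeholder)
    target=tvm.target.Target("cuda"),
)
cpp_launcher = wrapper.get_launcher_cpp_code()  # C++ launcher source
py_host = wrapper.host_func                     # generated Python host code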

property host_func

Overrides the parent's host_func to return the generated Python host code.

generate_tma_descriptor_args(desc_name_map, desc_name_var_map, tma_desc_code_map)

Generate TMA descriptor information for C++ code generation.

Returns:

List of descriptor variable names in the order they were processed.

Parameters:
  • desc_name_map (dict[str, str])

  • desc_name_var_map (dict[str, tilelang.tvm.tir.Var])

  • tma_desc_code_map (dict[str, str])

Return type:

list[str]

create_dispatch_func(code, function_informations)

Create the dispatch function; this backend always uses the C++ launcher.

create_dispatch_func_cpp_launcher(code, function_informations)

Create the dispatch function using the C++ launcher.

get_launcher_cpp_code()

Get the generated C++ launcher code.

Return type:

str

update_lib_code(code)

Update the library code with the given code string.

Parameters:

code (str)