tilelang.jit.adapter.utils

Classes

TMADescriptorParams

Parsed TMA descriptor parameters.

Functions

match_global_kernel(source[, annotation])

match_declare_kernel(source[, annotation])

match_declare_kernel_cutedsl(source[, annotation])

extract_python_func_declaration(source, func_name)

Extract the full Python function declaration from decorator to colon.

match_declare_kernel_cpu(source[, annotation])

is_cuda_target(target)

is_hip_target(target)

is_cpu_target(target)

is_metal_target(target)

is_cutedsl_target(target)

get_annotated_mod(func_or_mod[, target, target_host, ...])

pythonic_expr(expr[, dtype_map, ignore_cast, floor_div_op])

Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.

maybe_desc_name(name, matches, i[, desc_name_map])

Check if a parameter name corresponds to a TMA descriptor.

parse_function_call_args(declaration, function_args, ...)

Parse function call arguments from a kernel declaration.

parse_tma_descriptor_args(tma_descriptor_args, ...)

Parse TMA descriptor arguments into structured parameters.

Module Contents

tilelang.jit.adapter.utils.match_global_kernel(source, annotation='__global__')
Parameters:
  • source (str)

  • annotation (str)

Return type:

int
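The entries above document only the `int` return type. The following is a minimal sketch of the usual approach, assuming the returned `int` is the character offset in `source` where the annotated kernel declaration begins, and that a missing kernel raises an error. `find_kernel_offset` and its regex are hypothetical and are not TileLang's implementation:

```python
import re


def find_kernel_offset(source: str, annotation: str = "__global__") -> int:
    """Hypothetical sketch: offset of the first `<annotation> void <name>` declaration."""
    # Allow an optional __launch_bounds__(...) between "void" and the kernel name.
    pattern = re.compile(
        re.escape(annotation) + r"\s+void\s+(?:__launch_bounds__\([^)]*\)\s+)?\w+"
    )
    match = pattern.search(source)
    if match is None:
        raise ValueError(f"no kernel annotated with {annotation!r} found")
    return match.start()


src = (
    "#include <cuda_runtime.h>\n"
    'extern "C" __global__ void __launch_bounds__(128) main_kernel(float* A) {}\n'
)
offset = find_kernel_offset(src)
print(src[offset:])
```

The other matchers in this module take the same `(source, annotation)` shape but use a different default annotation. Their exact offsets may differ from this sketch.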

tilelang.jit.adapter.utils.match_declare_kernel(source, annotation='__global__')
Parameters:
  • source (str)

  • annotation (str)

Return type:

int

tilelang.jit.adapter.utils.match_declare_kernel_cutedsl(source, annotation='@cute.kernel')
Parameters:
  • source (str)

  • annotation (str)

Return type:

int

tilelang.jit.adapter.utils.extract_python_func_declaration(source, func_name)

Extract the full Python function declaration from decorator to colon.

Parameters:
  • source (str) – Source code containing the function

  • func_name (str) – Name of the function to extract (can include ‘(’ suffix)

Returns:

The function declaration from ‘def’ to ‘:’, including parameters.

Return type:

str

Example

For code:

    @cute.kernel
    def kernel(arg1: cute.Tensor, arg2: int):

Returns: “def kernel(arg1: cute.Tensor, arg2: int)”
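The example implies that the scan must handle two cases. First, the colons inside type annotations must not end the scan. Second, a parameter list may span several lines. The following is a minimal sketch under these assumptions: a plain-text scan is sufficient (it does not understand parentheses inside string literals), and the result starts at `def` and stops before the closing colon, as in the example output. `extract_declaration_sketch` is hypothetical and is not TileLang's implementation:

```python
def extract_declaration_sketch(source: str, func_name: str) -> str:
    """Hypothetical sketch: return 'def name(...)' up to, not including, the final colon."""
    name = func_name.rstrip("(")  # accept both "kernel" and "kernel("
    start = source.find(f"def {name}(")
    if start == -1:
        raise ValueError(f"function {name!r} not found")
    # Track bracket depth so colons in annotations (depth > 0) are skipped.
    depth = 0
    i = source.index("(", start)
    while i < len(source):
        ch = source[i]
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        elif ch == ":" and depth == 0:
            # Collapse whitespace so a multi-line parameter list comes back on one line.
            return " ".join(source[start:i].split())
        i += 1
    raise ValueError(f"no terminating ':' found for {name!r}")


code = "@cute.kernel\ndef kernel(arg1: cute.Tensor,\n           arg2: int):\n    pass\n"
print(extract_declaration_sketch(code, "kernel("))  # def kernel(arg1: cute.Tensor, arg2: int)
```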

tilelang.jit.adapter.utils.match_declare_kernel_cpu(source, annotation='int32_t')
Parameters:
  • source (str)

  • annotation (str)

Return type:

int

tilelang.jit.adapter.utils.is_cuda_target(target)
Parameters:

target (tvm.target.Target)

Return type:

bool

tilelang.jit.adapter.utils.is_hip_target(target)
Parameters:

target (tvm.target.Target)

Return type:

bool

tilelang.jit.adapter.utils.is_cpu_target(target)
Parameters:

target (tvm.target.Target)

Return type:

bool

tilelang.jit.adapter.utils.is_metal_target(target)
Parameters:

target (tvm.target.Target)

Return type:

bool

tilelang.jit.adapter.utils.is_cutedsl_target(target)
Parameters:

target (tvm.target.Target)

Return type:

bool

tilelang.jit.adapter.utils.get_annotated_mod(func_or_mod, target='auto', target_host=None, model_type='all')
Parameters:
  • func_or_mod (tvm.tir.PrimFunc | tilelang.tvm.IRModule)

  • target (str | tvm.target.Target)

  • target_host (str | tvm.target.Target | None)

  • model_type (Literal['device', 'host', 'all'])

Return type:

tvm.IRModule | tuple[tvm.IRModule, tvm.IRModule]

tilelang.jit.adapter.utils.pythonic_expr(expr, dtype_map=None, ignore_cast=False, floor_div_op='/')

Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.

Parameters:
  • expr (tilelang.tvm.tir.PrimExpr) – The TVM PrimExpr to convert.

  • dtype_map (dict[str, str] | None) – A dictionary mapping data types to their string representations.

  • ignore_cast (bool) – Whether to ignore the cast operator and return the string representation of the value without the cast.

  • floor_div_op (str) – Operator to use for tvm.tir.FloorDiv. The default ‘/’ preserves prior behavior and is suitable for generating C/C++ expressions. When generating Python code that requires integer division (e.g. for grid and block sizes), pass ‘//’ explicitly.

Returns:

A string representation of the expression.

Return type:

str
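To show what "correctly handling operator precedence" involves, here is a minimal sketch on a hypothetical three-node expression type. `Var`, `Const`, and `BinOp` stand in for TVM's `tir` nodes, and this is not TileLang's implementation. The printer emits parentheses only where precedence or left-associativity requires them, and `floor_div_op` switches floor division between C-style `/` and Python `//`:

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical stand-ins for TVM tir nodes (not the real tvm.tir classes).
@dataclass
class Var:
    name: str


@dataclass
class Const:
    value: int


@dataclass
class BinOp:
    op: str  # "+", "-", "*", or "floordiv"
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Var, Const, BinOp]

# Higher binds tighter; "*" and floor division share a level in both Python and C.
PRECEDENCE = {"+": 1, "-": 1, "*": 2, "floordiv": 2}


def to_str(expr: Expr, floor_div_op: str = "/") -> str:
    """Render expr, adding only the parentheses that precedence requires."""
    if isinstance(expr, Var):
        return expr.name
    if isinstance(expr, Const):
        return str(expr.value)
    prec = PRECEDENCE[expr.op]
    lhs = to_str(expr.lhs, floor_div_op)
    rhs = to_str(expr.rhs, floor_div_op)
    # A left operand needs parentheses only if it binds more loosely.
    if isinstance(expr.lhs, BinOp) and PRECEDENCE[expr.lhs.op] < prec:
        lhs = f"({lhs})"
    # A right operand also needs them at equal precedence, because these operators
    # are left-associative: a - (b - c) differs from a - b - c.
    if isinstance(expr.rhs, BinOp) and PRECEDENCE[expr.rhs.op] <= prec:
        rhs = f"({rhs})"
    op = floor_div_op if expr.op == "floordiv" else expr.op
    return f"{lhs} {op} {rhs}"


# ceil(n / 128) written as a floor division, e.g. for a grid size.
grid_x = BinOp("floordiv", BinOp("+", Var("n"), Const(127)), Const(128))
print(to_str(grid_x))        # (n + 127) / 128
print(to_str(grid_x, "//"))  # (n + 127) // 128
```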

tilelang.jit.adapter.utils.maybe_desc_name(name, matches, i, desc_name_map=None)

Check if a parameter name corresponds to a TMA descriptor.

Parameters:
  • name (str) – The parameter name to check.

  • matches (list[str]) – List of all matched parameter names.

  • i (int) – Index of the current match.

  • desc_name_map (dict[str, str] | None) – Optional mapping to store descriptor name relationships.

Returns:

True if the parameter is a TMA descriptor.

Return type:

bool

tilelang.jit.adapter.utils.parse_function_call_args(declaration, function_args, function_params, desc_name_map=None, desc_name_var_map=None, transform_arg=None)

Parse function call arguments from a kernel declaration.

Parameters:
  • declaration (str) – The kernel function declaration string.

  • function_args (list[dict[str, str]]) – List of function argument specifications.

  • function_params (list[Any]) – List of function parameters from TVM IR.

  • desc_name_map (dict[str, str] | None) – Optional mapping for descriptor names.

  • desc_name_var_map (dict[str, tilelang.tvm.tir.Var] | None) – Optional mapping from descriptor names to TVM variables.

  • transform_arg (Callable[[str, str], Any] | None) – Optional function to transform each argument (name, type) -> result.

Returns:

List of parsed call arguments.

Return type:

list[Any]
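The `transform_arg` hook receives each argument as a `(name, type)` pair. The following is a minimal sketch of that contract under these assumptions:

- The declaration is C/CUDA-style.
- Its last parenthesized group is the parameter list.
- Parameters contain no nested parentheses or template commas.
- Bare names are returned when no hook is given. This is an assumption of the sketch, not documented behavior.

`parse_call_args_sketch` is hypothetical and ignores `function_args`, `function_params`, and the descriptor maps:

```python
import re
from typing import Any, Callable, Optional


def parse_call_args_sketch(
    declaration: str,
    transform_arg: Optional[Callable[[str, str], Any]] = None,
) -> list:
    """Hypothetical sketch: apply transform_arg(name, type) to each declared parameter."""
    # Assumes the last "(...)" group is the parameter list and that parameters
    # contain no nested parentheses or template commas.
    params = declaration[declaration.rindex("(") + 1 : declaration.rindex(")")]
    results = []
    for raw in params.split(","):
        raw = raw.strip()
        if not raw:
            continue
        # The trailing identifier is the name; everything before it is the type.
        m = re.fullmatch(r"(.*?)\s*(\w+)", raw, flags=re.S)
        if m is None:
            raise ValueError(f"cannot parse parameter {raw!r}")
        arg_type, name = m.group(1).strip(), m.group(2)
        # Without a hook, fall back to the bare name (an assumption of this sketch).
        results.append(transform_arg(name, arg_type) if transform_arg else name)
    return results


decl = "__global__ void __launch_bounds__(128) main_kernel(float* __restrict__ A, int n)"
print(parse_call_args_sketch(decl))  # ['A', 'n']
print(parse_call_args_sketch(decl, lambda name, ty: (name, ty)))
# [('A', 'float* __restrict__'), ('n', 'int')]
```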

class tilelang.jit.adapter.utils.TMADescriptorParams(handle_name, dtype, tensor_rank, global_address, is_img2col=False)

Parsed TMA descriptor parameters.

Parameters:
  • handle_name (str)

  • dtype (str)

  • tensor_rank (int)

  • global_address (Any)

  • is_img2col (bool)

handle_name
dtype
tensor_rank
global_address
is_img2col = False
global_dim: list[str] = []
global_stride: list[str] = []
element_strides: list[str] = []
interleave: str = ''
swizzle: str = ''
l2_promotion: str = ''
oob_fill: str = ''
box_dim: list[str] = []
lower_corner: list[str] = []
upper_corner: list[str] = []
smem_box_channel: str = ''
smem_box_pixel: str = ''
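The constructor takes only the first five fields. The remaining attributes start out empty and are filled in afterwards, presumably by `parse_tma_descriptor_args`, which returns these objects. The following is a minimal sketch of an equivalent dataclass. `TMADescriptorParamsSketch` and the sample values are hypothetical and are not TileLang's class. The rendered `= []` defaults are read here as per-instance empty lists. A dataclass needs `default_factory` for that behavior, because it rejects a bare mutable default:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TMADescriptorParamsSketch:
    """Hypothetical mirror of the documented TMADescriptorParams fields."""

    # Constructor arguments, matching the documented signature.
    handle_name: str
    dtype: str
    tensor_rank: int
    global_address: Any
    is_img2col: bool = False
    # Populated after construction (init=False); default_factory gives every
    # instance its own list instead of one shared list.
    global_dim: list[str] = field(default_factory=list, init=False)
    global_stride: list[str] = field(default_factory=list, init=False)
    element_strides: list[str] = field(default_factory=list, init=False)
    box_dim: list[str] = field(default_factory=list, init=False)
    lower_corner: list[str] = field(default_factory=list, init=False)
    upper_corner: list[str] = field(default_factory=list, init=False)
    interleave: str = field(default="", init=False)
    swizzle: str = field(default="", init=False)
    l2_promotion: str = field(default="", init=False)
    oob_fill: str = field(default="", init=False)
    smem_box_channel: str = field(default="", init=False)
    smem_box_pixel: str = field(default="", init=False)


a = TMADescriptorParamsSketch("A_desc", "float16", 2, "A")
b = TMADescriptorParamsSketch("B_desc", "float16", 2, "B")
a.global_dim.extend(["K", "M"])
print(a.global_dim, b.global_dim)  # ['K', 'M'] []
```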

tilelang.jit.adapter.utils.parse_tma_descriptor_args(tma_descriptor_args, desc_name_map, desc_name_var_map, pythonic_expr_func)

Parse TMA descriptor arguments into structured parameters.

Parameters:
  • tma_descriptor_args (dict[tilelang.tvm.tir.Var, list[Any]]) – Dictionary mapping TMA descriptor variables to their arguments.

  • desc_name_map (dict[str, str]) – Mapping from descriptor handles to parameter names.

  • desc_name_var_map (dict[str, tilelang.tvm.tir.Var]) – Mapping from descriptor handles to TVM variables.

  • pythonic_expr_func (Callable[[Any], str]) – Function to convert TVM expressions to strings.

Returns:

List of parsed TMA descriptor parameters.

Return type:

list[TMADescriptorParams]