tilelang.jit.adapter.wrapper
============================
.. py:module:: tilelang.jit.adapter.wrapper
Attributes
----------
.. autoapisummary::
tilelang.jit.adapter.wrapper.PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY
tilelang.jit.adapter.wrapper.PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP
tilelang.jit.adapter.wrapper.PREDEF_INIT_FUNC
tilelang.jit.adapter.wrapper.PREDEF_HOST_FUNC
tilelang.jit.adapter.wrapper.PREDEF_HOST_FUNC_PY
tilelang.jit.adapter.wrapper.L2_PERSISTENT_MAP_CREATE_HANDLE
tilelang.jit.adapter.wrapper.L2_PERSISTENT_MAP_INIT_FUNC
tilelang.jit.adapter.wrapper.L2_PERSISTENT_MAP_RESET_HANDLE
tilelang.jit.adapter.wrapper.TMA_DESC_INIT_FUNC
tilelang.jit.adapter.wrapper.TMA_DESC_INIT_FUNC_PY
tilelang.jit.adapter.wrapper.KERNEL_LAUNCH_FUNC_PY
tilelang.jit.adapter.wrapper.logger
Classes
-------
.. autoapisummary::
tilelang.jit.adapter.wrapper.BaseWrapper
tilelang.jit.adapter.wrapper.TLCUDASourceWrapper
tilelang.jit.adapter.wrapper.TLNVRTCSourceWrapper
tilelang.jit.adapter.wrapper.TLHIPSourceWrapper
tilelang.jit.adapter.wrapper.TLCPUSourceWrapper
tilelang.jit.adapter.wrapper.TLWrapper
tilelang.jit.adapter.wrapper.TLPyWrapper
Module Contents
---------------
.. py:data:: PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY
:value: Multiline-String
.. code-block:: python
"""
cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
if (result_{0} != CUDA_SUCCESS) {{
snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, cudaGetErrorString(result_{0}));
return -1;
}}
"""
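In this template, ``{0}`` is the kernel (function) name and ``{1}`` is the dynamic shared-memory size in bytes. A minimal sketch of expanding it, with a hypothetical kernel name and size:

.. code-block:: python

   from tilelang.jit.adapter.wrapper import PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY

   # Hypothetical kernel name and shared-memory size, for illustration only.
   snippet = PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format("main_kernel", 49152)
   # The resulting C snippet raises cudaFuncAttributeMaxDynamicSharedMemorySize
   # for main_kernel and reports failures through error_buf.
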
.. py:data:: PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP
:value: Multiline-String
.. code-block:: python
"""
if ({1} > 65536) {{
snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size for {0} to %d", {1});
return -1;
}}
return 0;
"""
.. py:data:: PREDEF_INIT_FUNC
:value: Multiline-String
.. code-block:: python
"""
#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];
extern "C" const char* get_last_error() {{
return error_buf;
}}
extern "C" int init() {{
error_buf[0] = '\0';
{0}
return 0;
}}
"""
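Here ``{0}`` is the body spliced into the generated ``init()`` function, typically the attribute-setting snippet above. A hedged sketch with a hypothetical kernel:

.. code-block:: python

   from tilelang.jit.adapter.wrapper import (
       PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY,
       PREDEF_INIT_FUNC,
   )

   # Hypothetical kernel name and dynamic shared-memory size.
   body = PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format("main_kernel", 49152)
   init_source = PREDEF_INIT_FUNC.format(body)
   # init_source defines error_buf, get_last_error() and an init() entry point
   # that applies the shared-memory attribute before any kernel is launched.
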
.. py:data:: PREDEF_HOST_FUNC
:value: Multiline-String
.. code-block:: python
"""
extern "C" int call({}) {{
{}
return 0;
}}
"""
.. py:data:: PREDEF_HOST_FUNC_PY
:value: Multiline-String
.. code-block:: python
"""
import cuda.bindings.driver
import ctypes
_function_names = {}
def call({}):
{}
"""
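The three positional placeholders receive the kernel-name list, the ``call()`` signature, and the function body. A sketch with made-up values (the real wrapper derives them from the host module):

.. code-block:: python

   from tilelang.jit.adapter.wrapper import PREDEF_HOST_FUNC_PY

   # Illustrative values only.
   host_source = PREDEF_HOST_FUNC_PY.format(
       "['main_kernel']",    # _function_names
       "A, B, stream",       # call() signature
       "pass",               # body; normally the generated launch code
   )
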
.. py:data:: L2_PERSISTENT_MAP_CREATE_HANDLE
:value: Multiline-String
.. code-block:: python
"""
cudaStreamAttrValue stream_attribute;
size_t init_persisting_l2_cache_size;
cudaDeviceGetLimit(&init_persisting_l2_cache_size, cudaLimitPersistingL2CacheSize);
"""
.. py:data:: L2_PERSISTENT_MAP_INIT_FUNC
:value: Multiline-String
.. code-block:: python
"""
stream_attribute.accessPolicyWindow.hitRatio = {1};
stream_attribute.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
stream_attribute.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {2});
stream_attribute.accessPolicyWindow.base_ptr = (void*)({0});
stream_attribute.accessPolicyWindow.num_bytes = {2};
cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
"""
.. py:data:: L2_PERSISTENT_MAP_RESET_HANDLE
:value: Multiline-String
.. code-block:: python
"""
stream_attribute.accessPolicyWindow.num_bytes = 0;
cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
cudaCtxResetPersistingL2Cache();
cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, init_persisting_l2_cache_size);
"""
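The three ``L2_PERSISTENT_MAP_*`` templates are emitted as a unit: the create-handle snippet declares the stream attribute and saves the current persisting-L2 limit, the init snippet configures the access-policy window, and the reset snippet restores the defaults. In the init template, ``{0}`` is the base-pointer expression, ``{1}`` the hit ratio and ``{2}`` the window size in bytes. A sketch with hypothetical values:

.. code-block:: python

   from tilelang.jit.adapter.wrapper import L2_PERSISTENT_MAP_INIT_FUNC

   # Hypothetical buffer name, hit ratio and window size.
   snippet = L2_PERSISTENT_MAP_INIT_FUNC.format("A", 0.8, 1024 * 1024)
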
.. py:data:: TMA_DESC_INIT_FUNC
:value: Multiline-String
.. code-block:: python
"""
CUtensorMap {0};
CUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
cuuint32_t {0}_tensorRank= {2};
void *{0}_globalAddress= {3};
cuuint64_t {0}_globalDim[{2}]= {{{4}}};
cuuint64_t {0}_globalStride[{2}]= {{{5}}};
cuuint32_t {0}_boxDim[{2}]= {{{6}}};
cuuint32_t {0}_elementStrides[{2}]= {{{7}}};
CUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){8};
CUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){9};
CUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){10};
CUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){11};
CUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1, {0}_boxDim, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);
if ({0}_result != CUDA_SUCCESS) {{
std::stringstream ss;
ss << "Error: Failed to initialize the TMA descriptor {0}";
snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
return -1;
}}
"""
.. py:data:: TMA_DESC_INIT_FUNC_PY
:value: Multiline-String
.. code-block:: python
"""
{0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
{0}_tensorRank = {2}
{0}_globalAddress = {3}.data_ptr()
{0}_globalDim = [{4}]
{0}_globalStride = [{5}][1:]
{0}_boxDim = [{6}]
{0}_elementStrides = [{7}]
{0}_interleave = cuda.bindings.driver.CUtensorMapInterleave({8})
{0}_swizzle = cuda.bindings.driver.CUtensorMapSwizzle({9})
{0}_l2Promotion = cuda.bindings.driver.CUtensorMapL2promotion({10})
{0}_oobFill = cuda.bindings.driver.CUtensorMapFloatOOBfill({11})
res, {0} = cuda.bindings.driver.cuTensorMapEncodeTiled(
{0}_type,
{0}_tensorRank,
{0}_globalAddress,
{0}_globalDim,
{0}_globalStride,
{0}_boxDim,
{0}_elementStrides,
{0}_interleave,
{0}_swizzle,
{0}_l2Promotion,
{0}_oobFill,
)
if res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}")
"""
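The placeholders mirror the arguments of ``cuTensorMapEncodeTiled``: ``{0}`` names the descriptor, ``{1}`` is the data-type enum value, ``{2}`` the rank, ``{3}`` the tensor whose ``data_ptr()`` supplies the global address, and ``{4}``-``{11}`` the dims, strides, box sizes, element strides, interleave, swizzle, L2 promotion and OOB-fill settings. A hedged sketch for a hypothetical 2-D descriptor:

.. code-block:: python

   from tilelang.jit.adapter.wrapper import TMA_DESC_INIT_FUNC_PY

   # Illustrative values for a hypothetical descriptor named "A_desc".
   code = TMA_DESC_INIT_FUNC_PY.format(
       "A_desc",        # {0} descriptor name
       9,               # {1} CUtensorMapDataType value
       2,               # {2} tensor rank
       "A",             # {3} tensor variable; the template appends .data_ptr()
       "1024, 1024",    # {4} global dims
       "1, 1024, 1",    # {5} global strides (the first entry is dropped by [1:])
       "64, 64",        # {6} box dims
       "1, 1",          # {7} element strides
       0, 0, 0, 0,      # {8}-{11} interleave, swizzle, l2Promotion, oobFill
   )
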
.. py:data:: KERNEL_LAUNCH_FUNC_PY
:value: Multiline-String
.. code-block:: python
"""
res = cuda.bindings.driver.cuKernelSetAttribute(
cuda.bindings.driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
{7},
kernels["{0}"],
cuda.bindings.driver.CUdevice({10})
)[0]
if res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to set max dynamic shared memory size to {7} for kernel {0}: {{res}}")
config = cuda.bindings.driver.CUlaunchConfig()
config.gridDimX = {1}
config.gridDimY = {2}
config.gridDimZ = {3}
config.blockDimX = {4}
config.blockDimY = {5}
config.blockDimZ = {6}
config.sharedMemBytes = {7}
config.hStream = stream
arg_values = {8}
arg_types = {9}
res = cuda.bindings.driver.cuLaunchKernelEx(config, kernels["{0}"], (arg_values, arg_types), 0)[0]
if res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"Failed to launch kernel {0}: {{res}}")
"""
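The placeholders are the kernel name (``{0}``), grid dims (``{1}``-``{3}``), block dims (``{4}``-``{6}``), dynamic shared memory in bytes (``{7}``), the packed argument values and types (``{8}``, ``{9}``), and the device ordinal (``{10}``). A sketch with made-up launch parameters:

.. code-block:: python

   from tilelang.jit.adapter.wrapper import KERNEL_LAUNCH_FUNC_PY

   # Illustrative launch parameters only.
   launch_code = KERNEL_LAUNCH_FUNC_PY.format(
       "main_kernel",                           # {0} kernel name
       128, 1, 1,                               # {1}-{3} grid dims
       256, 1, 1,                               # {4}-{6} block dims
       49152,                                   # {7} dynamic shared memory (bytes)
       "(A.data_ptr(), B.data_ptr())",          # {8} packed argument values
       "(ctypes.c_void_p, ctypes.c_void_p)",    # {9} packed argument types
       0,                                       # {10} device ordinal
   )
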
.. py:class:: BaseWrapper
Bases: :py:obj:`abc.ABC`
Helper class that provides a standard way to create an ABC using
inheritance.
.. py:method:: wrap(*args, **kwargs)
:abstractmethod:
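Concrete wrappers subclass ``BaseWrapper`` and implement ``wrap``. A minimal, hypothetical subclass:

.. code-block:: python

   from tilelang.jit.adapter.wrapper import BaseWrapper

   class EchoWrapper(BaseWrapper):
       """Hypothetical wrapper that returns the generated source unchanged."""

       def wrap(self, c_source: str) -> str:
           return c_source
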
.. py:data:: logger
.. py:class:: TLCUDASourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)
Bases: :py:obj:`object`
.. py:attribute:: backend
:value: 'tl'
.. py:attribute:: device_mod
:type: Optional[tvm.IRModule]
:value: None
.. py:attribute:: host_mod
:type: Optional[tvm.IRModule]
:value: None
.. py:attribute:: pass_configs
:type: Optional[Dict[str, Any]]
:value: None
.. py:attribute:: mod
.. py:attribute:: target
.. py:attribute:: source
.. py:attribute:: function_names
:type: Optional[str]
:value: None
.. py:attribute:: dynamic_smem_buf
:type: Optional[int]
:value: None
.. py:attribute:: block_info
:type: Union[List[int], Dict]
:value: [1, 1, 1]
.. py:attribute:: grid_info
:type: Union[List[int], Dict]
:value: [1, 1, 1]
.. py:attribute:: tma_descriptor_args
:type: Optional[Dict]
:value: None
.. py:attribute:: l2_persistent_map
:type: Optional[Dict[str, Dict]]
.. py:attribute:: srcpath
:type: Optional[str]
:value: None
.. py:attribute:: libpath
:type: Optional[str]
:value: None
.. py:attribute:: lib_code
:type: Optional[str]
.. py:method:: is_tma_descriptor_arg(arg_name)
.. py:method:: create_dispatch_func(code, function_informations)
.. py:method:: generate_l2_persistent_map(function_name)
.. py:method:: generate_tma_descriptor_args(desc_name_map)
.. py:method:: parse_source_information()
.. py:method:: get_dynamic_symbolic_set(prim_func)
.. py:method:: get_init_func()
.. py:method:: update_lib_code(code)
.. py:method:: get_stream_type()
.. py:property:: prim_func
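A sketch of driving the wrapper directly; ``scheduled_mod``, ``cuda_source``, ``target``, ``device_mod`` and ``host_mod`` are placeholders assumed to come from earlier compilation stages:

.. code-block:: python

   from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper

   wrapper = TLCUDASourceWrapper(
       scheduled_ir_module=scheduled_mod,
       source=cuda_source,
       target=target,
       device_mod=device_mod,
       host_mod=host_mod,
       pass_configs=None,
   )
   wrapper.update_lib_code(cuda_source)   # splice host-side glue into the source
   wrapped_source = wrapper.lib_code
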
.. py:class:: TLNVRTCSourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)
Bases: :py:obj:`TLCUDASourceWrapper`
A wrapper class for the TileLang NVRTC backend.
.. py:method:: create_dispatch_func(code, function_informations)
.. py:method:: generate_tma_descriptor_args(desc_name_map)
.. py:method:: update_lib_code(code)
.. py:method:: get_stream_type()
.. py:class:: TLHIPSourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)
Bases: :py:obj:`TLCUDASourceWrapper`
A wrapper class for the TileLang HIP backend.
.. py:method:: get_init_func()
.. py:method:: get_stream_type()
.. py:class:: TLCPUSourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)
Bases: :py:obj:`object`
.. py:attribute:: INIT_FUNC
.. py:attribute:: CALL_PREFIX
.. py:attribute:: backend
:value: 'tl'
.. py:attribute:: device_mod
:type: Optional[tvm.IRModule]
:value: None
.. py:attribute:: host_mod
:type: Optional[tvm.IRModule]
:value: None
.. py:attribute:: pass_configs
:type: Optional[Dict[str, Any]]
:value: None
.. py:attribute:: mod
.. py:attribute:: target
.. py:attribute:: source
.. py:attribute:: function_names
:type: Optional[str]
:value: None
.. py:attribute:: dynamic_smem_buf
:type: Optional[int]
:value: None
.. py:attribute:: srcpath
:type: Optional[str]
:value: None
.. py:attribute:: libpath
:type: Optional[str]
:value: None
.. py:attribute:: lib_code
:type: Optional[str]
.. py:method:: create_call_func(code, function_informations)
.. py:method:: parse_source_information()
.. py:method:: get_dynamic_symbolic_set(prim_func)
.. py:method:: get_cpu_init_func()
.. py:method:: update_lib_code(code)
.. py:property:: prim_func
.. py:class:: TLWrapper(target)
Bases: :py:obj:`BaseWrapper`
A wrapper class for the TileLang backend.
.. py:attribute:: device_mod
:type: Optional[tvm.IRModule]
:value: None
.. py:attribute:: host_mod
:type: Optional[tvm.IRModule]
:value: None
.. py:attribute:: pass_configs
:type: Optional[Dict[str, Any]]
:value: None
.. py:attribute:: target
:type: Optional[tvm.target.Target]
:value: None
.. py:attribute:: lib
:type: Optional[object]
:value: None
.. py:attribute:: scheduled_ir_module
:value: None
.. py:method:: assign_optimized_module(scheduled_ir_module)
.. py:method:: assign_pass_configs(pass_configs)
.. py:method:: assign_host_module(host_mod)
.. py:method:: assign_device_module(device_mod)
.. py:method:: wrap(c_source)
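The usual flow, sketched with placeholder objects assumed to come from the compilation pipeline, is to assign the scheduled module, pass configs and host/device modules, then wrap the generated C source:

.. code-block:: python

   from tilelang.jit.adapter.wrapper import TLWrapper

   wrapper = TLWrapper(target)
   wrapper.assign_optimized_module(scheduled_mod)
   wrapper.assign_pass_configs({})
   wrapper.assign_host_module(host_mod)
   wrapper.assign_device_module(device_mod)
   wrapped_source = wrapper.wrap(c_source)   # assumed to return the host-augmented source
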
.. py:class:: TLPyWrapper(target)
Bases: :py:obj:`TLWrapper`
A wrapper class for the TileLang backend.
.. py:method:: wrap(c_source)