tilelang.jit.adapter.wrapper
============================

.. py:module:: tilelang.jit.adapter.wrapper


Attributes
----------

.. autoapisummary::

   tilelang.jit.adapter.wrapper.PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY
   tilelang.jit.adapter.wrapper.PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP
   tilelang.jit.adapter.wrapper.PREDEF_INIT_FUNC
   tilelang.jit.adapter.wrapper.PREDEF_HOST_FUNC
   tilelang.jit.adapter.wrapper.PREDEF_HOST_FUNC_PY
   tilelang.jit.adapter.wrapper.L2_PERSISTENT_MAP_CREATE_HANDLE
   tilelang.jit.adapter.wrapper.L2_PERSISTENT_MAP_INIT_FUNC
   tilelang.jit.adapter.wrapper.L2_PERSISTENT_MAP_RESET_HANDLE
   tilelang.jit.adapter.wrapper.TMA_DESC_INIT_FUNC
   tilelang.jit.adapter.wrapper.TMA_DESC_INIT_FUNC_PY
   tilelang.jit.adapter.wrapper.KERNEL_LAUNCH_FUNC_PY
   tilelang.jit.adapter.wrapper.logger


Classes
-------

.. autoapisummary::

   tilelang.jit.adapter.wrapper.BaseWrapper
   tilelang.jit.adapter.wrapper.TLCUDASourceWrapper
   tilelang.jit.adapter.wrapper.TLNVRTCSourceWrapper
   tilelang.jit.adapter.wrapper.TLHIPSourceWrapper
   tilelang.jit.adapter.wrapper.TLCPUSourceWrapper
   tilelang.jit.adapter.wrapper.TLWrapper
   tilelang.jit.adapter.wrapper.TLPyWrapper


Module Contents
---------------

.. py:data:: PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY
   :value: Multiline-String

   .. code-block:: python

      """
      cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
      if (result_{0} != CUDA_SUCCESS) {{
          snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, cudaGetErrorString(result_{0}));
          return -1;
      }}
      """
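   A minimal sketch of how this format template might be filled in; the kernel
   name and shared-memory size below are illustrative assumptions, not values
   produced by TileLang itself:

   .. code-block:: python

      from tilelang.jit.adapter.wrapper import PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY

      # Hypothetical values for illustration only.
      kernel_name = "main_kernel"       # {0}: device kernel symbol
      dynamic_smem_bytes = 132 * 1024   # {1}: requested dynamic shared memory

      # Produces the C snippet that calls cudaFuncSetAttribute for this kernel.
      snippet = PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(kernel_name, dynamic_smem_bytes)
      print(snippet)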
.. py:data:: PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP
   :value: Multiline-String

   .. code-block:: python

      """
      if ({1} > 65536) {{
          snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size for {0} to %d", {1});
          return -1;
      }}
      return 0;
      """
.. py:data:: PREDEF_INIT_FUNC
   :value: Multiline-String

   .. code-block:: python

      """
      #define ERROR_BUF_SIZE 1024
      static char error_buf[ERROR_BUF_SIZE];

      extern "C" const char* get_last_error() {{
          return error_buf;
      }}

      extern "C" int init() {{
          error_buf[0] = '\0';
          {0}
          return 0;
      }}
      """
.. py:data:: PREDEF_HOST_FUNC
   :value: Multiline-String

   .. code-block:: python

      """
      extern "C" int call({}) {{
          {}
          return 0;
      }}
      """
.. py:data:: PREDEF_HOST_FUNC_PY
   :value: Multiline-String

   .. code-block:: python

      """
      import cuda.bindings.driver
      import ctypes

      _function_names = {}

      def call({}):
          {}
      """
.. py:data:: L2_PERSISTENT_MAP_CREATE_HANDLE
   :value: Multiline-String

   .. code-block:: python

      """
      cudaStreamAttrValue stream_attribute;
      size_t init_persisting_l2_cache_size;
      cudaDeviceGetLimit(&init_persisting_l2_cache_size, cudaLimitPersistingL2CacheSize);
      """
.. py:data:: L2_PERSISTENT_MAP_INIT_FUNC
   :value: Multiline-String

   .. code-block:: python

      """
      stream_attribute.accessPolicyWindow.hitRatio = {1};
      stream_attribute.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
      stream_attribute.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
      cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {2});
      stream_attribute.accessPolicyWindow.base_ptr = (void*)({0});
      stream_attribute.accessPolicyWindow.num_bytes = {2};
      cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
      """
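   A minimal sketch of filling this template; the buffer expression, hit ratio,
   and window size below are assumptions chosen only to show the placeholder
   order:

   .. code-block:: python

      from tilelang.jit.adapter.wrapper import L2_PERSISTENT_MAP_INIT_FUNC

      snippet = L2_PERSISTENT_MAP_INIT_FUNC.format(
          "B",       # {0}: pointer expression for the buffer kept resident in L2
          0.75,      # {1}: accessPolicyWindow.hitRatio
          1 << 20,   # {2}: window size in bytes (also used as the cache-size limit)
      )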
.. py:data:: L2_PERSISTENT_MAP_RESET_HANDLE
   :value: Multiline-String

   .. code-block:: python

      """
      stream_attribute.accessPolicyWindow.num_bytes = 0;
      cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
      cudaCtxResetPersistingL2Cache();
      cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, init_persisting_l2_cache_size);
      """
.. py:data:: TMA_DESC_INIT_FUNC
   :value: Multiline-String

   .. code-block:: python

      """
      CUtensorMap {0};
      CUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
      cuuint32_t {0}_tensorRank= {2};
      void *{0}_globalAddress= {3};
      cuuint64_t {0}_globalDim[{2}]= {{{4}}};
      cuuint64_t {0}_globalStride[{2}]= {{{5}}};
      cuuint32_t {0}_boxDim[{2}]= {{{6}}};
      cuuint32_t {0}_elementStrides[{2}]= {{{7}}};
      CUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){8};
      CUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){9};
      CUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){10};
      CUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){11};

      CUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
          &{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim,
          {0}_globalStride + 1, {0}_boxDim, {0}_elementStrides, {0}_interleave,
          {0}_swizzle, {0}_l2Promotion, {0}_oobFill);

      if ({0}_result != CUDA_SUCCESS) {{
          std::stringstream ss;
          ss << "Error: Failed to initialize the TMA descriptor {0}";
          snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
          return -1;
      }}
      """
.. py:data:: TMA_DESC_INIT_FUNC_PY
   :value: Multiline-String

   .. code-block:: python

      """
      {0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
      {0}_tensorRank = {2}
      {0}_globalAddress = {3}.data_ptr()
      {0}_globalDim = [{4}]
      {0}_globalStride = [{5}][1:]
      {0}_boxDim = [{6}]
      {0}_elementStrides = [{7}]
      {0}_interleave = cuda.bindings.driver.CUtensorMapInterleave({8})
      {0}_swizzle = cuda.bindings.driver.CUtensorMapSwizzle({9})
      {0}_l2Promotion = cuda.bindings.driver.CUtensorMapL2promotion({10})
      {0}_oobFill = cuda.bindings.driver.CUtensorMapFloatOOBfill({11})

      res, {0} = cuda.bindings.driver.cuTensorMapEncodeTiled(
          {0}_type,
          {0}_tensorRank,
          {0}_globalAddress,
          {0}_globalDim,
          {0}_globalStride,
          {0}_boxDim,
          {0}_elementStrides,
          {0}_interleave,
          {0}_swizzle,
          {0}_l2Promotion,
          {0}_oobFill,
      )
      if res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
          raise RuntimeError(f"Failed to initialize the TMA descriptor {0}: {{res}}")
      """
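   A sketch of how the twelve placeholders might be filled for a single
   descriptor; every value below (descriptor name, dtype code, shapes, strides,
   and the trailing enum codes) is an illustrative assumption, not output of the
   wrapper:

   .. code-block:: python

      from tilelang.jit.adapter.wrapper import TMA_DESC_INIT_FUNC_PY

      desc_code = TMA_DESC_INIT_FUNC_PY.format(
          "A_desc",        # {0}: descriptor variable name
          6,               # {1}: CUtensorMapDataType enum value (illustrative)
          2,               # {2}: tensor rank
          "A",             # {3}: tensor whose .data_ptr() supplies the global address
          "1024, 1024",    # {4}: global dims
          "2, 2048",       # {5}: global strides; the first entry is dropped by [1:]
          "64, 64",        # {6}: box dims
          "1, 1",          # {7}: element strides
          0, 0, 0, 0,      # {8}-{11}: interleave, swizzle, L2 promotion, OOB fill
      )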
.. py:data:: KERNEL_LAUNCH_FUNC_PY
   :value: Multiline-String

   .. code-block:: python

      """
      res = cuda.bindings.driver.cuKernelSetAttribute(
          cuda.bindings.driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
          {7},
          kernels["{0}"],
          cuda.bindings.driver.CUdevice({10})
      )[0]
      if res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
          raise RuntimeError(f"Failed to set max dynamic shared memory size to {7} for kernel {0}: {{res}}")

      config = cuda.bindings.driver.CUlaunchConfig()
      config.gridDimX = {1}
      config.gridDimY = {2}
      config.gridDimZ = {3}
      config.blockDimX = {4}
      config.blockDimY = {5}
      config.blockDimZ = {6}
      config.sharedMemBytes = {7}
      config.hStream = stream

      arg_values = {8}
      arg_types = {9}

      res = cuda.bindings.driver.cuLaunchKernelEx(config, kernels["{0}"], (arg_values, arg_types), 0)[0]
      if res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
          raise RuntimeError(f"Failed to launch kernel {0}: {{res}}")
      """
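   A sketch of the positional arguments this launch template expects; the kernel
   name, launch geometry, and argument expressions below are assumptions used
   only to illustrate the placeholder layout:

   .. code-block:: python

      from tilelang.jit.adapter.wrapper import KERNEL_LAUNCH_FUNC_PY

      launch_code = KERNEL_LAUNCH_FUNC_PY.format(
          "main_kernel",                           # {0}: key into the `kernels` dict
          128, 1, 1,                               # {1}-{3}: grid dimensions
          256, 1, 1,                               # {4}-{6}: block dimensions
          49152,                                   # {7}: dynamic shared memory in bytes
          "(A.data_ptr(), B.data_ptr())",          # {8}: tuple expression of argument values
          "(ctypes.c_void_p, ctypes.c_void_p)",    # {9}: matching ctypes argument types
          0,                                       # {10}: CUDA device ordinal
      )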
.. py:class:: BaseWrapper

   Bases: :py:obj:`abc.ABC`

   Helper class that provides a standard way to create an ABC using inheritance.

   .. py:method:: wrap(*args, **kwargs)
      :abstractmethod:


.. py:data:: logger


.. py:class:: TLCUDASourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)

   Bases: :py:obj:`object`

   .. py:attribute:: backend
      :value: 'tl'

   .. py:attribute:: device_mod
      :type: Optional[tvm.IRModule]
      :value: None

   .. py:attribute:: host_mod
      :type: Optional[tvm.IRModule]
      :value: None

   .. py:attribute:: pass_configs
      :type: Optional[Dict[str, Any]]
      :value: None

   .. py:attribute:: mod

   .. py:attribute:: target

   .. py:attribute:: source

   .. py:attribute:: function_names
      :type: Optional[str]
      :value: None

   .. py:attribute:: dynamic_smem_buf
      :type: Optional[int]
      :value: None

   .. py:attribute:: block_info
      :type: Union[List[int], Dict]
      :value: [1, 1, 1]

   .. py:attribute:: grid_info
      :type: Union[List[int], Dict]
      :value: [1, 1, 1]

   .. py:attribute:: tma_descriptor_args
      :type: Optional[Dict]
      :value: None

   .. py:attribute:: l2_persistent_map
      :type: Optional[Dict[str, Dict]]

   .. py:attribute:: srcpath
      :type: Optional[str]
      :value: None

   .. py:attribute:: libpath
      :type: Optional[str]
      :value: None

   .. py:attribute:: lib_code
      :type: Optional[str]

   .. py:method:: is_tma_descriptor_arg(arg_name)

   .. py:method:: create_dispatch_func(code, function_informations)

   .. py:method:: generate_l2_persistent_map(function_name)

   .. py:method:: generate_tma_descriptor_args(desc_name_map)

   .. py:method:: parse_source_information()

   .. py:method:: get_dynamic_symbolic_set(prim_func)

   .. py:method:: get_init_func()

   .. py:method:: update_lib_code(code)

   .. py:method:: get_stream_type()

   .. py:property:: prim_func


.. py:class:: TLNVRTCSourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)

   Bases: :py:obj:`TLCUDASourceWrapper`

   A wrapper class for the TileLang NVRTC backend.

   .. py:method:: create_dispatch_func(code, function_informations)

   .. py:method:: generate_tma_descriptor_args(desc_name_map)

   .. py:method:: update_lib_code(code)

   .. py:method:: get_stream_type()


.. py:class:: TLHIPSourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)

   Bases: :py:obj:`TLCUDASourceWrapper`

   A wrapper class for the TileLang HIP backend.

   .. py:method:: get_init_func()

   .. py:method:: get_stream_type()


.. py:class:: TLCPUSourceWrapper(scheduled_ir_module, source, target, device_mod = None, host_mod = None, pass_configs = None)

   Bases: :py:obj:`object`

   .. py:attribute:: INIT_FUNC

   .. py:attribute:: CALL_PREFIX

   .. py:attribute:: backend
      :value: 'tl'

   .. py:attribute:: device_mod
      :type: Optional[tvm.IRModule]
      :value: None

   .. py:attribute:: host_mod
      :type: Optional[tvm.IRModule]
      :value: None

   .. py:attribute:: pass_configs
      :type: Optional[Dict[str, Any]]
      :value: None

   .. py:attribute:: mod

   .. py:attribute:: target

   .. py:attribute:: source

   .. py:attribute:: function_names
      :type: Optional[str]
      :value: None

   .. py:attribute:: dynamic_smem_buf
      :type: Optional[int]
      :value: None

   .. py:attribute:: srcpath
      :type: Optional[str]
      :value: None

   .. py:attribute:: libpath
      :type: Optional[str]
      :value: None

   .. py:attribute:: lib_code
      :type: Optional[str]

   .. py:method:: create_call_func(code, function_informations)

   .. py:method:: parse_source_information()

   .. py:method:: get_dynamic_symbolic_set(prim_func)

   .. py:method:: get_cpu_init_func()

   .. py:method:: update_lib_code(code)

   .. py:property:: prim_func
.. py:class:: TLWrapper(target)

   Bases: :py:obj:`BaseWrapper`

   A wrapper class for the TileLang backend.

   .. py:attribute:: device_mod
      :type: Optional[tvm.IRModule]
      :value: None

   .. py:attribute:: host_mod
      :type: Optional[tvm.IRModule]
      :value: None

   .. py:attribute:: pass_configs
      :type: Optional[Dict[str, Any]]
      :value: None

   .. py:attribute:: target
      :type: Optional[tvm.target.Target]
      :value: None

   .. py:attribute:: lib
      :type: Optional[object]
      :value: None

   .. py:attribute:: scheduled_ir_module
      :value: None

   .. py:method:: assign_optimized_module(scheduled_ir_module)

   .. py:method:: assign_pass_configs(pass_configs)

   .. py:method:: assign_host_module(host_mod)

   .. py:method:: assign_device_module(device_mod)

   .. py:method:: wrap(c_source)


.. py:class:: TLPyWrapper(target)

   Bases: :py:obj:`TLWrapper`

   A wrapper class for the TileLang backend.

   .. py:method:: wrap(c_source)
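A minimal sketch of how the wrapper is typically driven. The ``target``,
``scheduled_ir_module``, ``device_mod``, ``host_mod``, ``pass_configs``, and
``c_source`` objects are assumed to come from an earlier TileLang lowering and
codegen step and are not constructed here:

.. code-block:: python

   from tilelang.jit.adapter.wrapper import TLWrapper

   wrapper = TLWrapper(target)                        # target: tvm.target.Target
   wrapper.assign_optimized_module(scheduled_ir_module)
   wrapper.assign_pass_configs(pass_configs)
   wrapper.assign_host_module(host_mod)
   wrapper.assign_device_module(device_mod)

   # wrap() combines the generated device source with the host-side glue
   # (the init/call functions shown in the templates above).
   wrapped_source = wrapper.wrap(c_source)

``TLPyWrapper`` exposes the same ``wrap(c_source)`` entry point.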