tilelang.language.tir.op
Attributes
Functions
- Build expression by calling an external packed function.
- Build expression by calling an external packed function.
- Lowered version of call_packed.
- Lowered version of call_cpacked.
- Build expression by calling an intrinsic function.
- Build expression by calling a pure extern function.
- Build expression by calling an extern function.
- Build expression by calling an LLVM intrinsic function.
- Build expression by calling a pure LLVM intrinsic function.
- Check the return code of a nested call.
- Return a new on-stack array of type dtype[num].
- Allocate a shape tuple on the stack and return the handle.
- Allocate an NDArray (DLTensor) on the stack and return the handle.
- Provide a true statement that can be used for simplifications.
- Return an initialized but arbitrary value.
- Perform a call into another PrimFunc in the same IRModule.
- Start profile intrinsic.
- End profile intrinsic.
- Create a tuple structure in the value field of an AttrStmt.
- Get the value of a struct field in an array of structs.
- Set the value of a struct field in an array of structs.
- Return the address of an element in the buffer.
- Return the param by name.
- Perform allreduce inside a thread block.
- Mark condition as thread invariant.
- Perform synchronization in the specified scope.
- Exchange a value between threads inside a warp.
- Copy a value from a lane with a lower (by offset) index relative to the caller.
- Copy a value from a lane with a higher (by offset) index relative to the caller.
- Return a 32-bit mask indicating the currently active threads in the calling warp.
- Create a type annotation expression.
- Get the head access address with memory access pattern info.
- Throw TVMGetLastError().
- TVM intrinsic for tensor core load operators.
- TVM intrinsic for tensor core mma_sync operators.
- TVM intrinsic for tensor core bmma_sync operators.
- TVM intrinsic for tensor core fill_fragment operators.
- TVM intrinsic for tensor core store operators.
- TVM intrinsic for PTX tensor core mma instructions.
- TVM intrinsic for sparse tensor core PTX instructions.
- TVM intrinsic for storing the result of PTX MMA into a destination pointer.
- TVM intrinsic for zero-initializing an MMA accumulation register.
- TVM intrinsic for PTX load matrix from shared memory.
- TVM intrinsic for PTX async copy from global to shared memory using cp.async.
- TVM intrinsic for PTX async copy from global to shared memory using cp.async.bulk.
- TVM intrinsic for PTX async copy commit.
- TVM intrinsic for PTX async copy wait.
- TVM intrinsic for AMD matrix core MFMA instructions.
- TVM intrinsic for storing the result of MFMA into a destination pointer.
- TVM intrinsic for AMD RDNA matrix core WMMA instructions.
- TVM intrinsic for storing the result of RDNA WMMA into a destination pointer.
- TVM intrinsic for PTX async copy barrier using cp.async.mbarrier.arrive.
- TVM intrinsic for PTX barrier initialization of thread count using mbarrier.init.
- TVM intrinsic for PTX barrier arrival using mbarrier.arrive.
- TVM intrinsic for PTX barrier arrival with expect tx using mbarrier.arrive.expect_tx.
- TVM intrinsic for PTX barrier wait using mbarrier.try_wait.
- TVM intrinsic to create N barriers.
- Get the lower half of the vector.
- Get the upper half of the vector.
- Concatenate two vectors.
- Create a TIR return expression.
- Create a new expression of the union of all conditions in the arguments.
- Create a new expression of the intersection of all conditions in the arguments.
- Trace tensor data at runtime.
- Minimum value of dtype.
- Maximum value of dtype.
- Infinity value of dtype.
- Infinity value of dtype.
- Take exponential of input x.
- Calculate 2**x.
- Calculate 10**x.
- Take the Gauss error function of input x.
- Take hyperbolic tanh of input x.
- Quick function to get sigmoid.
- Take log of input x.
- Take log2 of input x.
- Take log10 of input x.
- Take log(x + 1) with respect to input x.
- Take tan of input x.
- Take cos of input x.
- Take cosh of input x.
- Take acos of input x.
- Take acosh of input x.
- Take sin of input x.
- Take sinh of input x.
- Take asin of input x.
- Take asinh of input x.
- Take atan of input x.
- Take atanh of input x.
- Take arctan2(x1, x2).
- Take square root of input x.
- Take reciprocal of square root of input x.
- Count leading zero bits of an integer x.
- Take floor of float input x.
- Take ceil of float input x.
- Get truncated value of the input.
- Get absolute value of the input element-wise.
- Take bitwise AND of two values.
- Take bitwise NOT of input value.
- Take bitwise OR of two values.
- Take bitwise XOR of two values.
- Round elements of the array to the nearest integer.
- Round elements of the array to the nearest integer.
- Return the next floating-point value after x1 towards x2.
- Equivalent to sqrt(x1**2 + x2**2), element-wise.
- Change the sign of x1 to that of x2, element-wise.
- Return x1 * (2 ** x2).
- Mark condition as likely.
- Check if input value is NaN.
- Check if input value is nullptr.
- Check if input value is finite.
- Check if input value is infinite.
- Faster power operation than pow(float, float).
- x power y.
- x power y.
- Count the number of set bits in input x.
- Execute a multiplication between two Q-numbers x and y.
- Execute a multiplication between two Q-numbers x and y.
- Return the result of x left shifted by y bits.
- Return the result of x right shifted by y bits.
- Return the remainder of x divided by y with the same sign as x.
- Conditional selection expression.
- Compute a / b as in C/C++ semantics.
- Compute floor(a / b) where a and b are non-negative.
- Compute the remainder of indexdiv, where a and b are non-negative.
- Compute the truncdiv of two expressions.
- Compute the truncmod of two expressions.
- Compute the floordiv of two expressions.
- Compute the floormod of two expressions.
- Generic ceildiv operator.
- Create a commutative reducer for reduction.
- Backend function to allocate temporary workspace.
- Backend function to free temporary workspace.
- Return an item from any list.
- Reset an item from any list.
- Set anylist item by result of packed call.
- Set anylist item by result of packed call.
- Get the target's vscale value. It will be lowered to the llvm.vscale intrinsic.
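The list above distinguishes several integer division flavours (C-style division, truncdiv/truncmod, floordiv/floormod, ceildiv). As a minimal sketch, assuming the conventional definitions (truncation toward zero as in C, flooring toward negative infinity as in Python), their behaviour on plain integers can be modelled as follows; these helpers are hypothetical illustrations, not the TileLang operators themselves.

```python
def truncdiv(a: int, b: int) -> int:
    """Quotient rounded toward zero, as C/C++ integer division does."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def truncmod(a: int, b: int) -> int:
    """Remainder with the sign of the dividend: a == truncdiv(a, b) * b + truncmod(a, b)."""
    return a - truncdiv(a, b) * b


def floordiv(a: int, b: int) -> int:
    """Quotient rounded toward negative infinity (Python's //)."""
    return a // b


def floormod(a: int, b: int) -> int:
    """Remainder with the sign of the divisor (Python's %)."""
    return a % b


def ceildiv(a: int, b: int) -> int:
    """Quotient rounded toward positive infinity."""
    return -((-a) // b)


for a, b in [(7, 2), (-7, 2), (7, -2)]:
    print(a, b, truncdiv(a, b), truncmod(a, b), floordiv(a, b), floormod(a, b), ceildiv(a, b))
```

For non-negative a and positive b, the truncating and flooring variants agree, which matches the non-negativity assumption stated for indexdiv and indexmod.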
Module Contents
- tilelang.language.tir.op.call_packed(*args, span=None)
Build expression by calling an external packed function.
The arguments to the packed function can be Expr or Buffer. When an argument is an Expr, it is passed as the corresponding POD type.
When an argument is a Buffer, the corresponding PackedFunc receives a TVMArrayHandle whose content is valid only for the duration of the callback. If the PackedFunc is a Python callback, the corresponding argument is an NDArray.
- Parameters:
args (list of Expr or Buffer) – Positional arguments.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
See also
te.extern
Create tensor with extern function call.
- tilelang.language.tir.op.call_cpacked(*args, span=None)
Build expression by calling an external packed function.
Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.
- Parameters:
args (list of Expr or Buffer) – Positional arguments.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
See also
te.extern
Create tensor with extern function call.
- tilelang.language.tir.op.call_packed_lowered(*args, span=None)
Lowered version of call_packed. The arguments to the packed function can be Expr or Buffer. When an argument is an Expr, it is passed as the corresponding POD type. When an argument is a Buffer, the corresponding PackedFunc receives a TVMArrayHandle whose content is valid only for the duration of the callback. If the PackedFunc is a Python callback, the corresponding argument is an NDArray.
- Parameters:
args (list of Expr or Buffer) – Positional arguments.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
See also
te.extern
Create tensor with extern function call.
- tilelang.language.tir.op.call_cpacked_lowered(*args, span=None)
Lowered version of call_cpacked. Same as call_packed, except that the first argument is the function name (as in call_extern), and the last argument is the resource handle.
- Parameters:
args (list of Expr or Buffer) – Positional arguments.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
See also
te.extern
Create tensor with extern function call.
- tilelang.language.tir.op.call_intrin(dtype, func_name, *args, span=None)
Build expression by calling an intrinsic function.
Intrinsics can be overloaded with multiple data types via the intrinsic translation rule.
- Parameters:
dtype (str) – The data type of the result.
func_name (str) – The intrinsic function name.
args (list) – Positional arguments.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.call_pure_extern(dtype, func_name, *args, span=None)
Build expression by calling a pure extern function.
- Parameters:
dtype (str) – The data type of the result.
func_name (str) – The extern function name.
args (list) – Positional arguments.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.call_extern(dtype, func_name, *args, span=None)
Build expression by calling an extern function.
- Parameters:
dtype (str) – The data type of the result.
func_name (str) – The extern function name.
args (list) – Positional arguments.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.call_llvm_intrin(dtype, name, *args, span=None)
Build expression by calling an LLVM intrinsic function.
- Parameters:
dtype (str) – The data type of the result.
name (str) – The name of the LLVM intrinsic function.
args (list) – Positional arguments.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.call_llvm_pure_intrin(dtype, name, *args, span=None)
Build expression by calling a pure LLVM intrinsic function.
- Parameters:
dtype (str) – The data type of the result.
name (str) – The name of the LLVM intrinsic function.
args (list) – Positional arguments.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_check_return(expected, return_unexpected, nested_call)
Check the return code of a nested call.
- Parameters:
expected (int) – The expected return code.
return_unexpected (int) – The unexpected return code.
nested_call (PrimExpr) – The call expression whose return code is checked.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_stack_alloca(dtype_str, num)
Return a new on-stack array of type dtype[num].
- Parameters:
dtype_str (str) – The data type of the array.
num (int) – The size of the array.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_stack_make_shape(*args)
Allocate a shape tuple on the stack and return the handle.
- Parameters:
args (int) – The tuple shape.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset)
Allocate an NDArray (DLTensor) on the stack and return the handle.
- Parameters:
data (Expr) – The data of the array.
shape (Expr) – The shape of the array.
strides (Expr) – The strides of the array.
ndim (Expr) – The number of dimensions of the array.
arr_dtype (Expr) – The data type of the array.
elem_offset (Expr) – The element offset of the array.
- Returns:
call – The call expression.
- Return type:
PrimExpr
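tvm_stack_make_array packs data, shape, strides, ndim, a dtype and an element offset into a DLTensor. The plain-Python sketch below shows how such fields are conventionally combined to locate one element. It assumes DLPack's conventions that strides are counted in elements (not bytes) and that absent (NULL) strides mean a compact row-major layout; compact_strides and element_offset are hypothetical helpers.

```python
from typing import Optional, Sequence, Tuple


def compact_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides, in elements, for a densely packed tensor."""
    strides = []
    step = 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))


def element_offset(index: Sequence[int], shape: Sequence[int],
                   strides: Optional[Sequence[int]] = None,
                   elem_offset: int = 0) -> int:
    """Position of `index`, in elements, relative to the data pointer."""
    if len(index) != len(shape):
        raise ValueError("index rank must equal ndim")
    if strides is None:
        # DLPack convention: NULL strides means a compact row-major layout.
        strides = compact_strides(shape)
    return elem_offset + sum(i * s for i, s in zip(index, strides))


print(compact_strides((2, 3, 4)))            # (12, 4, 1)
print(element_offset((1, 2, 3), (2, 3, 4)))  # 23
```

Passing explicit strides (for example a column-major (1, 4) for a 4 x 4 array) lets the same helper describe non-compact views.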
- tilelang.language.tir.op.assume(cond=None)
Provide a true statement that can be used for simplifications.
- Parameters:
cond (Expr) – The constraint condition.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.undef()
Return an initialized but arbitrary value.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.call_tir(global_var, *args)
Perform a call into another PrimFunc in the same IRModule.
- Parameters:
global_var (tvm.ir.GlobalVar) – The global variable of the PrimFunc to call.
args – The arguments passed to the called PrimFunc.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.start_profile_intrinsic(id)
Start profile intrinsic.
- Parameters:
id (int) – The intrinsic id.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.end_profile_intrinsic(id)
End profile intrinsic.
- Parameters:
id (int) – The intrinsic id.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_tuple(*value)
Create a tuple structure in the value field of an AttrStmt.
- Parameters:
value (Expr) – The value in the tuple.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_struct_get(arr, index, field, dtype)
Get the value of a struct field in an array of structs.
- Parameters:
arr (StructType*) – The array of structs.
index (int) – The index of the struct.
field (int) – The field of the struct.
dtype (str) – The data type of the result.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_struct_set(arr, index, field, value)
Set the value of a struct field in an array of structs.
- Parameters:
arr (StructType*) – The array of structs.
index (int) – The index of the struct.
field (int) – The field of the struct.
value (Expr) – The value to be set in the field.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.address_of(buffer_load, span=None)
Return the address of an element in the buffer.
- Parameters:
buffer_load (BufferLoad) – The buffer load.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.lookup_param(param_name, span=None)
Return the param by name.
- Parameters:
param_name (str) – The name of the param.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_thread_allreduce(*freduce_args)
Perform allreduce inside a thread block.
- Parameters:
freduce_args (Expr) – The reduction arguments.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_thread_invariant(cond)
Mark condition as thread invariant.
- Parameters:
cond (Expr) – The condition.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_storage_sync(storage_scope)
Perform synchronization in the specified scope.
- Parameters:
storage_scope (str) – The storage scope in which to perform synchronization.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_warp_shuffle(mask, value, warp_id, width, warp_size)
Exchange a value between threads inside a warp.
- Parameters:
mask (PrimExpr) – The warp mask indicating the active threads inside the warp.
value (PrimExpr) – The value to exchange.
warp_id (PrimExpr) – The source lane index to fetch the value from.
width (PrimExpr) – The width of the sub-sections in which to perform the warp shuffle.
warp_size (PrimExpr) – The warp size.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_warp_shuffle_up(mask, value, offset, width, warp_size)
Copy a value from a lane with a lower (by offset) index relative to the caller.
- Parameters:
mask (PrimExpr) – The warp mask indicating the active threads inside the warp.
value (PrimExpr) – The value to exchange.
offset (PrimExpr) – The difference between the destination lane index and the source lane index: offset = dst_lane_idx - src_lane_idx.
width (PrimExpr) – The width of the sub-sections in which to perform the warp shuffle.
warp_size (PrimExpr) – The warp size.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_warp_shuffle_down(mask, value, offset, width, warp_size)
Copy a value from a lane with a higher (by offset) index relative to the caller.
- Parameters:
mask (PrimExpr) – The warp mask indicating the active threads inside the warp.
value (PrimExpr) – The value to exchange.
offset (PrimExpr) – The difference between the source lane index and the destination lane index: offset = src_lane_idx - dst_lane_idx.
width (PrimExpr) – The width of the sub-sections in which to perform the warp shuffle.
warp_size (PrimExpr) – The warp size.
- Returns:
call – The call expression.
- Return type:
PrimExpr
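The three shuffle variants differ only in how each lane picks its source lane. The pure-Python model below assumes a full 32-lane warp with every lane active (the mask is ignored) and follows the CUDA __shfl_sync / __shfl_up_sync / __shfl_down_sync rules for width-sized sub-sections; the function names are hypothetical, and the model says nothing about how TileLang actually lowers these intrinsics.

```python
from typing import List

WARP_SIZE = 32


def shuffle(values: List[int], src_lane: int, width: int = WARP_SIZE) -> List[int]:
    """Each lane reads lane `src_lane % width` of its own width-sized segment."""
    return [values[(lane // width) * width + src_lane % width]
            for lane in range(len(values))]


def shuffle_up(values: List[int], offset: int, width: int = WARP_SIZE) -> List[int]:
    """Lane i reads lane i - offset; if that leaves its segment, it keeps its own value."""
    return [values[lane - offset] if lane % width >= offset else values[lane]
            for lane in range(len(values))]


def shuffle_down(values: List[int], offset: int, width: int = WARP_SIZE) -> List[int]:
    """Lane i reads lane i + offset; if that leaves its segment, it keeps its own value."""
    return [values[lane + offset] if lane % width + offset < width else values[lane]
            for lane in range(len(values))]


# Tree reduction: after log2(32) shuffle_down steps, lane 0 holds the warp-wide sum.
acc = list(range(WARP_SIZE))
offset = WARP_SIZE // 2
while offset > 0:
    acc = [a + b for a, b in zip(acc, shuffle_down(acc, offset))]
    offset //= 2
print(acc[0])  # 496
```

Only lane 0 is guaranteed to hold the full sum; the other lanes end up with partial or double-counted values, which is why a reduction like this usually reads the result from a single lane.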
- tilelang.language.tir.op.tvm_warp_activemask()
Return a 32-bit mask indicating the currently active threads in the calling warp.
- Returns:
call – The call expression.
- Return type:
PrimExpr
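The mask returned by tvm_warp_activemask is an ordinary 32-bit integer with one bit per lane, the same kind of value that CUDA's __activemask() returns and that the mask parameter of the shuffle intrinsics describes. A small pure-Python sketch of building and decoding such a mask; both helpers are hypothetical.

```python
from typing import Iterable, List


def active_mask(active_lanes: Iterable[int]) -> int:
    """Set bit i of a 32-bit mask for every active lane i."""
    mask = 0
    for lane in active_lanes:
        if not 0 <= lane < 32:
            raise ValueError(f"lane {lane} is outside a 32-thread warp")
        mask |= 1 << lane
    return mask


def lanes_in(mask: int) -> List[int]:
    """Inverse view: list the lanes whose bit is set."""
    return [lane for lane in range(32) if (mask >> lane) & 1]


# e.g. inside a branch taken only by lanes 0..15
print(hex(active_mask(range(16))))  # 0xffff
```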
- tilelang.language.tir.op.type_annotation(dtype)
Create a type annotation expression.
- Parameters:
dtype (Expr) – The data type.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
Get the head access address with memory access pattern info.
- Parameters:
ptype (Expr) – The data type of the pointer.
data (DType*) – The data of the pointer.
offset (int) – The offset of the pointer.
extent (int) – The extent of the pointer.
rw_mask (int) – The read/write mask.
- Returns:
call – The call expression.
- Return type:
PrimExpr
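The rw_mask argument is a small bit set. Assuming TVM's convention (Buffer.READ = 1 and Buffer.WRITE = 2, as used when Buffer.access_ptr turns an access string such as "rw" into a mask), the encoding and the element range covered by offset and extent can be sketched as follows; rw_mask_from_str and accessed_elements are hypothetical helpers.

```python
READ = 1   # assumed read bit (TVM's Buffer.READ)
WRITE = 2  # assumed write bit (TVM's Buffer.WRITE)


def rw_mask_from_str(access: str) -> int:
    """Turn an access string such as "r", "w" or "rw" into a bit mask."""
    mask = 0
    for flag in access:
        if flag == "r":
            mask |= READ
        elif flag == "w":
            mask |= WRITE
        else:
            raise ValueError(f"unknown access flag {flag!r}")
    return mask


def accessed_elements(offset: int, extent: int) -> range:
    """Element indices touched by an access that starts at `offset` and spans `extent`."""
    return range(offset, offset + extent)


print(rw_mask_from_str("rw"), list(accessed_elements(4, 3)))  # 3 [4, 5, 6]
```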
- tilelang.language.tir.op.tvm_throw_last_error()
Throw TVMGetLastError().
- Returns:
ret – The return expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)
TVM intrinsic for tensor core load operators.
- Parameters:
fragment (Var) – The wmma fragment.
m (UIntImm) – The m extent of the wmma fragment shape.
n (UIntImm) – The n extent of the wmma fragment shape.
k (UIntImm) – The k extent of the wmma fragment shape.
index (Expr) – The fragment index.
buffer_ptr (Expr) – The fragment buffer pointer.
stride (Expr) – The fragment stride.
layout (Literal["row_major", "column_major"]) – The fragment layout.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)
TVM intrinsic for tensor core mma_sync operators.
- Parameters:
fragment_d (Var) – The wmma fragment_d.
index_d (Expr) – The fragment_d index.
fragment_a (Var) – The wmma fragment_a.
index_a (Expr) – The fragment_a index.
fragment_b (Var) – The wmma fragment_b.
index_b (Expr) – The fragment_b index.
fragment_c (Var) – The wmma fragment_c.
index_c (Expr) – The fragment_c index.
- Returns:
call – The call expression.
- Return type:
PrimExpr
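Whatever the fragment layout, an mma_sync is a fused multiply-accumulate over fragments: D = A × B + C. A plain-Python reference for that arithmetic on tiny matrices can be handy for checking results; mma_reference is a hypothetical helper and deliberately ignores how the fragments are distributed across the threads of a warp.

```python
from typing import List

Matrix = List[List[float]]


def mma_reference(a: Matrix, b: Matrix, c: Matrix) -> Matrix:
    """Return D = A @ B + C for A (m x k), B (k x n) and C (m x n)."""
    m, k = len(a), len(a[0])
    n = len(b[0])
    if len(b) != k or len(c) != m or any(len(row) != n for row in c):
        raise ValueError("fragment shapes do not line up")
    return [[c[i][j] + sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


d = mma_reference([[1, 2], [3, 4]], [[5, 6], [7, 8]], [[1, 1], [1, 1]])
print(d)  # [[20, 23], [44, 51]]
```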
- tilelang.language.tir.op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)
TVM intrinsic for tensor core bmma_sync operators.
- Parameters:
fragment_d (Var) – The bwmma fragment_d.
index_d (Expr) – The fragment_d index.
fragment_a (Var) – The bwmma fragment_a.
index_a (Expr) – The fragment_a index.
fragment_b (Var) – The bwmma fragment_b.
index_b (Expr) – The fragment_b index.
fragment_c (Var) – The bwmma fragment_c.
index_c (Expr) – The fragment_c index.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_fill_fragment(fragment, m, n, k, index, value)
TVM intrinsic for tensor core fill_fragment operators.
- Parameters:
fragment (Var) – The wmma fragment.
m (UIntImm) – The m extent of the wmma fragment shape.
n (UIntImm) – The n extent of the wmma fragment shape.
k (UIntImm) – The k extent of the wmma fragment shape.
index (Expr) – The fragment index.
value (Expr) – The value to be filled in the fragment.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)
TVM intrinsic for tensor core store operators.
- Parameters:
fragment (Var) – The wmma fragment.
m (UIntImm) – The m extent of the wmma fragment shape.
n (UIntImm) – The n extent of the wmma fragment shape.
k (UIntImm) – The k extent of the wmma fragment shape.
index (Expr) – The fragment index.
buffer_ptr (Expr) – The fragment buffer pointer.
stride (Expr) – The fragment stride.
layout (Literal["row_major", "column_major"]) – The fragment layout.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.ptx_mma(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, saturate, operator=None)
TVM intrinsic for PTX tensor core mma instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
- Parameters:
dtype (str) – The data type of the result.
shape (str) – The shape of the mma fragment.
A_layout (Literal["row", "col"]) – The layout of multiplicand fragment A.
B_layout (Literal["row", "col"]) – The layout of multiplicand fragment B.
A_dtype (str) – The data type of multiplicand fragment A.
B_dtype (str) – The data type of multiplicand fragment B.
C_dtype (str) – The data type of accumulator fragment C.
multiplicand_a (Var) – The multiplicand fragment A variable.
a_index (Expr) – The index of multiplicand fragment A.
multiplicand_b (Var) – The multiplicand fragment B variable.
b_index (Expr) – The index of multiplicand fragment B.
accumulator (Var) – The accumulator fragment C variable.
c_index (Expr) – The index of accumulator fragment C.
saturate (bool) – Whether to apply optional saturation at the output.
operator (Optional[Literal["xor", "and"]]) – The 1-bit operator.
- Returns:
call – The call expression.
- Return type:
PrimExpr
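The shape string encodes the m, n and k extents (for example "m16n8k16"), so A is m × k, B is k × n and the accumulator is m × n. Assuming every operand is spread evenly over the 32 threads of a warp, which holds for common shapes such as m16n8k16 with f16 inputs but not for every PTX shape, the number of elements each thread holds follows directly. Both helpers below are hypothetical.

```python
import re
from typing import Dict, Tuple

WARP_SIZE = 32


def parse_mma_shape(shape: str) -> Tuple[int, int, int]:
    """Split a PTX shape string such as "m16n8k16" into (m, n, k)."""
    match = re.fullmatch(r"m(\d+)n(\d+)k(\d+)", shape)
    if match is None:
        raise ValueError(f"not an mma shape: {shape!r}")
    m, n, k = (int(group) for group in match.groups())
    return m, n, k


def elements_per_thread(shape: str) -> Dict[str, int]:
    """Per-thread element counts, assuming an even split over one warp."""
    m, n, k = parse_mma_shape(shape)
    totals = {"A": m * k, "B": k * n, "C": m * n}
    if any(total % WARP_SIZE for total in totals.values()):
        raise ValueError(f"{shape} does not split evenly over {WARP_SIZE} threads")
    return {name: total // WARP_SIZE for name, total in totals.items()}


print(elements_per_thread("m16n8k16"))  # {'A': 8, 'B': 4, 'C': 4}
```

These counts are what the multiplicand and accumulator variables passed to ptx_mma would have to hold per thread under the even-split assumption.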
- tilelang.language.tir.op.ptx_mma_sp(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, metadata, meta_index, sparse_selector, saturate)
TVM intrinsic for sparse tensor core PTX instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma
- Parameters:
dtype (str) – The data type of the result.
shape (str) – The shape of the mma fragment.
A_layout (Literal["row", "col"]) – The layout of multiplicand fragment A.
B_layout (Literal["row", "col"]) – The layout of multiplicand fragment B.
A_dtype (str) – The data type of multiplicand fragment A.
B_dtype (str) – The data type of multiplicand fragment B.
C_dtype (str) – The data type of accumulator fragment C.
multiplicand_a (Var) – The multiplicand fragment A variable.
a_index (Expr) – The index of multiplicand fragment A.
multiplicand_b (Var) – The multiplicand fragment B variable.
b_index (Expr) – The index of multiplicand fragment B.
accumulator (Var) – The accumulator fragment C variable.
c_index (Expr) – The index of accumulator fragment C.
metadata (Expr) – The metadata of the operand.
meta_index (Expr) – The metadata index of the operand.
sparse_selector (Expr) – The sparse selector indicating the thread that stores the metadata.
saturate (bool) – Whether to apply optional saturation at the output.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)
TVM intrinsic for storing the result of PTX MMA into a destination pointer.
- Parameters:
dtype (str) – The data type of the result.
m (IntImm) – The m extent of the mma fragment shape.
n (IntImm) – The n extent of the mma fragment shape.
dst_ptr (Var) – The destination pointer variable.
src_ptr (Var) – The source pointer variable.
src_offset (Expr) – The source offset.
dst_stride (Var) – The destination stride.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.mma_fill(dtype, local_size, local_ptr, offset)
TVM intrinsic for zero-initializing an MMA accumulation register.
- Parameters:
dtype (str) – The data type of the result.
local_size (IntImm) – The number of elements.
local_ptr (Var) – The destination pointer variable.
offset (Expr) – The destination offset.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)
TVM intrinsic for PTX load matrix from shared memory https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
- Parameters:
dtype (str) – The data type of the result.
trans (bool) – Whether the matrix is loaded in column-major format.
num (IntImm) – The number of matrices.
type (Literal[".b16"]) – The data type of the matrices.
local_ptr (Var) – The local pointer variable.
local_offset (Expr) – The offset of the local pointer.
smem_ptr (Var) – The shared memory pointer variable.
smem_offset (Expr) – The offset of the shared memory pointer.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes)
TVM intrinsic for PTX async copy from global to shared memory using cp.async https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
- Parameters:
dtype (str) – The data type of the result.
shared_ptr (Var) – The shared memory pointer variable.
shared_offset (Expr) – The offset of the shared memory pointer.
global_ptr (Var) – The global memory pointer variable.
global_offset (Expr) – The offset of the global memory pointer.
bytes (int) – The number of bytes to copy.
- Returns:
call – The call expression.
- Return type:
PrimExpr
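PTX restricts a single cp.async to a copy size of 4, 8 or 16 bytes, so a larger copy has to be issued as several calls. A minimal planning sketch, assuming both the shared and global addresses start 16-byte aligned; plan_cp_async is a hypothetical helper, and it works in bytes because this page does not say whether the intrinsic's offsets are counted in bytes or in elements.

```python
from typing import List, Tuple

CP_ASYNC_SIZES = (16, 8, 4)  # legal cp.async copy sizes in bytes, largest first


def plan_cp_async(total_bytes: int) -> List[Tuple[int, int]]:
    """Split a copy into (byte_offset, size) chunks using legal cp.async sizes.

    With 16-byte aligned starting addresses, taking the largest size first keeps
    every chunk naturally aligned to its own size.
    """
    if total_bytes % 4:
        raise ValueError("cp.async copies 4, 8 or 16 bytes at a time")
    chunks = []
    offset = 0
    for size in CP_ASYNC_SIZES:
        while total_bytes - offset >= size:
            chunks.append((offset, size))
            offset += size
    return chunks


print(plan_cp_async(40))  # [(0, 16), (16, 16), (32, 8)]
```

For a float16 tile, each 16-byte chunk carries 8 elements.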
- tilelang.language.tir.op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id)
TVM intrinsic for PTX async copy from global to shared memory using cp.async.bulk https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
- Parameters:
dtype (str) – The data type of the result.
shared_ptr (Var) – The shared memory pointer variable.
shared_offset (Expr) – The offset of the shared memory pointer.
global_ptr (Var) – The global memory pointer variable.
global_offset (Expr) – The offset of the global memory pointer.
bytes (int) – The number of bytes to copy.
barrier_id (int) – The ID of the barrier shared memory pointer.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.ptx_commit_group()
TVM intrinsic for PTX async copy commit https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.ptx_wait_group(num)
TVM intrinsic for PTX async copy wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group
- Parameters:
num (int) – The maximum number of the most recent committed cp.async groups that may still be pending when the wait returns.
- Returns:
call – The call expression.
- Return type:
PrimExpr
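ptx_commit_group and ptx_wait_group work together: copies issued since the last commit form one group, and a wait with num = N returns once at most N of the most recent groups are still pending, so every older group has completed. The toy bookkeeping model below (a hypothetical class, one thread, no real asynchrony) shows the common double-buffering pattern in which waiting with N = 1 guarantees the previous tile has landed while the next one may still be in flight.

```python
from collections import deque
from typing import Deque, List


class CpAsyncGroups:
    """Toy per-thread model of cp.async commit/wait bookkeeping."""

    def __init__(self) -> None:
        self.uncommitted: List[str] = []          # copies issued since the last commit
        self.pending: Deque[List[str]] = deque()  # committed groups, oldest first
        self.completed: List[str] = []

    def cp_async(self, name: str) -> None:
        self.uncommitted.append(name)

    def commit_group(self) -> None:
        # Close the current batch into a group (an empty batch gives an empty group).
        self.pending.append(self.uncommitted)
        self.uncommitted = []

    def wait_group(self, n: int) -> None:
        # Force the oldest groups to finish until at most n recent groups remain pending.
        while len(self.pending) > n:
            self.completed.extend(self.pending.popleft())


groups = CpAsyncGroups()
groups.cp_async("tile0")
groups.commit_group()
groups.cp_async("tile1")
groups.commit_group()
groups.wait_group(1)     # tile0 must be done, tile1 may still be in flight
print(groups.completed)  # ['tile0']
```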
- tilelang.language.tir.op.tvm_mfma(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index)
TVM intrinsic for AMD matrix core MFMA instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
- Parameters:
dtype (str) – The data type of the result.
shape (str) – The shape of the mma fragment.
A_layout (Literal["row", "col"]) – The layout of multiplicand fragment A.
B_layout (Literal["row", "col"]) – The layout of multiplicand fragment B.
A_dtype (str) – The data type of multiplicand fragment A.
B_dtype (str) – The data type of multiplicand fragment B.
C_dtype (str) – The data type of accumulator fragment C.
multiplicand_a (Var) – The multiplicand fragment A variable.
a_index (Expr) – The index of multiplicand fragment A.
multiplicand_b (Var) – The multiplicand fragment B variable.
b_index (Expr) – The index of multiplicand fragment B.
accumulator (Var) – The accumulator fragment C variable.
c_index (Expr) – The index of accumulator fragment C.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_mfma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)
TVM intrinsic for storing the result of MFMA into a destination pointer.
- Parameters:
dtype (str) – The data type of the result.
m (IntImm) – The m extent of the mma fragment shape.
n (IntImm) – The n extent of the mma fragment shape.
dst_ptr (Var) – The destination pointer variable.
src_ptr (Var) – The source pointer variable.
src_offset (Expr) – The source offset.
dst_stride (Var) – The destination stride.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_rdna_wmma(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index)
TVM intrinsic for AMD RDNA matrix core WMMA instructions https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
- Parameters:
dtype (str) – The data type of the result.
shape (str) – The shape of the mma fragment.
A_layout (Literal["row", "col"]) – The layout of multiplicand fragment A.
B_layout (Literal["row", "col"]) – The layout of multiplicand fragment B.
A_dtype (str) – The data type of multiplicand fragment A.
B_dtype (str) – The data type of multiplicand fragment B.
C_dtype (str) – The data type of accumulator fragment C.
multiplicand_a (Var) – The multiplicand fragment A variable.
a_index (Expr) – The index of multiplicand fragment A.
multiplicand_b (Var) – The multiplicand fragment B variable.
b_index (Expr) – The index of multiplicand fragment B.
accumulator (Var) – The accumulator fragment C variable.
c_index (Expr) – The index of accumulator fragment C.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.tvm_rdna_wmma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)
TVM intrinsic for storing the result of RDNA WMMA into a destination pointer.
- Parameters:
dtype (str) – The data type of the result.
m (IntImm) – The m extent of the mma fragment shape.
n (IntImm) – The n extent of the mma fragment shape.
dst_ptr (Var) – The destination pointer variable.
src_ptr (Var) – The source pointer variable.
src_offset (Expr) – The source offset.
dst_stride (Var) – The destination stride.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.ptx_cp_async_barrier(barrier_id)
TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
- Parameters:
barrier_id (int) – The ID of the barrier shared memory pointer.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.ptx_init_barrier_thread_count(barrier_id, thread_count)
TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
- Parameters:
barrier_id (int) – The ID of the barrier shared memory pointer.
thread_count (int) – Number of threads expected to arrive at the barrier.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.ptx_arrive_barrier(barrier_id)
TVM intrinsic for ptx barrier arrival using mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
- Parameters:
barrier_id (int) – The ID of the barrier shared memory pointer.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.ptx_arrive_barrier_expect_tx(barrier_id, byte_count)
TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation
- Parameters:
barrier_id (int) – The ID of the barrier shared memory pointer.
byte_count (int) – Increases the tx count of the mbarrier object to track completion of additional async transactions.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.ptx_wait_barrier(barrier_id)
TVM intrinsic for ptx barrier wait using mbarrier.try_wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
- Parameters:
barrier_id (int) – The ID of the barrier shared memory pointer.
- Returns:
call – The call expression.
- Return type:
PrimExpr
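The barrier intrinsics above follow the PTX mbarrier protocol: initialize with an expected arrival count, optionally raise the transaction (tx) count, arrive, then wait for the current phase to complete. The toy, single-threaded Python model below is for intuition only. The class `MBarrierModel` and its methods are hypothetical and are not part of TileLang or TVM. The model assumes the PTX rule that a phase completes once both the pending-arrival count and the tx-count reach zero, after which the arrival count resets and the phase parity flips.

```python
class MBarrierModel:
    """Hypothetical toy model of a PTX mbarrier (not a TileLang/TVM API)."""

    def __init__(self, thread_count: int) -> None:
        # Like mbarrier.init: expect `thread_count` arrivals per phase.
        self.expected = thread_count
        self.pending = thread_count
        self.tx_count = 0
        self.phase = 0  # parity bit of the current (not yet completed) phase

    def _maybe_complete_phase(self) -> None:
        # A phase completes when all arrivals and all expected bytes are in.
        if self.pending == 0 and self.tx_count == 0:
            self.phase ^= 1
            self.pending = self.expected

    def arrive(self) -> None:
        # Like mbarrier.arrive: one thread signals its arrival.
        self.pending -= 1
        self._maybe_complete_phase()

    def arrive_expect_tx(self, byte_count: int) -> None:
        # Like mbarrier.arrive.expect_tx: raise the tx-count, then arrive.
        self.tx_count += byte_count
        self.arrive()

    def complete_tx(self, byte_count: int) -> None:
        # Models an async copy reporting that `byte_count` bytes have landed.
        self.tx_count -= byte_count
        self._maybe_complete_phase()

    def try_wait(self, parity: int) -> bool:
        # True once the phase with the given parity has completed.
        return self.phase != parity


bar = MBarrierModel(thread_count=2)
bar.arrive_expect_tx(128)  # producer arrives and expects 128 bytes
bar.arrive()               # second arrival, but the bytes are still outstanding
print(bar.try_wait(0))     # False
bar.complete_tx(128)       # the async copy finishes
print(bar.try_wait(0))     # True
```

The key point the model shows is that arrivals alone do not release waiters when `byte_count` bytes are still expected.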
- tilelang.language.tir.op.create_barriers(barrier_count)
TVM intrinsic to create barrier_count barriers.
- Parameters:
barrier_count (int) – The number of barriers to create.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.vectorlow(dtype, vec)
Get the low half of the vector.
- Parameters:
dtype (str) – The data type of the result.
vec (list) – The input vector.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.vectorhigh(dtype, vec)
Get the high half of the vector.
- Parameters:
dtype (str) – The data type of the result.
vec (list) – The input vector.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.vectorcombine(dtype, vec1, vec2)
Concatenate two vectors.
- Parameters:
dtype (str) – The data type of the result.
vec1 (list) – The first input vector.
vec2 (list) – The second input vector.
- Returns:
call – The call expression.
- Return type:
PrimExpr
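As a rough mental model, `vectorlow`, `vectorhigh`, and `vectorcombine` behave like slicing and concatenating a list of lanes. The plain-Python sketch below assumes that the "low" half means the lower-indexed lanes and that the lane count is even. The helper names are hypothetical and do not build TIR expressions.

```python
from typing import List, Sequence


def vector_low(vec: Sequence[float]) -> List[float]:
    # Lower-indexed half of the lanes (assumes an even lane count).
    return list(vec[: len(vec) // 2])


def vector_high(vec: Sequence[float]) -> List[float]:
    # Upper-indexed half of the lanes.
    return list(vec[len(vec) // 2 :])


def vector_combine(vec1: Sequence[float], vec2: Sequence[float]) -> List[float]:
    # Concatenation: vec1 supplies the low lanes, vec2 the high lanes.
    return list(vec1) + list(vec2)


v = [0.0, 1.0, 2.0, 3.0]
print(vector_low(v), vector_high(v))  # [0.0, 1.0] [2.0, 3.0]
```

Under this model, combining the low and high halves reconstructs the original vector.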
- tilelang.language.tir.op.ret(val)
Create a TIR return expression.
- Parameters:
val (Expr) – The returned TIR expression, whose data type is int, float or void pointer.
- Returns:
ret – The return expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.any(*args, span=None)
Create a new expression of the union of all conditions in the arguments.
- Parameters:
args (list) – List of symbolic boolean expressions.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
expr – Expression.
- Return type:
Expr
- tilelang.language.tir.op.all(*args, span=None)
Create a new expression of the intersection of all conditions in the arguments.
- Parameters:
args (list) – List of symbolic boolean expressions.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
expr – Expression.
- Return type:
Expr
- tilelang.language.tir.op.trace(args, trace_action='tvm.default_trace_action')
Trace tensor data at runtime.
The trace function allows tracing a specific tensor at runtime. The value to trace should be passed as the last argument. A trace action should be specified; by default, tvm.default_trace_action is used.
- Parameters:
args (list of Expr or Buffers) – Positional arguments.
trace_action (str) – The name of the trace action.
- Returns:
call – The call expression.
- Return type:
PrimExpr
See also
tvm.tir.call_packed
Creates packed function.
- tilelang.language.tir.op.min_value(dtype, span=None)
Minimum value of dtype.
- Parameters:
dtype (str) – The data type.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
value – The minimum value of dtype.
- Return type:
tvm.Expr
- tilelang.language.tir.op.max_value(dtype, span=None)
Maximum value of dtype.
- Parameters:
dtype (str) – The data type.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
value – The maximum value of dtype.
- Return type:
tvm.Expr
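For integer dtypes, the bounds returned by `min_value` and `max_value` follow from the bit width and signedness. The plain-Python sketch below shows that arithmetic. The helper `int_range` is hypothetical, only understands "intN"/"uintN" strings, and does not call TileLang.

```python
import re
from typing import Tuple


def int_range(dtype: str) -> Tuple[int, int]:
    # Parse strings such as "int8" or "uint16" into signedness and bit width.
    match = re.fullmatch(r"(u?)int(\d+)", dtype)
    if match is None:
        raise ValueError(f"not an integer dtype: {dtype}")
    unsigned, bits = match.group(1) == "u", int(match.group(2))
    if unsigned:
        return 0, 2**bits - 1
    # Two's complement has one more negative value than positive values.
    return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1


print(int_range("int8"))    # (-128, 127)
print(int_range("uint16"))  # (0, 65535)
```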
- tilelang.language.tir.op.infinity(dtype, span=None)
Infinity value of dtype.
- Parameters:
dtype (str) – The data type.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
value – The infinity value of dtype.
- Return type:
tvm.Expr
- tilelang.language.tir.op.reinterpret(dtype, value, span=None)
Reinterpret the bits of value as dtype (a bitwise cast, not a value conversion).
- Parameters:
dtype (str) – The data type.
value (PrimExpr) – The input value.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
value – The reinterpret cast value of dtype.
- Return type:
tvm.Expr
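`reinterpret` keeps the bits and changes only the type, unlike a value-converting cast. The standard-library `struct` module shows the difference for a float32 reinterpreted as int32. This sketch models the semantics only and does not call TileLang; `reinterpret_f32_as_i32` is a hypothetical helper.

```python
import struct


def reinterpret_f32_as_i32(x: float) -> int:
    # Pack as little-endian IEEE-754 float32, then read the same 4 bytes as int32.
    return struct.unpack("<i", struct.pack("<f", x))[0]


print(reinterpret_f32_as_i32(1.0))  # 1065353216, i.e. 0x3F800000
print(int(1.0))                     # 1, a value conversion instead
```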
- tilelang.language.tir.op.exp(x)
Take the exponential of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.exp2(x)
Calculate 2**x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.exp10(x)
Calculate 10**x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.erf(x)
Take the Gauss error function of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.tanh(x)
Take hyperbolic tanh of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.sigmoid(x)
Compute the sigmoid of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.log(x)
Take log of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.log2(x)
Take log2 of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.log10(x)
Take log10 of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.log1p(x)
Take log(x + 1) with respect to input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
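`log1p` exists because `log(1 + x)` loses precision when x is tiny: `1 + x` is rounded to exactly 1.0 before the logarithm is taken. Python's `math` module, which exposes the same C library functions, makes this visible in float64. This illustrates the numerics only and does not call TileLang.

```python
import math

x = 1e-16
print(math.log(1 + x))  # 0.0, because 1 + 1e-16 == 1.0 in float64
print(math.log1p(x))    # approximately 1e-16
```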
- tilelang.language.tir.op.tan(x)
Take tan of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.cos(x)
Take cos of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.cosh(x)
Take cosh of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.acos(x)
Take acos of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.acosh(x)
Take acosh of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.sin(x)
Take sin of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.sinh(x)
Take sinh of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.asin(x)
Take asin of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.asinh(x)
Take asinh of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.atan(x)
Take atan of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.atanh(x)
Take atanh of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.atan2(x1, x2)
Take arctan2(x1, x2).
- Parameters:
x1 (PrimExpr) – Input argument.
x2 (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.sqrt(x)
Take square root of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.rsqrt(x)
Take reciprocal of square root of input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.clz(x)
Count leading zero bits of an integer x.
- Parameters:
x (PrimExpr) – Input 32 or 64 bit integer. The result is undefined if the input is 0.
- Returns:
y – The result.
- Return type:
PrimExpr
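The result of `clz` depends on the bit width of x. For a 32-bit unsigned value, a plain-Python model can use `int.bit_length`. The helper `clz32` is hypothetical; since the result is undefined for 0, the sketch rejects that input instead of returning a value.

```python
def clz32(x: int) -> int:
    # Leading zero bits of a nonzero 32-bit unsigned integer.
    if not 0 < x < 2**32:
        raise ValueError("clz32 expects 0 < x < 2**32")
    return 32 - x.bit_length()


print(clz32(1))           # 31
print(clz32(0x80000000))  # 0
```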
- tilelang.language.tir.op.floor(x, span=None)
Take floor of float input x.
- Parameters:
x (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.ceil(x, span=None)
Take ceil of float input x.
- Parameters:
x (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.trunc(x, span=None)
Get truncated value of the input.
The truncated value of the scalar x is the nearest integer i which is closer to zero than x is.
- Parameters:
x (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The result.
- Return type:
PrimExpr
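`floor`, `ceil`, and `trunc` agree on positive inputs and differ on negative ones. Python's `math` functions use the same definitions, so they serve as a quick reference for the expected results. This sketch does not call TileLang.

```python
import math

for x in (2.5, -2.5):
    # floor rounds toward -inf, ceil toward +inf, trunc toward zero.
    print(x, math.floor(x), math.ceil(x), math.trunc(x))
# 2.5 2 3 2
# -2.5 -3 -2 -2
```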
- tilelang.language.tir.op.abs(x, span=None)
Get absolute value of the input element-wise.
- Parameters:
x (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.bitwise_and(x, y, span=None)
Take bitwise and of two values.
- Parameters:
x (PrimExpr) – Left operand.
y (PrimExpr) – Right operand.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
res – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.bitwise_not(x, span=None)
Take bitwise not of input value.
- Parameters:
x (PrimExpr) – Input operand.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
res – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.bitwise_or(x, y, span=None)
Take bitwise or of two values.
- Parameters:
x (PrimExpr) – Left operand.
y (PrimExpr) – Right operand.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
res – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.bitwise_xor(x, y, span=None)
Take bitwise xor of two values.
- Parameters:
x (PrimExpr) – Left operand.
y (PrimExpr) – Right operand.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
res – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.round(x, span=None)
Round elements of the array to the nearest integer.
- Parameters:
x (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.nearbyint(x, span=None)
Round elements of the array to the nearest integer. This intrinsic uses llvm.nearbyint instead of llvm.round, which is faster but may give results that differ from te.round. Notably, nearbyint rounds according to the current rounding mode, whereas te.round (llvm.round) ignores it. For the differences between the two, see: https://en.cppreference.com/w/cpp/numeric/math/round https://en.cppreference.com/w/cpp/numeric/math/nearbyint
- Parameters:
x (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The result.
- Return type:
PrimExpr
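The practical difference between `round` and `nearbyint` shows up on ties. C `round` (llvm.round) sends halfway cases away from zero. C `nearbyint` follows the current rounding mode, which by default is round-to-nearest-even. The plain-Python model below uses the standard `decimal` module to express both tie rules. It assumes the default rounding mode for `nearbyint`, does not call TileLang, and the helper names are hypothetical.

```python
from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Decimal


def c_round(x: float) -> float:
    # C round()/llvm.round: ties go away from zero, whatever the rounding mode.
    # (Fine for |x| well below 1e28 under the default decimal context.)
    return float(Decimal(x).quantize(Decimal(1), rounding=ROUND_HALF_UP))


def nearbyint_default_mode(x: float) -> float:
    # C nearbyint() under the default round-to-nearest mode: ties go to even.
    return float(Decimal(x).quantize(Decimal(1), rounding=ROUND_HALF_EVEN))


for x in (0.5, 1.5, 2.5, -2.5):
    print(x, c_round(x), nearbyint_default_mode(x))
```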
- tilelang.language.tir.op.nextafter(x1, x2)
Return the next floating-point value after x1 towards x2.
- Parameters:
x1 (PrimExpr) – Input argument.
x2 (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.hypot(x1, x2)
Equivalent to sqrt(x1**2 + x2**2), element-wise.
- Parameters:
x1 (PrimExpr) – Input argument.
x2 (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.copysign(x1, x2)
Change the sign of x1 to that of x2, element-wise.
- Parameters:
x1 (PrimExpr) – Input argument.
x2 (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.ldexp(x1, x2)
Returns x1 * (2 ** x2).
- Parameters:
x1 (PrimExpr) – Input argument.
x2 (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
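`nextafter`, `hypot`, `copysign`, and `ldexp` mirror the C math functions of the same names. Python's `math` module exposes those C functions too (`math.nextafter` needs Python 3.9 or later), so it is a convenient way to check expected scalar results. This is a reference for the semantics only, not a TileLang call.

```python
import math

print(math.hypot(3.0, 4.0))                         # 5.0
print(math.copysign(2.0, -0.0))                     # -2.0, the sign of -0.0 counts
print(math.ldexp(3.0, 4))                           # 48.0 == 3 * 2**4
print(math.nextafter(1.0, 2.0) - 1.0 == 2.0**-52)   # True, one float64 ulp above 1.0
```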
- tilelang.language.tir.op.likely(cond, span=None)
Mark condition as likely.
- Parameters:
cond (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The marked expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.isnan(x, span=None)
Check if input value is NaN.
- Parameters:
x (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.isnullptr(x, span=None)
Check if input value is nullptr.
- Parameters:
x (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.isfinite(x, span=None)
Check if input value is finite.
- Parameters:
x (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.isinf(x, span=None)
Check if input value is infinite.
- Parameters:
x (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
y – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.pow_of_int(x, y)
A faster power operation than pow(float, float) for an integer exponent.
- Parameters:
x (PrimExpr) – Base value.
y (int) – Exponent value.
- Return type:
tvm.ir.PrimExpr
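One reason an integer exponent can be cheaper than a general `pow(float, float)` is that it can be computed with a few multiplications instead of an exp/log evaluation. The sketch below shows exponentiation by squaring in plain Python. It is an illustration of that general technique under our own assumption; it is not a claim about how `pow_of_int` is actually lowered, and `pow_by_squaring` is a hypothetical helper.

```python
def pow_by_squaring(x: float, n: int) -> float:
    # Computes x**n with O(log |n|) multiplications; negative n uses 1 / x**|n|.
    result, base, e = 1.0, x, abs(n)
    while e:
        if e & 1:
            result *= base
        base *= base
        e >>= 1
    return 1.0 / result if n < 0 else result


print(pow_by_squaring(2.0, 10))  # 1024.0
print(pow_by_squaring(2.0, -2))  # 0.25
```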
- tilelang.language.tir.op.power(x, y, span=None)
Compute x raised to the power y.
- Parameters:
x (PrimExpr) – Input argument.
y (PrimExpr) – The exponent.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
z – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.pow(x, y, span=None)
Compute x raised to the power y.
- Parameters:
x (PrimExpr) – Input argument.
y (PrimExpr) – The exponent.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
z – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.popcount(x)
Count the number of set bits in input x.
- Parameters:
x (PrimExpr) – Input argument.
- Returns:
y – The result.
- Return type:
PrimExpr
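A quick plain-Python reference for `popcount` on non-negative integers follows. It counts the `1` digits of the binary representation. `popcount_ref` is a hypothetical helper used only to check expected values; it is not part of TileLang.

```python
def popcount_ref(x: int) -> int:
    # Count set bits of a non-negative integer.
    if x < 0:
        raise ValueError("expects a non-negative integer")
    return bin(x).count("1")


print(popcount_ref(0b1011))  # 3
```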
- tilelang.language.tir.op.q_multiply_shift(x, y, q, s)
Execute a multiplication between two Q-numbers x and y followed by a right shift s. The mathematical expression is:
out = round(x*y*2^-s)
More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
The rounding rule is to the nearest value, rounding half up (i.e., round(x.1) = x and round(x.5) = x+1).
- Parameters:
x (PrimExpr) – First Q-number.
y (PrimExpr) – Second Q-number.
q (PrimExpr) – Number of fractional bits in x and y. Needs to be > 0.
s (PrimExpr) – Integer shift.
- Returns:
y – The result.
- Return type:
PrimExpr
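The formula above is stated on real values. If x and y are stored as integers X = x·2^q and Y = y·2^q, and the result uses the same format, it becomes round(X·Y / 2^(q+s)) on the stored integers. The plain-Python sketch below implements that integer form with round-half-up by adding half an output unit before an arithmetic right shift. It assumes q + s >= 1. `q_multiply_shift_ref` is a hypothetical reference helper, not TVM's actual lowering.

```python
def q_multiply_shift_ref(x: int, y: int, q: int, s: int) -> int:
    # x and y are stored Q-numbers with q fractional bits; so is the result.
    total_shift = q + s
    if total_shift < 1:
        raise ValueError("this sketch assumes q + s >= 1")
    # Add half an output unit, then arithmetic-shift right: round half up.
    return (x * y + (1 << (total_shift - 1))) >> total_shift


half = 1 << 30  # 0.5 in Q31
print(q_multiply_shift_ref(half, half, q=31, s=0))  # 536870912, i.e. 0.25 in Q31
print(q_multiply_shift_ref(half, half, q=31, s=1))  # 268435456, i.e. 0.125 in Q31
```

Python's `>>` on negative integers rounds toward negative infinity, so ties on negative products also round up (toward positive infinity), matching the stated rule.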
- tilelang.language.tir.op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required, is_rshift_required)
Execute a multiplication between two Q-numbers x and y.
- Parameters:
x (PrimExpr) – First Q-number.
y (PrimExpr) – Second Q-number.
ls (PrimExpr) – Integer left shift.
rs (PrimExpr) – Integer right shift.
q (IntImm) – Number of fractional bits in x and y. Needs to be > 0.
is_lshift_required (IntImm) – Whether we need to do left shift or not.
is_rshift_required (IntImm) – Whether we need to do right shift or not.
- Returns:
z – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.shift_left(x, y, span=None)
Return the result of x left shifted by y bits.
- Parameters:
x (PrimExpr) – Input argument.
y (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
z – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.shift_right(x, y, span=None)
Return the result of x right shifted by y bits.
- Parameters:
x (PrimExpr) – Input argument.
y (PrimExpr) – Input argument.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
z – The result.
- Return type:
PrimExpr
- tilelang.language.tir.op.fmod(x, y)
Return the remainder of x divided by y with the same sign as x.
- Parameters:
x (PrimExpr) – Input argument.
y (PrimExpr) – Input argument.
- Returns:
z – The result.
- Return type:
PrimExpr
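The "same sign as x" rule is what distinguishes `fmod` from a floor-based modulo. Python has both: `math.fmod` follows C `fmod`, while the `%` operator follows the sign of the divisor. This shows the expected semantics only and does not call TileLang.

```python
import math

print(math.fmod(-7.0, 2.0))  # -1.0, the sign follows x (like C fmod)
print(-7.0 % 2.0)            # 1.0, Python's % follows the sign of y
```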
- tilelang.language.tir.op.if_then_else(cond, t, f, span=None)
Conditional selection expression.
- Parameters:
cond (PrimExpr) – The condition.
t (PrimExpr) – The result expression if cond is true.
f (PrimExpr) – The result expression if cond is false.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
result – The result of conditional expression.
- Return type:
PrimExpr
Note
Unlike Select, if_then_else will not execute the branch that does not satisfy the condition. You can use it to guard against out-of-bounds access. Also unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions.
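The contrast between Select and if_then_else is the contrast between eager and lazy evaluation. The plain-Python analogy below uses an eager function, whose arguments are evaluated before the call, and a lazy one that takes thunks. It shows why only the lazy form can guard an out-of-bounds read. Both helpers are hypothetical illustrations, not TileLang functions.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def select_eager(cond: bool, t: T, f: T) -> T:
    # Both branch values were already computed by the caller.
    return t if cond else f


def if_then_else_lazy(cond: bool, t: Callable[[], T], f: Callable[[], T]) -> T:
    # Only the chosen branch is evaluated.
    return t() if cond else f()


buf = [10, 20, 30]
i = 5
print(if_then_else_lazy(i < len(buf), lambda: buf[i], lambda: 0))  # 0
try:
    select_eager(i < len(buf), buf[i], 0)
except IndexError:
    print("eager select evaluated buf[5] and failed")
```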
- tilelang.language.tir.op.div(a, b, span=None)
Compute a / b as in C/C++ semantics.
- Parameters:
a (PrimExpr) – The left hand operand.
b (PrimExpr) – The right hand operand.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
res – The result expression.
- Return type:
PrimExpr
Note
When operands are integers, returns truncdiv(a, b, span).
- tilelang.language.tir.op.indexdiv(a, b, span=None)
Compute floor(a / b) where a and b are non-negative.
- Parameters:
a (PrimExpr) – The left hand operand, known to be non-negative.
b (PrimExpr) – The right hand operand, known to be non-negative.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
res – The result expression.
- Return type:
PrimExpr
Note
Use this function to split non-negative indices. This function may take advantage of the operands' non-negativeness.
- tilelang.language.tir.op.indexmod(a, b, span=None)
Compute the remainder of indexdiv. a and b are non-negative.
- Parameters:
a (PrimExpr) – The left hand operand, known to be non-negative.
b (PrimExpr) – The right hand operand, known to be non-negative.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
res – The result expression.
- Return type:
PrimExpr
Note
Use this function to split non-negative indices. This function may take advantage of the operands' non-negativeness.
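Splitting a non-negative flat index into row and column is the typical use of the indexdiv/indexmod pair. The plain-Python sketch below shows the arithmetic. For non-negative operands, floor division and truncating division agree, which is the property these functions are allowed to exploit. `split_index` is a hypothetical helper, not a TileLang API.

```python
from typing import Tuple


def split_index(flat: int, cols: int) -> Tuple[int, int]:
    # Row = indexdiv(flat, cols), column = indexmod(flat, cols) for flat, cols >= 0.
    return flat // cols, flat % cols


row, col = split_index(7, 3)
print(row, col)       # 2 1
print(row * 3 + col)  # 7, recombining gives back the flat index
```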
- tilelang.language.tir.op.truncdiv(a, b, span=None)
Compute the truncdiv of two expressions.
- Parameters:
a (PrimExpr) – The left hand operand.
b (PrimExpr) – The right hand operand.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
res – The result expression.
- Return type:
PrimExpr
Note
This is the default integer division behavior in C.
- tilelang.language.tir.op.truncmod(a, b, span=None)
Compute the truncmod of two expressions.
- Parameters:
a (PrimExpr) – The left hand operand.
b (PrimExpr) – The right hand operand.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
res – The result expression.
- Return type:
PrimExpr
Note
This is the default integer remainder (%) behavior in C.
- tilelang.language.tir.op.floordiv(a, b, span=None)
Compute the floordiv of two expressions.
- Parameters:
a (PrimExpr) – The left hand operand.
b (PrimExpr) – The right hand operand.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
res – The result expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.floormod(a, b, span=None)
Compute the floormod of two expressions.
- Parameters:
a (PrimExpr) – The left hand operand.
b (PrimExpr) – The right hand operand.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
res – The result expression.
- Return type:
PrimExpr
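Truncating and floor division only differ when the operands have different signs. The plain-Python reference below computes both pairs so the results can be compared. Python's own `//` and `%` are floor-based, and the truncating versions are built from them. All four helpers are hypothetical and model the integer semantics only.

```python
def truncdiv_ref(a: int, b: int) -> int:
    # Round the quotient toward zero, like C integer division.
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


def truncmod_ref(a: int, b: int) -> int:
    # Remainder takes the sign of a, like C %.
    return a - b * truncdiv_ref(a, b)


def floordiv_ref(a: int, b: int) -> int:
    return a // b  # rounds toward -inf


def floormod_ref(a: int, b: int) -> int:
    return a % b  # remainder takes the sign of b


print(truncdiv_ref(-7, 2), truncmod_ref(-7, 2))  # -3 -1
print(floordiv_ref(-7, 2), floormod_ref(-7, 2))  # -4 1
```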
- tilelang.language.tir.op.ceildiv(lhs, rhs, span=None)
Generic ceildiv operator.
- Parameters:
lhs (object) – The left operand.
rhs (object) – The right operand.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
op – The result Expr of ceildiv operation.
- Return type:
tvm.Expr
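A common use of ceildiv is counting how many fixed-size tiles cover a dimension. The plain-Python sketch below derives the ceiling from floor division via ceil(a/b) = -((-a) // b). `ceildiv_ref` is a hypothetical reference helper for integers, not the TileLang operator.

```python
def ceildiv_ref(a: int, b: int) -> int:
    # Ceiling of a / b using floor division: ceil(a / b) == -((-a) // b).
    return -((-a) // b)


print(ceildiv_ref(10, 4))  # 3 tiles of size 4 are needed to cover 10 elements
```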
- tilelang.language.tir.op.comm_reducer(fcombine, fidentity, name='reduce')
Create a commutative reducer for reduction.
- Parameters:
fcombine (function(Expr -> Expr -> Expr)) – A binary function which takes two Exprs as input and returns an Expr.
fidentity (function(str -> Expr)) – A function which takes a type string as input and returns a const Expr.
- Returns:
reducer – A function which creates a reduce expression over an axis. There are two ways to use it:
call it with (expr, axis, where) to produce a Reduce Expr on the specified axis;
simply use it with multiple Exprs.
- Return type:
function
Example
import tvm
from tvm import te
n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x + y, lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
- tilelang.language.tir.op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)
Backend function to allocate temporary workspace.
- Parameters:
device_type (int) – The device type on which the space will be allocated.
device_id (int) – The device id on which the space will be allocated.
nbytes (int) – The size of the space requested.
dtype_code_hint (int) – The type code of the array elements. Only used in certain backends such as OpenGL.
dtype_bits_hint (int) – The type bits of the array elements. Only used in certain backends such as OpenGL.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.TVMBackendFreeWorkspace(device_type, device_id, ptr)
Backend function to free temporary workspace.
- Parameters:
device_type (int) – The device type on which the space was allocated.
device_id (int) – The device id on which the space was allocated.
ptr (Var) – The pointer to the allocated space.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.anylist_getitem(list_handle, index)
Returns an item from any list.
- Parameters:
list_handle (Var) – The handle to anylist.
index (int) – The index.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.anylist_resetitem(list_handle, index)
Reset an item from any list.
- Parameters:
list_handle (Var) – The handle to anylist.
index (int) – The index.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.anylist_setitem_call_packed(list_handle, index, func_name, *args)
Set anylist item by result of packed call.
- Parameters:
list_handle (Var) – The handle to anylist.
index (int) – The index.
func_name (str) – The name of the function to be called.
args – Extra arguments.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.anylist_setitem_call_cpacked(list_handle, index, func_name, *args)
Set anylist item by result of C-packed call.
- Parameters:
list_handle (Var) – The handle to anylist.
index (int) – The index.
func_name (str) – The name of the function to be called.
args – Extra arguments.
- Returns:
call – The call expression.
- Return type:
PrimExpr
- tilelang.language.tir.op.vscale()
Get the target's vscale value. It will be lowered to the llvm.vscale intrinsic (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic).
- Returns:
call – Call to the vscale intrinsic.
- Return type:
PrimExpr
- tilelang.language.tir.op.sum
- tilelang.language.tir.op.min
- tilelang.language.tir.op.max