246 changes: 245 additions & 1 deletion xformers/ops/_triton/tiled_matmul_kernels.py
@@ -346,6 +346,215 @@ def _xformers_tiled_matmul_kernel(
tl.atomic_add(C, acc, mask=mask)


@triton.autotune(
configs=TRITON_CONFIGS,
key=["M1", "M2", "M3", "N1", "N2", "N3", "K1", "K2", "K3"],
prune_configs_by={
"early_config_prune": our_early_config_prune,
"perf_model": our_estimate_matmul_time,
"top_k": 10,
},
)
@triton.heuristics(
{
"EVEN_K": lambda args: all(
k % (args["BLOCK_K"] * args["SPLIT_K"]) == 0
for k in [args["K1"], args["K2"], args["K3"]]
),
}
)
@triton.jit()
def _xformers_tiled_matmul_kernel_int64(
A11,
A12,
A13,
A21,
A22,
A23,
A31,
A32,
A33,
B11,
B12,
B13,
B21,
B22,
B23,
B31,
B32,
B33,
C11,
C12,
C13,
C21,
C22,
C23,
C31,
C32,
C33,
M1,
M2,
M3,
N1,
N2,
N3,
K1,
K2,
K3,
stride_am1,
stride_am2,
stride_am3,
stride_ak1,
stride_ak2,
stride_ak3,
stride_bk1,
stride_bk2,
stride_bk3,
stride_bn1,
stride_bn2,
stride_bn3,
stride_cm1,
stride_cm2,
stride_cm3,
stride_cn1,
stride_cn2,
stride_cn3,
BLOCK_M: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
BLOCK_N: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
BLOCK_K: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0).to(tl.int64)
pid_k = tl.program_id(1).to(tl.int64)
Comment on lines +431 to +432 (Contributor):

I suspect these are the only two lines that differ wrt the original version, is that so?

What is the downside of always casting to int64 in the original kernel? Did you observe some performance regression?
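
(Side note, not part of the diff: assuming the reviewer is right and the two `.to(tl.int64)` casts above are the only difference from `_xformers_tiled_matmul_kernel`, the cast matters because `pid` and `pid_k` feed the `rm`/`rn`/`rk` index math and therefore every pointer offset computed below; in 32-bit arithmetic those offsets can silently wrap for large operands. A minimal plain-Python sketch with hypothetical shapes:)

```python
# Hypothetical row-major operand: 50_000 x 50_000 elements.
M, K = 50_000, 50_000
stride_am, stride_ak = K, 1  # strides in elements, row-major

# Largest element offset the kernel would compute for this operand.
max_offset = (M - 1) * stride_am + (K - 1) * stride_ak
print(max_offset)              # 2499999999
print(max_offset > 2**31 - 1)  # True: does not fit in a signed 32-bit int

# The same value viewed as a two's-complement int32, i.e. what 32-bit
# offset arithmetic would silently produce: a negative offset.
wrapped = (max_offset + 2**31) % 2**32 - 2**31
print(wrapped)                 # -1794967297
```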

grid_m1 = tl.cdiv(M1, BLOCK_M)
grid_m2 = tl.cdiv(M2, BLOCK_M)
grid_m3 = tl.cdiv(M3, BLOCK_M)
grid_n1 = tl.cdiv(N1, BLOCK_N)
grid_n2 = tl.cdiv(N2, BLOCK_N)
grid_n3 = tl.cdiv(N3, BLOCK_N)
grid_m = grid_m1 + grid_m2 + grid_m3
grid_n = grid_n1 + grid_n2 + grid_n3

# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)

# We use tl.where to circumvent a regression in alignment auto-detection:
# https://github.com/openai/triton/issues/1784

A1 = tl.where(pid_m < grid_m1, A11, tl.where(pid_m < grid_m1 + grid_m2, A21, A31))
A2 = tl.where(pid_m < grid_m1, A12, tl.where(pid_m < grid_m1 + grid_m2, A22, A32))
A3 = tl.where(pid_m < grid_m1, A13, tl.where(pid_m < grid_m1 + grid_m2, A23, A33))
B1 = tl.where(pid_n < grid_n1, B11, tl.where(pid_n < grid_n1 + grid_n2, B12, B13))
B2 = tl.where(pid_n < grid_n1, B21, tl.where(pid_n < grid_n1 + grid_n2, B22, B23))
B3 = tl.where(pid_n < grid_n1, B31, tl.where(pid_n < grid_n1 + grid_n2, B32, B33))
C = tl.where(
pid_m < grid_m1,
tl.where(pid_n < grid_n1, C11, tl.where(pid_n < grid_n1 + grid_n2, C12, C13)),
tl.where(
pid_m < grid_m1 + grid_m2,
tl.where(
pid_n < grid_n1, C21, tl.where(pid_n < grid_n1 + grid_n2, C22, C23)
),
tl.where(
pid_n < grid_n1, C31, tl.where(pid_n < grid_n1 + grid_n2, C32, C33)
),
),
)
M = tl.where(pid_m < grid_m1, M1, tl.where(pid_m < grid_m1 + grid_m2, M2, M3))
N = tl.where(pid_n < grid_n1, N1, tl.where(pid_n < grid_n1 + grid_n2, N2, N3))
stride_ak = tl.where(
pid_m < grid_m1,
stride_ak1,
tl.where(pid_m < grid_m1 + grid_m2, stride_ak2, stride_ak3),
)
stride_bk = tl.where(
pid_n < grid_n1,
stride_bk1,
tl.where(pid_n < grid_n1 + grid_n2, stride_bk2, stride_bk3),
)
stride_cn = tl.where(
pid_m < grid_m1,
stride_cn1,
tl.where(pid_m < grid_m1 + grid_m2, stride_cn2, stride_cn3),
)
stride_cm = tl.where(
pid_n < grid_n1,
stride_cm1,
tl.where(pid_n < grid_n1 + grid_n2, stride_cm2, stride_cm3),
)
pid_m = tl.where(
pid_m < grid_m1,
pid_m,
tl.where(pid_m < grid_m1 + grid_m2, pid_m - grid_m1, pid_m - grid_m1 - grid_m2),
)
pid_n = tl.where(
pid_n < grid_n1,
pid_n,
tl.where(pid_n < grid_n1 + grid_n2, pid_n - grid_n1, pid_n - grid_n1 - grid_n2),
)

# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
# pointers
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
grid_k1 = tl.cdiv(K1, BLOCK_K)
grid_k2 = tl.cdiv(K2, BLOCK_K)
grid_k3 = tl.cdiv(K3, BLOCK_K)
for tile in range(pid_k, grid_k1 + grid_k2 + grid_k3, SPLIT_K):
A = tl.where(tile < grid_k1, A1, tl.where(tile < grid_k1 + grid_k2, A2, A3))
B = tl.where(tile < grid_k1, B1, tl.where(tile < grid_k1 + grid_k2, B2, B3))
K = tl.where(tile < grid_k1, K1, tl.where(tile < grid_k1 + grid_k2, K2, K3))
stride_am = tl.where(
tile < grid_k1,
stride_am1,
tl.where(tile < grid_k1 + grid_k2, stride_am2, stride_am3),
)
stride_bn = tl.where(
tile < grid_k1,
stride_bn1,
tl.where(tile < grid_k1 + grid_k2, stride_bn2, stride_bn3),
)
my_tile = tl.where(
tile < grid_k1,
tile,
tl.where(
tile < grid_k1 + grid_k2, tile - grid_k1, tile - grid_k1 - grid_k2
),
)
rk = my_tile * BLOCK_K + tl.arange(0, BLOCK_K)
Ain = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
Bin = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
if EVEN_K:
a = tl.load(Ain)
b = tl.load(Bin)
else:
a = tl.load(Ain, mask=rk[None, :] < K, other=0.0)
b = tl.load(Bin, mask=rk[:, None] < K, other=0.0)
acc += tl.dot(a, b, allow_tf32=False)
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)


def _check_row_or_column(row_or_col_type, row_or_col_idx, tensor_name, dim_name, vals):
assert len(vals) > 0
for pos, val in enumerate(vals[1:]):
@@ -407,7 +616,42 @@ def grid(META):
META["SPLIT_K"],
)

_xformers_tiled_matmul_kernel[grid](
# Decide whether 32-bit address arithmetic can overflow; if so, use int64-safe kernel
INT32_MAX = (1 << 31) - 1
def _dim_or_zero(xs, i):
return xs[i] if len(xs) > i else 0

use_int64 = False
for i in range(3):
Mi = max(0, _dim_or_zero(ms, i))
Ni = max(0, _dim_or_zero(ns, i))
Ki = max(0, _dim_or_zero(ks, i))

# A offsets
a_row_term = max(0, Mi - 1) * int(strides_am[i])
a_col_term = max(0, Ki - 1) * int(strides_ak[i])
# B offsets
b_row_term = max(0, Ki - 1) * int(strides_bk[i])
b_col_term = max(0, Ni - 1) * int(strides_bn[i])
# C offsets
c_row_term = max(0, Mi - 1) * int(strides_cm[i])
c_col_term = max(0, Ni - 1) * int(strides_cn[i])

# Check per-term and per-address sums
if (
a_row_term > INT32_MAX or a_col_term > INT32_MAX or
b_row_term > INT32_MAX or b_col_term > INT32_MAX or
c_row_term > INT32_MAX or c_col_term > INT32_MAX or
(a_row_term + a_col_term) > INT32_MAX or
(b_row_term + b_col_term) > INT32_MAX or
(c_row_term + c_col_term) > INT32_MAX
):
use_int64 = True
break
Comment on lines +619 to +650 (Contributor):

This code can be simplified
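
(Side note, not part of the PR: one possible shape for the simplification asked for above, sketched under the assumption that `ms`, `ns`, `ks` and the `strides_*` sequences are the ones used in the loop; the helper name `_max_offset` is made up. It drops the redundant per-term comparisons, since with non-negative strides each sum already dominates its terms.)

```python
INT32_MAX = (1 << 31) - 1  # same constant as above

def _max_offset(dims_a, strides_a, dims_b, strides_b, i):
    # Largest element offset touched inside the i-th tile of one operand,
    # assuming non-negative strides (same assumption as the loop above).
    da = dims_a[i] if len(dims_a) > i else 0
    db = dims_b[i] if len(dims_b) > i else 0
    return max(0, da - 1) * int(strides_a[i]) + max(0, db - 1) * int(strides_b[i])

use_int64 = any(
    _max_offset(dims_a, strides_a, dims_b, strides_b, i) > INT32_MAX
    for i in range(3)
    for dims_a, strides_a, dims_b, strides_b in (
        (ms, strides_am, ks, strides_ak),  # A tiles
        (ks, strides_bk, ns, strides_bn),  # B tiles
        (ms, strides_cm, ns, strides_cn),  # C tiles
    )
)
```

If negative strides can reach this point, the terms would need abs() before the comparison; the loop above has the same caveat.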


kernel_to_launch = _xformers_tiled_matmul_kernel_int64 if use_int64 else _xformers_tiled_matmul_kernel

kernel_to_launch[grid](
*[
a[min(i, len(a) - 1)][min(j, len(a[0]) - 1)]
for i in range(3)