@@ -46,7 +46,8 @@ def _select_experts(
             num_expert_group=num_expert_group,
             scoring_func=scoring_func,
         )
-        topk_weights.mul_(self.routed_scaling_factor)
+        if self.routed_scaling_factor != 1.0:
+            topk_weights.mul_(self.routed_scaling_factor)
         if self.redundancy_expert_num > 0:
             redundancy_topk_ids_repair(
                 topk_ids=topk_ids,
@@ -57,7 +57,8 @@ def _select_experts(
             num_expert_group=num_expert_group,
             scoring_func=scoring_func,
         )
-        topk_weights.mul_(self.routed_scaling_factor)
+        if self.routed_scaling_factor != 1.0:
+            topk_weights.mul_(self.routed_scaling_factor)
         if self.num_fused_shared_experts > 0:
             pad_topk_ids = (
                 torch.arange(
@@ -46,7 +46,7 @@ def fused_topk(
     sgl_ops.topk_softmax(
         topk_weights,
         topk_ids,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
+        gating_output,
         renormalize=renormalize,
     )
     return topk_weights, topk_ids
136 changes: 136 additions & 0 deletions lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
@@ -64,3 +64,139 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps):
num_warps=4,
)
return x


@triton.jit
def _qk_rms_norm_fused_kernel(
# Q Pointers & Strides
Q_ptr,
WQ_ptr,
stride_q_row,
stride_q_col,
# K Pointers & Strides
K_ptr,
WK_ptr,
stride_k_row,
stride_k_col,
# Dimensions
    num_heads_q: tl.constexpr,  # number of Q heads (used to decide the Q/K boundary)
Review comment (Contributor, severity: medium):
The comments within the new _qk_rms_norm_fused_kernel Triton kernel are in Chinese (e.g., "Q 的头数 (用于判断边界)"). For consistency with the rest of the codebase and to ensure it's understandable for all contributors, please translate these comments to English.
head_dim: tl.constexpr,
eps,
BLOCK_SIZE: tl.constexpr,
):
    # PID 0: which token (row) this program handles
    row_idx = tl.program_id(0)
    # PID 1: which head this program handles (combined index)
    # range is [0, num_heads_q + num_heads_k)
    combo_head_idx = tl.program_id(1)

    # column offsets shared by Q and K (0 ~ head_dim)
    offs = tl.arange(0, BLOCK_SIZE)

    # === Branch: does this program handle Q or K? ===
    if combo_head_idx < num_heads_q:
        # ------------------ Q path ------------------
        # Pointer arithmetic
        # The actual Q head index is combo_head_idx itself
        Q_ptr += row_idx * stride_q_row

        # Locate the Q data: base + row offset + head offset + column offset
        q_ptr_offset = (combo_head_idx * head_dim + offs) * stride_q_col

        # Load the Q data
        x = tl.load(Q_ptr + q_ptr_offset).to(tl.float32)
        # RMSNorm
        var = tl.sum(x * x, axis=0) / head_dim
        rstd = 1 / tl.sqrt(var + eps)

        # Load the Q weight (all heads share one weight vector of size head_dim)
        w = tl.load(WQ_ptr + offs).to(tl.float32)

        y = x * rstd * w

        # Write Q back
        tl.store(Q_ptr + q_ptr_offset, y.to(Q_ptr.dtype.element_ty))

    else:
        # ------------------ K path ------------------
        # Remap the K head index so that it starts from 0
        k_head_idx = combo_head_idx - num_heads_q

        # Pointer arithmetic
        K_ptr += row_idx * stride_k_row
        k_ptr_offset = (k_head_idx * head_dim + offs) * stride_k_col

        # Load the K data
        x = tl.load(K_ptr + k_ptr_offset).to(tl.float32)
        # RMSNorm
        var = tl.sum(x * x, axis=0) / head_dim
        rstd = 1 / tl.sqrt(var + eps)

        # Load the K weight
        w = tl.load(WK_ptr + offs).to(tl.float32)

        y = x * rstd * w

        # Write K back
        tl.store(K_ptr + k_ptr_offset, y.to(K_ptr.dtype.element_ty))
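
In effect, each program instance applies a weighted RMSNorm in float32 to a single (token, head) slice of length head_dim and casts the result back to the tensor's dtype; written out, the computation above is

    y_i = \frac{x_i \, w_i}{\sqrt{\tfrac{1}{d} \sum_{j=1}^{d} x_j^{2} + \epsilon}}, \qquad d = \mathrm{head\_dim}.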


def qk_rmsnorm_fused_forward(q: torch.Tensor, k: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor, eps: float = 1e-6):
"""
In-place RMSNorm for both Q and K in a single kernel launch.
Supports GQA (different number of heads for Q and K).

Args:
        q: (Total_Tokens, Hidden_Q) or (B, S, H_q, D) -> flattened to 2D inside
k: (Total_Tokens, Hidden_K)
w_q: (head_dim,) Scale parameter for Q
w_k: (head_dim,) Scale parameter for K
"""
    # 1. Shape and contiguity checks
    # View both inputs as 2D tensors of shape (Total_Tokens, Hidden_Size)
q_view = q.view(-1, q.shape[-1])
k_view = k.view(-1, k.shape[-1])

assert w_q.is_contiguous() and w_k.is_contiguous()

M = q_view.shape[0] # Total Tokens
assert k_view.shape[0] == M, "Q and K must have the same number of tokens"

head_dim = w_q.shape[0]
assert w_k.shape[0] == head_dim, "Head dim of Q and K must match"

    # Compute the number of heads
N_q = q_view.shape[1]
N_k = k_view.shape[1]

assert N_q % head_dim == 0
assert N_k % head_dim == 0

num_heads_q = N_q // head_dim
num_heads_k = N_k // head_dim

    # 2. Block size setup
BLOCK_SIZE = triton.next_power_of_2(head_dim)
assert BLOCK_SIZE == head_dim, "Currently only supports head_dim power of 2 (e.g., 64, 128)"

    # 3. Launch the kernel
    # Grid: (number of tokens, num Q heads + num K heads)
grid = (M, num_heads_q + num_heads_k)

_qk_rms_norm_fused_kernel[grid](
q_view,
w_q,
q_view.stride(0),
q_view.stride(1),
k_view,
w_k,
k_view.stride(0),
k_view.stride(1),
num_heads_q=num_heads_q,
head_dim=head_dim,
eps=eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=4,
)

return q, k
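
A minimal usage sketch of the new fused API follows; it is illustrative only — the shapes, dtype, and tolerances are assumptions, and it requires a CUDA device with Triton available. It checks the in-place result against a naive per-head PyTorch reference that mirrors the kernel's math.

import torch
from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward

# Hypothetical GQA shapes for illustration: 8 Q heads, 2 K heads, head_dim 128.
tokens, num_heads_q, num_heads_k, head_dim = 16, 8, 2, 128
q = torch.randn(tokens, num_heads_q * head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(tokens, num_heads_k * head_dim, dtype=torch.bfloat16, device="cuda")
w_q = torch.randn(head_dim, dtype=torch.bfloat16, device="cuda")
w_k = torch.randn(head_dim, dtype=torch.bfloat16, device="cuda")

def ref_rmsnorm(x, w, eps=1e-6):
    # Naive per-head reference in float32, cast back to the input dtype.
    xs = x.float().view(x.shape[0], -1, head_dim)
    rstd = torch.rsqrt(xs.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (xs * rstd * w.float()).reshape(x.shape).to(x.dtype)

# Compute references before the in-place call.
ref_q, ref_k = ref_rmsnorm(q, w_q), ref_rmsnorm(k, w_k)

qk_rmsnorm_fused_forward(q, k, w_q, w_k, eps=1e-6)  # normalizes q and k in place

# Loose tolerances because inputs and outputs are bfloat16.
assert torch.allclose(q.float(), ref_q.float(), atol=1e-2, rtol=1e-2)
assert torch.allclose(k.float(), ref_k.float(), atol=1e-2, rtol=1e-2)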
@@ -17,6 +17,15 @@
"num_stages": 3,
"num_warps": 4
},
"192": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 32,
@@ -35,6 +44,15 @@
"num_stages": 2,
"num_warps": 4
},
"384": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
@@ -53,6 +71,24 @@
"num_stages": 2,
"num_warps": 4
},
"640": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"768": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
@@ -79,5 +115,23 @@
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"896": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 32,
"NEED_TRANS": false,
"num_stages": 4,
"num_warps": 4
}
}
@@ -26,6 +26,24 @@
"num_stages": 3,
"num_warps": 4
},
"112": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 8
},
"12": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 8
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
@@ -44,6 +62,15 @@
"num_stages": 3,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
@@ -62,6 +89,15 @@
"num_stages": 3,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 8
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
@@ -79,5 +115,23 @@
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 8
},
"80": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
}
}
@@ -11,6 +11,14 @@
"BLOCK_SIZE": 256,
"num_warps": 4
},
"112": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"12": {
"BLOCK_SIZE": 512,
"num_warps": 8
},
"128": {
"BLOCK_SIZE": 256,
"num_warps": 8
@@ -19,6 +27,14 @@
"BLOCK_SIZE": 128,
"num_warps": 8
},
"2": {
"BLOCK_SIZE": 256,
"num_warps": 4
},
"24": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"256": {
"BLOCK_SIZE": 128,
"num_warps": 8
@@ -27,12 +43,24 @@
"BLOCK_SIZE": 128,
"num_warps": 8
},
"48": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"64": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"8": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"80": {
"BLOCK_SIZE": 128,
"num_warps": 8
},
"96": {
"BLOCK_SIZE": 128,
"num_warps": 8
}
}