
[Feature][NSA] Implement Grouped-Query Attention (GQA) kernel with sliding window. #167

@michaelwithu

Parent Issue

Part of #70

Task Type

  • L1: Kernel Implementation (Write TileLang kernel)
  • L2: Op Implementation (Wrapper + Unit Tests + Benchmarks)
  • L3: Function Implementation (Autograd Function)
  • L4: Layer Implementation (nn.Module Wrapper)
  • Benchmarks (Performance Profiling)

Description

This kernel is used in NSA (Native Sparse Attention) to preserve local precision during long-context modeling while retaining fine-grained token selection. The kernel supports both:

  1. Variable-length sequences (varlen)
  2. Sliding window attention (local attention with a fixed window size); a concrete visibility predicate is sketched below
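For reference, here is a minimal sketch of the sliding-window visibility rule, assuming the common flash-attention-style convention that a negative window size means "unlimited on that side" (this issue does not state the convention explicitly):

```python
# Hedged sketch: sliding-window visibility for query position i and key position j.
# Assumes window_size_left / window_size_right < 0 means "no limit on that side",
# a convention borrowed from flash-attention-style APIs (not confirmed by this issue).
def in_window(i: int, j: int, window_size_left: int, window_size_right: int) -> bool:
    left_ok = window_size_left < 0 or j >= i - window_size_left
    right_ok = window_size_right < 0 or j <= i + window_size_right
    return left_ok and right_ok
```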

The kernel should operate on unpadded tensors with the following inputs and outputs:

Inputs

  • Q_unpad: shape [UQ, heads, dim]
  • K_unpad: shape [UKV, head_kv, dim]
  • V_unpad: shape [UKV, head_kv, dim]
  • cu_seqlens_q: shape [B + 1]
  • cu_seqlens_k: shape [B + 1]
  • window_size_left
  • window_size_right

Outputs

  • Output_unpad: shape [UQ, heads, dim]
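As a correctness baseline for the L2 unit tests, a naive PyTorch reference for the varlen GQA + sliding-window semantics might look like the sketch below. The function name `gqa_sliding_window_ref`, the right-aligned query/key position convention, and the negative-means-unlimited window convention are assumptions for illustration, not part of this issue.

```python
# Hedged reference sketch (not the TileLang kernel itself): a naive PyTorch
# implementation of varlen GQA with a sliding window, usable as a unit-test baseline.
import torch


def gqa_sliding_window_ref(
    q_unpad: torch.Tensor,       # [UQ, heads, dim]
    k_unpad: torch.Tensor,       # [UKV, head_kv, dim]
    v_unpad: torch.Tensor,       # [UKV, head_kv, dim]
    cu_seqlens_q: torch.Tensor,  # [B + 1], prefix sums of per-sequence query lengths
    cu_seqlens_k: torch.Tensor,  # [B + 1], prefix sums of per-sequence key lengths
    window_size_left: int,       # assumed: < 0 means unlimited to the left
    window_size_right: int,      # assumed: < 0 means unlimited to the right
) -> torch.Tensor:               # [UQ, heads, dim]
    UQ, heads, dim = q_unpad.shape
    head_kv = k_unpad.shape[1]
    group = heads // head_kv     # GQA: each KV head serves `group` query heads
    scale = dim ** -0.5
    out = torch.empty_like(q_unpad)
    B = cu_seqlens_q.numel() - 1

    for b in range(B):
        q_start, q_end = cu_seqlens_q[b].item(), cu_seqlens_q[b + 1].item()
        k_start, k_end = cu_seqlens_k[b].item(), cu_seqlens_k[b + 1].item()
        sq, sk = q_end - q_start, k_end - k_start

        q = q_unpad[q_start:q_end]             # [sq, heads, dim]
        k = k_unpad[k_start:k_end]             # [sk, head_kv, dim]
        v = v_unpad[k_start:k_end]             # [sk, head_kv, dim]

        # Expand KV heads so query head h uses KV head h // group (GQA).
        k = k.repeat_interleave(group, dim=1)  # [sk, heads, dim]
        v = v.repeat_interleave(group, dim=1)  # [sk, heads, dim]

        # Attention scores: [heads, sq, sk]
        scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale

        # Sliding-window mask. Query i is aligned to key position i + (sk - sq),
        # i.e. right-aligned queries (an assumption; the kernel should document
        # its own alignment). Rows whose window contains no keys would produce
        # NaN in this naive sketch.
        qi = torch.arange(sq, device=q.device)[:, None] + (sk - sq)
        kj = torch.arange(sk, device=q.device)[None, :]
        mask = torch.ones(sq, sk, dtype=torch.bool, device=q.device)
        if window_size_left >= 0:
            mask &= kj >= qi - window_size_left
        if window_size_right >= 0:
            mask &= kj <= qi + window_size_right
        scores = scores.masked_fill(~mask, float("-inf"))

        attn = torch.softmax(scores, dim=-1)
        out[q_start:q_end] = torch.einsum(
            "hqk,khd->qhd", attn, v.float()
        ).to(q_unpad.dtype)

    return out
```

For a batch with query lengths [3, 5, 2], `cu_seqlens_q` would be `[0, 3, 8, 10]` and `UQ = 10`; the kernel output can then be compared element-wise against this reference in the unit tests.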
