diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 765cf2872f..99817f0657 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1436,6 +1436,15 @@ def fused_attn( context_parallel_axis=context_parallel_axis, softmax_offset=softmax_offset, ) + if max_segments_per_seq > 1 and not qkv_layout.is_thd(): + warnings.warn( + f"max_segments_per_seq={max_segments_per_seq} is set but qkv_layout={qkv_layout} is " + "not a THD layout. max_segments_per_seq > 1 only applies when using THD layouts " + "(e.g. QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD) for sequence " + "packing.", + UserWarning, + stacklevel=2, + ) output = _fused_attn( qkv, bias,