Changes from all commits (40 commits)
c335f6e  train with only layer distillation losses (oleksost, Dec 16, 2025)
e06a4b2  unscaled loss logging + training with distillation loss factor = 0 (oleksost, Dec 16, 2025)
179ae25  make logging more explicit (oleksost, Dec 17, 2025)
af456f0  Merge remote-tracking branch 'origin/main' into train_only_layer_losses (oleksost, Dec 17, 2025)
9968aac  clean + tests (oleksost, Dec 17, 2025)
945c5a7  nvm (oleksost, Dec 17, 2025)
4b6e3d7  forward KL (oleksost, Dec 19, 2025)
c5fefa0  test forward kl (oleksost, Dec 19, 2025)
4119596  wip: report unscaled + kl loss (oleksost, Dec 19, 2025)
b55a0a4  loss config (oleksost, Dec 22, 2025)
097baeb  wip (oleksost, Dec 22, 2025)
d773d98  tests (oleksost, Dec 22, 2025)
35400c1  Merge remote-tracking branch 'origin/main' into train_only_layer_losses (oleksost, Dec 22, 2025)
282925c  test (oleksost, Dec 22, 2025)
0f73ea2  tests (oleksost, Dec 22, 2025)
04a0193  Merge branch 'main' into train_only_layer_losses (oleksost, Dec 22, 2025)
fa85c41  wip (oleksost, Dec 22, 2025)
feb416e  Merge branch 'train_only_layer_losses' of https://github.com/ServiceN… (oleksost, Dec 22, 2025)
31cfb84  wip (oleksost, Dec 23, 2025)
24fe67b  no grad if factor 0 (oleksost, Dec 23, 2025)
00f6118  Merge remote-tracking branch 'origin/main' into train_only_layer_losses (oleksost, Dec 23, 2025)
0cadf98  Merge branch 'main' into train_only_layer_losses (oleksost, Dec 23, 2025)
0e562e9  addressed comments (oleksost, Dec 23, 2025)
2a474e2  Merge branch 'train_only_layer_losses' of https://github.com/ServiceN… (oleksost, Dec 23, 2025)
52c1c11  addressed comments (oleksost, Dec 23, 2025)
406d0a2  Removed Targets class (oleksost, Dec 30, 2025)
f25380a  fixes (oleksost, Dec 30, 2025)
8adb7dd  imports (oleksost, Dec 30, 2025)
1ce641d  polish naming (oleksost, Jan 6, 2026)
95f14af  addressing comments (oleksost, Jan 8, 2026)
5ad4c0c  explicit z_loss grads (oleksost, Jan 8, 2026)
0a66e14  removed z_loss as aux loss (oleksost, Jan 8, 2026)
f8f7041  move loss configs to the lm config (oleksost, Jan 8, 2026)
ab9c917  tests (oleksost, Jan 8, 2026)
89470dc  Merge branch 'main' into train_only_layer_losses (oleksost, Jan 9, 2026)
6e54c93  comments (oleksost, Jan 12, 2026)
8137b8c  Merge remote-tracking branch 'origin/main' into train_only_layer_losses (jlamypoirier, Jan 13, 2026)
3c8f3c2  misc (jlamypoirier, Jan 13, 2026)
705c482  fix (jlamypoirier, Jan 13, 2026)
3c8ce50  Merge branch 'main' into train_only_layer_losses (jlamypoirier, Jan 16, 2026)
fast_llm/functional/config.py (5 changes: 0 additions & 5 deletions)
@@ -100,11 +100,6 @@ class CrossEntropyImpl(str, enum.Enum):
triton = "triton"


class DistillationLossImpl(str, enum.Enum):
reverse_kl = "reverse_kl"
cross_entropy = "cross_entropy"


class TargetFormat(enum.StrEnum):
labels = "labels"
logits = "logits"
fast_llm/functional/cross_entropy.py (106 changes: 79 additions & 27 deletions)
@@ -75,7 +75,7 @@ def _fused_softmax(
return exp_logits / sum_exp_logits


# @torch.compile
@torch.compile
def _fused_cross_entropy_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
@@ -85,6 +85,7 @@ def _fused_cross_entropy_forward_backward(
target_format: TargetFormat,
group: ProcessGroup | None = None,
teacher_softmax_temperature: float = 1.0,
return_kl_loss: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
A fused implementation of cross-entropy with torch compile.
@@ -97,24 +98,28 @@ def _fused_cross_entropy_forward_backward(
logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group)

if target_format == TargetFormat.logits:
target = _fused_softmax(target, logits_scale_factor / teacher_softmax_temperature, group)
target_logits, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base(
target, logits_scale_factor / teacher_softmax_temperature, group
)
target = exp_logits_targets / sum_exp_target_logits

if target_format == TargetFormat.labels:
target = target.unsqueeze(-1)
loss_mask = target >= 0
if group is None:
# Keep values within range for scatter and gather ops to work.
target = target * loss_mask
target_masked = target * loss_mask
target_mask = None
else:
# Mask the target (fused)
# TODO: Could mask earlier on cpu or overlap with reduce?
vocab_start_index = logits.size(-1) * group.rank()
target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1))
target = (target - vocab_start_index) * target_mask
target_masked = (target - vocab_start_index) * target_mask
else:
# Target should be tensor-parallel already, no further manipulation needed.
target_mask = None
target_masked = target
if loss_mask is not None:
loss_mask = loss_mask.unsqueeze(-1)

@@ -124,10 +129,10 @@ def _fused_cross_entropy_forward_backward(
# grad / grad_output = exp_logits / sum_exp_logits - target_probabilities.
if target_format == TargetFormat.labels:
grad_base = exp_logits.scatter_add(
1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits)
1, target_masked, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits)
)
else:
grad_base = exp_logits - sum_exp_logits * target
grad_base = exp_logits - sum_exp_logits * target_masked

grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits)
if logits_scale_factor != 1.0:
@@ -138,13 +143,13 @@ def _fused_cross_entropy_forward_backward(

# loss = mean(log(sum_exp_logits) - sum(probabilities * logits))
if target_format == TargetFormat.labels:
predicted_logits = logits_norm.gather(1, target)
predicted_logits = logits_norm.gather(1, target_masked)
if group is not None:
predicted_logits = target_mask * predicted_logits

all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
else:
predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)
predicted_logits = (target_masked * logits_norm).sum(dim=-1, keepdim=True)
if group is not None and target_format != TargetFormat.labels:
# this is needed because on each rank we calculate log Z - sum_i t_i * z_i, where z_i is logit.
# Then we average on line 160: 1/K sum_ranks (log Z - sum_i t_i * z_i)
@@ -158,6 +163,18 @@ def _fused_cross_entropy_forward_backward(
loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.AVG, group=group)
if return_kl_loss:
if target_format == TargetFormat.logits:
teacher_log_prob = target_logits - sum_exp_target_logits.log()
else:
teacher_log_prob = torch.log(target + 1e-20)
target_entropy = -(target * teacher_log_prob).sum(dim=-1)
if loss_mask is not None:
target_entropy = target_entropy * loss_mask.squeeze(-1)
target_entropy = target_entropy.mean()
if group is not None:
all_reduce(target_entropy, op=ReduceOp.SUM, group=group)
loss -= target_entropy

return loss, grad
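The return_kl_loss branch above relies on a standard identity: the fused kernel already computes the cross-entropy between the teacher distribution p and the student q, and subtracting the teacher entropy converts that value into a forward KL:

```latex
\mathrm{KL}(p \,\|\, q)
  = \sum_i p_i \log \frac{p_i}{q_i}
  = \underbrace{-\sum_i p_i \log q_i}_{\mathrm{CE}(p,\,q)} - \underbrace{\left(-\sum_i p_i \log p_i\right)}_{H(p)}
  = \mathrm{CE}(p, q) - H(p).
```

Since H(p) depends only on the teacher, the "loss -= target_entropy" step changes the reported loss value but not the gradient with respect to the student logits, which is why the gradient computation earlier in the function is unchanged.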

@@ -233,11 +250,7 @@ def _reverse_kl_forward_backward(
target: torch.Tensor,
loss_mask: torch.Tensor | None,
grad_output: float | None,
target_format: TargetFormat,
group: ProcessGroup | None = None,
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Reverse KL using PyTorch's native kl_div function.
@@ -249,13 +262,6 @@ def _reverse_kl_forward_backward(
loss_mask: [BxS] or [B, S] or None
...
"""
Assert.eq(
teacher_softmax_temperature,
1,
msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL",
)
Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL")
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])
@@ -311,7 +317,6 @@ def reverse_kl_forward_backward(
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
target_format: TargetFormat = TargetFormat.labels,
sequence_parallel_logits: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher).
@@ -334,26 +339,73 @@
loss: Reverse KL divergence loss
grad: Gradients w.r.t. logits
"""

if sequence_parallel_logits:
# TODO: see hybrid dev branch where it is implemented
raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true")

Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
Assert.eq(
teacher_softmax_temperature,
1,
msg="Teacher softmax temperature must be 1 for reverse KL",
)
Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for reverse KL")
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])

# TODO: implement fused?
distillation_loss, distillation_grad = _reverse_kl_forward_backward(
logits=logits,
target=target,
loss_mask=loss_mask,
grad_output=grad_output,
group=group,
)
return distillation_loss, distillation_grad


def forward_kl_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor | None,
grad_output: float | None,
group: ProcessGroup | None = None,
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
target_format: TargetFormat = TargetFormat.labels,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student).
This is mode-covering (vs. mode-seeking for reverse KL) and useful for:
- Encouraging the model to cover all modes of the target distribution
- Spreading probability mass broadly across the target support
- Standard distillation scenarios where you want to match the full teacher distribution

Key differences from reverse KL:
- Forward KL: KL(p||q) = mode-covering (spreads mass broadly)
- Reverse KL: KL(q||p) = mode-seeking (focuses on target modes)

Takes:
logits: [BxS, V] or [B, S, V], where V is local vocab size
target: [BxS, V] or [B, S, V] (logits format)
loss_mask: [BxS] or [B, S] or None
...

Returns:
loss: Forward KL divergence loss
grad: Gradients w.r.t. logits
"""
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])

return _fused_cross_entropy_forward_backward(
logits=logits,
target=target,
loss_mask=loss_mask,
grad_output=grad_output,
logits_scale_factor=logits_scale_factor,
target_format=target_format,
teacher_softmax_temperature=teacher_softmax_temperature,
group=group,
teacher_softmax_temperature=teacher_softmax_temperature,
return_kl_loss=True,
)
return distillation_loss, distillation_grad
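One way to sanity-check the forward/reverse KL paths above is to compare them against plain PyTorch on small dense tensors (single process, no loss mask, temperature and scale factor of 1). The sketch below is illustrative only: the reference uses torch.nn.functional.kl_div, and the commented-out calls assume the function names and signatures shown in this diff.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(8, 32)
teacher_logits = torch.randn(8, 32)

# Forward KL(teacher || student): input = student log-probs, target = teacher probs.
forward_kl_ref = F.kl_div(
    F.log_softmax(student_logits, dim=-1),
    F.softmax(teacher_logits, dim=-1),
    reduction="batchmean",
)

# Reverse KL(student || teacher): the roles are swapped.
reverse_kl_ref = F.kl_div(
    F.log_softmax(teacher_logits, dim=-1),
    F.softmax(student_logits, dim=-1),
    reduction="batchmean",
)

# Comparison against the fused implementations from this diff (assumed imports/signatures):
# from fast_llm.functional.config import TargetFormat
# from fast_llm.functional.cross_entropy import forward_kl_forward_backward, reverse_kl_forward_backward
# loss_fwd, _ = forward_kl_forward_backward(
#     student_logits, teacher_logits, loss_mask=None, grad_output=None, target_format=TargetFormat.logits
# )
# torch.testing.assert_close(loss_fwd, forward_kl_ref)
print(forward_kl_ref, reverse_kl_ref)
```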
fast_llm/layers/common/auxiliary_loss.py (39 changes: 33 additions & 6 deletions)
@@ -3,9 +3,9 @@

class AuxiliaryLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, scores: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa
def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa
ctx.grad = torch.full_like(aux_loss, grad)
return scores
return input_

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa
@@ -14,12 +14,12 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]:

@torch.compile
def calculate_z_loss(logits: torch.Tensor, logits_scale_factor: float = 1.0) -> torch.Tensor:
if logits_scale_factor != 1.0:
logits *= logits_scale_factor
return torch.mean(torch.logsumexp(logits, dim=-1) ** 2)
return torch.mean(
torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2
)


def z_loss(
def auxiliary_z_loss(
logits: torch.Tensor,
z_loss_factor: float,
training: bool,
@@ -36,3 +36,30 @@ def z_loss(
logits = AuxiliaryLoss.apply(logits, loss, z_loss_factor * grad_scale)

return logits


def z_loss_forward_backward(
logits: torch.Tensor,
grad_output: float | None = None,
logits_scale_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Compute z-loss and its gradient.

Z-loss = mean(logsumexp(logits, dim=-1) ** 2)

Returns:
loss: The z-loss value (unscaled)
grad: The gradient w.r.t. logits (scaled by grad_scale), or None if grad_scale is None
"""

with torch.set_grad_enabled(grad_output is not None):
logits_ = logits.detach().requires_grad_(grad_output is not None)
loss = calculate_z_loss(logits_, logits_scale_factor)
if grad_output is None:
grad = None
else:
loss.backward(torch.full_like(loss, grad_output))
grad = logits_.grad.detach().to(logits.dtype)

return loss, grad
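The detach/requires_grad_ pattern in z_loss_forward_backward lets the caller obtain an explicit gradient without touching the main autograd graph. A minimal standalone sketch of that pattern (hypothetical helper name, not taken from the repo's tests):

```python
import torch


def z_loss_with_grad(logits: torch.Tensor, grad_output: float) -> tuple[torch.Tensor, torch.Tensor]:
    # Same pattern as z_loss_forward_backward above: work on a detached leaf copy,
    # then backpropagate a constant grad_output into it.
    logits_ = logits.detach().requires_grad_(True)
    loss = torch.mean(torch.logsumexp(logits_, dim=-1) ** 2)
    loss.backward(torch.full_like(loss, grad_output))
    return loss.detach(), logits_.grad.to(logits.dtype)


logits = torch.randn(4, 16)
loss, grad = z_loss_with_grad(logits, grad_output=1.0)

# Cross-check against plain autograd on a fresh copy.
ref = logits.detach().requires_grad_(True)
torch.mean(torch.logsumexp(ref, dim=-1) ** 2).backward()
torch.testing.assert_close(grad, ref.grad)
```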
fast_llm/layers/decoder/mlp/mixture_of_experts.py (4 changes: 2 additions & 2 deletions)
@@ -13,7 +13,7 @@
from fast_llm.functional.triton.sparse_copy import get_sparse_map
from fast_llm.layers.attention.config import AttentionKwargs
from fast_llm.layers.block.config import BlockKwargs
from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss
from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, auxiliary_z_loss
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType
from fast_llm.layers.decoder.mlp.mlp import MLPBase
@@ -102,7 +102,7 @@ def _forward(

# Apply z_loss if applicable
if self._config.z_loss_coefficient > 0.0:
logits = z_loss(
logits = auxiliary_z_loss(
logits,
self._config.z_loss_coefficient,
self.training,