
Commit afc33f3 (parent 63ac004)

Move the entropy_loss_forward_backward import from fast_llm.functional.cross_entropy to fast_llm.functional.entropy_loss.

3 files changed: 3 additions (+3), 3 deletions (-3)


fast_llm/layers/language_model/loss/config.py (2 additions, 2 deletions)

@@ -77,7 +77,7 @@ def get_loss(
         sequence_parallel_logits: bool = False,
         kwargs: dict[str, typing.Any],
     ) -> "tuple[torch.Tensor, torch.Tensor | None]":
-        from fast_llm.functional.cross_entropy import entropy_loss_forward_backward
+        from fast_llm.functional.entropy_loss import entropy_loss_forward_backward
 
         labels = kwargs[LanguageModelKwargs.labels]
 
@@ -167,7 +167,7 @@ def get_loss(
         sequence_parallel_logits: bool = False,
         kwargs: dict[str, typing.Any],
     ) -> "tuple[torch.Tensor, torch.Tensor | None]":
-        from fast_llm.functional.cross_entropy import entropy_loss_forward_backward
+        from fast_llm.functional.entropy_loss import entropy_loss_forward_backward
 
         if prediction_distance > 0:
             raise NotImplementedError()
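
Both hunks only change where the helper is imported from; the call sites are untouched. Any downstream code that still imports the helper from the old fast_llm.functional.cross_entropy path would break after this commit. As a sketch of a common migration pattern (hypothetical, not part of this commit, and assuming the old module still exists), a temporary re-export shim could keep the old path working while callers migrate:

# Hypothetical contents for fast_llm/functional/cross_entropy.py (NOT in this commit):
# re-export the relocated helper so older imports keep working during a transition.
from fast_llm.functional.entropy_loss import entropy_loss_forward_backward  # noqa: F401

__all__ = ["entropy_loss_forward_backward"]

Whether the old module survives this commit is not visible from the diff shown here, so treat the shim purely as an illustration.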

tests/functional/test_entropy_loss.py (1 addition, 1 deletion)

@@ -5,7 +5,7 @@
 
 from fast_llm.engine.distributed.config import DistributedBackend
 from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig
-from fast_llm.functional.cross_entropy import entropy_loss_forward_backward
+from fast_llm.functional.entropy_loss import entropy_loss_forward_backward
 from fast_llm.utils import Assert
 from tests.utils.subtest import DistributedTestContext
 
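
The updated test module now imports the helper from its new location, so the existing suite already exercises the new path. As an additional, purely illustrative smoke check (hypothetical, not part of the repository), one could assert that the relocated module exposes the helper by name without assuming anything about its signature:

# Hypothetical smoke test (not in this commit): verify the helper lives at its new path.
import importlib


def test_entropy_loss_helper_was_relocated():
    module = importlib.import_module("fast_llm.functional.entropy_loss")
    # The moved function should be importable from the new module.
    assert hasattr(module, "entropy_loss_forward_backward")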