Open
Description
Currently, loss masking for distillation is implemented by zeroing the per-token loss terms of masked tokens and then taking the mean over all tokens. The average therefore includes many 0 terms whenever tokens are masked, and the loss shrinks as the masked fraction grows.
Instead, we should take the mean only over non-masked tokens.
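To illustrate the difference, here is a minimal dependency-free sketch. It is not the Fast-LLM implementation: the function names, the list-based inputs, and the guard against an all-masked batch are assumptions made for illustration only.

```python
def mean_over_all_tokens(per_token_loss, loss_mask):
    """Behaviour described in this issue: masked terms are zeroed but still counted."""
    masked = [loss * m for loss, m in zip(per_token_loss, loss_mask)]
    return sum(masked) / len(masked)


def mean_over_unmasked_tokens(per_token_loss, loss_mask):
    """Proposed behaviour: divide by the number of non-masked tokens only."""
    masked = [loss * m for loss, m in zip(per_token_loss, loss_mask)]
    n_valid = sum(loss_mask)
    # Assumption: return 0 when every token is masked, instead of dividing by zero.
    return sum(masked) / max(n_valid, 1)


# Hypothetical per-token losses; mask value 1 keeps a token, 0 masks it out.
per_token_loss = [2.0, 4.0, 6.0, 8.0]
loss_mask = [1, 0, 1, 0]

loss_all = mean_over_all_tokens(per_token_loss, loss_mask)
loss_unmasked = mean_over_unmasked_tokens(per_token_loss, loss_mask)
print(loss_all, loss_unmasked)
```

With half of the tokens masked, the first function reports half the loss of the second, even though the unmasked tokens are identical in both cases.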
Places in the code we need to change:
- https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/functional/cross_entropy.py#L43
- https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/functional/cross_entropy.py#L152
- https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/functional/cross_entropy.py#L277
- https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py#L45
- https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py#L87