
Commit f08ac90

Loss masking for distillation (#250)
1 parent 24871d0 commit f08ac90

File tree: 8 files changed, +146 -60 lines changed


fast_llm/functional/cross_entropy.py

Lines changed: 33 additions & 19 deletions

@@ -8,9 +8,10 @@
 from fast_llm.utils import Assert


-def torch_cross_entropy_forward_backward(
+def _torch_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
+    loss_mask: torch.Tensor | None,
     grad_output: float | None,
     logits_scale_factor: float,
     target_format: TargetFormat,
@@ -28,9 +29,17 @@ def torch_cross_entropy_forward_backward(
         if logits_scale_factor != 1.0:
             target = target * logits_scale_factor
         target = torch.softmax(target, dim=-1)
-    loss = torch.nn.functional.cross_entropy(
-        logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target
-    ).mean()
+    if loss_mask is None:
+        loss = torch.nn.functional.cross_entropy(
+            logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target
+        )
+    else:
+        loss = (
+            torch.nn.functional.cross_entropy(
+                logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
+            )
+            * loss_mask
+        ).mean()
     if grad_output is None:
         grad = None
     else:
@@ -39,7 +48,7 @@ def torch_cross_entropy_forward_backward(
     return loss.detach_(), grad


-# @torch.compile
+@torch.compile
 def _fused_softmax_base(
     logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -57,18 +66,19 @@ def _fused_softmax_base(
     return logits_norm, exp_logits, sum_exp_logits


-# @torch.compile
-def fused_softmax(
+@torch.compile
+def _fused_softmax(
     logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup = None, dim: int = -1
 ) -> torch.Tensor:
     _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim)
     return exp_logits / sum_exp_logits


-@torch.compile
-def fused_cross_entropy_forward_backward(
+# @torch.compile
+def _fused_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
+    loss_mask: torch.Tensor | None,
     grad_output: float | None,
     logits_scale_factor: float,
     target_format: TargetFormat,
@@ -85,7 +95,7 @@ def fused_cross_entropy_forward_backward(
     logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group)

     if target_format == TargetFormat.logits:
-        target = fused_softmax(target, logits_scale_factor, group)
+        target = _fused_softmax(target, logits_scale_factor, group)

     if target_format == TargetFormat.labels:
         target = target.unsqueeze(-1)
@@ -101,10 +111,10 @@ def fused_cross_entropy_forward_backward(
             target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1))
             target = (target - vocab_start_index) * target_mask
     else:
-        # TODO: Support masking
-        loss_mask = None
         # Target should be tensor-parallel already, no further manipulation needed.
         target_mask = None
+        if loss_mask is not None:
+            loss_mask = loss_mask.unsqueeze(-1)

     if grad_output is None:
         grad = None
@@ -120,9 +130,9 @@ def fused_cross_entropy_forward_backward(
         grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits)
         if logits_scale_factor != 1.0:
             grad *= logits_scale_factor
-        grad = grad.to(logits.dtype)
         if loss_mask is not None:
-            grad = torch.where(loss_mask, grad.to(logits.dtype), 0)
+            grad *= loss_mask
+        grad = grad.to(logits.dtype)

     # loss = mean(log(sum_exp_logits) - sum(probabilities * logits))
     if target_format == TargetFormat.labels:
@@ -145,15 +155,16 @@ def fused_cross_entropy_forward_backward(


 _CROSS_ENTROPY_IMPLEMENTATIONS = {
-    CrossEntropyImpl.torch: torch_cross_entropy_forward_backward,
-    CrossEntropyImpl.fused: fused_cross_entropy_forward_backward,
+    CrossEntropyImpl.torch: _torch_cross_entropy_forward_backward,
+    CrossEntropyImpl.fused: _fused_cross_entropy_forward_backward,
     CrossEntropyImpl.triton: triton_cross_entropy_forward_backward,
 }


 def cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
+    loss_mask: torch.Tensor | None,
     grad_output: float | None,
     group: ProcessGroup | None = None,
     implementation: CrossEntropyImpl = CrossEntropyImpl.fused,
@@ -169,15 +180,18 @@ def cross_entropy_forward_backward(
     if target_format == TargetFormat.labels:
         Assert.eq(target.shape, logits.shape[:-1])
         Assert.eq(target.dtype, torch.int64)
+        assert loss_mask is None
     else:
         Assert.eq(target.shape, logits.shape)
         assert target.dtype.is_floating_point, target.dtype
+        if loss_mask is not None:
+            Assert.eq(loss_mask.shape, logits.shape[:-1])
     if group:
         Assert.eq(implementation, CrossEntropyImpl.fused)
-        return fused_cross_entropy_forward_backward(
-            logits, target, grad_output, logits_scale_factor, target_format, group
+        return _fused_cross_entropy_forward_backward(
+            logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group
         )
     else:
         return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
-            logits, target, grad_output, logits_scale_factor, target_format
+            logits, target, loss_mask, grad_output, logits_scale_factor, target_format
        )
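
For reference, the masked torch path above boils down to this standalone sketch (hypothetical helper name, toy shapes, no logits_scale_factor or TargetFormat handling): the per-token cross-entropy against the teacher's softmax is multiplied by the mask and then averaged over all tokens, so masked positions contribute zero to both the loss and the gradient.

import torch


def masked_distillation_loss(
    logits: torch.Tensor,  # (tokens, vocab) student logits
    teacher_logits: torch.Tensor,  # (tokens, vocab) reference-model logits
    loss_mask: torch.Tensor | None,  # (tokens,) 1 = keep, 0 = ignore
) -> torch.Tensor:
    # The teacher logits become a soft target distribution.
    target = torch.softmax(teacher_logits, dim=-1)
    if loss_mask is None:
        return torch.nn.functional.cross_entropy(logits, target)
    # Unreduced per-token loss, zeroed where the mask is 0, then averaged over
    # all tokens (masked positions still count in the denominator).
    per_token = torch.nn.functional.cross_entropy(logits, target, reduction="none")
    return (per_token * loss_mask).mean()


# Toy usage with made-up sizes.
logits = torch.randn(8, 32, requires_grad=True)
teacher_logits = torch.randn(8, 32)
loss_mask = torch.tensor([1, 1, 0, 1, 0, 1, 1, 1], dtype=torch.float32)
masked_distillation_loss(logits, teacher_logits, loss_mask).backward()
assert logits.grad[2].abs().sum() == 0  # masked rows receive no gradient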

fast_llm/functional/triton/cross_entropy.py

Lines changed: 15 additions & 0 deletions

@@ -57,6 +57,7 @@ def triton_cross_entropy_forward_backward_kernel(
 def triton_cross_entropy_from_distribution_forward_backward_kernel(
     logits_ptr,
     target_ptr,
+    loss_mask_ptr,
     grad_logits_ptr,
     losses_ptr,
     grad_losses,
@@ -73,6 +74,14 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel(
     col_offsets = tl.arange(0, block_size)
     mask = col_offsets < n_cols

+    if loss_mask_ptr is not None:
+        loss_mask = tl.load(loss_mask_ptr + block_idx)
+        if loss_mask == 0:
+            tl.store(losses_ptr + block_idx, 0)
+            if grad_losses is not None:
+                tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=mask)
+            return
+
     logits = tl.load(logits_ptr + block_idx * logits_stride_0 + col_offsets, mask=mask, other=-float("inf")).to(
         tl.float32
     )
@@ -104,12 +113,15 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel(
         grad_logits = grad_losses * (exp_logits / sum_exp_logits - target)
         if logits_scale_factor != 1.0:
             grad_logits *= logits_scale_factor
+        if loss_mask_ptr is not None:
+            grad_logits = grad_logits
         tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask)


 def triton_cross_entropy_forward_backward(
     logits: torch.Tensor,
     target: torch.Tensor,
+    loss_mask: torch.Tensor | None,
     grad_output: float | None,
     logits_scale_factor: float,
     target_format: TargetFormat,
@@ -146,9 +158,12 @@ def triton_cross_entropy_forward_backward(
             num_warps=num_warps,
         )
     else:
+        if loss_mask is not None:
+            assert loss_mask.is_contiguous()
         triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)](
             logits,
             target,
+            loss_mask,
             grad_logits,
             losses,
             None if grad_output is None else grad_output / n_rows,
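
As a mental model for the kernel change (not the Triton code itself, with the logits scale factor omitted and a hypothetical function name), here is a dense PyTorch sketch of the per-row behavior: a row whose mask entry is 0 gets a zero loss and, when a gradient is requested, a zero gradient row; every other row gets softmax cross-entropy against the target distribution with gradient grad_losses * (softmax(logits) - target).

import torch


def rowwise_reference(
    logits: torch.Tensor,  # (rows, cols)
    target: torch.Tensor,  # (rows, cols) probability distributions
    loss_mask: torch.Tensor | None,  # (rows,) 0/1
    grad_losses: float | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    logits = logits.float()
    logits_norm = logits - logits.max(dim=-1, keepdim=True).values
    exp_logits = logits_norm.exp()
    sum_exp_logits = exp_logits.sum(dim=-1, keepdim=True)
    # Per row: loss = log(sum_exp_logits) - sum(target * logits_norm)
    losses = sum_exp_logits.squeeze(-1).log() - (target * logits_norm).sum(dim=-1)
    grad = None if grad_losses is None else grad_losses * (exp_logits / sum_exp_logits - target)
    if loss_mask is not None:
        losses = losses * loss_mask  # masked rows -> zero loss
        if grad is not None:
            grad = grad * loss_mask.unsqueeze(-1)  # masked rows -> zero gradient
    return losses, grad


losses, grad = rowwise_reference(
    torch.randn(4, 16), torch.softmax(torch.randn(4, 16), -1), torch.tensor([1.0, 0.0, 1.0, 1.0]), 1.0
)
assert losses[1] == 0 and grad[1].abs().sum() == 0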

fast_llm/layers/language_model/config.py

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@ class LanguageModelKwargs:
     # TODO: These are generic
     labels = "labels"
     phase = "phase"
+    loss_mask = "loss_mask"


 @config_class()
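
The new constant is just another key in the shared kwargs dictionary: roughly, the GPT preprocessing stores the mask under it and the LM head reads it back with `kwargs.get`. A toy illustration (class trimmed to the relevant fields):

import torch


class LanguageModelKwargs:
    # Subset of the constants above, for illustration only.
    labels = "labels"
    loss_mask = "loss_mask"


kwargs: dict[str, torch.Tensor] = {}
kwargs[LanguageModelKwargs.loss_mask] = torch.ones(4, dtype=torch.bool)  # set during preprocessing
loss_mask = kwargs.get(LanguageModelKwargs.loss_mask)  # read in the LM head; None when absent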

fast_llm/layers/language_model/head.py

Lines changed: 19 additions & 4 deletions

@@ -146,6 +146,8 @@ def _forward_backward(
             if self._config.distillation_model is None
             else f"{self._config.distillation_model}_logits"
         )
+        # Loss mask for distillation. (Labels are already masked.)
+        loss_mask = None
         if target is not None:
             if self._config.distillation_model is None:
                 # MTP: Shift the labels
@@ -160,9 +162,14 @@ def _forward_backward(
             else:
                 # Target is reference model logits.
                 target = target.flatten(0, -2)
+                loss_mask = kwargs.get(LanguageModelKwargs.loss_mask)
+                if loss_mask is not None:
+                    loss_mask = loss_mask.flatten()

            if self._sequence_parallel_logits:
                target = split_op(target, self._tensor_space.distributed.tensor_group, 0)
+                if loss_mask is not None:
+                    loss_mask = split_op(loss_mask, self._tensor_space.distributed.tensor_group, 0)
         do_grad = target is not None and self.training
         input_ = input_.detach().requires_grad_(do_grad)
         with torch.enable_grad():
@@ -174,7 +181,7 @@ def _forward_backward(

         output_weights = self._get_output_weights(kwargs)
         loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split(
-            ln_output.detach(), target, output_weights, grad_output, kwargs, losses
+            ln_output.detach(), target, loss_mask, output_weights, grad_output, kwargs, losses
         )

         if do_grad:
@@ -194,14 +201,15 @@ def _logits_cross_entropy_forward_backward_split(
         self,
         input_: torch.Tensor,
         target: torch.Tensor | None,
+        loss_mask: torch.Tensor | None,
         weight: torch.Tensor,
         grad_output: float,
         kwargs: dict,
         losses: dict | None = None,
     ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
         if self._cross_entropy_splits is None or target is None:
             loss, logit_input_grad = self._logits_cross_entropy_forward_backward(
-                input_, target, weight, grad_output, kwargs, losses
+                input_, target, loss_mask, weight, grad_output, kwargs, losses
             )
             if target is None:
                 # TODO: Make a proper way of returning the model output.
@@ -214,12 +222,17 @@ def _logits_cross_entropy_forward_backward_split(
             grad_output /= self._cross_entropy_splits
             logit_input = input_.flatten(0, -2)
             logit_input_grad = torch.empty_like(logit_input)
-            for logit_input_, target_, logit_input_grad_ in zip(
-                logit_input.split(split_size), target.split(split_size), logit_input_grad.split(split_size)
+            for logit_input_, target_, loss_mask_, logit_input_grad_ in zip(
+                logit_input.split(split_size),
+                target.split(split_size),
+                [None] * self._cross_entropy_splits if loss_mask is None else loss_mask.split(split_size),
+                logit_input_grad.split(split_size),
+                strict=True,
             ):
                 loss_, grad_ = self._logits_cross_entropy_forward_backward(
                     logit_input_,
                     target_,
+                    loss_mask_,
                     weight,
                     grad_output,
                     kwargs,
@@ -240,6 +253,7 @@ def _logits_cross_entropy_forward_backward(
         self,
         input_: torch.Tensor,
         target: torch.Tensor | None,
+        loss_mask: torch.Tensor | None,
         weight: torch.Tensor,
         grad_output: float,
         kwargs: dict,
@@ -298,6 +312,7 @@ def _logits_cross_entropy_forward_backward(
         loss, grad = cross_entropy_forward_backward(
             logits.flatten(0, -2),
             target,
+            loss_mask,
             group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
             grad_output=grad_output,
             implementation=self._cross_entropy_impl,
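
The split loop now threads the optional mask through zip by substituting a list of Nones when no mask is given, which keeps strict=True valid in both cases. A minimal standalone sketch of that pattern (hypothetical function name, toy tensors and split count):

import torch


def iter_splits(logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, n_splits: int):
    split_size = logits.size(0) // n_splits
    for logits_, target_, loss_mask_ in zip(
        logits.split(split_size),
        target.split(split_size),
        # One None per chunk when there is no mask, so the strict zip still lines up.
        [None] * n_splits if loss_mask is None else loss_mask.split(split_size),
        strict=True,
    ):
        yield logits_, target_, loss_mask_


logits, target = torch.randn(6, 10), torch.randn(6, 10)
loss_mask = torch.tensor([1, 1, 1, 0, 0, 1], dtype=torch.bool)
assert len(list(iter_splits(logits, target, loss_mask, 3))) == 3
assert len(list(iter_splits(logits, target, None, 3))) == 3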

fast_llm/models/gpt/config.py

Lines changed: 0 additions & 3 deletions

@@ -179,9 +179,6 @@ def _validate(self) -> None:
             Assert.eq(self.reference_models.keys(), {name})
         if self.model.base_model.use_absolute_position_embeddings:
             Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length)
-        if self.model.base_model.distillation_model is not None:
-            # TODO: Support loss masking for distillation?
-            assert not self.batch.use_loss_masking_spans
         for reference_model in self.reference_models.values():
             Assert.none(reference_model.model.base_model.distillation_model)
             # TODO: Support more LM head features.

fast_llm/models/gpt/model.py

Lines changed: 6 additions & 2 deletions

@@ -315,11 +315,15 @@
                     valid_spans[:, 0].clamp_(min=sequence_offset)
                     valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1)
                     valid_spans -= sequence_offset
+                    loss_mask = torch.ones_like(labels, dtype=torch.bool)
                     for start, end in valid_spans:
                         if sequence_first:
-                            labels[start : end + 1, i] = -100
+                            loss_mask[start : end + 1, i] = False
                         else:
-                            labels[i, start : end + 1] = -100
+                            loss_mask[i, start : end + 1] = False
+                    if self._config.distillation_model is not None:
+                        kwargs[LanguageModelKwargs.loss_mask] = loss_mask
+                    labels = torch.where(loss_mask, labels, -100)
                 kwargs[LanguageModelKwargs.labels] = labels
                 kwargs.update(reference_logits[i])
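
A standalone sketch of the new preprocessing behavior (hypothetical spans and shapes, batch-first layout): the spans mark tokens to ignore, the boolean mask is what gets passed to the head for distillation, and the labels still get the usual -100 ignore index at masked positions.

import torch

# Toy batch: 2 sequences with 8 label positions each (batch-first, i.e. not sequence_first).
labels = torch.arange(16).reshape(2, 8)
# Hypothetical loss-masking spans per sequence: (start, end), inclusive.
spans = {0: [(2, 4)], 1: [(0, 1), (6, 7)]}

loss_mask = torch.ones_like(labels, dtype=torch.bool)
for i, sequence_spans in spans.items():
    for start, end in sequence_spans:
        loss_mask[i, start : end + 1] = False

# For distillation the mask itself is passed along (kwargs[LanguageModelKwargs.loss_mask] = loss_mask);
# for the label-based loss, masked positions are replaced with the ignore index.
labels = torch.where(loss_mask, labels, -100)
assert (labels[0, 2:5] == -100).all() and (labels[1, :2] == -100).all()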
