fix loss masking and padding #287
Changes from 18 commits
@@ -89,7 +89,7 @@ def __init__(
 self._indexed_dataset = indexed_dataset
 self._config = sampling.config
 self._parameters = sampling.parameters
-self._truncate_documents = sampling.truncate_documents
+self._truncate_documents = sampling.parameters.truncate_documents
 self._device = torch.device("cuda" if self._config.gpu else "cpu")

 if sampling.cache_directory is None:
@@ -144,7 +144,7 @@ def _sample(self) -> None:
     " Please make sure Fast-LLM is installed correctly."
 )
 long_docs_filter = document_sizes > self._parameters.sequence_length + 1
-ignored_documents = sum(long_docs_filter)
+ignored_documents = long_docs_filter.sum().item()
 if ignored_documents:
     log_main_rank(
         f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.",
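An aside on the one-line change above (illustration only, not code from the PR): Python's built-in `sum()` loops over the tensor element by element and returns a 0-d tensor, whereas `long_docs_filter.sum().item()` performs a single vectorized reduction and yields a plain `int`, which is also what the count in the log message below is expected to be.

```python
import torch

# Minimal sketch of the difference; the tensor values are made up.
document_sizes = torch.tensor([3, 9, 5, 12, 7])
sequence_length = 7

long_docs_filter = document_sizes > sequence_length + 1  # boolean tensor

slow = sum(long_docs_filter)          # Python-level loop, returns a 0-d tensor
fast = long_docs_filter.sum().item()  # single reduction, returns a plain int

print(type(slow), slow)  # a 0-d torch.Tensor holding the count
print(type(fast), fast)  # <class 'int'> 2
```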
@@ -201,9 +201,10 @@ def _sample(self) -> None:

 if self._yaml_path is not None and self._yaml_path.is_file():
     loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
     # Hack to make sure unshuffled tokens are loaded
-    if not self._truncate_documents:
-        yaml_data["unshuffled_tokens"] = loaded_yaml_data["unshuffled_tokens"]
     self._load_yaml_data(yaml_data)
+    if not self._truncate_documents and not self._parameters.use_preference_loss_spans:
+        del loaded_yaml_data["unshuffled_tokens"]

     if loaded_yaml_data != yaml_data:
         raise RuntimeError(
@@ -467,6 +468,12 @@ def __getitem__(self, index: int) -> typing.Any:
     else:
         # Move on to the next sample.
         token_count += padding_size
+elif document_size + tokens_in_sample == self._parameters.sequence_length + 1:
+    if token_count + document_size == token_start:
+        # Document belongs to the current sample but the condition below will include it for the next sample
+        token_count += document_size
+        document_sampling_index += 1
+        continue

 # Determine if the document belongs to the requested sample.
 if token_count + document_size >= token_start:

Comment: I'm not following, why are we ignoring the document if it belongs to the current sample? (Also, it clearly belongs to the previous sample.)

Reply: From what I understand, it seems like in this scenario we'll have … Seems to me the actual fix would be to replace …

Reply: Oh yes, I got confused because I faced this issue in the multimodal branch, but it only occurs when there are images right after the text tokens. Will handle it there.
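To make the thread above easier to follow, here is a minimal numeric walk-through of the exact-fit case. The concrete numbers, and the assumption that `token_start` is the index of the first token of the requested sample, are mine and not from the PR:

```python
# Hypothetical values chosen so that a document exactly fills the previous sample.
sequence_length = 7        # assume each sample holds sequence_length + 1 = 8 tokens
token_start = 8            # assumed start of the requested sample
token_count = 0            # tokens consumed before reaching this document
document_size = 8          # the document spans tokens 0..7, i.e. the previous sample

tokens_in_sample = token_count % (sequence_length + 1)  # 0

# The new `elif` fires: the document exactly fills a sample...
assert document_size + tokens_in_sample == sequence_length + 1
# ...and it ends exactly at token_start, so it was fully consumed by the
# previous sample and gets skipped here.
assert token_count + document_size == token_start

# Without the skip, the `>=` check below would also claim the document for
# this sample, since it matches the boundary with equality; tightening the
# comparison to a strict `>` would exclude this case as well.
assert token_count + document_size >= token_start
```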
@@ -487,7 +494,7 @@ def __getitem__(self, index: int) -> typing.Any:
     0,
     self._parameters.sequence_length + self._parameters.extra_tokens,
 )
-if span[1] > span[0]:
+if span[1] >= span[0]:
     loss_masking_spans.append(span)

 # Go to the next document.
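A small illustration of the `>` to `>=` change, under the assumption (suggested by the surrounding code but not shown in this excerpt) that loss-masking spans are inclusive `[start, end]` pairs: with `>=`, a one-token span where `span[1] == span[0]` is kept instead of being silently dropped.

```python
# Illustration only: inclusive [start, end] spans after clipping to the sample.
spans = [(2, 5), (7, 7), (9, 8)]  # (7, 7) is a one-token span, (9, 8) is empty

kept_old = [s for s in spans if s[1] > s[0]]   # old check: drops (7, 7)
kept_new = [s for s in spans if s[1] >= s[0]]  # new check: keeps (7, 7)

assert kept_old == [(2, 5)]
assert kept_new == [(2, 5), (7, 7)]
```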
@@ -547,7 +554,7 @@ def __init__(
 ):
     assert isinstance(sampling, GPTSamplingData)
     self._indexed_dataset = indexed_dataset
-    if not sampling.truncate_documents:
+    if not sampling.parameters.truncate_documents:
         raise NotImplementedError(
             "Legacy sampling only supports document truncation. Please use the latest dataset format."
         )
@@ -146,8 +146,13 @@ def _fused_cross_entropy_forward_backward(
 per_sample_loss = sum_exp_logits.log() - predicted_logits
 if loss_mask is not None:
     per_sample_loss = per_sample_loss * loss_mask
-
-loss = per_sample_loss.mean()
+    unmasked_inputs = loss_mask.sum()
+    if unmasked_inputs:
+        loss = per_sample_loss.sum() / unmasked_inputs
+    else:
+        loss = torch.tensor(0.0, dtype=per_sample_loss.dtype, device=per_sample_loss.device)
+else:
+    loss = per_sample_loss.mean()
 if target_format != TargetFormat.labels and group is not None:
     all_reduce(loss, op=ReduceOp.MEAN, group=group)

Comment: This still causes a CUDA sync. You can just do …

Reply: For my own understanding, how can I check whether a PyTorch op causes a CUDA sync?
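Two hedged sketches related to the thread above (the reviewer's exact one-liner was lost when this page was extracted, so the code below is an assumption, not their proposal): PyTorch's `torch.cuda.set_sync_debug_mode` can flag implicit host-device synchronizations, and the masked mean can be written without branching on a GPU tensor by clamping the denominator.

```python
import torch

# 1) Answering "how can I check whether a PyTorch op causes a CUDA sync?":
#    in "warn" (or "error") mode, PyTorch reports operations that force an
#    implicit host-device synchronization, e.g. calling .item() or bool() on a
#    CUDA tensor.
if torch.cuda.is_available():
    torch.cuda.set_sync_debug_mode("warn")


def masked_mean(per_sample_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # 2) A sync-free masked mean: `if unmasked_inputs:` evaluates a CUDA tensor
    #    as a Python bool and therefore synchronizes. Clamping the denominator
    #    keeps everything on device and still yields 0.0 for a fully masked
    #    batch, since the numerator is already 0 in that case.
    return (per_sample_loss * loss_mask).sum() / loss_mask.sum().clamp_min(1)
```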
Comment: Not sure we need to move this, but if we do, we need to add backward compatibility.
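If the move does happen (the comment above most likely refers to `truncate_documents` moving from the sampling config to the sampling parameters, as in the first hunks), one common way to keep old call sites working is a deprecated alias on the old location. The classes below are simplified stand-ins, not Fast-LLM's actual config machinery:

```python
import warnings


class GPTSamplingParameters:
    """Simplified stand-in for the parameters object that now owns the flag."""

    def __init__(self, truncate_documents: bool = True):
        self.truncate_documents = truncate_documents


class GPTSamplingData:
    """Simplified stand-in for the sampling data object."""

    def __init__(self, parameters: GPTSamplingParameters):
        self.parameters = parameters

    @property
    def truncate_documents(self) -> bool:
        # Deprecated alias: forwards to the new location and emits a DeprecationWarning.
        warnings.warn(
            "sampling.truncate_documents is deprecated; "
            "use sampling.parameters.truncate_documents instead.",
            DeprecationWarning,
        )
        return self.parameters.truncate_documents
```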