feat: dynamic data mixing based on per-domain loss #9142

amanyara wants to merge 2 commits into modelscope:main from
Conversation
Implement online dynamic data mixing that adjusts sampling weights during training based on per-domain loss using `softmax(L/T)`. Higher-loss domains receive more sampling weight. This adds `DynamicMixBatchSampler`, `DynamicMixingCallback`, and 4 new args (`dynamic_mix`, `dynamic_mix_temperature`, `dynamic_mix_update_steps`, `dynamic_mix_warmup_steps`). Reuses the existing channel loss infrastructure.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
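The `softmax(L/T)` update rule described above can be sketched in plain Python. This is a minimal illustration of the weighting math, not the PR's actual implementation; the domain names and loss values are made up:

```python
import math

def mix_weights(domain_losses, temperature=2.0):
    """Compute sampling probabilities as softmax(L_i / T).

    Higher-loss domains get proportionally more sampling weight;
    a larger temperature T flattens the distribution.
    """
    scaled = [loss / temperature for loss in domain_losses.values()]
    m = max(scaled)  # subtract the max before exp for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return {d: e / total for d, e in zip(domain_losses, exps)}

# Hypothetical per-domain losses captured from training logs
losses = {'math': 2.0, 'code': 1.0, 'general': 0.5}
probs = mix_weights(losses, temperature=2.0)
```

With these inputs, `math` (the highest-loss domain) ends up with the largest sampling probability, and raising `temperature` pushes the distribution back toward uniform.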
Code Review
This pull request implements dynamic data mixing, allowing sampling weights to be adjusted based on per-domain loss. Key additions include the DynamicMixingCallback, DynamicMixBatchSampler, and corresponding configuration arguments. Review feedback recommends optimizing the callback to avoid redundant dataloader initialization and improving the sampler's performance by vectorizing the multinomial sampling logic.
```python
def on_train_begin(self, args, state, control, **kwargs):
    # Get sampler reference from the dataloader
    dataloader = self.trainer.get_train_dataloader()
```
Calling self.trainer.get_train_dataloader() here is redundant and potentially inefficient. In the transformers.Trainer lifecycle, the training dataloader is already created and assigned to self.trainer.train_dataloader before on_train_begin is called. Re-calling the getter may trigger unnecessary re-initialization logic.
```diff
-    dataloader = self.trainer.get_train_dataloader()
+    dataloader = getattr(self.trainer, 'train_dataloader', None)
```
```python
for _ in range(self.batch_size * self.world_size):
    domain_idx = torch.multinomial(prob_tensor, 1, generator=generator).item()
    domain_name = self.domain_names[domain_idx]
```
The current implementation calls torch.multinomial inside a loop for every single sample in the global batch. This can become a significant performance bottleneck as the batch_size or world_size increases. It is much more efficient to sample all domain indices for the entire batch in a single vectorized call.
```diff
-for _ in range(self.batch_size * self.world_size):
-    domain_idx = torch.multinomial(prob_tensor, 1, generator=generator).item()
-    domain_name = self.domain_names[domain_idx]
+sampled_indices = torch.multinomial(prob_tensor, self.batch_size * self.world_size,
+                                    replacement=True, generator=generator).tolist()
+for domain_idx in sampled_indices:
+    domain_name = self.domain_names[domain_idx]
```
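The loop-vs-vectorized contrast can be shown with the standard library's `random.choices`, which, like `torch.multinomial(..., replacement=True)`, draws the whole batch of weighted indices in one call. This is a stdlib stand-in sketch, not the sampler's actual code; the domain names are illustrative:

```python
import random

def sample_domains(domain_names, probs, n, seed=0):
    """Draw n domain indices in a single weighted call, mirroring the
    vectorized torch.multinomial(..., replacement=True) suggestion."""
    rng = random.Random(seed)
    # One call replaces n separate single-draw calls inside a loop.
    sampled_indices = rng.choices(range(len(domain_names)), weights=probs, k=n)
    return [domain_names[i] for i in sampled_indices]

names = ['math', 'code', 'general']        # illustrative domains
batch = sample_domains(names, [0.5, 0.3, 0.2], n=16)
```

Beyond call overhead, the vectorized form also avoids the per-draw `.item()` device synchronization that the original loop incurs when `prob_tensor` lives on GPU.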
- Use `self.trainer.train_dataloader` instead of re-calling `get_train_dataloader()` in `on_train_begin` to avoid redundant dataloader initialization
- Vectorize the `torch.multinomial` call in `DynamicMixBatchSampler.__iter__` to sample all domain indices for a batch in one shot instead of looping

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
- Sampling weights are recomputed as `softmax(L_i / T)` from per-domain loss
- Reuses the existing `channel` field + `enable_channel_loss` infrastructure

Changes
- `swift/trainers/arguments.py`: 4 new args (`dynamic_mix`, `dynamic_mix_update_steps`, `dynamic_mix_temperature`, `dynamic_mix_warmup_steps`), auto-enable `enable_channel_loss`, mutual exclusion checks with streaming/packing/`interleave_prob`/`group_by_length`
- `swift/dataloader/shard.py`: `DynamicMixBatchSampler` — per-domain weighted sampling with runtime probability updates, deterministic across ranks via a shared `Generator`
- `swift/callbacks/dynamic_mix.py`: `DynamicMixingCallback` — captures `loss_{channel}` from logs, computes `softmax(L/T)` every N steps, updates sampler weights, logs `mix_prob_*` to tensorboard/wandb
- `swift/callbacks/mapping.py`: register the `dynamic_mix` callback
- `swift/dataloader/__init__.py`: export `DynamicMixBatchSampler`
- `swift/trainers/mixin.py`: integrate `DynamicMixBatchSampler` in `get_train_dataloader()`, add `_build_domain_indices()` helper
- `tests/train/test_dynamic_mix.py`: unit tests for sampler and callback

Usage
```shell
swift sft \
  --model Qwen/Qwen2.5-7B \
  --dataset math_data code_data general_data \
  --dynamic_mix true \
  --dynamic_mix_temperature 2.0 \
  --dynamic_mix_update_steps 100 \
  --dynamic_mix_warmup_steps 50
```

Requires
A `"channel"` field in each JSONL record (e.g. `{"messages": [...], "channel": "math"}`).

Limitations (v1)
- Incompatible with streaming, packing, `interleave_prob`, or `group_by_length`

Test plan
- Unit tests (`tests/train/test_dynamic_mix.py`)
- Verify `loss_{channel}` and `mix_prob_{channel}` appear in logs

🤖 Generated with Claude Code