
feat: dynamic data mixing based on per-domain loss #9142

Open
amanyara wants to merge 2 commits into modelscope:main from amanyara:feat/dynamic-data-mixing

Conversation

@amanyara

Summary

  • Add online dynamic data mixing that adjusts sampling weights during training based on per-domain loss via softmax(L_i / T)
  • Higher-loss domains automatically receive more sampling, helping the model focus on weaker areas
  • Reuses existing channel field + enable_channel_loss infrastructure
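The weighting rule from the first bullet can be sketched as follows. This is a minimal, self-contained illustration of softmax(L_i / T), not the PR's actual code; the function name and example losses are made up:

```python
import math

def mix_probs(domain_losses, temperature=2.0):
    """Map per-domain losses to sampling probabilities via softmax(L_i / T).

    A higher temperature flattens the distribution; a lower one
    concentrates sampling on the highest-loss domain.
    """
    scaled = [loss / temperature for loss in domain_losses.values()]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return {d: e / total for d, e in zip(domain_losses, exps)}

# The highest-loss domain ('math' here) receives the largest probability.
probs = mix_probs({'math': 2.0, 'code': 1.0, 'general': 0.5}, temperature=2.0)
```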

Changes

  • swift/trainers/arguments.py: 4 new args (dynamic_mix, dynamic_mix_update_steps, dynamic_mix_temperature, dynamic_mix_warmup_steps); auto-enables enable_channel_loss; adds mutual-exclusion checks against streaming/packing/interleave_prob/group_by_length
  • swift/dataloader/shard.py: DynamicMixBatchSampler — per-domain weighted sampling with runtime probability updates, deterministic across ranks via shared Generator
  • swift/callbacks/dynamic_mix.py: DynamicMixingCallback — captures loss_{channel} from logs, computes softmax(L/T) every N steps, updates sampler weights, logs mix_prob_* to tensorboard/wandb
  • swift/callbacks/mapping.py: Register dynamic_mix callback
  • swift/dataloader/__init__.py: Export DynamicMixBatchSampler
  • swift/trainers/mixin.py: Integrate DynamicMixBatchSampler in get_train_dataloader(), add _build_domain_indices() helper
  • tests/train/test_dynamic_mix.py: Unit tests for sampler and callback
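The sampler's core idea can be sketched roughly as below. Class and method names are illustrative, not the actual swift/dataloader/shard.py implementation; it only shows the shape of per-domain weighted sampling with runtime probability updates and a shared, seeded Generator for cross-rank determinism:

```python
import torch

class DynamicMixBatchSamplerSketch:
    """Illustrative sketch: draw each sample's domain from the current mix
    probabilities, then pick a dataset index within that domain."""

    def __init__(self, domain_indices, batch_size, seed=42):
        # domain_indices: e.g. {'math': [0, 1, ...], 'code': [...], ...}
        self.domain_names = list(domain_indices)
        self.domain_indices = domain_indices
        self.batch_size = batch_size
        # The same seed on every rank keeps sampling deterministic across ranks.
        self.generator = torch.Generator().manual_seed(seed)
        n = len(self.domain_names)
        self.probs = torch.full((n,), 1.0 / n)  # start from a uniform mix

    def set_probs(self, probs):
        # Called by the callback every N steps with the softmax(L/T) weights.
        self.probs = torch.tensor([probs[d] for d in self.domain_names])

    def sample_batch(self):
        # One vectorized draw of all domain choices for the batch.
        domains = torch.multinomial(self.probs, self.batch_size,
                                    replacement=True, generator=self.generator)
        batch = []
        for d in domains.tolist():
            pool = self.domain_indices[self.domain_names[d]]
            j = torch.randint(len(pool), (1,), generator=self.generator).item()
            batch.append(pool[j])
        return batch
```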

Usage

swift sft \
    --model Qwen/Qwen2.5-7B \
    --dataset math_data code_data general_data \
    --dynamic_mix true \
    --dynamic_mix_temperature 2.0 \
    --dynamic_mix_update_steps 100 \
    --dynamic_mix_warmup_steps 50

Requires "channel" field in each JSONL record (e.g. {"messages": [...], "channel": "math"}).
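For instance, a multi-channel JSONL file with the required field could be produced like this (the file name and record contents are illustrative):

```python
import json

records = [
    {"messages": [{"role": "user", "content": "What is 2 + 2?"},
                  {"role": "assistant", "content": "4"}],
     "channel": "math"},
    {"messages": [{"role": "user", "content": "Write hello world in Python."},
                  {"role": "assistant", "content": "print('hello world')"}],
     "channel": "code"},
]

# One JSON object per line; each record carries its domain in "channel".
with open("train.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")
```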

Limitations (v1)

  • Not compatible with packing, streaming, interleave_prob, or group_by_length
  • Sampling probabilities are not persisted on checkpoint resume

Test plan

  • Unit tests pass (tests/train/test_dynamic_mix.py)
  • Integration test with small model + multi-channel dataset
  • Verify loss_{channel} and mix_prob_{channel} appear in logs
  • Verify high-loss domain weights increase over time

🤖 Generated with Claude Code

Implement online dynamic data mixing that adjusts sampling weights
during training based on per-domain loss using softmax(L/T). Higher
loss domains get more sampling. This adds DynamicMixBatchSampler,
DynamicMixingCallback, and 4 new args (dynamic_mix, dynamic_mix_temperature,
dynamic_mix_update_steps, dynamic_mix_warmup_steps). Reuses existing
channel loss infrastructure.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request implements dynamic data mixing, allowing sampling weights to be adjusted based on per-domain loss. Key additions include the DynamicMixingCallback, DynamicMixBatchSampler, and corresponding configuration arguments. Review feedback recommends optimizing the callback to avoid redundant dataloader initialization and improving the sampler's performance by vectorizing the multinomial sampling logic.

Comment thread swift/callbacks/dynamic_mix.py Outdated

def on_train_begin(self, args, state, control, **kwargs):
    # Get sampler reference from the dataloader
    dataloader = self.trainer.get_train_dataloader()

Severity: medium

Calling self.trainer.get_train_dataloader() here is redundant and potentially inefficient. In the transformers.Trainer lifecycle, the training dataloader is already created and assigned to self.trainer.train_dataloader before on_train_begin is called. Re-calling the getter may trigger unnecessary re-initialization logic.

Suggested change
-    dataloader = self.trainer.get_train_dataloader()
+    dataloader = getattr(self.trainer, 'train_dataloader', None)

Comment thread swift/dataloader/shard.py Outdated
Comment on lines +164 to +166
for _ in range(self.batch_size * self.world_size):
    domain_idx = torch.multinomial(prob_tensor, 1, generator=generator).item()
    domain_name = self.domain_names[domain_idx]

Severity: medium

The current implementation calls torch.multinomial inside a loop for every single sample in the global batch. This can become a significant performance bottleneck as the batch_size or world_size increases. It is much more efficient to sample all domain indices for the entire batch in a single vectorized call.

Suggested change
-    for _ in range(self.batch_size * self.world_size):
-        domain_idx = torch.multinomial(prob_tensor, 1, generator=generator).item()
-        domain_name = self.domain_names[domain_idx]
+    sampled_indices = torch.multinomial(prob_tensor, self.batch_size * self.world_size,
+                                        replacement=True, generator=generator).tolist()
+    for domain_idx in sampled_indices:
+        domain_name = self.domain_names[domain_idx]

- Use `self.trainer.train_dataloader` instead of re-calling
  `get_train_dataloader()` in `on_train_begin` to avoid redundant
  dataloader initialization
- Vectorize `torch.multinomial` call in `DynamicMixBatchSampler.__iter__`
  to sample all domain indices per batch in one shot instead of looping

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
