
Conversation

@jialei777 (Collaborator) commented Aug 12, 2025

Ran locally on 4 chips with fsdp=4: `export PJRT_DEVICE=TPU; export TORCHPRIME_TPU_TYPE=v6e-4 && python torchprime/torch_xla_models/train.py model=flex-qwen-1b`

MFU: 0.21

On a v5p-128 cluster, using `tp run --name jialei-0812-qwen-fsdp32tensor2 torchprime/torch_xla_models/train.py model=flex-qwen-1b task.global_batch_size=64 ici_mesh.fsdp=32 ici_mesh.tensor=2` (with the corresponding `ici_mesh.*` overrides for each configuration; a sweep sketch follows the list):

  • fsdp=64: hangs (not yet diagnosed)
  • fsdp=32, tensor=2: finished, MFU 0.22
  • fsdp=16, tensor=4: finished, MFU 0.19
  • fsdp=8, tensor=8: finished, MFU 0.11
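
For reference, a minimal sketch of how these runs could be scripted as a single sweep, assuming the `tp run` invocation and `ici_mesh.*` overrides shown above; the run names in the sketch are hypothetical, not the ones actually used:

```bash
#!/usr/bin/env bash
# Sweep mesh configurations on the v5p-128 cluster (64 chips total).
# Assumes the `tp run` CLI and the ici_mesh.fsdp / ici_mesh.tensor overrides
# from the runs reported above; the fsdp=64 configuration hung in practice.
set -euo pipefail

for cfg in "64 1" "32 2" "16 4" "8 8"; do
  read -r fsdp tensor <<< "$cfg"
  tp run --name "qwen-fsdp${fsdp}-tp${tensor}" \
    torchprime/torch_xla_models/train.py \
    model=flex-qwen-1b \
    task.global_batch_size=64 \
    ici_mesh.fsdp="$fsdp" \
    ici_mesh.tensor="$tensor"
done
```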

