# Vision Language Model training in `torchtitan`

**under active development**

This folder showcases how to train modern Vision Language Models (VLMs) in torchtitan.


## Features:
- Native Aspect Ratio: not limited to square crops.
- Native Resolution: images in a batch can have different sizes; no more image tiles and thumbnails.
- Native Interleaved data: training samples can have a variable number of images, interleaved with text at different positions. You can train more than just a captioning model.


## Design
Distributed training usually does not play nicely with inputs of varying shapes. To handle a varying number of images and image sizes, we require two hyperparameters, image batch size `N` and image length `L` (in patches), and pad the actual image patches to this fixed size.
Then we scatter the patch embeddings to their actual positions in the LLM input tokens.

<img width="1398" height="840" alt="Screenshot 2025-08-21 at 16 21 57" src="https://github.com/user-attachments/assets/63fcbbc1-c587-4a63-8246-411cb72f5789" />

- After `tok_embedding`, we obtain tokens of shape `BxS`.
- After `encoder`, we obtain visual tokens of shape `NxL`.
- We extract only the valid visual tokens.
- Then we scatter those tokens to their actual positions in the LLM input tokens (see the sketch below).
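
A minimal sketch of the extract-and-scatter step, assuming hypothetical tensor names (`tokens`, `visual_tokens`, `valid`, `image_mask`) rather than the actual torchtitan code:

```python
import torch

# Assumed shapes (hypothetical names):
#   tokens:        [B, S, D]  LLM token embeddings after tok_embedding
#   visual_tokens: [N, L, D]  encoder output for the padded image patches
#   valid:         [N, L]     True for real patches, False for padding
#   image_mask:    [B, S]     True at LLM positions reserved for image tokens

def scatter_visual_tokens(tokens, visual_tokens, valid, image_mask):
    # Keep only the embeddings of real (non-padding) patches: [num_valid, D]
    valid_visual = visual_tokens[valid]
    # Write them, in order, into the image-token slots of the LLM sequence.
    # The dataloader guarantees image_mask.sum() == valid.sum().
    return tokens.masked_scatter(image_mask.unsqueeze(-1), valid_visual)
```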


This results in a very simple and general interface to train modern VLMs with interleaved data and native resolution & aspect ratio:
- Depending on the data mixture, we can set the dataloader's hyperparameters `N, L` to have minimal empty image padding (in the batch dimension); see the sketch after this list.
- We use modern PyTorch features like FlexAttention and torch.compile to efficiently handle variable sequence lengths.
- Interfaces nicely with TP, PP, etc.
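
As a rough illustration of how `N` and `L` could be chosen from dataset statistics (a hypothetical helper, not part of torchtitan):

```python
import math

def choose_image_hparams(image_counts, image_sizes, llm_batch_size,
                         patch_size=16, pct=0.95):
    """Pick N (images per batch) and L (patches per image) so that most
    batches incur little image padding.

    image_counts: number of images in each training sample
    image_sizes:  (height, width) of every image in the sample set
    """
    patches = sorted(math.ceil(h / patch_size) * math.ceil(w / patch_size)
                     for h, w in image_sizes)
    counts = sorted(image_counts)
    pick = lambda xs: xs[min(int(pct * len(xs)), len(xs) - 1)]
    N = pick(counts) * llm_batch_size  # images per LLM batch
    L = pick(patches)                  # patch length each image is padded to
    return N, L
```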


## Implementation

### Dataloader
This approach requires the dataloader to handle the following aspects:
- [x] Interleave the precise number of image tokens into the input tokens, based on the encoder's patch size and each input image's size.
- [x] Convert images/videos to a 1D sequence of patches (see the patchification sketch after this list):
  - `rearrange(pixels, 'n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)', pt=temporal_ps, ph=patch_size, pw=patch_size)`
  - Pad all image patch sequences to a fixed length and return `pixel_values.shape == [N, L, D]`
- [x] Return a `grid_thw.shape == [N, L, 3]` to keep track of the location indices of each patch in its image. Padding patches can be tracked in the same tensor with value `-1`.
- [x] LLM Sample / Document Packing.
- [x] Captioning dataset: CC12M
- [x] Interleaved dataset: Obelics
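
A rough sketch of the patchification step for a single image (the `n` batch dimension dropped for brevity), assuming einops and hypothetical names; the real dataloader also handles interleaving and document packing:

```python
import torch
from einops import rearrange

def patchify(pixels, patch_size=16, temporal_ps=1, max_len=1024):
    """pixels: [T, H, W, C] with T, H, W divisible by the patch sizes.
    Returns padded patches [L, D] and grid_thw [L, 3] with -1 for padding."""
    t = pixels.shape[0] // temporal_ps
    h = pixels.shape[1] // patch_size
    w = pixels.shape[2] // patch_size
    patches = rearrange(
        pixels, '(t pt) (h ph) (w pw) c -> (t h w) (pt ph pw c)',
        pt=temporal_ps, ph=patch_size, pw=patch_size)
    # (t, h, w) index of every patch, flattened in the same order as above.
    grid = torch.stack(torch.meshgrid(
        torch.arange(t), torch.arange(h), torch.arange(w), indexing='ij'),
        dim=-1).reshape(-1, 3)
    # Pad (or truncate) the patch sequence to the fixed length L = max_len.
    n, d = patches.shape
    out = patches.new_zeros(max_len, d)
    out[:min(n, max_len)] = patches[:max_len]
    thw = torch.full((max_len, 3), -1, dtype=torch.long)
    thw[:min(n, max_len)] = grid[:max_len]
    return out, thw
```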


### Model
We also need a pretrained vision encoder with support for native resolution and aspect ratio. Until recently, relatively few vision encoders had this capability; examples include Siglip2, AimV2, and most recently DINOv3.
- [ ] Currently we support the Siglip2 encoder using a positional embedding interpolation approach.
  - [x] Base modelling code.
  - [ ] Weights conversion and loading from HF.
- [x] FSDP for both Encoder and Decoder
- [x] Context Parallel for the LLM only, since we will use FlexAttention for the Encoder.
- [ ] FlexAttention with different sequence lengths per image (sketched below).
- [ ] Compile for Encoder + Decoder
- [ ] Tensor Parallel
- [ ] Pipeline Parallel
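
For the FlexAttention item above, a sketch of how a per-image block mask could be built for packed patch sequences of different lengths; `image_ids` is a hypothetical tensor marking which image each position belongs to (`-1` for padding):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def image_block_mask(image_ids):
    """image_ids: [S] int tensor; positions with equal ids attend to each
    other, -1 marks padding and attends to nothing."""
    def mask_mod(b, h, q_idx, kv_idx):
        same_image = image_ids[q_idx] == image_ids[kv_idx]
        not_pad = image_ids[q_idx] >= 0
        return same_image & not_pad
    S = image_ids.numel()
    return create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S,
                             device=image_ids.device)

# Usage inside the encoder's attention layer (q, k, v: [B, H, S, Dh]):
# out = flex_attention(q, k, v, block_mask=image_block_mask(image_ids))
```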