
Commit c9cb304

lkhphuc, wwwjn, Griffintaur, and tianyu-l authored
VLM: Onboarding native resolution, native aspect ratio, interleaved VLM training (#1615)
First PR to onboard modern VLM training to torchtitan.

## Features:
- Native Aspect Ratio: not limited to square crops.
- Native Resolution: images in a batch can have different sizes; no more image tiles and thumbnails.
- Native Interleaved data: training samples can have a variable number of images, interleaved with text at different positions. You can train more than just a captioning model.

## Design
Distributed training usually does not play nice with inputs of varying shapes. To handle a varying number of images and image sizes, we require two additional hyperparameters: the number of images per batch `N` and the max image patch length `L`. We then pad the actual image patches to this fixed size.

<img width="1398" height="840" alt="Screenshot 2025-08-21 at 16 21 57" src="https://github.com/user-attachments/assets/63fcbbc1-c587-4a63-8246-411cb72f5789" />

- After `tok_embedding`, we obtain tokens of shape `BxS`.
- After `encoder`, we obtain visual tokens of shape `NxL`.
- We extract only the valid visual tokens.
- Then we scatter those tokens to their actual positions in the LLM input tokens.

This requires the dataloader to handle the following aspects:
- Interleave the precise number of image tokens into the input tokens, based on the encoder's patch size and the input images' sizes.
- Convert images/videos to a 1D sequence of patches:
  - `rearrange(pixels, 'n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)', pt=temporal_ps, ph=patch_size, pw=patch_size)`
  - Pad all image patch sequences to a fixed length and return `pixel_values.shape == [N, L, D]`.
- Return a `grid_thw.shape == [N, L, 3]` to keep track of the location indices of each patch in the images. Padding images can be tracked in the same tensor with the value `-1`.

This results in a very simple and general interface to train modern VLMs with interleaved data and native resolution & aspect ratio:
- Depending on data mixtures, we can set the dataloader's hyperparameters `N, L` to minimize empty image padding (in the batch dimension).
- Use modern PyTorch features (FlexAttention, compile, etc.) for efficient handling of a different attention mask per image (padding in the sequence dimension).
- Interfaces nicely with TP, PP, etc.

## In this PR
- Minimal interleaved Obelics dataloader with native resolution and aspect ratio.
  - The dataloader is currently very slow, as it needs to download images from the internet every time you run. (Same for the current implementation in the `multimodal` experiment.)
- Siglip2 model code, mostly based on HF.
- VLM model code called `Llama3Siglip2`, connecting the vision encoder and language decoder.
- Minimal infra code for the debug model to run.

<img width="1672" height="1303" alt="Screenshot 2025-08-21 at 15 25 25" src="https://github.com/user-attachments/assets/c5c70ae2-04cf-4459-90d7-77d045291b88" />

## Todo:
- [x] Add support for captioning HF datasets that store images inside the dataset (CC12M, like the Flux exp?) so it's not super slow to load.
- Flex Attention for the encoder.
- Modify the Llama3 tokenizer to add special tokens.
- Script to combine Siglip2 + Llama3 weights and load them.
- Test Siglip2 encoder correctness.
- Multimodal CE loss to correct for image token bias.
- All the parallelisms: DP, CP, TP, PP.

---------

Co-authored-by: Jiani Wang <40016222+wwwjn@users.noreply.github.com>
Co-authored-by: Ankit Singh <ankitsingh135@gmail.com>
Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
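A minimal sketch of the pad-and-scatter design described above, with hypothetical tensor names and toy shapes (this is not the code added in this PR):

```python
import torch

B, S, D = 2, 32, 64        # LLM batch size, sequence length, hidden dim
N, L = 3, 8                # image batch size N and max patches per image L

h = torch.randn(B, S, D)                    # token embeddings after tok_embedding
visual = torch.randn(N, L, D)               # encoder output, already padded to [N, L, D]
grid_thw = torch.randint(0, 4, (N, L, 3))   # patch coordinates; -1 marks padding
grid_thw[-1, 3:] = -1                       # e.g. the last image has only 3 real patches

# Keep only the real (non-padding) visual tokens across all images.
valid = grid_thw[..., 0] != -1              # [N, L]
valid_tokens = visual[valid]                # [num_real_patches, D] -> here 8 + 8 + 3 = 19

# Positions of the image placeholder tokens in the LLM input, one per real patch
# (the dataloader is responsible for interleaving exactly this many of them).
image_pos = torch.zeros(B, S, dtype=torch.bool)
image_pos[0, 4 : 4 + valid_tokens.shape[0]] = True   # toy layout for this example

# Scatter the real visual tokens into their placeholder positions.
h[image_pos] = valid_tokens
```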
1 parent 60645bc commit c9cb304

File tree

21 files changed: +2078 −9 lines changed


tests/assets/tokenizer/tokenizer.json

Lines changed: 5 additions & 1 deletion
@@ -2029,7 +2029,11 @@
     "land": 1994,
     "?\n": 1995,
     " respect": 1996,
-    "ances": 1997
+    "ances": 1997,
+    "<|image|>": 1998,
+    "<|begin_of_image|>": 1999,
+    "<|end_of_image|>": 2000,
+    "<|pad|>": 2001
   },
   "merges": [
   ]

tests/assets/tokenizer/tokenizer_config.json

Lines changed: 36 additions & 0 deletions
@@ -15,11 +15,47 @@
     "rstrip": false,
     "single_word": false,
     "special": true
+  },
+  "1998": {
+    "content": "<|image|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false,
+    "special": true
+  },
+  "1999": {
+    "content": "<|begin_of_image|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false,
+    "special": true
+  },
+  "2000": {
+    "content": "<|end_of_image|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false,
+    "special": true
+  },
+  "2001": {
+    "content": "<|pad|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false,
+    "special": true
   }
 },
 "bos_token": "<|begin_of_text|>",
 "clean_up_tokenization_spaces": true,
 "eos_token": "<|end_of_text|>",
+"img_token": "<|image|>",
+"boi_token": "<|begin_of_image|>",
+"eoi_token": "<|end_of_image|>",
+"pad_token": "<|pad|>",
 "model_input_names": [
   "input_ids",
   "attention_mask"

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@
 import torchtitan.experiments.llama4  # noqa: F401
 import torchtitan.experiments.qwen3
 import torchtitan.experiments.simple_fsdp  # noqa: F401
+import torchtitan.experiments.vlm  # noqa: F401

torchtitan/experiments/llama4/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -30,7 +30,7 @@
         dim=256,
         n_layers=6,
         n_heads=16,
-        vocab_size=2000,
+        vocab_size=2048,
         rope_theta=500000,
     ),
     "17bx16e": TransformerModelArgs(
@@ -59,7 +59,7 @@
         dim=256,
         n_layers=6,
         n_heads=16,
-        vocab_size=2000,
+        vocab_size=2048,
         rope_theta=500000,
         every_n_layers_nope=4,
         fixed_attn_block_size=256,
Lines changed: 57 additions & 0 deletions (new file)

# Vision Language Model training in `torchtitan`

**under active development**

This folder showcases how to train modern Vision Language Models (VLMs) in torchtitan.


## Features:
- Native Aspect Ratio: not limited to square crops.
- Native Resolution: images in a batch can have different sizes; no more image tiles and thumbnails.
- Native Interleaved data: training samples can have a variable number of images, interleaved with text at different positions. You can train more than just a captioning model.

## Design
Distributed training usually does not play nice with inputs of varying shapes. To handle a varying number of images and image sizes, we require two hyperparameters, the image batch size `N` and the image length `L` (in patches), and pad the actual image patches to this fixed size.
Then we scatter the patch embeddings to their actual positions in the LLM input tokens.

<img width="1398" height="840" alt="Screenshot 2025-08-21 at 16 21 57" src="https://github.com/user-attachments/assets/63fcbbc1-c587-4a63-8246-411cb72f5789" />

- After `tok_embedding`, we obtain tokens of shape `BxS`.
- After `encoder`, we obtain visual tokens of shape `NxL`.
- We extract only the valid visual tokens.
- Then we scatter those tokens to their actual positions in the LLM input tokens.

This results in a very simple and general interface to train modern VLMs with interleaved data and native resolution & aspect ratio:
- Depending on data mixtures, we can set the dataloader's hyperparameters `N, L` to minimize empty image padding (in the batch dimension).
- We use modern PyTorch features like FlexAttention and torch.compile to efficiently handle variable sequence lengths.
- Interfaces nicely with TP, PP, etc.

## Implementation

### Dataloader
This approach requires the dataloader to handle the following aspects (a sketch follows this list):
- [x] Interleave the precise number of image tokens into the input tokens, based on the encoder's patch size and the input images' sizes.
- [x] Convert images/videos to a 1D sequence of patches:
  - `rearrange(pixels, 'n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)', pt=temporal_ps, ph=patch_size, pw=patch_size)`
  - Pad all image patch sequences to a fixed length and return `pixel_values.shape == [N, L, D]`.
- [x] Return a `grid_thw.shape == [N, L, 3]` to keep track of the location indices of each patch in the images. Padding images can be tracked in the same tensor with the value `-1`.
- [x] LLM Sample / Document Packing.
- [x] Captioning dataset: CC12M
- [x] Interleaved dataset: Obelics
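Below is a rough sketch of the patchify-and-pad step for a single still image, following the layout above (einops assumed; `max_patches_per_image` mirrors the job config further down, and the function itself is illustrative, not the dataloader added in this PR):

```python
import torch
from einops import rearrange

patch_size = 16
max_patches_per_image = 256            # L


def patchify(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """image: [H, W, C] with H and W divisible by patch_size (native resolution)."""
    h, w, c = image.shape
    gh, gw = h // patch_size, w // patch_size

    # One row per patch: [gh * gw, patch_size * patch_size * c].
    patches = rearrange(
        image, "(h ph) (w pw) c -> (h w) (ph pw c)", ph=patch_size, pw=patch_size
    )

    # (t, h, w) index of every patch; t == 0 for still images.
    coords = torch.stack(
        torch.meshgrid(
            torch.arange(1), torch.arange(gh), torch.arange(gw), indexing="ij"
        ),
        dim=-1,
    ).reshape(-1, 3)

    # Pad both tensors to the fixed length L; padding rows are marked with -1.
    pad = max_patches_per_image - patches.shape[0]
    patches = torch.nn.functional.pad(patches, (0, 0, 0, pad))
    coords = torch.cat([coords, coords.new_full((pad, 3), -1)])
    return patches, coords   # [L, D] and [L, 3]; stacking N of these gives [N, L, *]


pixel_values, grid_thw = patchify(torch.rand(224, 288, 3))   # 14 * 18 = 252 real patches
```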
### Model
We also need a pretrained vision encoder that supports native resolution and aspect ratio. Until recently, relatively few vision encoders had this capability; examples include Siglip2, AimV2, and most recently DINOv3.
- [ ] Currently we support the Siglip2 encoder using the positional embedding interpolation approach.
- [x] Base modelling code.
- [ ] Weights conversion and loading from HF.
- [x] FSDP for both Encoder and Decoder
- [x] Context Parallel for LLM only, since we will use FlexAttention for the Encoder.
- [ ] FlexAttention with a different sequence length per image (a sketch follows this list).
- [ ] Compile for Encoder + Decoder
- [ ] Tensor Parallel
- [ ] Pipeline Parallel
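One possible shape of the planned FlexAttention masking for the encoder, sketched under the assumptions that patches are padded to `[N, L, D]` with `grid_thw == -1` marking padding and that a recent PyTorch (2.6+) with FlexAttention is available. This is a sketch of the idea, not the implementation tracked by the checklist above:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

N, L, H, Dh = 4, 256, 2, 64                  # images, max patches, heads, head dim
q = k = v = torch.randn(N, H, L, Dh)

grid_thw = torch.randint(0, 16, (N, L, 3))
grid_thw[0, 200:] = -1                       # first image has only 200 real patches
valid = grid_thw[..., 0] != -1               # [N, L]


def padding_mask(b, h, q_idx, kv_idx):
    # Real patches attend only to real patches of the same image;
    # padding rows and columns are masked out entirely.
    return valid[b, q_idx] & valid[b, kv_idx]


block_mask = create_block_mask(
    padding_mask, B=N, H=None, Q_LEN=L, KV_LEN=L, device=q.device
)
out = flex_attention(q, k, v, block_mask=block_mask)   # [N, H, L, Dh]
```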
Lines changed: 57 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import asdict, replace

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.models.llama3 import llama3_configs
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .datasets.mm_datasets import build_mm_dataloader
from .infra.parallelize import parallelize_vlm
from .model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs
from .model.model import Llama3Siglip2Transformer

__all__ = [
    "parallelize_vlm",
    "Llama3Siglip2ModelArgs",
    "Llama3Siglip2Transformer",
    "llama3_siglip2_configs",
]


llama3_siglip2_configs = {
    "debugmodel": Llama3Siglip2ModelArgs(
        **asdict(replace(llama3_configs["debugmodel"], vocab_size=2048)),
        encoder=Siglip2ModelArgs(
            dim=128,
            ffn_dim=256,
            n_layers=4,
            n_heads=2,
        ),
    ),
}


register_train_spec(
    TrainSpec(
        name="llama3-siglip2",
        model_cls=Llama3Siglip2Transformer,
        model_args=llama3_siglip2_configs,
        parallelize_fn=parallelize_vlm,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_mm_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        build_validator_fn=build_validator,
    )
)
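For reference, the registered spec can then be looked up by name. The snippet below assumes the registry exposes a lookup helper named `get_train_spec` as the counterpart to `register_train_spec` (an assumption; it is not shown in this diff):

```python
from torchtitan.protocols.train_spec import get_train_spec  # assumed accessor

spec = get_train_spec("llama3-siglip2")
print(spec.model_cls)                          # Llama3Siglip2Transformer
print(spec.model_args["debugmodel"].encoder)   # Siglip2ModelArgs(dim=128, ffn_dim=256, ...)
```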
Lines changed: 34 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field


@dataclass
class Data:
    max_images_per_batch: int = 10
    """Vision encoder batch size (N)."""
    max_patches_per_image: int = 256
    """Vision encoder sequence length (L)."""
    patch_size: int = 16
    """Patch size of the vision encoder.
    For example, with image size 256x256 and patch size 16,
    the number of visual tokens is (256/16)**2 = 256.
    """
    spatial_merge_size: int = 1
    """Spatially merge visual tokens after the encoder. Default 1 means no merging.
    For example, with image size 256x256, patch size 16, and spatial merge size 2,
    the number of visual tokens for the LLM is (256/16/2)**2 = 64.
    """
    packing_buffer_size: int = 0
    """Set to a value > 0 to enable sample packing.
    This controls the size of the buffer used to store training samples available for packing.
    """


@dataclass
class JobConfig:
    data: Data = field(default_factory=Data)
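A back-of-the-envelope check of the values in the docstrings above (illustrative only): how many patches the encoder sees for one image, and how many visual tokens reach the LLM after spatial merging.

```python
patch_size = 16
spatial_merge_size = 2        # the docstring example; the config default is 1
max_patches_per_image = 256   # L

h, w = 256, 256               # native resolution: any size divisible by patch_size works

encoder_patches = (h // patch_size) * (w // patch_size)       # 16 * 16 = 256
llm_tokens = encoder_patches // (spatial_merge_size ** 2)     # 256 // 4 = 64

assert encoder_patches <= max_patches_per_image               # fits in L without truncation
print(encoder_patches, llm_tokens)                            # 256 64
```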
