Multimodal (vision) support #227

Open · wants to merge 99 commits into main

Conversation

@sohamparikh sohamparikh (Member) commented Apr 8, 2025

✨ Description

Multimodal support, starting with Pixtral's vision encoder.

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

List the key changes introduced in this PR:

  1. prepare now optionally takes images and image positions (where they should appear in the text). Images are stored in the memmap file along with the text.
  2. GPTBaseModel optionally has a vision encoder attached (audio/video encoders could be added in the future), consisting of a Conv2D, a vision transformer, and an MLP adapter; a minimal sketch follows below.
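
A minimal, hypothetical sketch of that encoder structure (names, dimensions, and defaults are assumptions for illustration, not the actual Fast-LLM implementation):

# Minimal sketch of the described vision encoder (illustrative only).
import torch
import torch.nn as nn

class VisionEncoderSketch(nn.Module):
    def __init__(self, patch_size=16, vision_dim=1024, lm_hidden_dim=4096, num_layers=24, num_heads=16):
        super().__init__()
        # Conv2D turns raw pixels into a grid of patch embeddings.
        self.patch_embed = nn.Conv2d(3, vision_dim, kernel_size=patch_size, stride=patch_size)
        # Vision transformer over the flattened patch sequence.
        layer = nn.TransformerEncoderLayer(vision_dim, num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers)
        # MLP adapter projects patch features into the language model's hidden size.
        self.adapter = nn.Sequential(
            nn.Linear(vision_dim, lm_hidden_dim), nn.GELU(), nn.Linear(lm_hidden_dim, lm_hidden_dim)
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        patches = self.patch_embed(images)             # (batch, vision_dim, H/ps, W/ps)
        tokens = patches.flatten(2).transpose(1, 2)    # (batch, num_patches, vision_dim)
        return self.adapter(self.transformer(tokens))  # (batch, num_patches, lm_hidden_dim)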

@tscholak tscholak mentioned this pull request May 9, 2025
@sohamparikh sohamparikh marked this pull request as ready for review July 9, 2025 18:38
@tscholak tscholak mentioned this pull request Jul 11, 2025
if self._indexed_dataset.has_images and self._truncate_documents:
raise RuntimeError(
"Truncating documents with images is not yet supported. Please turn off truncation to use images."
)
Collaborator

what does that mean in practice? documents with images that are longer than the sequence length are discarded?


# Calculate basic stats.
if not self._truncate_documents:
assert _extension_available, (
"The C++ extension for dataset sampling is missing."
" Please make sure Fast-LLM is installed correctly."
)
- long_docs_filter = document_sizes > self._parameters.sequence_length + 1
+ long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1
ignored_documents = long_docs_filter.sum().item()
Collaborator

I guess yes, long docs with images will be ignored. ok
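
A tiny self-contained illustration of the filter above, with made-up sizes (not Fast-LLM code):

# The third document (4 text tokens + 6 image tokens) exceeds sequence_length + 1
# and is therefore ignored.
import torch

sequence_length = 8
document_sizes = torch.tensor([5, 9, 4])
image_token_sizes = torch.tensor([0, 0, 6])

long_docs_filter = document_sizes + image_token_sizes > sequence_length + 1
print(long_docs_filter)        # tensor([False, False,  True])
print(long_docs_filter.sum())  # tensor(1): one ignored document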

@jlamypoirier jlamypoirier (Collaborator) left a comment

Looks good, but I'm worried about the added complexity. Main suggestions:

  • Make vision into a separate model.
  • Replace the dim and kwarg name changes with simpler alternatives.
  • Break down some methods that have grown too big to be properly understandable.
  • Avoid abbreviations when possible so names are self-explanatory for everyone.

@@ -133,24 +141,48 @@ def _sample(self) -> None:
Create a `GPTSampledDataset` with the requested parameters.
"""
# Get the document sizes, the main information needed for sampling.
- document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device)
+ document_sizes, image_sizes = self._indexed_dataset.get_document_sizes()
Collaborator

sample is getting way too long and complicated to follow. Can we please break it down by step and/or feature? Same for __getitem__

self._config = config
self._tensor_space = tensor_space
# TODO Soham: fix assert
Collaborator

?

self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size
self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size
self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size
self._kv_channels = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels).size
Collaborator

Why do we need this?

super().__init__(config, tensor_space)

# @torch.compile
def _forward(
Collaborator

Can't this use super()._forward instead of copying it?
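
A generic toy example of the pattern being suggested (not the actual classes in this PR):

# Toy illustration of reusing a parent implementation instead of copying it.
class Base:
    def _forward(self, x):
        return x * 2

class Child(Base):
    def _forward(self, x):
        # Reuse the parent logic, then apply only the child-specific change.
        return super()._forward(x) + 1

print(Child()._forward(3))  # 7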

dtype=self._distributed_config.training_dtype.torch,
)

def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None:
Collaborator

Can this be broken down a bit?

@@ -71,6 +71,17 @@ class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointForma
trust_remote_code: typing.ClassVar[bool] = True


class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
Collaborator

I don't understand, is all that added complexity just so we can call it llava instead of either pixtral or mistral?

@@ -63,6 +70,10 @@ def __init__(
if self._config.enable_dpo: # TODO better way to pass in?
self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space))

if self._config.vision_encoder.enabled:
Collaborator

I think this would be more appropriate as a separate model, like we did for SSM

Collaborator

No, this is intentionally part of the base GPT model. It ensures that multimodal support is transparent and inherited by all model variants (including SSM) without needing separate parallel class hierarchies. We're not introducing architectural silos unless there's a concrete need.

@jlamypoirier jlamypoirier (Collaborator) Jul 29, 2025

That's an option, but it comes with important drawbacks. Fast-LLM models are not designed to support more than one task. We already pushed GPT far beyond what a model should do with hacks and patches, and things are quickly getting difficult to manage, especially when it comes to preprocessing. Keeping everything in the same model means we have to break things down, modularize and simplify. We also need to ensure the PR has no side effect on non-vision models, which is difficult to do with the current implementation.

In short, keeping vision in the same model makes this PR significantly more difficult to merge. Either way, I don't think SSM support is an obstacle, because SSMs need to be integrated into the GPT model, and that's a lot easier and safer to do.

@@ -47,6 +49,10 @@ class LanguageModelBaseConfig(BaseModelConfig):
desc="Configuration for the transformer architecture.",
hint=FieldHint.architecture,
)
vision_encoder: VisionEncoderConfig = Field(
Collaborator

This shouldn't be part of a "language model" config. How about adding this in a separate Vision model instead?

Collaborator

I disagree, see also above. The key advantages of integrating multimodal support directly into the GPT base model are:

  • Seamless transition between text-only and multimodal models (a toy config sketch follows below this list).

  • Automatic inheritance of multimodal support by the existing SSM subclasses, without extra complexity or maintenance.
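
A toy illustration of the nesting under discussion, using plain dataclasses rather than Fast-LLM's Field/config machinery (all field names here are assumptions):

# Toy sketch only, not Fast-LLM's config system.
from dataclasses import dataclass, field

@dataclass
class VisionEncoderConfigSketch:
    enabled: bool = False
    patch_size: int = 16
    hidden_size: int = 1024

@dataclass
class LanguageModelConfigSketch:
    hidden_size: int = 4096
    vision_encoder: VisionEncoderConfigSketch = field(default_factory=VisionEncoderConfigSketch)

# Text-only and multimodal configs differ only in one nested section:
text_only = LanguageModelConfigSketch()
multimodal = LanguageModelConfigSketch(vision_encoder=VisionEncoderConfigSketch(enabled=True))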

@tscholak tscholak changed the title WIP: multimodal support Multimodal (vision) support Jul 17, 2025

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = "MistralForCausalLM"
Contributor

Why are the architecture classvars modified here?

@@ -137,7 +137,7 @@ def backward(
assert self._mode.support_backward
input_, output = grad_context
output.backward(output_grad)
- return input_.grad
+ return input_.grad if input_.grad is not None else torch.zeros_like(input_)
@jlamypoirier jlamypoirier (Collaborator) Aug 15, 2025

That's not a good idea, it will cause unnecessary operations when not needed, which is most of the time. Why is it needed?
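
For background on when .grad can be None, a neutral PyTorch illustration (not the PR's code):

# .grad stays None when a leaf tensor either does not require grad or does not
# contribute to the output; the new fallback would return zeros in that case.
import torch

a = torch.ones(3, requires_grad=True)
b = torch.ones(3, requires_grad=True)  # never used below
(a * 2).sum().backward()
print(a.grad)  # tensor([2., 2., 2.])
print(b.grad)  # None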

@@ -67,7 +68,8 @@ def _set_activation_fn_map() -> None:
global _ACTIVATION_FN_MAP

_ACTIVATION_FN_MAP = {
- ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
+ ActivationType.gelu: torch.nn.functional.gelu,
Collaborator

This is incorrect, Fast-LLM uses the tanh version.
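
For reference, a quick check of how the exact and tanh-approximated GELU differ (illustrative only):

# The tanh approximation differs slightly from the exact (erf-based) GELU.
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
exact = F.gelu(x)                            # erf-based GELU
tanh_approx = F.gelu(x, approximate="tanh")  # the variant Fast-LLM uses
print((exact - tanh_approx).abs().max())     # small but nonzero difference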

@@ -19,12 +19,26 @@ def get_document_sizes(self) -> np.ndarray:
and derived classes should try to avoid holding the whole array in memory.
"""

@abc.abstractmethod
Collaborator

I separated this because there were several issues with the previous version.

@@ -1,8 +1,10 @@
import io
Collaborator

Reworked this file so the various concepts are properly compartmentalized. The next step would be to extract those components into separate classes; it's not strictly needed yet, but it will become necessary if we keep adding stuff.
