Multimodal (vision) support #227
base: main
Conversation
if self._indexed_dataset.has_images and self._truncate_documents:
    raise RuntimeError(
        "Truncating documents with images is not yet supported. Please turn off truncation to use images."
    )
what does that mean in practice? documents with images that are longer than the sequence length are discarded?
     # Calculate basic stats.
     if not self._truncate_documents:
         assert _extension_available, (
             "The C++ extension for dataset sampling is missing."
             " Please make sure Fast-LLM is installed correctly."
         )
-        long_docs_filter = document_sizes > self._parameters.sequence_length + 1
+        long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1
         ignored_documents = long_docs_filter.sum().item()
I guess yes, long docs with images will be ignored. ok
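For readers following the thread, here is a tiny standalone sketch of the filter from the hunk above; the tensor values and sequence length are made up for illustration and are not taken from the codebase:

```python
import torch

document_sizes = torch.tensor([512, 4096, 1024])
image_token_sizes = torch.tensor([0, 256, 8192])
sequence_length = 4096

# Documents whose text tokens plus image tokens cannot fit in one sample are
# dropped entirely (no truncation), matching the discussion above.
long_docs_filter = document_sizes + image_token_sizes > sequence_length + 1
print(long_docs_filter.sum().item())  # 2 documents ignored
```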
Looks good, but I'm worried about the added complexity. Main suggestions:
- Make vision into a separate model.
- Replace the dim and kwarg name changes with simpler alternatives.
- Break down some methods that have grown too big to be properly understandable.
- Avoid abbreviations when possible so names are self-explanatory for everyone.
@@ -133,24 +141,48 @@ def _sample(self) -> None:
     Create a `GPTSampledDataset` with the requested parameters.
     """
     # Get the document sizes, the main information needed for sampling.
-    document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device)
+    document_sizes, image_sizes = self._indexed_dataset.get_document_sizes()
`sample` is getting way too long and complicated to follow. Can we please break it down by step and/or feature? Same for `__getitem__`.
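As a purely illustrative sketch of what such a breakdown could look like (all class and helper names below are hypothetical, not Fast-LLM code), the idea is a top-level method that reads as a short list of named steps:

```python
import numpy as np


class SamplingSketch:
    """Hypothetical decomposition: the top-level method reads as a list of steps."""

    def __init__(self, sequence_length: int) -> None:
        self._sequence_length = sequence_length

    def sample(self, document_sizes: np.ndarray, image_token_sizes: np.ndarray) -> np.ndarray:
        keep = self._filter_long_documents(document_sizes, image_token_sizes)
        return self._build_document_index(keep)

    def _filter_long_documents(self, document_sizes: np.ndarray, image_token_sizes: np.ndarray) -> np.ndarray:
        # Keep only documents whose text + image tokens fit in a sample.
        return document_sizes + image_token_sizes <= self._sequence_length + 1

    def _build_document_index(self, keep: np.ndarray) -> np.ndarray:
        # Placeholder for the shuffling / index-writing step.
        return np.flatnonzero(keep)


print(SamplingSketch(8).sample(np.array([4, 10, 6]), np.array([0, 0, 4])))  # [0]
```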
self._config = config
self._tensor_space = tensor_space
# TODO Soham: fix assert
?
self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size
self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size
self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size
self._kv_channels = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels).size
Why do we need this?
super().__init__(config, tensor_space)

# @torch.compile
def _forward(
Can't this use `super()._forward` instead of copying it?
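A minimal, self-contained sketch of the suggested pattern; the class and method names are illustrative stand-ins, not the actual Fast-LLM classes:

```python
import torch


class BaseLayer(torch.nn.Module):
    """Illustrative stand-in for the parent class."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)

    def _forward(self, input_: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(self.linear(input_), approximate="tanh")


class VisionLayer(BaseLayer):
    def _forward(self, input_: torch.Tensor) -> torch.Tensor:
        # Delegate to the shared implementation instead of copying its body;
        # only vision-specific behaviour would be added around this call.
        return super()._forward(input_)


print(VisionLayer(8)._forward(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```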
    dtype=self._distributed_config.training_dtype.torch,
)

def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None:
Can this be broken down a bit?
@@ -71,6 +71,17 @@ class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
    trust_remote_code: typing.ClassVar[bool] = True


class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
I don't understand, is all that added complexity just so we can call it `llava` instead of either `pixtral` or `mistral`?
@@ -63,6 +70,10 @@ def __init__(
    if self._config.enable_dpo:  # TODO better way to pass in?
        self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space))

    if self._config.vision_encoder.enabled:
I think this would be more appropriate as a separate model, like we did for SSM
No, this is intentionally part of the base GPT model. It ensures that multimodal support is transparent and inherited by all model variants (including SSM) without needing separate parallel class hierarchies. We're not introducing architectural silos unless there's a concrete need.
That's an option, but it comes with important drawbacks. Fast-LLM models are not designed to support more than one task. We already pushed `GPT` far beyond what a model should do with hacks and patches, and things are quickly getting difficult to manage, especially when it comes to preprocessing. Keeping everything in the same model means we have to break things down, modularize and simplify. We also need to ensure the PR has no side effects on non-vision models, which is difficult to do with the current implementation.
In short, keeping vision in the same model makes this PR significantly more difficult to merge. Either way, I don't think SSM support is an obstacle, because SSMs need to be integrated into the GPT model, and that's a lot easier and safer to do.
@@ -47,6 +49,10 @@ class LanguageModelBaseConfig(BaseModelConfig):
        desc="Configuration for the transformer architecture.",
        hint=FieldHint.architecture,
    )
    vision_encoder: VisionEncoderConfig = Field(
This shouldn't be part of a "language model" config. How about adding this in a separate `Vision` model instead?
I disagree, see also above. The key advantages of integrating multimodal support directly into the GPT base model are:
- Seamless transition between text-only and multimodal models.
- Automatic inheritance of multimodal support by the existing SSM subclasses, without extra complexity or maintenance.
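A minimal sketch of the config-side argument, using hypothetical dataclasses rather than the actual Fast-LLM config classes: a disabled-by-default vision section leaves text-only configs untouched.

```python
import dataclasses


@dataclasses.dataclass
class VisionEncoderConfigSketch:
    # Hypothetical stand-in for VisionEncoderConfig.
    enabled: bool = False
    patch_size: int = 16


@dataclasses.dataclass
class LanguageModelConfigSketch:
    # Hypothetical stand-in for LanguageModelBaseConfig.
    hidden_size: int = 4096
    # Text-only configs never touch this field, so existing models (including
    # the SSM subclasses) behave as before unless vision is explicitly enabled.
    vision_encoder: VisionEncoderConfigSketch = dataclasses.field(
        default_factory=VisionEncoderConfigSketch
    )


config = LanguageModelConfigSketch()
assert not config.vision_encoder.enabled  # text-only by default
```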
@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
    cls.architecture = "MistralForCausalLM"
Why are the `architecture` classvars modified here?
@@ -137,7 +137,7 @@ def backward(
     assert self._mode.support_backward
     input_, output = grad_context
     output.backward(output_grad)
-    return input_.grad
+    return input_.grad if input_.grad is not None else torch.zeros_like(input_)
That's not a good idea; it will cause unnecessary operations in the common case where they're not needed. Why is it needed?
@@ -67,7 +68,8 @@ def _set_activation_fn_map() -> None:
     global _ACTIVATION_FN_MAP

     _ACTIVATION_FN_MAP = {
-        ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
+        ActivationType.gelu: torch.nn.functional.gelu,
This is incorrect, Fast-LLM uses the tanh version.
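For reference, a quick standalone check (not Fast-LLM code) that the exact and tanh-approximate GELU are not interchangeable:

```python
import torch

x = torch.linspace(-3, 3, steps=7)
exact = torch.nn.functional.gelu(x)                       # erf-based GELU
approx = torch.nn.functional.gelu(x, approximate="tanh")  # tanh approximation
print((exact - approx).abs().max())  # non-zero: swapping them changes model outputs
```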
@@ -19,12 +19,26 @@ def get_document_sizes(self) -> np.ndarray:
     and derived classes should try to avoid holding the whole array in memory.
     """

     @abc.abstractmethod
I separated this because there were several issues with the previous version.
@@ -1,8 +1,10 @@
 import io
Reworked this file so the various concepts are properly compartmentalized. The next step would be to extract those components into separate classes; that's not strictly needed yet, but it will become necessary if we keep adding stuff.
✨ Description
Multi-modal support, starting with pixtral's vision encoder
Type of change
Select all that apply:
Changes
List the key changes introduced in this PR:
- `prepare` now optionally takes images and image positions (where they should appear in the text). Images are stored in the memmap file along with text.
- `GPTBaseModel` optionally has a vision encoder attached (it could have audio/video encoders in the future), which consists of a conv2D, a vision transformer, and an MLP adapter.
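For orientation, here is a rough, self-contained sketch of that pipeline (patch conv2D → vision transformer → MLP adapter). The shapes, layer counts, and hyperparameters are illustrative assumptions, not the actual implementation:

```python
import torch


class VisionEncoderSketch(torch.nn.Module):
    def __init__(self, patch_size: int = 16, vision_dim: int = 256, text_dim: int = 512) -> None:
        super().__init__()
        # Conv2D turns the image into a sequence of patch embeddings.
        self.patch_conv = torch.nn.Conv2d(3, vision_dim, kernel_size=patch_size, stride=patch_size)
        # A small vision transformer over the patch sequence.
        layer = torch.nn.TransformerEncoderLayer(d_model=vision_dim, nhead=4, batch_first=True)
        self.transformer = torch.nn.TransformerEncoder(layer, num_layers=2)
        # MLP adapter projects vision features into the language model's hidden size.
        self.adapter = torch.nn.Sequential(
            torch.nn.Linear(vision_dim, text_dim),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(text_dim, text_dim),
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        patches = self.patch_conv(images)            # (batch, vision_dim, h/p, w/p)
        tokens = patches.flatten(2).transpose(1, 2)  # (batch, num_patches, vision_dim)
        return self.adapter(self.transformer(tokens))


# The resulting image tokens would be spliced into the text sequence at the
# stored image positions before the language model consumes it.
print(VisionEncoderSketch()(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 16, 512])
```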