
[Feat] Adding Intern-S1 #39722


Open · wants to merge 14 commits into main

Conversation

@hhaAndroid commented Jul 28, 2025

Adding Intern-S1

This PR adds support for the Intern-S1 models. Please visit https://huggingface.co/internlm/Intern-S1

Features

  • Strong performance across language and vision reasoning benchmarks, especially scientific tasks.
  • Continuously pretrained on a massive 5T token dataset, with over 50% specialized scientific data, embedding deep domain expertise.
  • Dynamic tokenizer enables native understanding of molecular formulas, protein sequences, and seismic signals.

Usage

from transformers import AutoProcessor, AutoModelForImageTextToText
import torch

model_checkpoint = 'xxxx'  # e.g. the released checkpoint from https://huggingface.co/internlm/Intern-S1
processor = AutoProcessor.from_pretrained(model_checkpoint)
model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map="auto", torch_dtype="auto")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
            {"type": "text", "text": "Please briefly describe the image."},
        ],
    }
]

# Tokenize the chat, move it to the model's device, and cast the floating-point
# tensors (the pixel values) to bfloat16.
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

generate_ids = model.generate(**inputs, max_new_tokens=32768)
# Decode only the newly generated tokens, skipping the prompt.
decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(decoded_output)

Progress

  • add modeling.py
  • add tokenizer.py
  • add test
  • fix lint

@Rocketknight1 (Member)

cc @zucchini-nlp for VLMs!

@zucchini-nlp (Member)

Taking a look tomorrow

@zucchini-nlp (Member) left a comment

Hey, sorry for the late review, I got caught up in another model release.

The model looks very much like InternVL, and I want us to re-use as much code as possible with modular. In the long term that will make maintenance easier for us, and the review process is much faster when we can spot the differences between models. I left comments below about which classes can be re-used from where.

Feel free to tag me when it is ready for re-review or if you need any assistance :)

Comment on lines +969 to +973
if self.is_moe_model:
    output_router_logits = (
        output_router_logits if output_router_logits is not None else self.config.text_config.output_router_logits
    )
    kwargs['output_router_logits'] = output_router_logits
Member

Will there be model checkpoints which are MoE and others that are not? If yes, for the non-MoE ones we can probably use the InternVL class; it looks identical to me.

Otherwise we need to write the InternS1 model code as MoE-only.

Author

This is a great idea, but we hope that InternS1 can serve as a unified model, supporting not only N-vision + dense but also N-vision + MoE. I’m not sure whether Hugging Face’s guidelines strictly require dense and MoE to be separated into two model folders. Looking forward to your reply! @zucchini-nlp

Member

The transformers philosophy is to add support for an architecture only when there is an official pre-trained checkpoint for it. So if InternS1 has both MoE and dense checkpoints, we'd need to support both. That can be done by having separate InternS1Moe and InternS1Dense decoder layers; imo that is preferable.

Otherwise let's just add the architecture to support the released checkpoints.

Author

Thank you very much for your reply. InternS1 will have two sets of weights: ViT+235B MoE and ViT+8B MoE. Currently, they can be automatically constructed via AutoModel.from_config(config.text_config). The LLM module directly reuses the LLM models already supported in Transformers, so we don't need to distinguish between dense and MoE layers.

In this case, what would be the most reasonable way to provide support?

Member

Oh cool, then it makes everything much easier if we can just instantiate an existing LLM from transformers. All we need to do is make sure the configs are saved with the correct model_type and that the code calls AutoModel.from_config(config.text_config).

I see that is_moe_model is needed only to decide on output_router_logits. Actually, we can use the can_record_outputs attribute, which handles all extra model outputs (see for example Cohere2Vision). It can be a bit tricky for multimodal, so let me know if you need help with it.
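For illustration, a minimal sketch of that pattern (the wrapper class and attribute names are assumptions, loosely following this PR):

from transformers import AutoModel

class InternS1Model(InternS1PreTrainedModel):  # hypothetical composite wrapper
    def __init__(self, config):
        super().__init__(config)
        self.vision_tower = InternS1VisionModel(config.vision_config)
        # config.text_config.model_type decides which LLM gets built,
        # so dense and MoE checkpoints can share this one wrapper class
        self.language_model = AutoModel.from_config(config.text_config)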

Author

Thank you!

Member

I will take a look at can_record_outputs, as using it is much cleaner than checking config values at init time

Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, interns1

@hhaAndroid (Author)

@zucchini-nlp Hello, I've pushed a revised version as requested. However, regarding the usage of can_record_outputs, I'm unsure whether it fits my scenario: after adapting it for the MoE model, I need to pass the output_router_logits parameter into the MoE LLM, rather than just capture the output results. Looking forward to your next round of review comments.

@zucchini-nlp (Member) left a comment

Super clean after using modular, thanks for iterating! There are some bits left, especially moving all the code into the modular file. We shouldn't be importing from other models unless it is in the modular file :)

Also, I will take a look at the can_record_outputs thing this week, would be nice to get it sorted

Update: Oh btw, let's make CI green and fix the failing tests. You might need to rebase if unrelated tests are failing

Comment on lines +21 to +25

class InternS1VisionConfig(InternVLVisionConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternS1VisionModel`]. It is used to instantiate
    an InternS1VisionModel model according to the specified arguments, defining the model architecture. Instantiating a
Member

We need to move this to the modular file. Inheriting from other models outside of it is not allowed and goes against the transformers one model, one file philosophy.
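In other words, a one-liner in the modular file should be enough (a sketch, following the same pattern suggested for the processor further below; modular expands it into the full standalone config and handles the renaming):

# modular_interns1.py
class InternS1VisionConfig(InternVLVisionConfig):
    pass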

Comment on lines +57 to +70

class InternS1VisionRMSNorm(InternVLVisionRMSNorm):
    pass


class InternS1VisionAttention(InternVLVisionAttention):
    pass


@auto_docstring
class InternS1VisionPreTrainedModel(InternVLVisionPreTrainedModel):
    config: InternS1VisionConfig


Member

Looks perfect and much less code 🤩

NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternS1VisionRMSNorm}


class InternS1VisionLayer(GradientCheckpointingLayer):
Member

I meant that the only difference is the drop_path, so it can be inherited and we only write out the new drop_path modules.

For example, this is how it was done for the attention module with new QK-Norm layers:

class Qwen3Attention(LlamaAttention):
    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
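Applied to this layer, a hedged sketch could look like the following (InternS1VisionDropPath and the drop_path_rate argument are illustrative, not the PR's final code):

import torch.nn as nn

class InternS1VisionLayer(InternVLVisionLayer):
    def __init__(self, config: InternS1VisionConfig, drop_path_rate: float = 0.0):
        super().__init__(config)
        # only the new stochastic-depth modules are written out; the forward
        # that applies them would be overridden too, everything else is inherited
        self.drop_path1 = InternS1VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.drop_path2 = InternS1VisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()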

        return layer_output, attention_weights


class InternS1VisionEncoder(nn.Module):
Member

The forward is identical, so we can inherit and override the init only.
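A sketch of that suggestion (the self.layer attribute name and the config.drop_path_rate field are assumptions):

import torch
import torch.nn as nn

class InternS1VisionEncoder(InternVLVisionEncoder):
    def __init__(self, config: InternS1VisionConfig):
        super().__init__(config)
        # rebuild the layers with a linearly increasing drop-path rate and
        # reuse the inherited forward unchanged
        dpr = torch.linspace(0, config.drop_path_rate, config.num_hidden_layers).tolist()
        self.layer = nn.ModuleList(
            [InternS1VisionLayer(config, drop_path_rate=dpr[i]) for i in range(config.num_hidden_layers)]
        )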



@auto_docstring
class InternS1VisionModel(InternS1VisionPreTrainedModel):
Member

Comment not addressed; this can be copied from InternVLVisionModel.

Comment on lines +353 to +370
if input_ids is None:
    special_image_mask = inputs_embeds == self.get_input_embeddings()(
        torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
    )
    special_image_mask = special_image_mask.all(-1)
else:
    special_image_mask = input_ids == self.config.image_token_id

n_image_tokens = special_image_mask.sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
    n_image_features = image_features.shape[0] * image_features.shape[1]
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
    )
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
Member

This was recently refactored into the get_placeholder_mask helper method in all VLMs. Can you update here as well?
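For reference, the refactored pattern in other VLMs looks roughly like this (a hedged sketch; worth double-checking the helper's current signature in transformers):

# get_placeholder_mask builds the expanded mask and raises on a token/feature count mismatch
special_image_mask = self.get_placeholder_mask(
    input_ids, inputs_embeds=inputs_embeds, image_features=image_features
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)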

Comment on lines +969 to +973
if self.is_moe_model:
    output_router_logits = (
        output_router_logits if output_router_logits is not None else self.config.text_config.output_router_logits
    )
    kwargs['output_router_logits'] = output_router_logits
Member

I will take a look at can_record_outputs, as using it is much cleaner than checking config values at init time

    self,
    images: Optional[ImageInput] = None,
    text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
    audio=None,
Member

audio is not needed; we decided not to add unused modalities to the call signature

Comment on lines +46 to +48
# TODO: It will support temporal information processing in the future.
class InternS1Processor(InternVLProcessor):
    r"""
Member

If there is inheritance, it has to be defined in the modular file. AFAIU the diff is meant to be in the docstring only? We can let modular handle all the re-naming, and simply writing it as below will be enough:

# modular_interns1.py
class InternS1Processor(InternVLProcessor):
    pass

Comment on lines +24 to +34
class InternS1VideoProcessorInitKwargs(VideosKwargs):
    initial_shift: Union[bool, float, int]


@requires(backends=("torchvision",))
class InternS1VideoProcessor(InternVLVideoProcessor):
    valid_kwargs = InternS1VideoProcessorInitKwargs

    def __init__(self, **kwargs: Unpack[InternS1VideoProcessorInitKwargs]):
        super().__init__(**kwargs)

Member

Same here; only defining in modular that the video processor is identical will copy everything else for you:

class InternS1VideoProcessor(InternVLVideoProcessor):
    pass
