[Frontend] Raise an extremely dangerous warning when using VLLM_ALLOW_LONG_MAX_MODEL_LEN #20904


Open · noooop wants to merge 1 commit into main from rm_VLLM_ALLOW_LONG_MAX_MODEL_LEN

Conversation

noooop (Contributor) commented Jul 14, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

VLLM_ALLOW_LONG_MAX_MODEL_LEN sounds harmless, and even looks like a way to extend the context length. The existing check in vllm/config.py even points users to it:

raise ValueError(
    f"{msg} To allow overriding this maximum, set "
    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")

This ValueError encourages users to try VLLM_ALLOW_LONG_MAX_MODEL_LEN.

But in fact, using it almost always triggers an error (a minimal illustration follows this list):

  • If the model uses RoPE position encoding, positions exceeding derived_max_model_len lead to NaN outputs.
  • If the model uses absolute position encoding, positions exceeding derived_max_model_len cause a CUDA array out-of-bounds error.
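
To make the failure modes concrete, here is a minimal, illustrative PyTorch sketch (not vLLM's actual code; the sizes and shapes are made up) of why positions beyond the derived maximum break both schemes:

import torch
import torch.nn as nn

max_position_embeddings = 8  # stands in for derived_max_model_len

# Absolute position encoding: the embedding table has exactly
# max_position_embeddings rows, so any larger position index is out of bounds.
abs_pos_emb = nn.Embedding(max_position_embeddings, 16)
try:
    abs_pos_emb(torch.tensor([max_position_embeddings]))  # position 8, table of size 8
except IndexError as err:
    # On CPU this is an IndexError; on CUDA it surfaces as a device-side assert.
    print("absolute position encoding:", err)

# RoPE: the cos/sin cache is typically precomputed only up to
# max_position_embeddings, so there is nothing valid to read for larger positions.
inv_freq = 1.0 / (10000 ** (torch.arange(0, 16, 2).float() / 16))
t = torch.arange(max_position_embeddings).float()
freqs = torch.outer(t, inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
print(cos_sin_cache.shape)  # torch.Size([8, 16]); positions >= 8 have no cached entries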

Discovering this error during testing is the best case; I can hardly imagine what would happen if this were deployed in a production environment.

I can't think of any possible use cases for VLLM_ALLOW_LONG_MAX_MODEL_LEN.

Was it meant to save the VRAM used by the cos_sin_cache? Even so, outputting NaN makes it impossible to run any meaningful tests.
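
For scale, a back-of-the-envelope sketch (assumed shapes and dtype, not taken from this PR) of what a full-length cos_sin_cache would cost:

max_position_embeddings = 1_048_576  # hypothetical 1M-token target length
rotary_dim = 128                     # illustrative per-head rotary dimension
bytes_per_elem = 2                   # assuming an fp16/bf16 cache

cache_bytes = max_position_embeddings * rotary_dim * bytes_per_elem
print(f"cos_sin_cache ~ {cache_bytes / 2**20:.0f} MiB")  # ~256 MiB

Even if a saving of that order were the motivation, the NaN outputs described above make the resulting configuration impossible to validate.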

Let's Remove VLLM_ALLOW_LONG_MAX_MODEL_LEN

Perhaps an extremely dangerous warning should be raised when VLLM_ALLOW_LONG_MAX_MODEL_LEN is used, instead of deleting it outright.

WARNING 07-14 21:37:16 [config.py:3634]
WARNING 07-14 21:37:16 [config.py:3634] ===== EXTREMELY DANGEROUS =====
WARNING 07-14 21:37:16 [config.py:3634] Using VLLM_ALLOW_LONG_MAX_MODEL_LEN may lead to the following exceptions:
WARNING 07-14 21:37:16 [config.py:3634] - If the model uses rope position encoding, positions exceeding derived_max_model_len lead to nan.
WARNING 07-14 21:37:16 [config.py:3634] - If the model uses absolute position encoding, positions exceeding derived_max_model_len will cause a CUDA array out-of-bounds error.
WARNING 07-14 21:37:16 [config.py:3634] For testing purposes only, do not implement any functionality relying on VLLM_ALLOW_LONG_MAX_MODEL_LEN.
WARNING 07-14 21:37:16 [config.py:3634] ===============================
WARNING 07-14 21:37:16 [config.py:3634]
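
A minimal sketch of how this multi-line warning could be emitted, assuming it sits next to the existing max_model_len check in vllm/config.py and uses vLLM's standard logger (the guard condition and variable names here mirror that check but are illustrative):

if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN and max_model_len > derived_max_model_len:
    # One logger.warning call per line reproduces the banner shown above.
    logger.warning("")
    logger.warning("===== EXTREMELY DANGEROUS =====")
    logger.warning(
        "Using VLLM_ALLOW_LONG_MAX_MODEL_LEN may lead to the following "
        "exceptions:")
    logger.warning(
        "- If the model uses rope position encoding, positions exceeding "
        "derived_max_model_len lead to nan.")
    logger.warning(
        "- If the model uses absolute position encoding, positions exceeding "
        "derived_max_model_len will cause a CUDA array out-of-bounds error.")
    logger.warning(
        "For testing purposes only, do not implement any functionality "
        "relying on VLLM_ALLOW_LONG_MAX_MODEL_LEN.")
    logger.warning("===============================")
    logger.warning("")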

Related to #20828, #20837, #17924, ...

Test Plan

Test Result

(Optional) Documentation Update


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @noooop, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an issue where the VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable, intended to allow users to specify a maximum model length beyond the derived limit, was found to cause incorrect model outputs or CUDA errors. My changes remove this problematic override mechanism, ensuring that max_model_len is strictly validated against the model's capabilities to prevent these issues and improve stability.

Highlights

  • Removal of VLLM_ALLOW_LONG_MAX_MODEL_LEN: The environment variable VLLM_ALLOW_LONG_MAX_MODEL_LEN and all its associated logic have been completely removed from the codebase, including its definition and parsing in vllm/envs.py.
  • Strict max_model_len validation: The validation logic in vllm/config.py has been updated to unconditionally raise a ValueError if the specified max_model_len exceeds the model's derived maximum length, eliminating the previous allowance for overriding this limit via the removed environment variable.
  • Updated example and benchmark configurations: Various Buildkite benchmark configurations (latency-tests-cpu.json, serving-tests-cpu.json, throughput-tests-cpu.json) and the qwen_1m.py example script have been updated to remove references to the now-defunct VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable setting. The qwen_1m.py example now explicitly passes hf_overrides for max_position_embeddings to the LLM constructor for proper long-context handling (a sketch of this pattern follows the list).
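
The last bullet replaces the env var in qwen_1m.py with an hf_overrides entry. A hedged sketch of that pattern (the model name and lengths below are illustrative, not taken from the diff): raising max_position_embeddings via hf_overrides lifts derived_max_model_len itself, so no override env var is needed.

from vllm import LLM

# Illustrative values; the point is that overriding the HF config raises the
# derived limit so max_model_len no longer exceeds it.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct-1M",  # hypothetical long-context model
    hf_overrides={"max_position_embeddings": 1_048_576},
    max_model_len=1_048_576,
)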

mergify bot added the documentation, ci/build, performance, and qwen labels on Jul 14, 2025
noooop force-pushed the rm_VLLM_ALLOW_LONG_MAX_MODEL_LEN branch from 677847b to 684aaad on July 14, 2025 at 05:50
noooop (Contributor, Author) commented Jul 14, 2025

@louie-tsai

Would it be possible to remove the VLLM_ALLOW_LONG_MAX_MODEL_LEN that was added in #18444?

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request removes the VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable and related code, preventing users from overriding the maximum model length. This change aims to improve code correctness and prevent potential CUDA errors or incorrect model outputs. The removal includes updates to configuration files, environment variable definitions, and error messages.

hmellor (Member) commented Jul 14, 2025

@noooop this was not added in #18444, it was added in #7080.

@njhill, as the reviewer of the original PR, is VLLM_ALLOW_LONG_MAX_MODEL_LEN still a valid feature today?

noooop (Contributor, Author) commented Jul 14, 2025

> @noooop this was not added in #18444, it was added in #7080.

Currently, throughout the vLLM project, VLLM_ALLOW_LONG_MAX_MODEL_LEN is only used by #18444. I want to know why #18444 needs VLLM_ALLOW_LONG_MAX_MODEL_LEN, and whether removing it would have any impact.

Also, examples/offline_inference/qwen_1m.py uses VLLM_ALLOW_LONG_MAX_MODEL_LEN as well, and that fix might not be the most effective approach.

noooop changed the title from "Remove VLLM_ALLOW_LONG_MAX_MODEL_LEN" to "[Frontend] Raise an extremely dangerous warning when using VLLM_ALLOW_LONG_MAX_MODEL_LEN" on Jul 14, 2025
noooop (Contributor, Author) commented Jul 14, 2025

Perhaps an extremely dangerous warning should be raised when VLLM_ALLOW_LONG_MAX_MODEL_LEN is used, instead of deleting it outright.

Signed-off-by: wang.yuqi <noooop@126.com>
noooop force-pushed the rm_VLLM_ALLOW_LONG_MAX_MODEL_LEN branch from 4e16bb9 to 26f9cf0 on July 14, 2025 at 13:39
Labels: ci/build · documentation · performance · qwen