diff --git a/.gitignore b/.gitignore
index f6f64fcef..ea0331b26 100644
--- a/.gitignore
+++ b/.gitignore
@@ -53,6 +53,7 @@ wheels/
.installed.cfg
*.egg
MANIFEST
+.cache/*
# PyInstaller
# Usually these files are written by a python script from a template
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 000000000..4797720f2
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,22 @@
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the OS, Python version, and other tools you might need
+build:
+ os: ubuntu-24.04
+ tools:
+ python: "3.12"
+
+# Build documentation with Mkdocs
+mkdocs:
+ configuration: mkdocs.yml
+
+python:
+ install:
+ - method: pip
+ path: .
+ extra_requirements:
+ - dev
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 000000000..99db91219
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,77 @@
+# LLM Compressor Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our community include:
+
+- Demonstrating empathy and kindness toward other people
+- Being respectful of differing opinions, viewpoints, and experiences
+- Giving and gracefully accepting constructive feedback
+- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
+- Focusing on what is best not just for us as individuals, but for the overall community
+
+Examples of unacceptable behavior include:
+
+- The use of sexualized language or imagery, and sexual attention or advances of any kind
+- Trolling, insulting or derogatory comments, and personal or political attacks
+- Public or private harassment
+- Publishing others’ private information, such as a physical or email address, without their explicit permission
+- Other conduct which could reasonably be considered inappropriate in a professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement through GitHub, Slack, or Email. All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla’s code of conduct enforcement ladder](https://github.com/mozilla/diversity).
+
+For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.
diff --git a/README.md b/README.md
index 424d83597..dd9e8a9ae 100644
--- a/README.md
+++ b/README.md
@@ -18,11 +18,11 @@ Big updates have landed in LLM Compressor! To get a more in-depth look, check ou
Some of the exciting new features include:
-* **Large Model Support with Sequential Onloading** As of llm-compressor>=0.6.0, you can now quantize very large language models on a single GPU. Models are broken into disjoint layers which are then onloaded to the GPU one layer at a time. For more information on sequential onloading, see [Big Modeling with Sequential Onloading](examples/big_models_with_sequential_onloading/README.md) as well as the [DeepSeek-R1 Example](examples/quantizing_moe/deepseek_r1_example.py).
+* **Llama4 Quantization Support**: Quantize a Llama4 model to [W4A16](examples/multimodal_vision/llama4_example.py) or [NVFP4](examples/quantization_w4a4_fp4/llama4_example.py). The checkpoint produced can seamlessly run in vLLM.
+* **Large Model Support with Sequential Onloading**: As of llm-compressor>=0.6.0, you can now quantize very large language models on a single GPU. Models are broken into disjoint layers which are then onloaded to the GPU one layer at a time. For more information on sequential onloading, see [Big Modeling with Sequential Onloading](examples/big_models_with_sequential_onloading/README.md) as well as the [DeepSeek-R1 Example](examples/quantizing_moe/deepseek_r1_example.py).
* **Preliminary FP4 Quantization Support:** Quantize weights and activations to FP4 and seamlessly run the compressed model in vLLM. Model weights and activations are quantized following the NVFP4 [configuration](https://github.com/neuralmagic/compressed-tensors/blob/f5dbfc336b9c9c361b9fe7ae085d5cb0673e56eb/src/compressed_tensors/quantization/quant_scheme.py#L104). See examples of [weight-only quantization](examples/quantization_w4a16_fp4/llama3_example.py) and [fp4 activation support](examples/quantization_w4a4_fp4/llama3_example.py). Support is currently preliminary and additional support will be added for MoEs.
* **Updated AWQ Support:** Improved support for MoEs with better handling of larger models
* **Axolotl Sparse Finetuning Integration:** Seamlessly finetune sparse LLMs with our Axolotl integration. Learn how to create [fast sparse open-source models with Axolotl and LLM Compressor](https://developers.redhat.com/articles/2025/06/17/axolotl-meets-llm-compressor-fast-sparse-open). See also the [Axolotl integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#llmcompressor).
-* **Day 0 Llama 4 Support:** Meta utilized LLM Compressor to create the [FP8-quantized Llama-4-Maverick-17B-128E](https://huggingface.co/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8), optimized for vLLM inference using [compressed-tensors](https://github.com/neuralmagic/compressed-tensors) format.
### Supported Formats
* Activation Quantization: W8A8 (int8 and fp8)
diff --git a/docs/assets/llmcompressor-icon-white.png b/docs/assets/llmcompressor-icon-white.png
new file mode 100644
index 000000000..c14a0f72b
Binary files /dev/null and b/docs/assets/llmcompressor-icon-white.png differ
diff --git a/docs/assets/llmcompressor-icon.png b/docs/assets/llmcompressor-icon.png
new file mode 100644
index 000000000..d8b6748f4
Binary files /dev/null and b/docs/assets/llmcompressor-icon.png differ
diff --git a/docs/assets/llmcompressor-user-flows.png b/docs/assets/llmcompressor-user-flows.png
new file mode 100644
index 000000000..734ba3b1e
Binary files /dev/null and b/docs/assets/llmcompressor-user-flows.png differ
diff --git a/docs/developer/index.md b/docs/developer/index.md
new file mode 100644
index 000000000..9c55c6e9e
--- /dev/null
+++ b/docs/developer/index.md
@@ -0,0 +1,39 @@
+---
+weight: -3
+---
+
+# Developer
+
+Welcome to the Developer section of LLM Compressor! This area provides essential resources for developers who want to contribute to or extend LLM Compressor. Whether you're interested in fixing bugs, adding new features, improving documentation, or understanding the project's governance, you'll find comprehensive guides to help you get started.
+
+LLM Compressor is an open-source project that values community contributions. We maintain high standards for code quality, documentation, and community interactions to ensure that LLM Compressor remains a robust, reliable, and user-friendly tool for compressing large language models.
+
+## Developer Resources
+
+<div class="grid cards" markdown>
+
+- :material-handshake:{ .lg .middle } Code of Conduct
+
+ ---
+
+ Our community guidelines ensure that participation in the LLM Compressor project is a positive, inclusive, and respectful experience for everyone.
+
+ [:octicons-arrow-right-24: Code of Conduct](code-of-conduct.md)
+
+- :material-source-pull:{ .lg .middle } Contributing Guide
+
+ ---
+
+ Learn how to effectively contribute to LLM Compressor, including reporting bugs, suggesting features, improving documentation, and submitting code.
+
+ [:octicons-arrow-right-24: Contributing Guide](contributing.md)
+
+- :material-tools:{ .lg .middle } Development Guide
+
+ ---
+
+ Detailed instructions for setting up your development environment, implementing changes, and adhering to the project's coding standards and best practices.
+
+ [:octicons-arrow-right-24: Development Guide](developing.md)
+
+</div>
diff --git a/docs/examples/index.md b/docs/examples/index.md
new file mode 100644
index 000000000..d4059330f
--- /dev/null
+++ b/docs/examples/index.md
@@ -0,0 +1,9 @@
+---
+weight: -4
+---
+
+# Examples
+
+Welcome to the LLM Compressor examples section! Here, you'll find practical demonstrations showing how to use LLM Compressor to optimize large language models for faster and more efficient deployment with vLLM. These examples will help you understand the various compression techniques and functionalities available in LLM Compressor, making it easier to apply them to your own models.
+
+To explore the examples, you can either navigate through the list provided in the sidebar or click next to see the next example in the series. Each example is designed to be self-contained, with clear instructions and code snippets that you can run directly.
diff --git a/docs/getting-started/compress.md b/docs/getting-started/compress.md
new file mode 100644
index 000000000..c10bbfd83
--- /dev/null
+++ b/docs/getting-started/compress.md
@@ -0,0 +1,67 @@
+---
+weight: -8
+---
+
+# Compress Your Model
+
+LLM Compressor provides a straightforward way to compress your models using various optimization techniques. This guide will walk you through the process of compressing a model using different quantization methods.
+
+## Prerequisites
+
+Before you begin, ensure you have the following prerequisites:
+- **Operating System:** Linux (recommended for GPU support)
+- **Python Version:** 3.9 or newer
+- **Available GPU:** For optimal performance, it's recommended to use a GPU. LLM Compressor supports the latest PyTorch and CUDA versions for compatibility with NVIDIA GPUs.
+
+## Select a Model and Dataset
+
+Before you start compressing, select the model you'd like to compress and a calibration dataset that is representative of your use case. LLM Compressor supports a variety of models and integrates natively with Hugging Face Transformers and the Hugging Face Model Hub, so a great starting point is to use a model from the Model Hub. LLM Compressor also supports many datasets from the Hugging Face Datasets library, making it easy to find a suitable dataset for calibration.
+
+For this guide, we'll use the `TinyLlama` model and the `open_platypus` dataset for calibration. You can replace these with your own model and dataset as needed.
+
+## Select a Quantization Method and Scheme
+
+LLM Compressor supports several quantization methods and schemes, each with its own strengths and weaknesses. The choice of method and scheme will depend on your specific use case, hardware capabilities, and desired trade-offs between model size, speed, and accuracy.
+
+Some common quantization schemes include:
+
+| Scheme | Description | Hardware Compatibility |
+|--------|-------------|------------------------|
+| **FP W8A8** | 8-bit floating point (FP8) quantization for weights and activations, providing ~2X smaller weights with 8-bit arithmetic operations. Good for general performance and compression, especially for server and batch inference. | Latest NVIDIA GPUs (Ada Lovelace, Hopper, and later) and latest AMD GPUs |
+| **INT W8A8** | 8-bit integer (INT8) quantization for weights and activations, providing ~2X smaller weights with 8-bit arithmetic operations. Good for general performance and compression, especially for server and batch inference. | All NVIDIA GPUs, AMD GPUs, TPUs, CPUs, and other accelerators |
+| **W4A16** | 4-bit integer (INT4) weights with 16-bit floating point (FP16) activations, providing ~3.7X smaller weights but requiring 16-bit arithmetic operations. Maximum compression for latency-sensitive applications with limited memory. | All NVIDIA GPUs, AMD GPUs, TPUs, CPUs, and other accelerators |
+
+Some common quantization methods include:
+
+| Method | Description | Accuracy Recovery vs. Time |
+|--------|-------------|----------------------------|
+| **GPTQ** | Utilizes second-order layer-wise optimizations to prioritize important weights/activations and enables updates to remaining weights | High accuracy recovery but more expensive/slower to run |
+| **AWQ** | Uses channelwise scaling to better preserve important outliers in weights and activations | Moderate accuracy recovery with faster runtime than GPTQ |
+| **SmoothQuant** | Smooths outliers in activations by folding them into weights, ensuring better accuracy for weight and activation quantized models | Good accuracy recovery with minimal calibration time; composable with other methods |
+
+For this guide, we'll use `GPTQ` composed with `SmoothQuant` to create an `INT W8A8` quantized model. This combination provides a good balance of performance, accuracy, and compatibility across a wide range of hardware.
+
+## Apply the Recipe
+
+LLM Compressor provides the `oneshot` API for simple and straightforward model compression. This API allows you to apply a pre-defined recipe to your model and dataset, making it easy to get started with compression. To apply what we discussed above, we'll import the necessary modifiers and create a recipe to apply to our model and dataset:
+
+```python
+from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor import oneshot
+
+recipe = [
+ SmoothQuantModifier(smoothing_strength=0.8),
+ GPTQModifier(scheme="W8A8", targets="Linear", ignore=["lm_head"]),
+]
+oneshot(
+ model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ dataset="open_platypus",
+ recipe=recipe,
+ output_dir="TinyLlama-1.1B-Chat-v1.0-INT8",
+ max_seq_length=2048,
+ num_calibration_samples=512,
+)
+```
+
+Once the above code is run, it will save the compressed model to the specified output directory: `TinyLlama-1.1B-Chat-v1.0-INT8`. You can then load this model using the Hugging Face Transformers library or vLLM for inference and testing.
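+
+As a quick sanity check, the sketch below reloads the saved checkpoint with Hugging Face Transformers and generates a few tokens. This is a minimal, illustrative example; it assumes your installed `transformers` version supports loading compressed-tensors checkpoints (compressed-tensors is installed alongside LLM Compressor).
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Reload the checkpoint written by `oneshot` above
+model = AutoModelForCausalLM.from_pretrained(
+    "TinyLlama-1.1B-Chat-v1.0-INT8", device_map="auto", torch_dtype="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained("TinyLlama-1.1B-Chat-v1.0-INT8")
+
+# Generate a short completion to confirm the model still produces sensible text
+inputs = tokenizer("Quantization is useful because", return_tensors="pt").to(model.device)
+output = model.generate(**inputs, max_new_tokens=32)
+print(tokenizer.decode(output[0]))
+```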
diff --git a/docs/getting-started/deploy.md b/docs/getting-started/deploy.md
new file mode 100644
index 000000000..a396410fe
--- /dev/null
+++ b/docs/getting-started/deploy.md
@@ -0,0 +1,57 @@
+---
+weight: -6
+---
+
+# Deploy with vLLM
+
+Once you've compressed your model using LLM Compressor, you can deploy it for efficient inference using vLLM. This guide walks you through the deployment process, using the output from the [Compress Your Model](compress.md) guide. If you haven't completed that step, change the model arguments in the code snippets below to point to your desired model.
+
+vLLM is a high-performance inference engine designed for large language models, providing support for various quantization formats and optimized for both single and multi-GPU setups. It also offers an OpenAI-compatible API for easy integration with existing applications.
+
+## Prerequisites
+
+Before deploying your model, ensure you have the following prerequisites:
+- **Operating System:** Linux (recommended for GPU support)
+- **Python Version:** 3.9 or newer
+- **Available GPU:** For optimal performance, it's recommended to use a GPU. vLLM supports a range of hardware, including NVIDIA GPUs, AMD GPUs, TPUs, and other accelerators.
+- **vLLM Installed:** Ensure you have vLLM installed. You can install it using pip:
+ ```bash
+ pip install vllm
+ ```
+
+## Python API
+
+vLLM provides a Python API for easy integration with your applications, enabling you to load and use your compressed model directly in your Python code. To test the compressed model, use the following code:
+
+```python
+from vllm import LLM, SamplingParams
+
+model = LLM("./TinyLlama-1.1B-Chat-v1.0-INT8")
+outputs = model.generate("What is machine learning?", SamplingParams(max_tokens=256))
+print(outputs[0].outputs[0].text)
+```
+
+After running the above code, you should see the generated output from your compressed model. This confirms that the model is loaded and ready for inference.
+
+## HTTP Server
+
+vLLM also provides an HTTP server for serving your model via a RESTful API that is compatible with OpenAI's API definitions. This allows you to easily integrate your model into existing applications or services.
+To start the HTTP server, use the following command:
+
+```bash
+vllm serve "./TinyLlama-1.1B-Chat-v1.0-INT8"
+```
+
+By default, the server will run on `localhost:8000`. You can change the host and port by using the `--host` and `--port` flags. Now that the server is running, you can send requests to it using any HTTP client. For example, you can use `curl` to send a request:
+
+```bash
+curl -X POST http://localhost:8000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+    "model": "./TinyLlama-1.1B-Chat-v1.0-INT8",
+ "messages": [{"role": "user", "content": "What is machine learning?"}],
+ "max_tokens": 256
+ }'
+```
+
+This will return a JSON response with the generated text from your model. You can also use any HTTP client library in your programming language of choice to send requests to the server.
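+
+For example, here is a minimal sketch using the official `openai` Python client (installed separately with `pip install openai`). The `api_key` value is a placeholder, since the local server does not require authentication by default, and the `model` field should match the name reported by the server's `/v1/models` endpoint.
+
+```python
+from openai import OpenAI
+
+# Point the client at the local vLLM server started above
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+    model="./TinyLlama-1.1B-Chat-v1.0-INT8",
+    messages=[{"role": "user", "content": "What is machine learning?"}],
+    max_tokens=256,
+)
+print(response.choices[0].message.content)
+```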
diff --git a/docs/getting-started/index.md b/docs/getting-started/index.md
new file mode 100644
index 000000000..38d3fdd60
--- /dev/null
+++ b/docs/getting-started/index.md
@@ -0,0 +1,41 @@
+---
+weight: -10
+---
+
+# Getting Started
+
+Welcome to LLM Compressor! This section will guide you through the process of installing the library, compressing your first model, and deploying it with vLLM for faster, more efficient inference.
+
+LLM Compressor makes it simple to optimize large language models for deployment, offering various quantization techniques that help you find the perfect balance between model quality, performance, and resource efficiency.
+
+## Quick Start Guides
+
+Follow the guides below to get started with LLM Compressor and optimize your models for production deployment.
+
+<div class="grid cards" markdown>
+
+- :material-package-variant:{ .lg .middle } Installation
+
+ ---
+
+ Learn how to install LLM Compressor using pip or from source.
+
+ [:octicons-arrow-right-24: Installation Guide](install.md)
+
+- :material-memory:{ .lg .middle } Compress Your Model
+
+ ---
+
+ Learn how to apply quantization to your models using different algorithms and formats.
+
+ [:octicons-arrow-right-24: Compression Guide](compress.md)
+
+- :material-rocket-launch:{ .lg .middle } Deploy with vLLM
+
+ ---
+
+ Deploy your compressed model for efficient inference using vLLM.
+
+ [:octicons-arrow-right-24: Deployment Guide](deploy.md)
+
+</div>
diff --git a/docs/getting-started/install.md b/docs/getting-started/install.md
new file mode 100644
index 000000000..abef6e63f
--- /dev/null
+++ b/docs/getting-started/install.md
@@ -0,0 +1,67 @@
+---
+weight: -10
+---
+
+# Installation
+
+LLM Compressor can be installed using several methods depending on your requirements. Below are the detailed instructions for each installation pathway.
+
+## Prerequisites
+
+Before installing LLM Compressor, ensure you have the following prerequisites:
+
+- **Operating System:** Linux (recommended for GPU support)
+- **Python Version:** 3.9 or newer
+- **Pip Version:** Ensure you have the latest version of pip installed. You can upgrade pip using the following command:
+
+ ```bash
+ python -m pip install --upgrade pip
+ ```
+
+## Installation Methods
+
+### Install from PyPI
+
+The simplest way to install LLM Compressor is via pip from the Python Package Index (PyPI):
+
+```bash
+pip install llmcompressor
+```
+
+This will install the latest stable release of LLM Compressor.
+
+### Install a Specific Version from PyPI
+
+If you need a specific version of LLM Compressor, you can specify the version number during installation:
+
+```bash
+pip install llmcompressor==0.5.1
+```
+
+Replace `0.5.1` with your desired version number.
+
+### Install from Source
+
+To install the latest development version of LLM Compressor from the main branch, use the following command:
+
+```bash
+pip install git+https://github.com/vllm-project/llm-compressor.git
+```
+
+This will clone the repository and install LLM Compressor directly from the main branch.
+
+### Install from a Local Clone
+
+If you have cloned the LLM Compressor repository locally and want to install it, navigate to the repository directory and run:
+
+```bash
+pip install .
+```
+
+For development purposes, you can install it in editable mode with the `dev` extra:
+
+```bash
+pip install -e .[dev]
+```
+
+This allows you to make changes to the source code and have them reflected immediately without reinstalling.
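+
+Regardless of which method you choose, you can quickly verify the installation and check which version was installed. The snippet below is a minimal check using only the Python standard library:
+
+```python
+from importlib.metadata import version
+
+# Print the installed LLM Compressor version
+print(version("llmcompressor"))
+```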
diff --git a/docs/schemes.md b/docs/guides/compression_schemes.md
similarity index 86%
rename from docs/schemes.md
rename to docs/guides/compression_schemes.md
index 19ff746e4..29bd99e7e 100644
--- a/docs/schemes.md
+++ b/docs/guides/compression_schemes.md
@@ -1,4 +1,4 @@
-# Optimization Schemes
+# Compression Schemes
## PTQ
PTQ is performed to reduce the precision of quantizable weights (e.g., linear layers) to a lower bit-width. Supported formats are:
@@ -19,6 +19,9 @@ PTQ is performed to reduce the precision of quantizable weights (e.g., linear la
- Useful for speed ups in high QPS regimes or offline serving on vLLM.
- Recommended for NVIDIA GPUs with compute capability >=9.0 (Hopper and Blackwell).
+### [W8A8-FP8_BLOCK](../examples/quantization_w8a8_fp8/fp8_block_example.py)
+- Uses block-wise quantization to compress weights to FP8 in blocks (commonly 128x128 tiles) and dynamic per-token-group (group size 128) quantization for activations. Does not require a calibration dataset. Activation quantization is carried out during inference on vLLM.
+
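+The linked example boils down to the following sketch: weights are compressed to FP8 blocks at quantization time, while activations are quantized dynamically by vLLM at inference, so `oneshot` is called without a dataset.
+
+```python
+from transformers import AutoModelForCausalLM
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+MODEL_ID = "Qwen/Qwen3-0.6B"
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, device_map="auto", torch_dtype="auto"
+)
+
+# Block-wise FP8 weights; dynamic per-token-group FP8 activations at inference
+recipe = QuantizationModifier(targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"])
+
+# No calibration dataset is needed for this scheme
+oneshot(model=model, recipe=recipe)
+```
+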
## Sparsification
Sparsification reduces model complexity by pruning selected weight values to zero while retaining essential weights in a subset of parameters. Supported formats include:
diff --git a/docs/guides/index.md b/docs/guides/index.md
new file mode 100644
index 000000000..b1d012d0c
--- /dev/null
+++ b/docs/guides/index.md
@@ -0,0 +1,29 @@
+---
+weight: -5
+---
+
+# Guides
+
+Welcome to the LLM Compressor guides section! Here you'll find comprehensive documentation covering key components and concepts of LLM Compressor. These guides will help you understand the various compression options available, how to apply them effectively, and how to deploy your optimized models for maximum performance.
+
+## Key Guides
+
+<div class="grid cards" markdown>
+
+- :material-tune:{ .lg .middle } Compression Schemes
+
+ ---
+
+ Explore the available compression schemes for Quantization and Pruning to determine which is best for your use case.
+
+ [:octicons-arrow-right-24: Compression Schemes](compression_schemes.md)
+
+- :material-content-save:{ .lg .middle } Saving Models
+
+ ---
+
+    Learn how to save your compressed models with the library's extended `save_pretrained` functionality for compatibility with vLLM deployment.
+
+ [:octicons-arrow-right-24: Saving a Model](saving_a_model.md)
+
+</div>
diff --git a/docs/save_pretrained.md b/docs/guides/saving_a_model.md
similarity index 99%
rename from docs/save_pretrained.md
rename to docs/guides/saving_a_model.md
index 4790a1500..6cc5e137f 100644
--- a/docs/save_pretrained.md
+++ b/docs/guides/saving_a_model.md
@@ -1,4 +1,4 @@
-# Enhanced `save_pretrained` Arguments
+# Saving a Model
The `llmcompressor` library extends Hugging Face's `save_pretrained` method with additional arguments to support model compression functionality. This document explains these extra arguments and how to use them effectively.
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 000000000..711581f8a
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,71 @@
+# Home
+
+!!! info "New Feature: Axolotl Sparse Finetuning Integration"
+ Easily finetune sparse LLMs through our seamless integration with Axolotl.
+ [Learn more](https://docs.axolotl.ai/docs/custom_integrations.html#llmcompressor).
+
+!!! info "New Feature: AutoAWQ Integration"
+ Perform low-bit weight-only quantization efficiently using AutoAWQ, now part of LLM Compressor. [Learn more](https://github.com/vllm-project/llm-compressor/pull/1177).
+
+## LLM Compressor
+
+![LLM Compressor user flows](assets/llmcompressor-user-flows.png)
+
+**LLM Compressor** is an easy-to-use library for optimizing large language models for deployment with vLLM, enabling up to **5X faster, cheaper inference**. It provides a comprehensive toolkit for:
+
+- Applying a wide variety of compression algorithms, including weight and activation quantization, pruning, and more
+- Seamlessly integrating with Hugging Face Transformers, Models, and Datasets
+- Using a `safetensors`-based file format for compressed model storage that is compatible with `vLLM`
+- Supporting performant compression of large models via `accelerate`
+
+## Key Features
+
+- **Weight and Activation Quantization:** Reduce model size and improve inference performance for general and server-based applications with the latest research.
+ - Supported Algorithms: GPTQ, AWQ, SmoothQuant, RTN
+ - Supported Formats: INT W8A8, FP W8A8
+- **Weight-Only Quantization:** Reduce model size and improve inference performance for latency-sensitive applications with the latest research
+ - Supported Algorithms: GPTQ, AWQ, RTN
+ - Supported Formats: INT W4A16, INT W8A16
+- **Weight Pruning:** Reduce model size and improve inference performance for all use cases with the latest research
+ - Supported Algorithms: SparseGPT, Magnitude, Sparse Finetuning
+ - Supported Formats: 2:4 (semi-structured), unstructured
+
+## Key Sections
+
+<div class="grid cards" markdown>
+
+- :material-rocket-launch:{ .lg .middle } Getting Started
+
+ ---
+
+ Install LLM Compressor and learn how to apply your first optimization recipe.
+
+ [:octicons-arrow-right-24: Getting started](./getting-started/)
+
+- :material-book-open-variant:{ .lg .middle } Guides
+
+ ---
+
+ Detailed guides covering compression schemes, algorithms, and advanced usage patterns.
+
+ [:octicons-arrow-right-24: Guides](./guides/)
+
+- :material-flask:{ .lg .middle } Examples
+
+ ---
+
+ Step-by-step examples for different compression techniques and model types.
+
+ [:octicons-arrow-right-24: Examples](./examples/)
+
+- :material-tools:{ .lg .middle } Developer Resources
+
+ ---
+
+ Information for contributors and developers extending LLM Compressor.
+
+ [:octicons-arrow-right-24: Developer Resources](./developer/)
+
+</div>
diff --git a/docs/scripts/__init__.py b/docs/scripts/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/docs/scripts/gen_files.py b/docs/scripts/gen_files.py
new file mode 100644
index 000000000..afae28da5
--- /dev/null
+++ b/docs/scripts/gen_files.py
@@ -0,0 +1,114 @@
+"""
+Copy required files from outside of the docs directory into the docs directory
+for the documentation build and site.
+Uses mkdocs-gen-files to handle the file generation and compatibility with MkDocs.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import mkdocs_gen_files
+
+
+@dataclass
+class ProcessFile:
+ root_path: Path
+ docs_path: Path
+ title: Optional[str] = None
+ weight: Optional[float] = None
+
+
+def find_project_root() -> Path:
+ start_path = Path(__file__).absolute()
+ current_path = start_path.parent
+
+    while current_path != current_path.parent:
+        if (current_path / "mkdocs.yml").exists():
+            return current_path
+        current_path = current_path.parent
+
+ raise FileNotFoundError(
+ f"Could not find mkdocs.yml in the directory tree starting from {start_path}"
+ )
+
+
+def process_files(files: list[ProcessFile], project_root: Path):
+ for file in files:
+ source_path = project_root / file.root_path
+ target_path = file.docs_path
+
+ if not source_path.exists():
+ raise FileNotFoundError(
+ f"Source file {source_path} does not exist for copying into docs "
+ f"directory at {target_path}"
+ )
+
+ content = source_path.read_text(encoding="utf-8")
+
+ # Only add frontmatter if title or weight are set
+ if file.title is not None or file.weight is not None:
+ frontmatter = "---\n"
+ if file.title is not None:
+ frontmatter += f"title: {file.title}\n"
+ if file.weight is not None:
+ frontmatter += f"weight: {file.weight}\n"
+ frontmatter += "---\n\n"
+ content = frontmatter + content
+
+ with mkdocs_gen_files.open(target_path, "w") as file_handle:
+ file_handle.write(content)
+
+ mkdocs_gen_files.set_edit_path(target_path, source_path)
+
+
+def migrate_developer_docs():
+ project_root = find_project_root()
+ files = [
+ ProcessFile(
+ root_path=Path("CODE_OF_CONDUCT.md"),
+ docs_path=Path("developer/code-of-conduct.md"),
+ title="Code of Conduct",
+ weight=-10,
+ ),
+ ProcessFile(
+ root_path=Path("CONTRIBUTING.md"),
+ docs_path=Path("developer/contributing.md"),
+ title="Contributing Guide",
+ weight=-8,
+ ),
+ ProcessFile(
+ root_path=Path("DEVELOPING.md"),
+ docs_path=Path("developer/developing.md"),
+ title="Development Guide",
+ weight=-6,
+ ),
+ ]
+ process_files(files, project_root)
+
+
+def migrate_examples():
+ project_root = find_project_root()
+ examples_path = project_root / "examples"
+ files = []
+
+    # Collect each example's README.md (examples/EXAMPLE_NAME/README.md)
+    for example_dir in examples_path.iterdir():
+        readme_path = example_dir / "README.md"
+        if not example_dir.is_dir() or not readme_path.exists():
+            continue
+
+ example_name = example_dir.name
+ files.append(
+ ProcessFile(
+ root_path=readme_path.relative_to(project_root),
+ docs_path=Path(f"examples/{example_name}.md"),
+ title=None,
+ weight=None,
+ )
+ )
+
+ process_files(files, project_root)
+
+
+migrate_developer_docs()
+migrate_examples()
diff --git a/docs/scripts/mathjax.js b/docs/scripts/mathjax.js
new file mode 100644
index 000000000..7e48906af
--- /dev/null
+++ b/docs/scripts/mathjax.js
@@ -0,0 +1,19 @@
+window.MathJax = {
+ tex: {
+ inlineMath: [["\\(", "\\)"]],
+ displayMath: [["\\[", "\\]"]],
+ processEscapes: true,
+ processEnvironments: true
+ },
+ options: {
+ ignoreHtmlClass: ".*|",
+ processHtmlClass: "arithmatex"
+ }
+};
+
+document$.subscribe(() => {
+ MathJax.startup.output.clearCache()
+ MathJax.typesetClear()
+ MathJax.texReset()
+ MathJax.typesetPromise()
+})
diff --git a/docs/stylesheets/style.css b/docs/stylesheets/style.css
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py
new file mode 100644
index 000000000..a28d89ced
--- /dev/null
+++ b/examples/multimodal_vision/llama4_example.py
@@ -0,0 +1,92 @@
+import torch
+from datasets import load_dataset
+from transformers import Llama4ForConditionalGeneration, Llama4Processor
+
+from llmcompressor import oneshot
+from llmcompressor.modeling import prepare_for_calibration
+from llmcompressor.modifiers.quantization import GPTQModifier
+
+# Select model and load it.
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
+processor = Llama4Processor.from_pretrained(model_id)
+# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
+# This change allows compatibility with vllm.
+# To apply your own custom module for experimentation, consider updating
+# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
+model = prepare_for_calibration(model)
+
+DATASET_ID = "neuralmagic/calibration"
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 8192
+
+ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")
+
+
+def preprocess_function(example):
+    messages = []
+ for message in example["messages"]:
+        messages.append(
+ {
+ "role": message["role"],
+ "content": [{"type": "text", "text": message["content"]}],
+ }
+ )
+
+ return processor.apply_chat_template(
+        messages,
+ return_tensors="pt",
+ padding=False,
+ truncation=True,
+ max_length=MAX_SEQUENCE_LENGTH,
+ tokenize=True,
+ add_special_tokens=False,
+ return_dict=True,
+ add_generation_prompt=False,
+ )
+
+
+ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)
+
+
+def data_collator(batch):
+ assert len(batch) == 1
+ return {
+ key: torch.tensor(value)
+ if key != "pixel_values"
+ else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
+ for key, value in batch[0].items()
+ }
+
+
+# Configure the quantization algorithm to run.
+recipe = GPTQModifier(
+ targets="Linear",
+ scheme="W4A16",
+ ignore=[
+ "re:.*lm_head",
+ "re:.*self_attn",
+ "re:.*router",
+ "re:vision_model.*",
+ "re:multi_modal_projector.*",
+ "Llama4TextAttention",
+ ],
+)
+
+# Apply algorithms.
+# due to the large size of Llama4, we specify sequential targets such that
+# only one MLP is loaded into GPU memory at a time
+oneshot(
+ model=model,
+ dataset=ds,
+ recipe=recipe,
+ max_seq_length=MAX_SEQUENCE_LENGTH,
+ num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+ data_collator=data_collator,
+ sequential_targets=["Llama4TextMLP"],
+)
+
+# Save to disk compressed.
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+processor.save_pretrained(SAVE_DIR)
diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py
new file mode 100644
index 000000000..7b56928e8
--- /dev/null
+++ b/examples/quantization_w4a4_fp4/llama4_example.py
@@ -0,0 +1,93 @@
+import torch
+from datasets import load_dataset
+from transformers import Llama4ForConditionalGeneration, Llama4Processor
+
+from llmcompressor import oneshot
+from llmcompressor.modeling import prepare_for_calibration
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+# Select model and load it.
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
+processor = Llama4Processor.from_pretrained(model_id)
+# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
+# This change allows compatibility with vllm.
+# To apply your own custom module for experimentation, consider updating
+# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
+model = prepare_for_calibration(model)
+
+DATASET_ID = "neuralmagic/calibration"
+NUM_CALIBRATION_SAMPLES = 20
+MAX_SEQUENCE_LENGTH = 8192
+
+ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")
+
+
+def preprocess_function(example):
+    messages = []
+ for message in example["messages"]:
+        messages.append(
+ {
+ "role": message["role"],
+ "content": [{"type": "text", "text": message["content"]}],
+ }
+ )
+
+ return processor.apply_chat_template(
+        messages,
+ return_tensors="pt",
+ padding=False,
+ truncation=True,
+ max_length=MAX_SEQUENCE_LENGTH,
+ tokenize=True,
+ add_special_tokens=False,
+ return_dict=True,
+ add_generation_prompt=False,
+ )
+
+
+ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)
+
+
+def data_collator(batch):
+ assert len(batch) == 1
+ return {
+ key: torch.tensor(value)
+ if key != "pixel_values"
+ else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
+ for key, value in batch[0].items()
+ }
+
+
+# Configure the quantization algorithm to run.
+recipe = QuantizationModifier(
+ targets="Linear",
+ scheme="NVFP4",
+ ignore=[
+ "re:.*lm_head",
+ "re:.*self_attn",
+ "re:.*router",
+ "re:vision_model.*",
+ "re:multi_modal_projector.*",
+ "Llama4TextAttention",
+ ],
+)
+
+# Apply algorithms.
+# due to the large size of Llama4, we specify sequential targets such that
+# only one MLP is loaded into GPU memory at a time
+oneshot(
+ model=model,
+ dataset=ds,
+ recipe=recipe,
+ max_seq_length=MAX_SEQUENCE_LENGTH,
+ num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+ sequential_targets=["Llama4TextMLP"],
+ data_collator=data_collator,
+)
+
+
+# Save to disk compressed.
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-NVFP4"
+model.save_pretrained(SAVE_DIR)
+processor.save_pretrained(SAVE_DIR)
diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py
new file mode 100644
index 000000000..fd496fe15
--- /dev/null
+++ b/examples/quantization_w8a8_fp8/fp8_block_example.py
@@ -0,0 +1,33 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+MODEL_ID = "Qwen/Qwen3-0.6B"
+
+# Load model.
+model = AutoModelForCausalLM.from_pretrained(
+ MODEL_ID, device_map="auto", torch_dtype="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Configure the quantization algorithm and scheme.
+# In this case, we:
+# * quantize the weights to fp8 with block-wise quantization
+# * quantize the activations to fp8 with dynamic per-token-group quantization
+recipe = QuantizationModifier(targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"])
+
+# Apply quantization.
+oneshot(model=model, recipe=recipe)
+
+# Confirm generations of the quantized model look sane.
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=20)
+print(tokenizer.decode(output[0]))
+print("==========================================")
+
+# Save to disk in compressed-tensors format.
+SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-BLOCK"
+model.save_pretrained(SAVE_DIR)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 000000000..48acc5e48
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,113 @@
+site_name: LLM Compressor Docs
+site_description: Documentation for LLM Compressor, an easy-to-use library for compressing large language models for deployment with vLLM.
+site_url: https://docs.vllm.ai/projects/llm-compressor
+repo_url: https://github.com/vllm-project/llm-compressor
+edit_uri: https://github.com/vllm-project/llm-compressor/tree/main/docs
+
+theme:
+ name: material
+ font:
+ text: Roboto
+ code: Roboto Mono
+ language: en
+ logo: assets/llmcompressor-icon-white.png
+ favicon: assets/llmcompressor-icon-white.png
+ features:
+ - content.action.edit
+ - content.code.annotate
+ - content.code.copy
+ - content.code.select
+ - navigation.footer
+ - navigation.indexes
+ - navigation.instant
+ - navigation.path
+ - navigation.top
+ - navigation.tracking
+ - search.highlight
+ - search.share
+ - search.suggest
+ - toc.follow
+ palette:
+ # Palette toggle for automatic mode
+ - media: "(prefers-color-scheme)"
+ toggle:
+ icon: material/brightness-auto
+ name: Switch to light mode
+
+ # Palette toggle for light mode
+ - media: "(prefers-color-scheme: light)"
+ scheme: youtube
+ toggle:
+ icon: material/brightness-7
+ name: Switch to dark mode
+
+ # Palette toggle for dark mode
+ - media: "(prefers-color-scheme: dark)"
+ scheme: slate
+ toggle:
+ icon: material/brightness-4
+ name: Switch to system preference
+
+markdown_extensions:
+ - abbr
+ - admonition
+ - attr_list
+ - def_list
+ - footnotes
+ - md_in_html
+ - pymdownx.arithmatex:
+ generic: true
+ - pymdownx.blocks.caption
+ - pymdownx.details
+ - pymdownx.emoji:
+ emoji_index: !!python/name:material.extensions.emoji.twemoji
+ emoji_generator: !!python/name:material.extensions.emoji.to_svg
+ - pymdownx.highlight:
+ anchor_linenums: true
+ line_spans: __span
+ pygments_lang_class: true
+ - pymdownx.inlinehilite
+ - pymdownx.mark
+ - pymdownx.smartsymbols
+ - pymdownx.snippets
+ - pymdownx.superfences:
+ custom_fences:
+ - name: mermaid
+ class: mermaid
+ format: !!python/name:pymdownx.superfences.fence_code_format
+ - pymdownx.tabbed:
+ alternate_style: true
+ - pymdownx.tasklist:
+ custom_checkbox: true
+ - pymdownx.tilde
+ - tables
+
+plugins:
+ - api-autonav:
+ modules: ['src/llmcompressor']
+ - gen-files:
+ scripts:
+ - docs/scripts/gen_files.py
+ - minify:
+ minify_html: true
+ - mkdocs-nav-weight
+ - mkdocstrings:
+ default_handler: python
+ handlers:
+ python:
+ options:
+ docstring_style: sphinx
+ - search
+ - section-index
+ - social
+ - tags
+
+extra:
+ generator: false
+
+extra_css:
+ - stylesheets/style.css
+
+extra_javascript:
+ - scripts/mathjax.js
+ - https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js
diff --git a/setup.py b/setup.py
index 04de6484d..88fa55223 100644
--- a/setup.py
+++ b/setup.py
@@ -155,6 +155,17 @@ def localversion_func(version: ScmVersion) -> str:
"flake8~=7.0.0",
# pre commit hooks
"pre-commit",
+ # docs
+ "mkdocs",
+ "mkdocs-material[imaging]",
+ "markdown",
+ "pymdown-extensions",
+ "mkdocs-section-index",
+ "mkdocs-minify-plugin",
+ "mkdocs-api-autonav",
+ "mkdocstrings-python",
+ "mkdocs-gen-files",
+ "mkdocs-nav-weight",
]
},
entry_points={
diff --git a/src/llmcompressor/entrypoints/utils.py b/src/llmcompressor/entrypoints/utils.py
index 5647e4d06..95ec832fb 100644
--- a/src/llmcompressor/entrypoints/utils.py
+++ b/src/llmcompressor/entrypoints/utils.py
@@ -20,7 +20,7 @@
from llmcompressor.pytorch.model_load.helpers import parse_dtype
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
- patch_tied_tensors_bug,
+ untie_word_embeddings,
)
from llmcompressor.transformers.utils.helpers import (
detect_last_checkpoint,
@@ -61,7 +61,8 @@ def pre_process(model_args: "ModelArguments"):
)
# untie tie_word_embeddings weights
- patch_tied_tensors_bug(model_args.model)
+ if not model_args.tie_word_embeddings:
+ untie_word_embeddings(model_args.model)
# wrap model.save_pretrained
modify_save_pretrained(model_args.model)
@@ -143,7 +144,6 @@ def initialize_model_from_path(
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
- tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,
)
@@ -156,7 +156,6 @@ def initialize_model_from_path(
AutoConfig.from_pretrained(
model_args.distill_teacher,
use_auth_token=True if model_args.use_auth_token else None,
- tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,
)
if model_args.distill_teacher
diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py
index c5de440ce..60436cdc9 100644
--- a/src/llmcompressor/modeling/deepseek_v3.py
+++ b/src/llmcompressor/modeling/deepseek_v3.py
@@ -1,20 +1,23 @@
import torch
+from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+__all__ = ["DeepseekV3MoECalibrate"]
+
class DeepseekV3MoECalibrate(torch.nn.Module):
"""
Patched DeepseekV3MoE which sends all tokens to all experts for calibration
"""
- def __init__(self, config, experts, gate, shared_experts):
+ def __init__(self, config: DeepseekV3Config, original: DeepseekV3MoE):
super().__init__()
self.config = config
- self.experts = experts
- self.gate = gate
- self.shared_experts = shared_experts
+ self.experts = original.experts
+ self.gate = original.gate
+ self.shared_experts = original.shared_experts
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residuals = hidden_states
orig_shape = hidden_states.shape
topk_indices, topk_weights = self.gate(hidden_states)
@@ -46,7 +49,5 @@ def forward(self, hidden_states):
return hidden_states
-def replace(module: DeepseekV3MoE) -> DeepseekV3MoECalibrate:
- return DeepseekV3MoECalibrate(
- module.config, module.experts, module.gate, module.shared_experts
- )
+def replace(config: DeepseekV3Config, module: DeepseekV3MoE):
+ return DeepseekV3MoECalibrate(config=config, original=module)
diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py
new file mode 100644
index 000000000..a8168a31f
--- /dev/null
+++ b/src/llmcompressor/modeling/fuse.py
@@ -0,0 +1,60 @@
+from typing import Iterable
+
+import torch
+from compressed_tensors import (
+ align_module_device,
+ get_execution_device,
+ update_offload_parameter,
+)
+
+__all__ = ["center_embeddings", "fuse_norm_linears"]
+
+
+PRECISION = torch.float64
+
+
+def center_embeddings(embedding: torch.nn.Module):
+ """
+ Shift each embedding to have a mean of zero
+
+ :param embedding: embedding module containing embeddings to center
+ """
+ if not hasattr(embedding, "weight"):
+        raise ValueError(f"Cannot center embedding of type {type(embedding)}")
+
+ with align_module_device(embedding):
+ weight_dtype = embedding.weight.dtype
+ weight = embedding.weight.to(PRECISION)
+ new_weight = weight - weight.mean(dim=-1, keepdim=True)
+ new_weight = new_weight.to(weight_dtype)
+
+ update_offload_parameter(embedding, "weight", new_weight)
+
+
+def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
+ """
+ Fuse the scaling operation of norm layer into subsequent linear layers.
+    This is useful for ensuring transform invariance between norm and linear layers.
+
+ Note that unitary transforms (rotation) commute with normalization, but not scaling
+
+ :param norm: norm layer whose weight will be fused into subsequent linears
+ :param linears: linear layers which directly follow the norm layer
+ """
+ if not hasattr(norm, "weight"):
+ raise ValueError(f"Cannot fuse norm of type {type(norm)}")
+
+ for linear in linears:
+ # NOTE: spinquant does this op in float64
+ exec_device = get_execution_device(norm)
+ with align_module_device(norm, exec_device), align_module_device(
+ linear, exec_device
+ ):
+ weight_dtype = linear.weight.dtype
+ new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
+ new_weight = new_weight.to(weight_dtype)
+
+ update_offload_parameter(linear, "weight", new_weight)
+
+ new_norm_weight = torch.ones_like(norm.weight, device="cpu")
+ update_offload_parameter(norm, "weight", new_norm_weight)
diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py
new file mode 100644
index 000000000..02e3dc8fc
--- /dev/null
+++ b/src/llmcompressor/modeling/llama4.py
@@ -0,0 +1,68 @@
+from typing import Tuple
+
+import torch
+from transformers.models import Llama4Config
+from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
+from transformers.models.llama4.modeling_llama4 import (
+ Llama4TextExperts,
+ Llama4TextMLP,
+ Llama4TextMoe,
+)
+
+from llmcompressor.utils.dev import skip_weights_initialize
+
+__all__ = ["SequentialLlama4TextMoe"]
+
+
+class SequentialLlama4TextMoe(torch.nn.Module):
+ def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
+ super().__init__()
+ self.top_k = config.num_experts_per_tok
+ self.hidden_dim = config.hidden_size
+ self.num_experts = config.num_local_experts
+ self.experts = SequentialLlama4TextExperts(config, original.experts)
+ self.router = original.router
+ self.shared_expert = original.shared_expert
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = self.router(hidden_states)
+
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+
+ router_scores = (
+ torch.full_like(router_logits, float("-inf"))
+ .scatter_(1, router_indices, router_top_value)
+ .transpose(0, 1)
+ )
+ router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+
+ out = self.shared_expert(hidden_states)
+ for i in range(self.num_experts):
+ out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+
+ return out, router_scores
+
+
+class SequentialLlama4TextExperts(torch.nn.ModuleList):
+ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
+ self.num_experts = original.gate_up_proj.shape[0]
+ with skip_weights_initialize():
+ super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
+
+ intermediate_size = original.down_proj.shape[1]
+
+ for i in range(self.num_experts):
+ gate_up = original.gate_up_proj[i]
+ down = original.down_proj[i]
+
+ gate_proj = gate_up[:, :intermediate_size]
+ up_proj = gate_up[:, intermediate_size:]
+
+ self[i].gate_proj.weight.data = gate_proj.t().clone().contiguous()
+ self[i].up_proj.weight.data = up_proj.t().clone().contiguous()
+ self[i].down_proj.weight.data = down.t().clone().contiguous()
+
+
+def replace(config: Llama4Config, module: Llama4TextMoe):
+ return SequentialLlama4TextMoe(config=config.get_text_config(), original=module)
diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
index 6944327b0..0ef627db4 100644
--- a/src/llmcompressor/modeling/prepare.py
+++ b/src/llmcompressor/modeling/prepare.py
@@ -1,12 +1,14 @@
from compressed_tensors.utils import replace_module
from transformers import PreTrainedModel
-from llmcompressor.modeling.deepseek_v3 import replace as replace_DeepseekV3MoE
+from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
+from llmcompressor.modeling.llama4 import replace as replace_llama4
__all__ = ["prepare_for_calibration"]
replacements = {
- "DeepseekV3MoE": replace_DeepseekV3MoE,
+ "DeepseekV3MoE": replace_deepseekv3,
+ "Llama4TextMoe": replace_llama4,
}
@@ -14,7 +16,7 @@ def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
for name, module in model.named_modules():
cls_name = module.__class__.__name__
if cls_name in replacements:
- new_module = replacements[cls_name](module)
+ new_module = replacements[cls_name](config=model.config, module=module)
replace_module(model, name, new_module)
return model
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index b10a4cb31..c43d1ae5a 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -10,7 +10,11 @@
)
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import align_module_device, update_parameter_data
+from compressed_tensors.utils import (
+ align_module_device,
+ update_offload_parameter,
+ update_parameter_data,
+)
from loguru import logger
from torch.nn import Module
@@ -124,8 +128,19 @@ def call_observer(
updated_scale, updated_zero_point = observer(
value, g_idx=g_idx, global_scale=global_scale
)
- update_parameter_data(module, updated_scale, f"{base_name}_scale")
- update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+ # register or update scale & zero_point parameters (supports block shapes)
+ scale_name = f"{base_name}_scale"
+ zp_name = f"{base_name}_zero_point"
+ for name, value in [
+ (scale_name, updated_scale),
+ (zp_name, updated_zero_point),
+ ]:
+ if not hasattr(module, name):
+ module.register_parameter(
+ name, torch.nn.Parameter(value.clone(), requires_grad=False)
+ )
+ else:
+ update_offload_parameter(module, name, value)
def update_weight_global_scale(module: Module):
diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index 3ee446cf3..2541f88ab 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -63,12 +63,17 @@ def calculate_qparams(
self,
observed: Tensor,
reduce_dims: Optional[Tuple[int]] = None,
+ tensor_id: Optional[Any] = None,
+ global_scale: Optional[torch.Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
:param observed: observed tensor to calculate quantization parameters for
:param reduce_dims: optional tuple of dimensions to reduce along,
returned scale and zero point will be shaped (1,) along the
reduced dimensions
+ :param tensor_id: Optional id if different ranges of observed tensors are
+ passed, useful for sharding tensors by group_size
+ :param global_scale: optional scale to further scale local quantization scales
:return: tuple of scale and zero point derived from the observed tensor
"""
raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
@@ -193,12 +198,57 @@ def get_qparams(
)
elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
- # TODO (#1475) add support for block-wise quantization
- raise NotImplementedError(
- "Block-wise quantization is not yet supported, "
- "consider group-wise quantization instead. More info at "
- "https://github.com/vllm-project/llm-compressor/issues/1475"
+ # Block-wise quantization: one scale/zero_point per block of shape
+ # [block_rows, block_cols]
+ rows, cols = observed.shape[:2]
+ bs = self.quantization_args.block_structure
+ if not (
+ isinstance(bs, (list, tuple))
+ and len(bs) == 2
+ and all(isinstance(x, int) for x in bs)
+ ):
+ raise ValueError(
+ f"Invalid block_structure '{bs}'. "
+ "Must be a list of two ints [rows, cols]."
+ )
+ block_rows, block_cols = bs
+
+ # Enforce exact division (dimensions must be divisible by block size)
+ if rows % block_rows != 0:
+ raise ValueError(
+ f"Tensor height {rows} is not divisible by block_rows "
+ f"{block_rows}. Block quantization requires exact division."
+ )
+ if cols % block_cols != 0:
+ raise ValueError(
+ f"Tensor width {cols} is not divisible by block_cols "
+ f"{block_cols}. Block quantization requires exact division."
+ )
+
+ num_br = rows // block_rows
+ num_bc = cols // block_cols
+ # allocate per-block scale and zero_point
+ self._scale = torch.empty(
+ (num_br, num_bc), dtype=observed.dtype, device=observed.device
+ )
+ self._zero_point = torch.empty(
+ (num_br, num_bc), dtype=observed.dtype, device=observed.device
)
+ # compute qparams for each block
+ for i in range(num_br):
+ r0 = i * block_rows
+ r1 = (i + 1) * block_rows
+ for j in range(num_bc):
+ c0 = j * block_cols
+ c1 = (j + 1) * block_cols
+ # reduce across both dims to get one scale and zp per block
+ scale_bp, zp_bp = self.calculate_qparams(
+ observed[r0:r1, c0:c1],
+ reduce_dims=(0, 1),
+ tensor_id=i * num_bc + j,
+ )
+ self._scale[i, j] = scale_bp
+ self._zero_point[i, j] = zp_bp
return self._scale, self._zero_point
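The nested loop above produces one scale/zero_point per `[block_rows, block_cols]` tile, giving a `(num_br, num_bc)` grid. An equivalent way to picture that reduction, as a vectorized standalone sketch; the symmetric absmax scale is illustrative only, since the real per-block math is delegated to `calculate_qparams`:

```python
import torch

rows, cols = 256, 512
block_rows, block_cols = 128, 128
observed = torch.randn(rows, cols)

num_br, num_bc = rows // block_rows, cols // block_cols

# view the tensor as (num_br, block_rows, num_bc, block_cols) and reduce each tile,
# which matches observed[r0:r1, c0:c1] in the double loop above
blocks = observed.reshape(num_br, block_rows, num_bc, block_cols)
block_absmax = blocks.abs().amax(dim=(1, 3))  # shape (num_br, num_bc)

scales = block_absmax / 127.0  # simplified symmetric int8 scale per block
print(scales.shape)  # torch.Size([2, 4])
```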
diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index 69b0e3f28..1495f6d06 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -9,11 +9,11 @@
CompressionFormat,
ModelCompressor,
SparsityCompressionConfig,
+ delete_offload_parameter,
is_module_offloaded,
- update_offload_parameter,
+ register_offload_parameter,
)
from loguru import logger
-from safetensors.torch import storage_ptr
from transformers import PreTrainedModel
from llmcompressor.core import active_session
@@ -27,7 +27,7 @@
from llmcompressor.transformers.utils import RECIPE_FILE_NAME
from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path
-__all__ = ["modify_save_pretrained"]
+__all__ = ["modify_save_pretrained", "untie_word_embeddings"]
def modify_save_pretrained(model: PreTrainedModel):
@@ -120,7 +120,7 @@ def save_pretrained_wrapper(
model.save_pretrained = save_pretrained_compressed(model.save_pretrained)
-def patch_tied_tensors_bug(model: torch.nn.Module):
+def untie_word_embeddings(model: PreTrainedModel):
"""
Patches bug where HF transformers will fail to untie weights under specific
circumstances (https://github.com/huggingface/transformers/issues/33689).
@@ -129,28 +129,27 @@ def patch_tied_tensors_bug(model: torch.nn.Module):
:param model: model to fix
"""
- if (
- hasattr(model.config, "tie_word_embeddings")
- and not model.config.tie_word_embeddings
- ):
- input_embed = model.get_input_embeddings()
- output_embed = model.get_output_embeddings()
-
- if input_embed is None or output_embed is None:
- # some models fail to properly override the abstract methods
- return
-
- if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight):
- for module in (input_embed, output_embed):
- if not is_module_offloaded(module):
- # create new storage ptr for onloaded weight
- untied_data = module.weight.data.clone()
- module.weight.data = untied_data
- else:
- # create new storage ptr for offloaded weight
- # note `update_offload_parameter` does not create a new storage ptr
- untied_data = module._hf_hook.weights_map["weight"].clone()
- update_offload_parameter(module, "weight", untied_data)
+ input_embed = model.get_input_embeddings()
+ output_embed = model.get_output_embeddings()
+
+ for module in (input_embed, output_embed):
+ if module is None or not hasattr(module, "weight"):
+ logger.warning(f"Cannot untie {module} which does not have weight param")
+ continue
+
+ # this could be replaced by a `get_offloaded_parameter` util
+ if not is_module_offloaded(module):
+ untied_data = module.weight.data.clone()
+ else:
+ untied_data = module._hf_hook.weights_map["weight"].clone()
+
+ requires_grad = module.weight.requires_grad
+ new_parameter = torch.nn.Parameter(untied_data, requires_grad=requires_grad)
+ delete_offload_parameter(module, "weight")
+ register_offload_parameter(module, "weight", new_parameter)
+
+ if hasattr(model.config, "tie_word_embeddings"):
+ model.config.tie_word_embeddings = False
def get_model_compressor(
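At its core, the rewritten `untie_word_embeddings` breaks weight tying by giving the embedding and LM head separate storage and then flipping the config flag. The storage-level effect, reduced to plain PyTorch with no offloading (module names here are hypothetical):

```python
import torch

embed = torch.nn.Embedding(100, 32)
lm_head = torch.nn.Linear(32, 100, bias=False)
lm_head.weight = embed.weight  # tied: both modules share one storage

assert lm_head.weight.data_ptr() == embed.weight.data_ptr()

# untie by cloning into a fresh Parameter, mirroring the delete/register flow above
lm_head.weight = torch.nn.Parameter(
    embed.weight.data.clone(), requires_grad=embed.weight.requires_grad
)

assert lm_head.weight.data_ptr() != embed.weight.data_ptr()
```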
diff --git a/tests/llmcompressor/modeling/test_fuse.py b/tests/llmcompressor/modeling/test_fuse.py
new file mode 100644
index 000000000..005d89f99
--- /dev/null
+++ b/tests/llmcompressor/modeling/test_fuse.py
@@ -0,0 +1,32 @@
+import pytest
+import torch
+
+from llmcompressor.modeling.fuse import center_embeddings, fuse_norm_linears
+
+
+@pytest.mark.unit
+def test_center_embeddings():
+ embedding = torch.nn.Embedding(10, 10)
+ center_embeddings(embedding)
+
+ assert torch.allclose(
+ embedding.weight.mean(dim=1), torch.zeros(embedding.num_embeddings), atol=1e-5
+ )
+
+
+@pytest.mark.unit
+def test_fuse_norm_linears():
+ norm = torch.nn.LayerNorm((5,))
+ norm.weight.data = torch.rand(norm.weight.shape)
+ linears = [
+ torch.nn.Linear(5, 5),
+ torch.nn.Linear(5, 5),
+ ]
+
+ input = torch.rand((1, 5), requires_grad=False)
+ true_output = torch.stack([linear(norm(input)) for linear in linears])
+
+ fuse_norm_linears(norm, linears)
+ output = torch.stack([linear(norm(input)) for linear in linears])
+
+ assert torch.allclose(true_output, output)
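The fusion test above relies on a simple identity: for a LayerNorm whose bias is zero, scaling the normalized activations by gamma and then applying a Linear is the same as applying the Linear with each input column pre-scaled by gamma. A standalone check of that algebra, assuming `fuse_norm_linears` folds the norm weight into the following linear in this way:

```python
import torch

norm = torch.nn.LayerNorm(5)
norm.weight.data = torch.rand(5)  # bias stays at its zero default
linear = torch.nn.Linear(5, 5)

x = torch.rand(1, 5)
ref = linear(norm(x))

# fold gamma into the linear weight: W' = W * gamma (scales each input column)
with torch.no_grad():
    linear.weight.mul_(norm.weight)
    norm.weight.fill_(1.0)

assert torch.allclose(ref, linear(norm(x)), atol=1e-6)
```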
diff --git a/tests/llmcompressor/modifiers/quantization/test_base.py b/tests/llmcompressor/modifiers/quantization/test_base.py
index b95ee9c1c..c11e817b6 100644
--- a/tests/llmcompressor/modifiers/quantization/test_base.py
+++ b/tests/llmcompressor/modifiers/quantization/test_base.py
@@ -35,6 +35,34 @@ def q_config_kwargs(config_0, config_1):
)
+@pytest.fixture
+def block_q_config_kwargs():
+ return dict(
+ config_groups=dict(
+ group_block=dict(
+ targets=["Linear"],
+ input_activations=dict(
+ num_bits=8, symmetric=True, strategy="group", group_size=128
+ ),
+ weights=dict(
+ num_bits=8,
+ symmetric=True,
+ strategy="block",
+ block_structure=[128, 128],
+ ),
+ ),
+ )
+ )
+
+
+def test_block_strategy_parsing(block_q_config_kwargs):
+ modifier = GPTQModifier(**block_q_config_kwargs)
+ resolved = modifier.resolve_quantization_config()
+ w_scheme = resolved.config_groups["group_block"].weights
+ assert w_scheme.strategy == "block"
+ assert w_scheme.block_structure == [128, 128]
+
+
@pytest.mark.parametrize(
"has_actorder,actorder,config_0,config_1,expected_0,expected_1",
[
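For a rough sense of what the `block_structure=[128, 128]` scheme in the fixture above produces, each Linear weight ends up with a 2-D grid of qparams, one per 128x128 tile; the weight shape below is hypothetical:

```python
rows, cols = 512, 1024  # hypothetical Linear weight shape
block_rows, block_cols = 128, 128

# the observer enforces exact divisibility, then emits one scale/zero_point per tile
assert rows % block_rows == 0 and cols % block_cols == 0
print((rows // block_rows, cols // block_cols))  # (4, 8) -> 32 scales for this weight
```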
diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
index 140e706d1..aad551ff8 100644
--- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
+++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
@@ -28,7 +28,7 @@
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
get_model_compressor,
modify_save_pretrained,
- patch_tied_tensors_bug,
+ untie_word_embeddings,
)
from tests.testing_utils import requires_gpu
@@ -224,8 +224,6 @@ def test_quant_model_reload(format, dtype, tmp_path):
shutil.rmtree(tmp_path)
-# technically only tie_word_embeddings=False is supported right now
-# setting to True is discouraged
@pytest.mark.parametrize(
"offload,torch_dtype,tie_word_embeddings,device",
[
@@ -237,25 +235,23 @@ def test_quant_model_reload(format, dtype, tmp_path):
# offloading
(True, torch.float16, False, "cpu"),
(True, torch.float32, False, "cpu"),
- # (True, torch.float16, True, "cpu"), # TODO: fails
- # (True, torch.float32, True, "cpu"), # TODO: fails
+ (True, torch.float16, True, "cpu"),
+ (True, torch.float32, True, "cpu"),
],
)
def test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path):
model_path = "nm-testing/llama2.c-stories15M"
save_path = tmp_path / "save_path"
- model = AutoModelForCausalLM.from_pretrained(
- model_path,
- tie_word_embeddings=tie_word_embeddings,
- torch_dtype=torch_dtype,
- )
+ model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
if offload:
model = dispatch_model(model, {"": device}, force_hooks=True)
else:
model = model.to(device)
- patch_tied_tensors_bug(model)
+ if not tie_word_embeddings:
+ untie_word_embeddings(model)
+
modify_save_pretrained(model)
model.save_pretrained(save_path, safe_serialization=True)
@@ -294,22 +290,18 @@ def test_model_reload_gpu(offload, torch_dtype, tie_word_embeddings, device, tmp
(True, torch.float32, True, "cpu"),
],
)
-def test_model_shared_tensors(
- offload, torch_dtype, tie_word_embeddings, device, tmp_path
-):
+def test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device):
# load model
- model = AutoModelForCausalLM.from_pretrained(
- "nm-testing/llama2.c-stories15M",
- torch_dtype=torch_dtype,
- tie_word_embeddings=tie_word_embeddings,
- )
- patch_tied_tensors_bug(model)
-
+ model_path = "nm-testing/llama2.c-stories15M"
+ model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
if offload:
model = dispatch_model(model, {"": device}, force_hooks=True)
else:
model = model.to(device)
+ if not tie_word_embeddings:
+ untie_word_embeddings(model)
+
# modify lm head
with torch.no_grad(), align_module_device(model.lm_head):
update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1)
@@ -332,12 +324,8 @@ def test_model_shared_tensors(
(False, torch.float32, True, "cuda:0"),
],
)
-def test_model_shared_tensors_gpu(
- offload, torch_dtype, tie_word_embeddings, device, tmp_path
-):
- test_model_shared_tensors(
- offload, torch_dtype, tie_word_embeddings, device, tmp_path
- )
+def test_model_shared_tensors_gpu(offload, torch_dtype, tie_word_embeddings, device):
+ test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device)
@requires_gpu