fix: update nnx decoders from main#4132
Open
mesakhcienet wants to merge 1 commit into
Open
Conversation
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
10ee47d to
e08ab86
Compare
4 tasks
fd85387 to
7bc7d04
Compare
5052d63 to
247a7e3
Compare
816939f to
9aac666
Compare
9aac666 to
0199152
Compare
- Update NNX decoders and qwen model usage - Fix quantization and resolve static attribute error in ToLinen state updates - Modernize sharding deprecated warnings to out_sharding - Add unit test for Gemma4 Small NNXDecoder to maximize coverage (TPU only)
0199152 to
64091a6
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR synchronizes the NNX-based decoders (
src/maxtext/layers/nnx_decoders.py) with the Flax Linen-based decoders (src/maxtext/layers/decoders.py). It resolves functional, naming, and architectural gaps between Linen and NNX paradigms—specifically targeting Qwen3/Qwen3.5, and Gemma 4 model support.Additionally, this PR modernizes variable constructors to avoid upstream JAX/Flax NNX deprecation warnings:
sharding=toout_sharding=in variables (e.g.,RMSNormscale parameters and Attention sinks).tests/unit/qwen3_next_vs_reference_test.pyto extract actualnnx.Paramvalues correctly outside module constructors (using.metadata["nnx_value"]to safely unwrap rawdataclasses.Fieldobjects in modern Flax versions).Why is this change being made?
shardingconfigurations inside touched files.Tests
Tested on CPU with:
codespell,pylint,pyink,yamllint).Checklist
gemini-reviewlabel.