fix: update nnx decoders from main #4132

Open
mesakhcienet wants to merge 1 commit into AI-Hypercomputer:main from CIeNET-International:fix/update-nnx-decoder

mesakhcienet commented Jun 10, 2026

Collaborator

Description

This PR synchronizes the NNX-based decoders (src/maxtext/layers/nnx_decoders.py) with the Flax Linen-based decoders (src/maxtext/layers/decoders.py). It resolves functional, naming, and architectural gaps between the Linen and NNX implementations, specifically targeting Qwen3/Qwen3.5 and Gemma 4 model support.

Additionally, this PR modernizes variable constructors to avoid upstream JAX/Flax NNX deprecation warnings:

  • Updated occurrences of sharding= to out_sharding= in variables (e.g., RMSNorm scale parameters and Attention sinks).
  • Refactored test configurations in tests/unit/qwen3_next_vs_reference_test.py to extract actual nnx.Param values correctly outside module constructors (using .metadata["nnx_value"] to safely unwrap raw dataclasses.Field objects in modern Flax versions).
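
The unwrapping described in the second bullet can be sketched with the standard library alone. This is a minimal illustration, not the code in this PR: the `unwrap_nnx_value` helper and the hand-built `raw_field` are hypothetical, and the only detail taken from the description is that the value is read from the `"nnx_value"` key of a `dataclasses.Field`'s metadata.

```python
import dataclasses
from typing import Any


def unwrap_nnx_value(obj: Any) -> Any:
  """Return the wrapped value if obj is a raw dataclasses.Field, else obj itself."""
  if isinstance(obj, dataclasses.Field) and "nnx_value" in obj.metadata:
    return obj.metadata["nnx_value"]
  return obj


# Hypothetical stand-in for an attribute that is still a raw Field because it
# was read outside a module constructor (assumption for illustration only).
raw_field = dataclasses.field(default=None, metadata={"nnx_value": [1.0, 2.0]})

unwrapped = unwrap_nnx_value(raw_field)  # value pulled out of the Field metadata
passthrough = unwrap_nnx_value([3.0])  # non-Field inputs are returned unchanged
```

Passing non-`Field` inputs through unchanged lets a test call the helper without first checking which form it received.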

Why is this change being made?

  1. Functional Alignment: Aligns the NNX decoder pipeline fully with Flax Linen decoders to prevent behavioral/architectural drift in modern decoder layers.
  2. Eliminates Warnings: Cleans up noisy deprecation warnings regarding sharding configurations inside touched files.
  3. Correctness in Tests: Resolves test suite failures and allows proper execution/evaluation of decoder models in purely JAX environments.
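
For readers unfamiliar with how a keyword rename produces this kind of warning, the pattern can be sketched generically. This is not Flax's implementation: `normalize_sharding_kwargs` is a hypothetical helper, and only the two keyword names (`sharding` and `out_sharding`) come from this PR.

```python
import warnings
from typing import Any


def normalize_sharding_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
  """Hypothetical shim: map the old `sharding` keyword to `out_sharding`."""
  updated = dict(kwargs)  # copy so the caller's dict is not mutated
  if "sharding" in updated:
    if "out_sharding" in updated:
      raise TypeError("Pass either `sharding` or `out_sharding`, not both.")
    warnings.warn(
        "`sharding=` is deprecated; use `out_sharding=`.",
        DeprecationWarning,
        stacklevel=2,
    )
    updated["out_sharding"] = updated.pop("sharding")
  return updated


with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  # Old spelling: accepted, but emits a DeprecationWarning.
  old_style = normalize_sharding_kwargs({"sharding": ("embed",)})
  # New spelling: passes through silently.
  new_style = normalize_sharding_kwargs({"out_sharding": ("embed",)})
```

Updating call sites to the new keyword, as this PR does, avoids the warning path entirely.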

Tests

Tested on CPU with:

JAX_PLATFORMS=cpu PYTHONPATH=src python -m pytest tests/unit/qwen3_next_vs_reference_test.py
  • Results: Verified that normalizers, model structure, and gating layouts are aligned correctly.
  • Linting & Code Quality: Fully verified via the pre-commit pipeline (codespell, pylint, pyink, yamllint).

Checklist

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

codecov Bot commented Jun 10, 2026


Codecov Report

❌ Patch coverage is 74.22222% with 58 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/layers/nnx_decoders.py | 76.76% | 32 Missing and 14 partials ⚠️ |
| src/maxtext/checkpoint_conversion/utils/utils.py | 8.33% | 11 Missing ⚠️ |
| src/maxtext/checkpoint_conversion/to_maxtext.py | 50.00% | 1 Missing ⚠️ |
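
As a rough consistency check on these figures, the percentages can be reproduced from the missing-line counts. The total changed-line counts used below (225 overall, then 198, 12, and 2 per file) are not reported by Codecov; they are inferred from the percentages, and counting partials as uncovered is an assumption about how the patch figure is computed.

```python
def patch_coverage(total_lines: int, missing: int, partials: int = 0) -> float:
  """Percentage of changed lines covered, treating partials as uncovered (assumption)."""
  covered = total_lines - missing - partials
  return 100.0 * covered / total_lines


# Inferred totals (not from the report) paired with the reported missing counts.
overall = patch_coverage(225, 58)
nnx_decoders = patch_coverage(198, 32, 14)
conversion_utils = patch_coverage(12, 11)
to_maxtext = patch_coverage(2, 1)
```

Under these assumptions the three listed files account for all 58 missing or partial lines (46 + 11 + 1), so the remaining changed lines would be fully covered.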


RexBearIU force-pushed the fix/update-nnx-decoder branch 3 times, most recently from 816939f to 9aac666 on June 12, 2026 10:50
xibinliu force-pushed the fix/update-nnx-decoder branch from 9aac666 to 0199152 on June 12, 2026 16:49
- Update NNX decoders and qwen model usage
- Fix quantization and resolve static attribute error in ToLinen state updates
- Replace deprecated sharding arguments with out_sharding to silence deprecation warnings
- Add unit test for Gemma4 Small NNXDecoder to maximize coverage (TPU only)
xibinliu force-pushed the fix/update-nnx-decoder branch from 0199152 to 64091a6 on June 12, 2026 16:56