Skip to content

Fix: Gemma 3 & 4 Base Model and Flax Linen Decoding #4066

Open
RexBearIU wants to merge 1 commit into
mainfrom
jackyf/gemma3_4-base-decoding-fixes
Open

Fix: Gemma 3 & 4 Base Model and Flax Linen Decoding #4066
RexBearIU wants to merge 1 commit into
mainfrom
jackyf/gemma3_4-base-decoding-fixes

Conversation

@RexBearIU

@RexBearIU RexBearIU commented Jun 4, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR introduces architectural configurations for Gemma3 and Gemma4, and implements critical decoding layer KV-cache propagation fixes for JAX/NNX
(nnx_decoders.py) and Flax Linen (decoders.py) pipelines.

Before this change, Flax Linen scanned blocks (_apply_gemma3_scanned_blocks and _apply_gemma4_scanned_blocks) lacked mechanisms to propagate
intermediate KV-caches across scanned layers under nn.scan, causing end-to-end decodes to fail.

Key Changes:
- Flax Linen Decoder (src/maxtext/layers/decoders.py): Added kv_caches and attention_metadata propagation across Gemma3/4 scanned blocks
via a stack/unstack mapping over nn.scan.
- JAX/NNX Decoder (src/maxtext/layers/nnx_decoders.py): Unified dynamic KV-cache carry updates within scanned block execution.
- Model Scaffolding (gemma3.py, gemma4.py): Added kv_cache routing through scannable blocks and layer inputs.

Tests

  • Validated via end-to-end decoding with Gemma3-4B on an 8-device TPU mesh.
  • Run unit tests:
python3 -m pytest tests/post_training/unit/lora_utils_test.py

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 4, 2026

Copy link
Copy Markdown

@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch 10 times, most recently from bf6986a to c4c93a7 Compare June 4, 2026 18:59
@github-actions

github-actions Bot commented Jun 5, 2026

Copy link
Copy Markdown

🤖 Hi @RexBearIU, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch from 3a983c6 to c4c93a7 Compare June 5, 2026 10:41
@AI-Hypercomputer AI-Hypercomputer deleted a comment from github-actions Bot Jun 5, 2026
@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch 3 times, most recently from e755816 to 5c2a98d Compare June 8, 2026 12:18
@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch 2 times, most recently from 467033c to 355c904 Compare June 11, 2026 08:12
@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch from 355c904 to 59ca028 Compare June 12, 2026 08:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant