
Conversation


@maharajamihir commented Jul 7, 2025

This PR is a mirror of p-doom#60. I have tested it in our repo on our Minecraft world-modelling dataset and checkpoints. Sampling 16 frames went from 6 min 30 s to 1 min 20 s, with compilation now being the largest fraction of the remaining time.

Autoregressively samples up to `seq_len` future frames, following Figure 8 of the paper:

  • Input frames are tokenized once.
  • Future frames are generated autoregressively in token space.
  • All frames are detokenized in a single pass at the end.
  • This speeds up sampling and significantly improves reconstruction quality, since the distribution shift between training and inference is reduced (a sketch of the pipeline follows this list).
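A minimal sketch of that pipeline in JAX. The names `tokenizer_encode`, `generate_future_tokens`, and `tokenizer_decode` are hypothetical stand-ins for this repo's actual tokenizer and dynamics APIs, not its real signatures:

```python
import jax
import jax.numpy as jnp

@jax.jit  # compiled once; later calls with the same shapes reuse the cache
def sample(params, context_frames, rng):
    # 1) Tokenize the input (context) frames once.
    context_tokens = tokenizer_encode(params, context_frames)           # (B, T_ctx, N)

    # 2) Generate future frames autoregressively in token space only.
    future_tokens = generate_future_tokens(params, context_tokens, rng) # (B, T_fut, N)

    # 3) Detokenize all frames in a single pass at the very end.
    all_tokens = jnp.concatenate([context_tokens, future_tokens], axis=1)
    return tokenizer_decode(params, all_tokens)                         # (B, T, H, W, C)
```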

Note:

  • For interactive or stepwise sampling, detokenization should occur after each action.
  • To maintain consistent tensor shapes across timesteps, all current and future frames are (MaskGIT-)decoded at every step.
  • Causal structure is preserved by (1) reapplying the mask before each decoding step and (2) applying a temporal causal mask within each ST-transformer block (see the sketch after this list).
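A sketch of point (2), assuming a standard lower-triangular causal mask on the temporal attention of an ST-transformer block (spatial attention is left unmasked); the function names are illustrative, not this repo's API:

```python
import jax
import jax.numpy as jnp

def temporal_causal_mask(t: int) -> jnp.ndarray:
    # Position i may attend only to positions j <= i along the time axis.
    return jnp.tril(jnp.ones((t, t), dtype=bool))

def temporal_attention(q, k, v):
    # q, k, v: (B, N, T, d) -- attention runs over T, separately per spatial token N.
    scores = jnp.einsum("bntd,bnsd->bnts", q, k) / jnp.sqrt(q.shape[-1])
    scores = jnp.where(temporal_causal_mask(q.shape[2]), scores, -jnp.inf)
    return jnp.einsum("bnts,bnsd->bntd", jax.nn.softmax(scores, axis=-1), v)
```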


…avoid recompilation, 3) jitted sampling, 4) scan instead of loop
@maharajamihir changed the title from "Modified sampling to be closer to genies inference; made it faster" to "Modified sampling to be closer to genies inference + made it faster" on Jul 7, 2025
- Overhauled Genie.sample() to support autoregressive generation of up to `seq_len` future frames, following the approach in Figure 8 of the paper.
    - Input frames are tokenized once, and future frames are generated autoregressively in token space.
    - All frames are detokenized in a single pass at the end.
    - Added detailed docstring explaining the sampling process and tensor dimensions.
    - Replaced the old MaskGIT loop with a step-wise scan over timesteps, masking and updating tokens for each future frame (see the sketch after this list).
- Updated MaskGITStep mask update logic:
    - Flattened and sorted token probabilities for mask updates using einops.
    - Changed mask update to operate on the flattened token dimension, then reshaped back.
    - Fixed mask indexing to ensure the correct number of tokens is unmasked at each step.
- Added einops import for tensor reshaping.
- Minor type and variable name improvements for clarity and correctness.
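A sketch of the scanned MaskGIT step with the flattened einops mask update, under explicit assumptions: `dynamics_apply(tokens, mask) -> logits (B, T, N, vocab)` stands in for the repo's dynamics model, `mask` marks positions that are still hidden, and the even per-step unmasking schedule is a simplification (the paper uses a cosine schedule):

```python
import jax
import jax.numpy as jnp
from einops import rearrange

def make_maskgit_sampler(dynamics_apply, num_steps):
    def step(carry, i):
        tokens, mask, rng = carry                        # mask: True = still masked
        rng, key = jax.random.split(rng)

        # Reapply the mask before each decoding step (preserves causality).
        logits = dynamics_apply(tokens, mask)
        sampled = jax.random.categorical(key, logits)                 # (B, T, N)
        probs = jax.nn.softmax(logits, axis=-1)
        conf = jnp.take_along_axis(probs, sampled[..., None], axis=-1)[..., 0]

        # Flatten the (time, space) token grid and rank confidences;
        # already-unmasked positions are pushed to -inf so they never win.
        conf_flat = rearrange(jnp.where(mask, conf, -jnp.inf), "b t n -> b (t n)")

        # Unmask the k most confident tokens; the even schedule spreads the
        # remaining masked tokens over the remaining steps (simplification).
        num_masked = jnp.sum(mask, axis=(1, 2))                       # (B,)
        k = jnp.ceil(num_masked / (num_steps - i)).astype(jnp.int32)
        desc = jnp.sort(conf_flat, axis=-1)[:, ::-1]
        thresh = jnp.take_along_axis(desc, jnp.maximum(k - 1, 0)[:, None], axis=-1)
        newly = rearrange(conf_flat >= thresh, "b (t n) -> b t n", t=tokens.shape[1])

        tokens = jnp.where(mask & newly, sampled, tokens)
        return (tokens, mask & ~newly, rng), None

    def run(tokens, mask, rng):
        (tokens, _, _), _ = jax.lax.scan(step, (tokens, mask, rng),
                                         jnp.arange(num_steps))
        return tokens

    return jax.jit(run)  # jitted once; scan avoids per-step retracing
```

Because the scan body has fixed shapes, XLA traces it once rather than unrolling a Python loop, which is what keeps the jitted sampler's compile time down.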
