
Fix Incorrect Gradients and Illegal Memory Access Error in Mamba2 #537


Merged
2 commits merged into state-spaces:main on Nov 5, 2024

Conversation

Hprairie
Contributor

@Hprairie Hprairie commented Aug 24, 2024

Hey Tri and Albert,

Digging through the code base, I found a bug in the backward-pass gradient calculations in the function _chunk_scan_bwd_ddAcs_stable_kernel. I think this solves the issue reported in #503; in any case, it definitely fixes incorrect gradient calculations for $\Delta A$.

Essentially, what was happening on lines 1146-1149 is that if BLOCK_SIZE_N was greater than BLOCK_SIZE_M, the first loop would advance the pointer by BLOCK_SIZE_N per iteration. Then, when we try to zero out the remainder of the sequence in lines 1152-1154, that loop assumes the ddA_cs_ptr pointer sits at hi = (pid_m + 1) * BLOCK_SIZE_M, when we could actually have overshot it, given that BLOCK_SIZE_M < BLOCK_SIZE_N.

All that I do in the PR is remove the potential for us to overshoot by disabling the kernel configs where BLOCK_SIZE_N > BLOCK_SIZE_M. This means that at the end of the first loop our pointer will always have been updated by (pid_m + 1) * BLOCK_SIZE_M. More complicated things could be done to fix this, but I don't want to add complexity to the kernel, so I think this works for now.
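
To make the failure mode concrete, here is a plain-Python sketch of the pointer arithmetic (not the actual Triton kernel; the block sizes are made up, and the pointer is modeled as an element offset):

    # Plain-Python model of the first loop's pointer stepping (illustrative only).
    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 64                  # problematic case: BLOCK_SIZE_N > BLOCK_SIZE_M
    pid_m = 0
    hi = (pid_m + 1) * BLOCK_SIZE_M    # = 32; the zeroing loop assumes the pointer ends here

    ptr = 0                            # stands in for ddA_cs_ptr, as an element offset
    for start_n in range(0, hi, BLOCK_SIZE_N):
        ptr += BLOCK_SIZE_N            # the first loop advances by BLOCK_SIZE_N

    print(ptr, hi)                     # 64 vs 32: the pointer has overshot hi
    assert ptr > hi                    # overshoot happens whenever hi is not a multiple
                                       # of BLOCK_SIZE_N, possible only when N > M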

I found that incorrect gradient calculations for dA were incredibly common while this bug was present, so I don't know how much it will have affected model training.

Let me know if you need any clarifications or if I can do anything else to help!

@Hprairie
Contributor Author

@tridao @albertfgu, I'm also going to ping you guys directly, given the nature of the problem.

@Hprairie
Contributor Author

Hprairie commented Aug 24, 2024

Also, I should mention that the incorrect gradients occur when one thread block accidentally zeroes out gradients that have already been computed and stored by another thread block. For example, thread block 0 could accidentally zero out gradients which have been computed by thread block 1.

In this case, thread block 1 iterates through the first loop calculating gradients, while thread block 0 finishes and proceeds to zero out gradients that were stored by block 1.
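
Roughly, this is what the shifted zeroing window looks like in plain Python (made-up sizes; the constant +1 element offset in the real store is left out for simplicity):

    # Element offsets the zeroing loop touches, depending on where the pointer starts.
    BLOCK_SIZE_M, BLOCK_SIZE_N, chunk_size = 32, 64, 128
    pid_m = 0
    hi = (pid_m + 1) * BLOCK_SIZE_M                   # 32: where the loop assumes the pointer is
    overshot = -(-hi // BLOCK_SIZE_N) * BLOCK_SIZE_N  # 64: where the first loop actually left it

    def touched(start_ptr):
        out, ptr = [], start_ptr
        for start_n in range(hi, chunk_size, BLOCK_SIZE_N):
            out += [ptr + k for k in range(BLOCK_SIZE_N)
                    if k < chunk_size - start_n - 1]  # the kernel's store mask
            ptr += BLOCK_SIZE_N
        return out

    intended, buggy = touched(hi), touched(overshot)
    print(max(buggy), chunk_size)                     # 158 > 128: writes run past the chunk,
                                                      # clobbering memory written by other blocks
    print(min(set(intended) - set(buggy)))            # 32: and some positions never get zeroed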

@tridao
Collaborator

tridao commented Aug 24, 2024

Thanks so much @Hprairie! Let me take a careful look

@Hprairie
Contributor Author

Also fixed another bug in _chunk_scan_bwd_ddAcs_stable_kernel: on line 1136 we weren't multiplying start_n by the stride.
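
For context, the kind of mistake being described, in toy form (a flat buffer with a hypothetical stride, not the kernel's actual layout): forgetting the stride multiply only gives the right address when the stride happens to be 1.

    # Toy illustration of a missing stride multiply when indexing a strided buffer.
    buf = list(range(64))
    stride = 2                      # logical elements live at every 2nd storage slot
    start_n = 5                     # logical index we want to address

    wrong = buf[start_n]            # forgets the stride: lands on the wrong slot
    right = buf[start_n * stride]   # logical element 5 actually lives here
    print(wrong, right)             # 5 vs 10: different memory locations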

@tridao tridao merged commit 442fab4 into state-spaces:main Nov 5, 2024
@tridao
Collaborator

tridao commented Nov 5, 2024

@Hprairie is it simpler to set ddA_cs_ptrs to hi after the loop? Then we don't need to worry about which block M & block N would work?

@Hprairie
Contributor Author

Hprairie commented Nov 5, 2024

@tridao Yes, I think that would be simpler. The only problem is that you need to revert the pointer updates done in the first loop, which requires either saving the original pointers or undoing the update, and that isn't quite as simple. Another thing that could be done (which should fix the problem) is to compute new_hi = tl.cdiv(hi, BLOCK_SIZE_N) * BLOCK_SIZE_N after the first loop. You could also only do this computation when BLOCK_SIZE_N > BLOCK_SIZE_M, which the compiler should be able to figure out. It's been a while since I tried this, but it might have been slower because of poor Triton optimizations. I don't have much bandwidth right now to try it, but I was going to toy around with some new ideas for Mamba kernels in a month and could try some things then.
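
In plain Python, the round-up described here does land exactly where the first loop's pointer ends up, at least under the stepping pattern discussed above (first loop starting at 0 and advancing by BLOCK_SIZE_N); a quick check:

    # Check that cdiv(hi, N) * N reproduces the pointer position the first loop ends at.
    def cdiv(a, b):                          # same semantics as tl.cdiv
        return -(-a // b)

    for BLOCK_SIZE_M in (16, 32, 64, 128):
        for BLOCK_SIZE_N in (16, 32, 64, 128):
            for pid_m in range(4):
                hi = (pid_m + 1) * BLOCK_SIZE_M
                ptr = 0
                for start_n in range(0, hi, BLOCK_SIZE_N):
                    ptr += BLOCK_SIZE_N      # first loop's stepping
                new_hi = cdiv(hi, BLOCK_SIZE_N) * BLOCK_SIZE_N
                assert ptr == new_hi         # so the zeroing loop could start from new_hi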

@tridao
Collaborator

tridao commented Nov 5, 2024

Oh I was gonna do the simple thing.
Instead of

    for start_n in range(hi, chunk_size, BLOCK_SIZE_N):
        tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1)
        ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n

We can do

    for start_n in range(hi, chunk_size, BLOCK_SIZE_N):
        tl.store(ddA_cumsum_ptr + (offs_n + start_n + 1) * stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1)
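
The nice property of this version (sketched below in plain Python, with strides and the base pointer abstracted away to element offsets) is that the store addresses are recomputed from ddA_cumsum_ptr and start_n on every iteration, so they no longer depend on where the first loop left an accumulated pointer:

    # Plain-Python comparison of the two addressing schemes in the zeroing loop.
    BLOCK_SIZE_N, chunk_size, hi = 64, 128, 32

    def old_scheme(ptr0):
        """Running-pointer version: addresses depend on where the first loop left ptr."""
        addrs, ptr = [], ptr0
        for start_n in range(hi, chunk_size, BLOCK_SIZE_N):
            addrs += [ptr + k + 1 for k in range(BLOCK_SIZE_N)
                      if k < chunk_size - start_n - 1]
            ptr += BLOCK_SIZE_N
        return addrs

    def new_scheme():
        """Absolute-offset version: addresses depend only on start_n."""
        addrs = []
        for start_n in range(hi, chunk_size, BLOCK_SIZE_N):
            addrs += [start_n + k + 1 for k in range(BLOCK_SIZE_N)
                      if k < chunk_size - start_n - 1]
        return addrs

    assert new_scheme() == old_scheme(hi)   # identical when the pointer really ends at hi
    assert new_scheme() != old_scheme(64)   # but an overshot pointer (e.g. at 64) diverges
    print(max(old_scheme(64)), chunk_size)  # 159 > 128: the old scheme can write out of bounds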

@Hprairie
Contributor Author

Hprairie commented Nov 5, 2024

Ahh lol yeah that seems much simpler. Also, I could be wrong, but I don't think the +1 is necessary.
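
For reference, the +1 only shifts every stored offset right by one element; a neutral plain-Python comparison of the two index sets (it doesn't settle which variant matches how the rest of the kernel lays out ddA_cumsum):

    # Offsets the zeroing loop would hit with and without the "+ 1" shift.
    BLOCK_SIZE_N, chunk_size, hi = 64, 128, 32

    def zeroed(shift):
        out = []
        for start_n in range(hi, chunk_size, BLOCK_SIZE_N):
            out += [start_n + k + shift for k in range(BLOCK_SIZE_N)
                    if k < chunk_size - start_n - 1]
        return out

    print(zeroed(1)[0], zeroed(1)[-1])   # 33 .. 127 with the +1
    print(zeroed(0)[0], zeroed(0)[-1])   # 32 .. 126 without it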

@klae01

klae01 commented Nov 30, 2024

I trained models using this repository's implementation prior to this PR and consistently observed a ~0.2% difference in grad norm across runs starting from the same checkpoint and data. This raises concerns about non-reproducible training behavior.

Does this fix significantly impact gradient backpropagation? If so, it might suggest that Mamba2 models trained with the pre-PR implementation were affected. Addressing this soon would be valuable, as future SSM research may build upon this repository.

grad norm: [plot image]
loss: [plot image]
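
One way to probe this on a given setup is to run the same backward pass twice from identical state and compare grad norms; a minimal sketch, assuming a hypothetical build_model_and_batch() helper that reconstructs the same model weights and batch each call (this is not part of this repo's training code):

    import torch

    def grad_norm_of_one_step(seed: int = 0) -> float:
        """Rebuild the same model and batch, run one backward pass, return the grad norm."""
        torch.manual_seed(seed)
        model, batch, targets = build_model_and_batch()   # hypothetical helper
        loss = torch.nn.functional.cross_entropy(model(batch), targets)
        loss.backward()
        grads = [p.grad.flatten() for p in model.parameters() if p.grad is not None]
        return torch.linalg.vector_norm(torch.cat(grads)).item()

    # With deterministic kernels the two values should agree up to float noise; a
    # persistent ~0.2% gap across identical runs points at non-deterministic or
    # incorrect gradient computation somewhere in the backward pass.
    print(grad_norm_of_one_step(), grad_norm_of_one_step())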


@Hprairie
Contributor Author

It's hard to say for sure. The current branch of the repo should be fine now that the PR is merged. Before the merge, there was a chance that trained models had wrong gradient calculations, but it's hard to say exactly, because it depends on which config Triton's autotuner chooses. If it chose one of the configs I removed in the PR, then incorrect gradients would likely occur due to non-deterministic GPU scheduling.

The hardest part about the bug is that it was probabilistic, meaning that running the same code multiple times may or may not trigger it.

So it's hard to say whether this fix will have an impact on your specific machine, but there is a non-zero probability that it will.
