Fix Incorrect Gradients and Illegal Memory Access Error in Mamba2 #537
Conversation
@tridao @albertfgu, I'm also going to ping you guys due to the nature of the problem.
Also, I should mention that the incorrect gradients occur when one thread block accidentally zeroes out gradients that have already been computed and stored by another thread block. For example, thread block 0 could accidentally zero out gradients computed by thread block 1: while thread block 1 is still iterating through the first loop calculating gradients, thread block 0 finishes and proceeds to zero out the gradients that block 1 has stored.
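To make the failure mode concrete, here is a toy, single-threaded sketch of the write pattern being described (the buffer layout, sizes, and variable names are illustrative assumptions, not the actual Triton kernel):

```python
import numpy as np

# Toy, single-threaded model of the bug (illustrative names and sizes, not the
# real Triton kernel). Program 0 owns out[0:16]; program 1 owns out[16:32] and
# has already stored its gradients there.
CHUNK, BLOCK_M, BLOCK_N = 16, 4, 8
out = np.zeros(2 * CHUNK)
out[CHUNK:] = 7.0                      # program 1's already-computed gradients

pid_m = 2
hi = (pid_m + 1) * BLOCK_M             # 12: where the first loop is meant to stop

# First loop: store computed gradients, advancing the pointer by BLOCK_N.
ptr = 0
while ptr < hi:
    out[ptr:min(ptr + BLOCK_N, hi)] = 1.0
    ptr += BLOCK_N                     # BLOCK_N > BLOCK_M, so ptr ends at 16, not 12

# Second loop: zero out the rest of the chunk, assuming ptr stopped at hi.
for _ in range((CHUNK - hi) // BLOCK_M):
    out[ptr:ptr + BLOCK_M] = 0.0       # writes out[16:20], i.e. program 1's region
    ptr += BLOCK_M

print(out[CHUNK:CHUNK + BLOCK_M])      # [0. 0. 0. 0.] -- program 1's values got zeroed
```

Running this prints `[0. 0. 0. 0.]`: the zeroing pass of one program lands on values another program already stored, and if that shifted store instead ran past the end of the tensor, it would also explain the illegal memory access in the title.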
Thanks so much @Hprairie! Let me take a careful look.
Also fixed another bug in
@Hprairie is it simpler to set
@tridao Yes, that would be simpler I think; the only problem is that you need to revert the pointer updates done in the first loop, which requires either saving the original pointers or undoing the update, which is not as simple. Another thing that could be done (which should fix the problem) is to do
Oh I was gonna do the simple thing.
We can do
Ahh lol yeah that seems much simpler. Also, I could be wrong, but I don't think the
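For comparison, a rough sketch of the pointer-reset alternative floated above (toy indices, not the real kernel code): recompute the output offset before the zeroing pass rather than trusting the increments accumulated in the first loop.

```python
import numpy as np

# Sketch of the pointer-reset alternative (toy indices, not the real kernel):
# recompute the output offset before the zeroing pass instead of trusting the
# increments accumulated in the first loop.
CHUNK, BLOCK_M, BLOCK_N, pid_m = 16, 4, 8, 2
out = np.ones(CHUNK)                   # this program's slice of the output
hi = (pid_m + 1) * BLOCK_M             # 12: end of the computed region

# Where the first loop's increments actually leave the pointer (overshoots hi).
ptr = ((hi + BLOCK_N - 1) // BLOCK_N) * BLOCK_N   # 16

ptr = hi                               # the reset: jump back to the intended offset
for _ in range((CHUNK - hi) // BLOCK_M):
    out[ptr:ptr + BLOCK_M] = 0.0       # zeroing now stays inside out[12:16]
    ptr += BLOCK_M

print(out)                             # ones up to index 12, zeros after
```

Restricting the block-size configs, as the merged fix does, avoids this extra pointer bookkeeping entirely.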
I trained models using this repository's implementation prior to this PR and consistently observed a ~0.2% difference in grad norm across runs starting from the same checkpoint and data. This raises concerns about non-reproducible training behavior. Does this fix significantly impact gradient backpropagation? If so, it might suggest that Mamba2 models trained with the pre-PR implementation were affected. Addressing this soon would be valuable, as future SSM research may build upon this repository.
It's hard to say for sure. The current branch of the repo should be fine, as the PR was merged. Before the merge, there was a chance that trained models had wrong gradient calculations, but it is hard to say exactly because it depends on which kernel configs were selected. The hardest part about the bug was that it was probabilistic, meaning that running the code multiple times may or may not trigger it. Thus it's hard to say whether this fix will have an impact on your specific machine, but there is a non-zero probability that it will.
Hey Tri and Albert,
Digging through the code base, I have found a bug in the backward pass for the gradient calculations in the function `_chunk_scan_bwd_ddAcs_stable_kernel`. I think this solves the issue given by #503; however, it will definitely fix incorrect gradient calculations for $\Delta A$.

Essentially, here is what was going on: in lines 1146-1149, if `BLOCK_SIZE_N` was greater than `BLOCK_SIZE_M`, the pointer would be updated by `BLOCK_SIZE_N`. Then, when we try to zero out the remainder of the sequence in lines 1152-1154, the loop assumes that `ddA_cs_ptr` is at `hi = (pid_m + 1) * BLOCK_SIZE_M`, when we could have overshot it, given that `BLOCK_SIZE_M < BLOCK_SIZE_N`.

All that I do in this PR is remove the potential for us to overshoot by disabling the kernel configs where `BLOCK_SIZE_N > BLOCK_SIZE_M`. This means that at the end of the first loop the pointer will always have been updated by `(pid_m + 1) * BLOCK_SIZE_M`. More complicated things could be done to fix this, but I don't want to add complexity to the kernel, so I think this works for now.

I found that incorrect gradient calculations for `dA` were incredibly common when this bug was present, so I don't know how much this will have affected model training.

Let me know if you need any clarifications or if I can do anything else to help!