
Training script #51


Open
nviolante25 wants to merge 23 commits into main
Conversation

nviolante25

From #42

@kvuong2711

kvuong2711 commented Apr 21, 2025

Hi @nviolante25,

Thanks for the PR. I'm trying to test it out, but I keep running into float16/float32 dtype issues that make the attention function crash. I'm running python main.py --base configs/example_training/seva-clipl_dl3dv.yaml. I'm also happy to contribute to the PR if you have some clues on where to potentially fix this.

@nviolante25
Author

Hi @kvuong2711,
I'm having the same issue; the only workaround so far has been to disable the with sdpa_kernel(SDPBackend.FLASH_ATTENTION) context. I'm not sure what is causing it to fail in the first place. I thought it could be the torch version, but unfortunately changing it didn't help.
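For reference, a minimal sketch of the kind of workaround I mean, assuming the crash comes from query/key/value reaching scaled_dot_product_attention in mixed float16/float32 dtypes under autocast (the helper name, the cast, and the backend list are illustrative, not the actual code in this PR):

```python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def attention(q, k, v):
    # Illustrative only: cast key/value to the query dtype so the fused
    # kernels never see mixed float16/float32 inputs.
    k = k.to(q.dtype)
    v = v.to(q.dtype)
    # Allow fallback backends instead of forcing FLASH_ATTENTION alone,
    # so PyTorch can pick a kernel that supports the given dtypes.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION,
                      SDPBackend.EFFICIENT_ATTENTION,
                      SDPBackend.MATH]):
        return F.scaled_dot_product_attention(q, k, v)
```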

@jensenz-sai
Contributor

jensenz-sai commented Apr 25, 2025

@kvuong2711 @nviolante25,

Yup, it's a bit tricky to set up flash attention for mixed-precision training. You could replace it with the xformers attention implemented in the generative-models codebase.
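A minimal sketch of what that swap could look like, assuming inputs of shape (b, n, heads * dim_head) and the public xformers.ops.memory_efficient_attention API (the function name and reshapes below are illustrative, not the exact generative-models code):

```python
from einops import rearrange
import xformers.ops as xops

def xformers_attention(q, k, v, heads):
    # Illustrative only: flatten heads into the batch dimension,
    # (b, n, heads * dim_head) -> (b * heads, n, dim_head), a layout
    # that xformers' memory-efficient attention accepts.
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=heads), (q, k, v))
    out = xops.memory_efficient_attention(q, k, v, attn_bias=None)
    # Restore the original (b, n, heads * dim_head) layout.
    return rearrange(out, "(b h) n d -> b n (h d)", h=heads)
```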
