Define shared memory for oneAPI (Flash attention on Intel GPUs) #13
I know oneAPI is only minimally supported in Julia, but it works surprisingly well with KernelAbstractions.jl, and everything except flash attention seems to run already.
As I have an Arc A770 lying around, I wanted to be able to try some DL stuff on my home desktop.
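For anyone unfamiliar with the terminology in the title: "shared memory" here is KernelAbstractions' workgroup-local memory, declared with `@localmem`. A minimal sketch of the idea on the oneAPI backend (the kernel below is illustrative, not the flash attention kernel from this PR):

```julia
using KernelAbstractions, oneAPI

# Illustrative kernel: each workgroup stages a tile of `x` in
# workgroup-local ("shared") memory before writing it back out.
@kernel function copy_through_localmem!(y, @Const(x))
    i  = @index(Global, Linear)
    li = @index(Local, Linear)
    # The shared-memory allocation; the size must be a compile-time
    # constant matching the workgroup size used at launch.
    tile = @localmem eltype(x) (256,)
    tile[li] = x[i]
    @synchronize  # make the tile visible to the whole workgroup
    y[i] = tile[li]
end

x = oneArray(rand(Float32, 1024))
y = similar(x)
copy_through_localmem!(oneAPIBackend(), 256)(y, x; ndrange = length(x))
KernelAbstractions.synchronize(oneAPIBackend())
@assert Array(y) == Array(x)
```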
But again, since oneAPI is minimally supported across the ecosystem, `naive_attention` from the unit tests doesn't run because of `NNlib.batched_mul`.
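For reference, `batched_mul` is semantically just slice-wise matrix multiplication over a stack of matrices; a hedged illustration of the call that fails (the slice-wise fallback is mine, not NNlib's implementation or the approach in the linked PR):

```julia
using NNlib, oneAPI

A = oneArray(rand(Float32, 16, 8, 4))   # stack of four 16×8 matrices
B = oneArray(rand(Float32, 8, 32, 4))   # stack of four 8×32 matrices

# What naive_attention relies on; at the time of writing this errors on
# oneArray because NNlib has no oneAPI path (hence the PR linked below):
# C = NNlib.batched_mul(A, B)

# Illustrative slice-wise equivalent, multiplying each pair of slices:
C = cat((A[:, :, k] * B[:, :, k] for k in 1:size(A, 3))...; dims = 3)
@assert size(C) == (16, 32, 4)
```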
I've opened a PR for partial oneAPI support: FluxML/NNlib.jl#644

Flash attention benchmarks
Since `naive_attention` doesn't work, I made a separate script just for flash attention. The timings seem to be roughly within an order of magnitude above the benchmarks in #11.
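The script is roughly of this shape (a sketch under assumptions: `flash_attention` is a placeholder for this repo's actual entry point, and the tensor layout and sizes are made up rather than copied from #11):

```julia
using BenchmarkTools, KernelAbstractions, oneAPI

# Placeholder sizes and layout; substitute whatever #11 benchmarks.
head_dim, seq_len, n_heads, batch = 64, 1024, 4, 1
q = oneArray(rand(Float32, head_dim, seq_len, n_heads, batch))
k = oneArray(rand(Float32, head_dim, seq_len, n_heads, batch))
v = oneArray(rand(Float32, head_dim, seq_len, n_heads, batch))

# Synchronize inside the timed expression so we measure the kernel
# itself, not just the asynchronous launch.
@btime begin
    flash_attention($q, $k, $v)  # hypothetical entry point, not the repo's API
    KernelAbstractions.synchronize(oneAPIBackend())
end
```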