@AntonOresten (Contributor) commented Aug 16, 2025

I know oneAPI is generally only minimally supported in the Julia ecosystem, but it works surprisingly well with KernelAbstractions.jl, and it seems everything here except flash attention already runs.
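
For context, this is the kind of vendor-agnostic KernelAbstractions.jl code that runs on oneAPI without modification; a minimal sketch (not from this repo):

```julia
using oneAPI, KernelAbstractions

# Trivial elementwise kernel; the same source runs on CUDA, AMDGPU, Metal,
# or oneAPI, depending only on which backend the arrays live on.
@kernel function vadd!(c, @Const(a), @Const(b))
    i = @index(Global)
    @inbounds c[i] = a[i] + b[i]
end

a = oneArray(rand(Float32, 1024))
b = oneArray(rand(Float32, 1024))
c = similar(a)

backend = get_backend(c)                    # oneAPIBackend()
vadd!(backend)(c, a, b; ndrange=length(c))  # launch on the Arc GPU
KernelAbstractions.synchronize(backend)
```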

As I have an Arc A770 lying around, I wanted to be able to try some DL stuff on my home desktop.

But again, since oneAPI is minimally supported across the ecosystem, naive_attention from the unit tests doesn't run because of the NNlib.batched_mul call. I've opened a PR adding partial oneAPI support: FluxML/NNlib.jl#644
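
For reference, a naive attention along these lines is where batched_mul comes in. This is a hedged sketch of that pattern, not necessarily the repo's exact naive_attention:

```julia
using NNlib

# Naive scaled dot-product attention over batched 3D arrays.
# q, k, v are assumed to have size (dim, len, batch); the unit tests'
# actual shapes and signature may differ.
function naive_attention(q, k, v)
    scale = inv(sqrt(Float32(size(q, 1))))
    scores = batched_mul(batched_transpose(k), q) .* scale  # (len_k, len_q, batch)
    batched_mul(v, softmax(scores; dims=1))                 # (dim, len_q, batch)
end
```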

Flash attention benchmarks

Since naive_attention doesn't work, I made a separate script that benchmarks just flash attention. The timings come in roughly within an order of magnitude above the benchmarks in #11.
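
The script boils down to something like the sketch below. The shapes and the flash_attention signature here are stand-ins, and the peak-memory numbers come from the script's own tracking, which isn't shown:

```julia
using oneAPI, KernelAbstractions, BenchmarkTools, Zygote

# Hypothetical inputs: (head_dim, seqlen, batch) Float32 arrays on the GPU.
q = oneArray(randn(Float32, 64, 1024, 8))
k = oneArray(randn(Float32, 64, 1024, 8))
v = oneArray(randn(Float32, 64, 1024, 8))
backend = get_backend(q)

# Forward pass (synchronize so the GPU work is actually timed).
@btime begin
    flash_attention($q, $k, $v)
    KernelAbstractions.synchronize($backend)
end

# Forward + backward via Zygote.
@btime begin
    Zygote.gradient((q, k, v) -> sum(flash_attention(q, k, v)), $q, $k, $v)
    KernelAbstractions.synchronize($backend)
end
```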

```
Causal: false, use_padmask: false, use_pair: false
Flash attention FWD:
  17.929 ms (247 allocations: 14.12 KiB)
 - Peak memory usage: 8.250 MiB
Flash attention FWD + BWD:
  278.993 ms (2764 allocations: 208.47 KiB)
 - Peak memory usage: 48.383 MiB

Causal: false, use_padmask: false, use_pair: true
Flash attention FWD:
  25.261 ms (253 allocations: 15.52 KiB)
 - Peak memory usage: 8.250 MiB
Flash attention FWD + BWD:
  316.307 ms (2800 allocations: 213.06 KiB)
 - Peak memory usage: 304.383 MiB

Causal: false, use_padmask: true, use_pair: false
Flash attention FWD:
  19.050 ms (253 allocations: 15.31 KiB)
 - Peak memory usage: 8.250 MiB
Flash attention FWD + BWD:
  291.967 ms (2776 allocations: 211.83 KiB)
 - Peak memory usage: 48.383 MiB

Causal: false, use_padmask: true, use_pair: true
Flash attention FWD:
  26.510 ms (259 allocations: 16.58 KiB)
 - Peak memory usage: 8.250 MiB
Flash attention FWD + BWD:
  326.150 ms (2812 allocations: 216.53 KiB)
 - Peak memory usage: 304.383 MiB

Causal: true, use_padmask: false, use_pair: false
Flash attention FWD:
  11.215 ms (247 allocations: 14.12 KiB)
 - Peak memory usage: 8.250 MiB
Flash attention FWD + BWD:
  151.188 ms (2764 allocations: 208.47 KiB)
 - Peak memory usage: 48.383 MiB

Causal: true, use_padmask: false, use_pair: true
Flash attention FWD:
  15.863 ms (253 allocations: 15.52 KiB)
 - Peak memory usage: 8.250 MiB
Flash attention FWD + BWD:
  170.084 ms (2800 allocations: 213.06 KiB)
 - Peak memory usage: 304.383 MiB

Causal: true, use_padmask: true, use_pair: false
Flash attention FWD:
  11.756 ms (253 allocations: 15.31 KiB)
 - Peak memory usage: 8.250 MiB
Flash attention FWD + BWD:
  150.940 ms (2776 allocations: 211.83 KiB)
 - Peak memory usage: 48.383 MiB

Causal: true, use_padmask: true, use_pair: true
Flash attention FWD:
  16.165 ms (259 allocations: 16.58 KiB)
 - Peak memory usage: 8.250 MiB
Flash attention FWD + BWD:
  170.592 ms (2812 allocations: 216.53 KiB)
 - Peak memory usage: 304.383 MiB
```

@AntonOresten (Contributor, Author) commented:

I'm getting a lot of segmentation faults when trying to run the unit tests.
