Skip to content

Commit de6781b

Browse files
henrylhtsangfacebook-github-bot
authored andcommitted
make apply_mask a bit more efficient (#4788)
Summary: Pull Request resolved: #4788 X-link: facebookresearch/FBGEMM#1812 Make forward faster. Reviewed By: Aya-ZIbra, sryap Differential Revision: D81246965
1 parent abc52f9 commit de6781b

File tree

1 file changed

+16
-8
lines changed
  • fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective

1 file changed

+16
-8
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_fusion.hpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -469,17 +469,21 @@ struct LocalMask : NoMask {
469469
// where we only compute the next row and use cache for the rest
470470
// - if you'd like this, you only need to set kIsQBegin=false
471471

472+
const int K = get<1>(problem_size);
473+
472474
if constexpr (IsQBegin) {
473475
CUTLASS_PRAGMA_UNROLL
474476
for (int i = 0; i < size(acc_qk); i++) {
475477
auto pos = index_qk(i);
476478
const int pos_i = get<0>(pos);
477479
const int pos_j = get<1>(pos);
478480

479-
bool masked = (pos_i - window_size_left > pos_j) || (pos_i + window_size_right < pos_j) || !elem_less(pos, problem_size);
480-
if (masked) {
481-
acc_qk(i) = -INFINITY;
482-
}
481+
const int window_left_bound = pos_i - window_size_left;
482+
const int window_right_bound = pos_i + window_size_right;
483+
484+
bool masked = (pos_j < window_left_bound) || (pos_j > window_right_bound) || (pos_j >= K);
485+
486+
acc_qk(i) = masked ? -INFINITY : acc_qk(i);
483487
}
484488
} else {
485489
const auto offset_q = get<1>(problem_size) - get<0>(problem_size);
@@ -489,10 +493,14 @@ struct LocalMask : NoMask {
489493
const int pos_i = get<0>(pos);
490494
const int pos_j = get<1>(pos);
491495

492-
bool masked = (pos_i + offset_q - window_size_left > pos_j) || (pos_i + offset_q + window_size_right < pos_j) || (pos_j >= get<1>(problem_size));
493-
if (masked) {
494-
acc_qk(i) = -INFINITY;
495-
}
496+
const int offset_pos_i = pos_i + offset_q;
497+
498+
const int window_left_bound = offset_pos_i - window_size_left;
499+
const int window_right_bound = offset_pos_i + window_size_right;
500+
501+
bool masked = (pos_j < window_left_bound) || (pos_j > window_right_bound) || (pos_j >= K);
502+
503+
acc_qk(i) = masked ? -INFINITY : acc_qk(i);
496504
}
497505
}
498506
}

0 commit comments

Comments
 (0)