@@ -469,17 +469,21 @@ struct LocalMask : NoMask {
469
469
// where we only compute the next row and use cache for the rest
470
470
// - if you'd like this, you only need to set kIsQBegin=false
471
471
472
+ const int K = get<1 >(problem_size);
473
+
472
474
if constexpr (IsQBegin) {
473
475
CUTLASS_PRAGMA_UNROLL
474
476
for (int i = 0 ; i < size (acc_qk); i++) {
475
477
auto pos = index_qk (i);
476
478
const int pos_i = get<0 >(pos);
477
479
const int pos_j = get<1 >(pos);
478
480
479
- bool masked = (pos_i - window_size_left > pos_j) || (pos_i + window_size_right < pos_j) || !elem_less (pos, problem_size);
480
- if (masked) {
481
- acc_qk (i) = -INFINITY;
482
- }
481
+ const int window_left_bound = pos_i - window_size_left;
482
+ const int window_right_bound = pos_i + window_size_right;
483
+
484
+ bool masked = (pos_j < window_left_bound) || (pos_j > window_right_bound) || (pos_j >= K);
485
+
486
+ acc_qk (i) = masked ? -INFINITY : acc_qk (i);
483
487
}
484
488
} else {
485
489
const auto offset_q = get<1 >(problem_size) - get<0 >(problem_size);
@@ -489,10 +493,14 @@ struct LocalMask : NoMask {
489
493
const int pos_i = get<0 >(pos);
490
494
const int pos_j = get<1 >(pos);
491
495
492
- bool masked = (pos_i + offset_q - window_size_left > pos_j) || (pos_i + offset_q + window_size_right < pos_j) || (pos_j >= get<1 >(problem_size));
493
- if (masked) {
494
- acc_qk (i) = -INFINITY;
495
- }
496
+ const int offset_pos_i = pos_i + offset_q;
497
+
498
+ const int window_left_bound = offset_pos_i - window_size_left;
499
+ const int window_right_bound = offset_pos_i + window_size_right;
500
+
501
+ bool masked = (pos_j < window_left_bound) || (pos_j > window_right_bound) || (pos_j >= K);
502
+
503
+ acc_qk (i) = masked ? -INFINITY : acc_qk (i);
496
504
}
497
505
}
498
506
}
0 commit comments