Skip to content

Commit 65d2d3a

Browse files
authored
optimize fp8 deep gemm tma (#10580)
1 parent 6e8b317 commit 65d2d3a

File tree

1 file changed

+7
-0
lines changed
  • ops/csrc/fp8/deep_gemm/jit_kernels

1 file changed

+7
-0
lines changed

ops/csrc/fp8/deep_gemm/jit_kernels/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ def get_col_major_tma_aligned_tensor(x: Tensor) -> Tensor:
109109
assert x.dim() in (2, 3)
110110
remove_dim = False
111111
if x.dim() == 2:
112+
m, n = x.shape
113+
114+
aligned_m = get_tma_aligned_size(m, x.element_size())
115+
116+
if aligned_m == m and x.strides[0] == 1 and x.strides[1] == aligned_m:
117+
return x
118+
112119
x, remove_dim = x.unsqueeze(0), True
113120

114121
b, m, n = x.shape

0 commit comments

Comments
 (0)