
Commit 2feadb6

jasonjk-park authored and facebook-github-bot committed
Fix imports
Summary: Update imports for the latest module locations and the silu_mul interface change.

Reviewed By: jianyuh

Differential Revision: D64516452

fbshipit-source-id: b9b98a6eda45a093661e8b23f6b8ec300b559960
1 parent: 58f3b1f · commit: 2feadb6

1 file changed (+5, -4)

torchbenchmark/operators/fp8_fused_quant_gemm_rowwise/operator.py

Lines changed: 5 additions & 4 deletions
@@ -8,10 +8,12 @@
 import triton
 
 try:
-    from gen_ai.llm_inference.fb.llm.llama_layers import (
-        quantize_fp8_row,
+    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
+    from gen_ai.llm_inference.fb.llm.kernel.rms_norm import (
         rms_norm,
         rms_norm_fp8_rowwise_quant,
+    )
+    from gen_ai.llm_inference.fb.llm.kernel.silu_mul import (
         silu_mul,
         silu_mul_fp8_rowwise_quant,
     )
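For context, this hunk sits inside an optional-import guard: the try: pairs with an except clause (outside the diff) that records whether the fb-internal kernels are available, and that flag gates benchmark registration in the next hunk via @register_benchmark(enabled=HAS_FB_IMPORT). A minimal sketch of the pattern, assuming only the HAS_FB_IMPORT name shown in the second hunk; the actual except body in operator.py is not part of this diff:

    # Optional-import guard: the internal kernels only exist in the fb
    # environment, so record whether they imported cleanly.
    try:
        from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
        from gen_ai.llm_inference.fb.llm.kernel.rms_norm import (
            rms_norm,
            rms_norm_fp8_rowwise_quant,
        )
        from gen_ai.llm_inference.fb.llm.kernel.silu_mul import (
            silu_mul,
            silu_mul_fp8_rowwise_quant,
        )

        HAS_FB_IMPORT = True
    except ImportError:
        # Without the internal modules, benchmarks gated on this flag are
        # skipped at registration time instead of failing on import.
        HAS_FB_IMPORT = False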
@@ -120,8 +122,7 @@ def _impl(x1, x2, wq, w_scale, wd):
     @register_benchmark(enabled=HAS_FB_IMPORT)
     def silu_mul_quant(self, x1, x2, wq, w_scale, wd) -> Callable:
         def _impl(x1, x2, wq, w_scale, wd):
-            y = torch.empty_like(x1)
-            x = silu_mul(x1, x2, y)
+            x = silu_mul(x1, x2)
             xq, x_scale = quantize_fp8_row(x, use_triton=True)
             if torch.version.hip:
                 # use CK kernel for AMD
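The second hunk tracks the silu_mul interface change called out in the summary: the kernel no longer takes a caller-allocated output tensor and instead allocates and returns its result. A before/after sketch with an eager reference body; the reference semantics (elementwise SiLU(x1) * x2, the usual SwiGLU-style fused op) are an assumption for illustration, not taken from this diff:

    import torch

    # Old interface: the caller preallocated the output buffer.
    #   y = torch.empty_like(x1)
    #   x = silu_mul(x1, x2, y)
    # New interface: the kernel returns the output itself.
    #   x = silu_mul(x1, x2)

    def silu_mul_reference(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # Assumed semantics of the fused kernel: SiLU(x1) * x2, elementwise.
        return torch.nn.functional.silu(x1) * x2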
