Skip to content

Commit 7a55749

Browse files
spcyppt authored and facebook-github-bot committed
Fix OSS GenAI
Summary:
- add quantize_qkv_per_head for cuda <12
- skip silu_mul_quant test as it requires H100 and more memory
- fix lint
- fix moe.parallelism import
- fix fairscale dependency
- disable layers_test as it is not a UnitTest.

Reviewed By: q10

Differential Revision: D74493137

fbshipit-source-id: 59923aa312ad8b874ff14767476a286059785a1f
1 parent 4492631 commit 7a55749

File tree

6 files changed

+35
-10
lines changed

6 files changed

+35
-10
lines changed

.github/scripts/fbgemm_gpu_test.bash

Lines changed: 4 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -79,7 +79,7 @@ __configure_fbgemm_gpu_test_cpu () {
7979
# shellcheck disable=SC2086
8080
print_exec conda env config vars set ${env_prefix} CUDA_VISIBLE_DEVICES=-1
8181

82-
ignored_tests=(
82+
export ignored_tests=(
8383
# These tests have non-CPU operators referenced in @given
8484
./uvm/copy_test.py
8585
./uvm/uvm_test.py
@@ -99,7 +99,8 @@ __configure_fbgemm_gpu_test_cuda () {
9999
# shellcheck disable=SC2086
100100
print_exec conda env config vars unset ${env_prefix} CUDA_VISIBLE_DEVICES
101101

102-
ignored_tests=(
102+
export ignored_tests=(
103+
./moe/layers_test.py # not a UnitTest
103104
)
104105
}
105106

@@ -123,7 +124,7 @@ __configure_fbgemm_gpu_test_rocm () {
123124
print_exec conda env config vars set ${env_prefix} HSA_XNACK=1
124125
fi
125126

126-
ignored_tests=(
127+
export ignored_tests=(
127128
# https://github.com/pytorch/FBGEMM/issues/1559
128129
./batched_unary_embeddings_test.py
129130
./sll/triton_sll_test.py

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 16 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -2997,5 +2997,21 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
29972997
throw std::runtime_error(
29982998
"CUDA version is older than 12.0"); // requires CUDA>=12
29992999
}
3000+
3001+
at::Tensor quantize_qkv_per_head(
3002+
at::Tensor xqkv_amax_row, // [B_T, HH]
3003+
at::Tensor xqkv, // [B_T, HH, D_H]
3004+
at::Tensor varseq_seqpos, // [B_T]
3005+
std::optional<at::Tensor> varseq_batch, // [B_T]
3006+
at::Tensor q_seqstarts, // [B+1]
3007+
at::Tensor cache_K, // [B][MAX_T][N_KVH][D_H]
3008+
at::Tensor cache_V, // [B][MAX_T][N_KVH][D_H]
3009+
at::Tensor XQ_O, // [B_T][N_H][D]
3010+
int64_t max_seq_length, // Length of the sequence
3011+
std::optional<at::Tensor> qparam_k,
3012+
std::optional<at::Tensor> qparam_v) {
3013+
throw std::runtime_error(
3014+
"CUDA version is older than 12.0"); // requires CUDA>=12
3015+
}
30003016
#endif
30013017
} // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/test/moe/activation_test.py

Lines changed: 5 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -71,6 +71,11 @@ def ref_fn() -> torch.Tensor:
7171

7272
torch.testing.assert_allclose(y, y_ref, rtol=1.6e-2, atol=1e-3)
7373

74+
@unittest.skipIf(
75+
not torch.cuda.is_available()
76+
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
77+
"Skip when H100 is not available",
78+
)
7479
@given(
7580
T=st.sampled_from([1, 128, 2048, 4096, 16384]),
7681
D=st.sampled_from([5120, 7168]),

fbgemm_gpu/experimental/gen_ai/test/moe/layers_test.py

Lines changed: 8 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -13,15 +13,9 @@
1313
import traceback
1414
from datetime import datetime
1515
from functools import partial
16-
from typing import Callable, Mapping, Tuple, Union
16+
from typing import Tuple
1717

1818
import torch
19-
from deeplearning.fbgemm.fbgemm_gpu.experimental.gen_ai.test.moe.parallelism import (
20-
get_ep_group,
21-
get_global_rank,
22-
get_routed_experts_mp_group,
23-
init_parallel,
24-
)
2519
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_row
2620
from fbgemm_gpu.experimental.gen_ai.moe.layers import (
2721
BaselineMoE,
@@ -34,6 +28,13 @@
3428
# pyre-fixme[21]: Could not find name `ProfilerActivity` in `torch.profiler`.
3529
from torch.profiler import profile, ProfilerActivity
3630

31+
from .parallelism import (
32+
get_ep_group,
33+
get_global_rank,
34+
get_routed_experts_mp_group,
35+
init_parallel,
36+
)
37+
3738
TRACE_DIR: str = "/tmp/"
3839
WARM_UP_ITERS = 15
3940
PROFILE_ITERS = 20

fbgemm_gpu/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -26,3 +26,4 @@ setuptools
2626
setuptools_git_versioning
2727
tabulate
2828
patchelf
29+
fairscale

fbgemm_gpu/requirements_genai.txt

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -28,3 +28,4 @@ setuptools
2828
setuptools_git_versioning
2929
tabulate
3030
patchelf
31+
fairscale

0 commit comments

Comments (0)