Summary:
Pull Request resolved: #4636
X-link: facebookresearch/FBGEMM#1668
The existing group GEMM does not work when K is not a multiple of 8, because TMA load cannot be used in that case. This diff adds support for this general case by applying a mask when loading the data.
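To make the masked-load idea concrete, here is a minimal Triton sketch; it is not the actual FBGEMM kernel, and the kernel name, the plain 1-D copy, and BLOCK_K=16 are made up for illustration. A tl.load with a mask handles a ragged tail along K by reading only the valid lanes and substituting zeros for the rest, where a TMA load would require K to meet its alignment constraints.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def copy_k_tail_kernel(x_ptr, y_ptr, K, BLOCK_K: tl.constexpr):
        # Offsets covering one block along K.
        offs_k = tl.arange(0, BLOCK_K)
        # Guard the ragged tail: lanes at or past K are masked off.
        mask = offs_k < K
        # Masked load: masked-off lanes read 0.0 instead of out-of-bounds memory.
        x = tl.load(x_ptr + offs_k, mask=mask, other=0.0)
        tl.store(y_ptr + offs_k, x, mask=mask)


    # Requires a GPU with Triton support. K = 10 is not a multiple of 8,
    # so the single block of 16 lanes is only partially valid.
    x = torch.randn(10, device="cuda")
    y = torch.empty_like(x)
    copy_k_tail_kernel[(1,)](x, y, 10, BLOCK_K=16)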
With this change, when K is not a multiple of 8, using fuse_scatter_add results in larger numerical discrepancies that cause the unit test to fail. This is most likely due to the atomic add operations in scatter_add or to increased rounding error with negative numbers, since the unit test uses randn; if the test uses rand to generate only positive numbers, it passes. For now, in this case we disable fused scatter_add in the unit test and do not allow it in the group GEMM implementation.
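As a toy illustration of that explanation (not the actual unit test; the sizes and the use of a random permutation to stand in for non-deterministic atomic-add ordering are assumptions), summing the same float32 values in two different orders typically shows a much larger relative discrepancy for zero-mean randn data than for all-positive rand data, because cancellation makes the true sum small compared to the accumulated rounding noise:

    import torch

    torch.manual_seed(0)
    n = 1_000_000  # arbitrary size chosen for illustration

    for name, gen in [("randn", torch.randn), ("rand", torch.rand)]:
        vals = gen(n)
        ref = vals.double().sum()            # high-precision reference sum
        s1 = vals.sum()                      # float32 sum in the original order
        s2 = vals[torch.randperm(n)].sum()   # float32 sum in a shuffled order
        rel = (s1.double() - s2.double()).abs() / ref.abs()
        print(f"{name}: relative discrepancy between orders = {rel.item():.2e}")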
Reviewed By: sgrigory
Differential Revision: D79393881
fbshipit-source-id: a04ad8ed1be64ac31b5db325524b44e6eabf7de3
warnings.warn("TMA load is disabled as there is no TMA descriptor support!")
969
+
warnings.warn(
970
+
"TMA load is disabled as there is no TMA descriptor support!", stacklevel=2
971
+
)
964
972
965
973
ifUSE_TMA_STOREandnotutils.HAS_TMA_DESC:
966
974
USE_TMA_STORE=False
967
-
warnings.warn("TMA store is disabled as there is no TMA descriptor support!")
975
+
warnings.warn(
976
+
"TMA store is disabled as there is no TMA descriptor support!", stacklevel=2
977
+
)
968
978
969
979
# TODO(shikaili): Check the readniess of WS on ROCm side in Meta's Triton.
970
980
ifuse_warp_specializationandtorch.version.hip:
971
-
warnings.warn("Warp specialization is disabled as it is not supported on ROCm.")
981
+
warnings.warn(
982
+
"Warp specialization is disabled as it is not supported on ROCm.",
983
+
stacklevel=2,
984
+
)
972
985
use_warp_specialization=False
973
986
974
987
ifuse_warp_specializationandnot_HAS_WS_SUPPORT:
975
988
warnings.warn(
976
-
"Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs."
989
+
"Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.",
990
+
stacklevel=2,
977
991
)
978
992
use_warp_specialization=False
979
993
@@ -991,6 +1005,22 @@ def _grouped_gemm(
991
1005
N=w.shape[0] //G
992
1006
assertK==w.shape[1]
993
1007
1008
+
ifK%8!=0:
1009
+
use_warp_specialization=False
1010
+
USE_TMA_LOAD=False
1011
+
USE_TMA_STORE=False
1012
+
warnings.warn(
1013
+
f"TMA load and warp specialization are disabled since K is not a multiple of 8: {K=}.",
1014
+
stacklevel=2,
1015
+
)
1016
+
assert (
1017
+
x_scaleisNone
1018
+
), f"Quantisation is not supported yet when K is not a multiple of 8: {K=}"
1019
+
1020
+
assert (
1021
+
output_tensorisNone
1022
+
), f"Fused scatter add has large rounding error when K is not a multiple of 8: {K=}"