@@ -736,11 +736,26 @@ def enable_batch_invariant_mode():
 
     _batch_invariant_MODE = True
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+
+    # Batch-invariant matmul overrides are no longer needed after the cuBLAS overrides
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        if current_platform.is_device_capability(100):
+            # On PyTorch 2.9, B200 uses GEMV for bs=1.
+            # The fix requires https://github.com/pytorch/pytorch/pull/166735, so use the mm overrides instead.
+            _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+        else:
+            # The only source of batch variance on Hopper is split-K, which can be
+            # disabled through the cuBLAS workspace config.
+            _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
+            _original_cublaslt_workspace_size = os.environ.get(
+                "CUBLASLT_WORKSPACE_SIZE", None
+            )
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
+
     _batch_invariant_LIB.impl(
         "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
     )
@@ -749,6 +764,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
 
     # Also monkeypatch torch.bmm directly as a fallback
+    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant
 
@@ -770,14 +786,6 @@ def enable_batch_invariant_mode():
     )
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
-    if not is_torch_equal_or_newer("2.10.0.dev"):
-        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
-        _original_cublaslt_workspace_size = os.environ.get(
-            "CUBLASLT_WORKSPACE_SIZE", None
-        )
-        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
-
 
 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
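
For reviewers who want to sanity-check the Hopper branch locally, here is a minimal probe, not part of this PR: `probe_batch_invariance` is a hypothetical helper, and the shapes and dtype are arbitrary choices. It applies the same workspace settings as the diff (a small cuBLAS workspace rules out split-K kernels that need scratch space for partial sums; `CUBLASLT_WORKSPACE_SIZE` is interpreted in KiB) and then checks whether a bs=1 matmul is bitwise identical to the corresponding row of a batched matmul.

```python
import os

# Set these before importing torch so they are in place when the first
# cuBLAS handle is created. Values mirror the diff above.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"

import torch


def probe_batch_invariance(m: int = 64, k: int = 4096, n: int = 4096) -> bool:
    """Return True if row 0 of a batched matmul is bitwise identical to
    the same row computed at batch size 1."""
    a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
    batched = a @ b     # bs=m path
    single = a[:1] @ b  # bs=1 path; may select a different kernel
    # Batch invariance means bitwise equality, not torch.allclose.
    return torch.equal(batched[:1], single)


if __name__ == "__main__":
    print("batch invariant:", probe_batch_invariance())
```

Since cuBLAS presumably reads these variables when its handle is created, the probe sets them before importing torch; the diff likewise sets them inside `enable_batch_invariant_mode()`, which is expected to run before any matmul is dispatched.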