diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 85fd6dec..1859ab0d 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -31,8 +31,10 @@ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined +from huggingface_hub import PyTorchModelHubMixin -class Mamba2(nn.Module): + +class Mamba2(nn.Module, PyTorchModelHubMixin): def __init__( self, d_model,