From df2daa37d3b7e1e814c21690435bc4c83d89d51a Mon Sep 17 00:00:00 2001 From: Niels Date: Mon, 15 Jul 2024 18:48:11 +0200 Subject: [PATCH] Add mixin --- mamba_ssm/modules/mamba2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 85fd6dec..1859ab0d 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -31,8 +31,10 @@ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined +from huggingface_hub import PyTorchModelHubMixin -class Mamba2(nn.Module): + +class Mamba2(nn.Module, PyTorchModelHubMixin): def __init__( self, d_model,