diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 112ac670e9a0..6d9512d15a3d 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -71,6 +71,12 @@ layer_name="RMSNorm", ) }, + "mps": { + Mode.INFERENCE: LayerRepository( + repo_id="kernels-community/mlx_rmsnorm", + layer_name="RMSNorm", + ) + }, }, "MLP": { "cuda": LayerRepository(