diff --git a/backends/candle/src/layers/linear.rs b/backends/candle/src/layers/linear.rs index e15ca8e87..9eca78a7c 100644 --- a/backends/candle/src/layers/linear.rs +++ b/backends/candle/src/layers/linear.rs @@ -5,7 +5,10 @@ use serde::Deserialize; #[derive(Debug, Deserialize, PartialEq, Clone)] #[serde(rename_all = "lowercase")] pub enum HiddenAct { - #[serde(alias = "gelu_pytorch_tanh")] + // NOTE: `GeluErf` is excluded due to incompatibility with cuBLASLt, as only GeLU + tanh + // approximation is implemented due to efficiency, so GeLU is standardized to tanh approx. with + // slight numerical deviation from GeLU erf (negligible on inference quality) + #[serde(alias = "gelu_new", alias = "gelu_pytorch_tanh")] Gelu, Relu, Silu, diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index ff824f555..46607d60d 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -180,7 +180,19 @@ impl CandleBackend { let config: String = std::fs::read_to_string(model_path.join("config.json")) .context("Unable to read config file") .map_err(|err| BackendError::Start(format!("{err:?}")))?; - let config: Config = serde_json::from_str(&config) + + let config_value: serde_json::Value = serde_json::from_str(&config) + .context("Unable to parse config.json") + .map_err(|err| BackendError::Start(format!("{err:?}")))?; + + if let Some(hidden_act) = config_value.get("hidden_act").and_then(|v| v.as_str()) { + if hidden_act == "gelu" { + // NOTE: https://github.com/huggingface/text-embeddings-inference/pull/753 + tracing::warn!("The `config.json` contains `hidden_act=gelu` and GeLU + tanh approximation will be used instead of exact GeLU (aka. GeLU erf), which might lead to subtle differences with Transformers or Sentence Transformers outputs which use exact GeLU when `hidden_act=gelu`, unless specified otherwise. 
GeLU + tanh is more efficient and more consistent across devices (e.g., cuBLASLt comes with fused GeLU + tanh), and will have negligible impact on the inference quality."); + } + } + + let config: Config = serde_json::from_value(config_value) .context("Model is not supported") .map_err(|err| BackendError::Start(format!("{err:?}")))?; diff --git a/backends/candle/tests/snapshots/test_bert__emotions_batch.snap b/backends/candle/tests/snapshots/test_bert__emotions_batch.snap index fd582b8c2..149ccf21c 100644 --- a/backends/candle/tests/snapshots/test_bert__emotions_batch.snap +++ b/backends/candle/tests/snapshots/test_bert__emotions_batch.snap @@ -2,87 +2,87 @@ source: backends/candle/tests/test_bert.rs expression: predictions_batch --- -- - -6.548559 - - -6.302024 - - -4.8671727 - - -3.9600255 - - -4.6329865 - - -6.2816987 - - -6.069644 - - -5.7742686 - - -6.9259467 - - -6.1909447 - - -5.67395 +- - -6.5485673 + - -6.3020196 + - -4.86717 + - -3.9600184 + - -4.632993 + - -6.2817054 + - -6.069636 + - -5.7742705 + - -6.925953 + - -6.190939 + - -5.6739373 - -6.1698227 - - -7.513461 - - -6.865867 - - -7.186479 - - -7.128109 - - -8.210709 - - -7.0171394 - - -7.1321163 - - -8.533409 - - -6.2294865 - - -8.742306 - - -5.7792044 - - -8.657227 - - -8.258305 - - -6.64832 - - -7.4060283 - - 3.046496 -- - -5.8167515 - - -6.6119466 - - -5.2771955 - - -2.6306503 - - -4.6419163 - - -5.579778 - - -5.797174 - - -6.0305815 - - -5.8720746 - - 0.45377323 - - -3.0235887 - - -5.3944407 - - -5.186683 - - -6.2649117 - - -6.1962767 - - -6.97937 - - -5.5674877 - - -5.521044 - - -5.8899207 - - -4.8699703 - - -5.6259933 - - -7.6109924 - - -4.3881936 - - -6.039008 - - -4.934696 - - -0.6715916 - - -6.399376 - - -2.4499295 -- - -6.548559 - - -6.302024 - - -4.8671727 - - -3.9600255 - - -4.6329865 - - -6.2816987 - - -6.069644 - - -5.7742686 - - -6.9259467 - - -6.1909447 - - -5.67395 + - -7.5134573 + - -6.8658743 + - -7.1864815 + - -7.128115 + - -8.2107115 + - -7.017146 + - -7.132131 + - -8.533407 + - 
-6.229486 + - -8.742311 + - -5.7792006 + - -8.65723 + - -8.258308 + - -6.648321 + - -7.406026 + - 3.0464942 +- - -5.816747 + - -6.611947 + - -5.2771983 + - -2.6306484 + - -4.6419153 + - -5.5797825 + - -5.7971735 + - -6.030578 + - -5.872076 + - 0.45378062 + - -3.0235896 + - -5.3944383 + - -5.18668 + - -6.264913 + - -6.196284 + - -6.9793677 + - -5.567489 + - -5.5210495 + - -5.889915 + - -4.8699794 + - -5.625993 + - -7.6109934 + - -4.388194 + - -6.0390115 + - -4.934693 + - -0.6715966 + - -6.3993735 + - -2.4499245 +- - -6.5485673 + - -6.3020196 + - -4.86717 + - -3.9600184 + - -4.632993 + - -6.2817054 + - -6.069636 + - -5.7742705 + - -6.925953 + - -6.190939 + - -5.6739373 - -6.1698227 - - -7.513461 - - -6.865867 - - -7.186479 - - -7.128109 - - -8.210709 - - -7.0171394 - - -7.1321163 - - -8.533409 - - -6.2294865 - - -8.742306 - - -5.7792044 - - -8.657227 - - -8.258305 - - -6.64832 - - -7.4060283 - - 3.046496 + - -7.5134573 + - -6.8658743 + - -7.1864815 + - -7.128115 + - -8.2107115 + - -7.017146 + - -7.132131 + - -8.533407 + - -6.229486 + - -8.742311 + - -5.7792006 + - -8.65723 + - -8.258308 + - -6.648321 + - -7.406026 + - 3.0464942