
Distilled models may produce embeddings with different dimensions than original model #284

@rola93

Description


Hello!

I discovered an issue where StaticEmbedding.from_distillation() can produce embeddings with different dimensions than the original model, potentially leading to unexpected behavior and incompatible embeddings.

Issue Description

When distilling certain models (specifically CLIP-based models), the distilled embeddings can have larger dimensions than the original model's output. This happens because the distillation process extracts embeddings from intermediate layers rather than the final model output.

Reproduction Example

from sentence_transformers.models import StaticEmbedding
from sentence_transformers import SentenceTransformer
import pandas as pd

model_name = "sentence-transformers/clip-ViT-B-32-multilingual-v1"

all_results = []

for pca_dims in ["auto", None, 512, 100, 900]:
    static_embedding = StaticEmbedding.from_distillation(
        model_name, 
        device="cuda", 
        pca_dims=pca_dims,
        apply_zipf=True
    )
    query_model = SentenceTransformer(modules=[static_embedding])
    inferred_dims = query_model.encode("hello").shape[0]

    all_results.append({
        "pca_dims": pca_dims,
        "inferred_dims": inferred_dims,
        "declared_dims": query_model.get_sentence_embedding_dimension()
    })

# Compare with original model
original_model = SentenceTransformer(model_name)
inferred_dims = original_model.encode("hello").shape[0]

all_results.append({
    "pca_dims": "original_model",
    "inferred_dims": inferred_dims,
    "declared_dims": original_model.get_sentence_embedding_dimension()
})

pd.DataFrame(all_results)

Output:
(screenshot: DataFrame comparing inferred and declared dimensions for each pca_dims setting against the original model)

Root Cause

The issue occurs with CLIP-based models that have a two-stage architecture:

  1. Text encoder: Produces 768-dimensional embeddings
  2. Linear projection: Maps from 768 to 512 dimensions (shared embedding space)

The distillation process extracts embeddings from the text encoder (last_hidden_state in _encode_mean_using_model) before the final projection layer, resulting in 768-dimensional embeddings instead of the expected 512.

When pca_dims="auto", the code preserves the extracted dimensions (768) rather than matching the original model's output dimensions (512).
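A minimal sketch of the mismatch, using toy numpy stand-ins for the two stages (the shapes match the issue; the weights are random, not the actual CLIP parameters):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, hidden_dim, proj_dim = 4, 768, 512  # shapes from the issue

last_hidden_state = rng.standard_normal((seq_len, hidden_dim))  # text encoder output
projection = rng.standard_normal((hidden_dim, proj_dim))        # final linear projection

# Distillation mean-pools the encoder output *before* the projection...
distilled = last_hidden_state.mean(axis=0)
# ...while the original model pools and then projects into the shared space.
original = last_hidden_state.mean(axis=0) @ projection

print(distilled.shape, original.shape)  # (768,) (512,)
```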

Impact

  • Incompatible embeddings: Distilled embeddings cannot be directly compared with original model embeddings
  • Unexpected behavior: Users expecting 512-dimensional embeddings get 768-dimensional ones
  • Silent failure: No warning is issued about the dimension mismatch
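To make the first bullet concrete, here is a minimal numpy sketch (dummy vectors, not real model outputs) of what happens when the two embeddings are compared directly:

```python
import numpy as np

distilled = np.ones(768)  # dimensionality extracted during distillation
original = np.ones(512)   # dimensionality of the original model's output

try:
    distilled @ original  # dot/cosine similarity requires matching dimensions
except ValueError as e:
    print(f"Cannot compare embeddings: {e}")
```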

Proposed Solution

Add a sanity check in distillation.py to warn users about dimension mismatches:

# Create the embeddings
embeddings = create_embeddings(
    tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
)

# Sanity check: warn if dimensions don't match
try:
    original_dims = model.get_sentence_embedding_dimension()
    distilled_dims = embeddings.shape[1]
    
    if distilled_dims != original_dims:
        logger.warning(
            f"Dimension mismatch detected! "
            f"Original model produces {original_dims}D embeddings, "
            f"but distillation extracted {distilled_dims}D embeddings. "
            f"This typically happens with multi-stage models (e.g., CLIP). "
        )
except (AttributeError, TypeError):
    # Handle cases where model doesn't have get_sentence_embedding_dimension()
    pass

# Post process the embeddings by applying PCA and Zipf weighting.
embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)

Final Considerations

It's important to note that simply matching dimensions does not solve the semantic incompatibility. Suggesting or enforcing pca_dims=original_dims would only address the dimensional mismatch, but the resulting embeddings would still represent fundamentally different semantic spaces:

  • Original model embeddings: Output from the complete pipeline, including final projection layers that map to a shared semantic space (e.g., text-image shared space in CLIP models)
  • Distilled embeddings: Extracted from intermediate layers (text encoder only), representing a different semantic space
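A toy numpy sketch of this point (random stand-in weights; naive truncation standing in for PCA): the two 512-dimensional vectors have the same shape, but they live in unrelated spaces, so a cross-space cosine similarity is essentially noise:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, proj_dim = 768, 512

pooled = rng.standard_normal(hidden_dim)                  # pooled encoder output
projection = rng.standard_normal((hidden_dim, proj_dim))  # stand-in projection

original = pooled @ projection  # 512-d vector in the shared (projected) space
reduced = pooled[:proj_dim]     # 512-d vector via truncation (stand-in for PCA)

cos = original @ reduced / (np.linalg.norm(original) * np.linalg.norm(reduced))
print(f"cross-space cosine similarity: {cos:.3f}")  # near zero: same shape, different space
```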
