Skip to content

[inactive] Track entropy and MI of routing distribution for topk MoE #188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fast_llm/layers/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class TransformerKwargs:
class TransformerLossNames:
    # String identifiers for the auxiliary router losses reported by the transformer layers.
    load_balancing_loss = "load_balancing_loss"
    router_z_loss = "router_z_loss"
    # Identifiers for the routing-distribution metrics (normalized entropy and mutual information).
    # NOTE(review): the visible MoE routing and GPT loss-definition hunks spell these two names
    # as string literals rather than referencing these constants — confirm and unify.
    router_entropy = "router_entropy"
    router_mutual_info = "router_mutual_info"


class RotaryEmbeddingType(str, enum.Enum):
Expand Down
44 changes: 44 additions & 0 deletions fast_llm/layers/transformer/mixture_of_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,36 @@
logger = logging.getLogger(__name__)


def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
    """
    Calculate the routing entropy of each token, average it over all tokens, and
    normalize by the maximum possible entropy, log(n_experts).

    If low, a lot of mass is put on a single expert in all tokens, which can indicate
    collapse or specialization.

    Args:
        probs: Routing probabilities. The last dimension indexes the experts; every
            leading dimension is averaged over.

    Returns:
        A scalar tensor: 1 for uniform routing, 0 for one-hot routing. With a single
        expert (or none) the normalizer is undefined, so 0 is returned instead of NaN.
    """
    # Review note (collaborator): could try @torch.compile on these for a free performance boost.
    import math

    n_experts = probs.size(-1)
    average_entropy = entropy(probs).mean()  # Average over batch and tokens
    if n_experts <= 1:
        # log(1) == 0, so normalizing would compute 0 / 0 = NaN. A router with a single
        # expert carries no uncertainty, so report zero entropy.
        return torch.zeros_like(average_entropy)
    # Dividing by a Python float keeps the dtype and device of `average_entropy` and avoids
    # allocating a fresh tensor on every call just to take its log.
    return average_entropy / math.log(n_experts)


def entropy(probs: torch.Tensor) -> torch.Tensor:
    """
    Compute the Shannon entropy (in nats) of the distributions stored along the last dimension.

    Args:
        probs: Probabilities; the last dimension is summed over.

    Returns:
        A tensor with the last dimension reduced, holding -sum(p * log(p)).
        Entries equal to zero contribute exactly zero.
    """
    # Review note (collaborator): suggested renaming this function to `calculate_entropy`.
    # The name is kept here so that existing callers keep working.

    # Negative inputs are not valid probabilities; map them to zero so the log stays defined.
    safe_probs = probs.clamp(min=0)
    # xlogy(p, p) == p * log(p), defined as 0 where p == 0. This avoids log(0) without a fixed
    # epsilon, which would add spurious entropy for zero entries and can underflow in
    # low-precision dtypes.
    return -torch.xlogy(safe_probs, safe_probs).sum(dim=-1)


def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor:
    """
    Calculate the difference between the entropy of the average routing and the
    average routing entropy, both normalized by log(n_experts). The average is taken
    across all tokens of all examples in the batch.

    If low, routing is not informative: every token is routed the same way.

    Args:
        probs: Routing probabilities. The last dimension indexes the experts; all
            leading dimensions are flattened into a single token dimension.

    Returns:
        A scalar tensor holding H[E[X]] - E[H[X]]. With a single expert (or none) the
        normalizer is undefined, so 0 is returned instead of NaN.
    """
    import math

    n_experts = probs.size(-1)
    if n_experts <= 1:
        # log(1) == 0 would turn both normalized entropies into 0 / 0 = NaN.
        return torch.zeros((), dtype=probs.dtype, device=probs.device)
    max_entropy = math.log(n_experts)

    # `reshape` (unlike `view`) also accepts non-contiguous inputs such as transposed tensors.
    average_routing = probs.reshape(-1, n_experts).mean(dim=0)  # Average over tokens
    entropy_avg_routing = entropy(average_routing) / max_entropy  # H[E[X]]
    entropy_routing = calculate_normalized_average_entropy(probs)  # E[H[X]]

    return entropy_avg_routing - entropy_routing


class MixtureOfExpertMLP(MLPBase):
"""
MoeLayer following implementation from
Expand Down Expand Up @@ -174,6 +204,20 @@ def _topk_routing(
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
if losses is not None or (self.training and grad_scale is not None):
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)

# Calculate and log entropy and mutual information
entropy = calculate_normalized_average_entropy(probs)
mutual_info = calculate_mutual_information(probs)

# Store these metrics
if "router_entropy" not in losses:
losses["router_entropy"] = []
if "router_mutual_info" not in losses:
losses["router_mutual_info"] = []

losses["router_entropy"].append(entropy.detach())
losses["router_mutual_info"].append(mutual_info.detach())

mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1)
# Auxiliary loss, corresponding to the sum of probabilities for the top experts.
# In the optimal case (uniform distribution), loss = experts_per_token / num_experts.
Expand Down
16 changes: 16 additions & 0 deletions fast_llm/models/gpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,22 @@ def loss_defs(self) -> list[LossDef]:
count=self._config.transformer.num_layers,
)
)
# Add new metrics
loss_defs.append(
LossDef(
name="router_entropy",
formatted_name="router entropy",
count=self._config.transformer.num_layers,
)
)
loss_defs.append(
LossDef(
name="router_mutual_info",
formatted_name="router mutual info",
count=self._config.transformer.num_layers,
)
)

if self._config.logit_z_loss:
LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1)
return loss_defs
Expand Down
165 changes: 165 additions & 0 deletions tests/test_routing_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import pytest
import torch

from fast_llm.layers.transformer.mixture_of_experts import (
calculate_mutual_information,
calculate_normalized_average_entropy,
)


def test_diversity_entropy():
    """
    Collapsed routing should have low entropy and low mutual information, while
    confident-but-diverse routing should have low entropy and high mutual information.
    """

    # Collapsed: every token of both batch entries puts almost all of its mass on expert 0.
    collapsed_probs = torch.tensor(
        [
            # Batch 1
            [
                [0.99, 0.01, 0.0, 0.0],
                [0.99, 0.01, 0.0, 0.0],
                [0.99, 0.01, 0.0, 0.0],
            ],
            # Batch 2
            [
                [0.99, 0.01, 0.0, 0.0],
                [0.99, 0.01, 0.0, 0.0],
                [0.99, 0.01, 0.0, 0.0],
            ],
        ]
    )
    norm_entropy = calculate_normalized_average_entropy(collapsed_probs)
    mutual_info = calculate_mutual_information(collapsed_probs)
    assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected 0.0, got {norm_entropy}"
    assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got {mutual_info}"

    # Diverse but no collapse: each token is confident, but different tokens pick different experts.
    # Should give low entropy and high mutual information.
    # Every row is a valid distribution (sums to 1).
    diverse_probs = torch.tensor(
        [
            # Batch 1
            [
                [0.99, 0.01, 0.0, 0.0],
                [0.01, 0.99, 0.0, 0.0],
                [0.0, 0.01, 0.99, 0.0],
            ],
            # Batch 2
            [
                [0.0, 0.01, 0.99, 0.0],
                [0.99, 0.01, 0.0, 0.0],
                [0.0, 0.0, 0.01, 0.99],
            ],
        ]
    )
    norm_entropy = calculate_normalized_average_entropy(diverse_probs)
    mutual_info = calculate_mutual_information(diverse_probs)
    assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected 0.0, got {norm_entropy}"
    assert torch.isclose(mutual_info, torch.tensor(0.9), atol=1e-1), f"Expected 0.9, got {mutual_info}"


def test_calculate_normalized_average_entropy():
    """Check the normalized average entropy on uniform, one-hot and mixed routing."""
    # AI generated test case
    batch_size, seq_len, n_experts = 2, 3, 4

    # Case 1: uniform routing is the maximum-entropy case, so the normalized value is 1.0.
    uniform_probs = torch.full((batch_size, seq_len, n_experts), 1.0 / n_experts)
    norm_entropy = calculate_normalized_average_entropy(uniform_probs)
    assert torch.isclose(norm_entropy, torch.tensor(1.0), atol=1e-5), f"Expected 1.0, got {norm_entropy}"

    # Case 2: one-hot routing carries no uncertainty, so the normalized value is 0.0.
    # All tokens of batch entry `b` select expert `b % n_experts`.
    one_hot = torch.zeros(batch_size, seq_len, n_experts)
    for b in range(batch_size):
        one_hot[b, :, b % n_experts] = 1.0
    norm_entropy = calculate_normalized_average_entropy(one_hot)
    assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got {norm_entropy}"

    # Case 3: a mix of peaked, split and uniform tokens lands strictly between the extremes.
    batch_1 = [
        [0.7, 0.1, 0.1, 0.1],  # Token 1: mostly expert 0
        [0.1, 0.7, 0.1, 0.1],  # Token 2: mostly expert 1
        [0.25, 0.25, 0.25, 0.25],  # Token 3: uniform
    ]
    batch_2 = [
        [0.4, 0.4, 0.1, 0.1],  # Token 1: split between experts 0 and 1
        [0.1, 0.1, 0.4, 0.4],  # Token 2: split between experts 2 and 3
        [0.1, 0.1, 0.1, 0.7],  # Token 3: mostly expert 3
    ]
    mixed_probs = torch.tensor([batch_1, batch_2])
    norm_entropy = calculate_normalized_average_entropy(mixed_probs)
    assert 0.0 < norm_entropy < 1.0, f"Expected value between 0 and 1, got {norm_entropy}"


def test_calculate_mutual_information():
    """Check the mutual information on identical, distinct and mixed routing patterns."""
    # AI generated test cases
    batch_size, seq_len, n_experts = 2, 3, 4

    # Case 1: every token routes to expert 0, so routing tells nothing about the token.
    same_expert = torch.zeros(batch_size, seq_len, n_experts)
    same_expert[..., 0] = 1.0
    mutual_info = calculate_mutual_information(same_expert)
    assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}"

    # Case 2: the token at position `s` routes to expert `s % n_experts` in every batch entry,
    # so different positions use different experts and the value must be positive.
    different_experts = torch.zeros(batch_size, seq_len, n_experts)
    for s in range(seq_len):
        different_experts[:, s, s % n_experts] = 1.0
    mutual_info = calculate_mutual_information(different_experts)
    assert mutual_info > 0.0, f"Expected positive value, got {mutual_info}"

    # Case 3: soft preferences spread over all experts give a value strictly inside (0, 1).
    batch_1 = [
        [0.7, 0.1, 0.1, 0.1],  # Token 1: mostly expert 0
        [0.1, 0.7, 0.1, 0.1],  # Token 2: mostly expert 1
        [0.1, 0.1, 0.7, 0.1],  # Token 3: mostly expert 2
    ]
    batch_2 = [
        [0.1, 0.1, 0.1, 0.7],  # Token 1: mostly expert 3
        [0.7, 0.1, 0.1, 0.1],  # Token 2: mostly expert 0
        [0.1, 0.7, 0.1, 0.1],  # Token 3: mostly expert 1
    ]
    mixed_probs = torch.tensor([batch_1, batch_2])
    mutual_info = calculate_mutual_information(mixed_probs)
    assert 0.0 < mutual_info < 1.0, f"Expected value between 0 and 1, got {mutual_info}"


def test_edge_cases():
    """Check both metrics on a single-token input and on near-one-hot probabilities."""
    # AI generated test cases

    # A single uniform token (batch=1, seq_len=1, n_experts=4): maximal entropy, and the
    # average routing equals the only token's routing, so the mutual information is zero.
    tiny_probs = torch.full((1, 1, 4), 0.25)
    norm_entropy = calculate_normalized_average_entropy(tiny_probs)
    mutual_info = calculate_mutual_information(tiny_probs)
    assert torch.isclose(norm_entropy, torch.tensor(1.0)), f"Expected 1.0, got {norm_entropy}"
    assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}"

    # Nearly all mass on expert 0, with very small probabilities on the remaining experts.
    small_probs = torch.full((2, 3, 4), 1e-8)
    small_probs[..., 0] = 1.0 - 3e-8  # Make sure they sum to 1
    norm_entropy = calculate_normalized_average_entropy(small_probs)
    mutual_info = calculate_mutual_information(small_probs)
    assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {norm_entropy}"
    assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {mutual_info}"


# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main(args=[__file__])