-
Notifications
You must be signed in to change notification settings - Fork 33
[inactive] Track entropy and MI of routing distribution for topk MoE #188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
oleksost
wants to merge
18
commits into
main
Choose a base branch
from
routing_stats
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
2a7cf1b
added mutual information and entropy for routing probs
oleksost dd85e84
format
oleksost aef18e7
pre-commits
oleksost bef39d8
improved
oleksost 620ec76
using metrics dict instead of losses
oleksost 7a93aee
reduce metrics
oleksost eb617e8
check return_metrics before reducing metrics
oleksost 440738a
check return metrics before reducing
oleksost e5f3c4b
corrwect averaging with number of layers
oleksost 27e2a5c
device
oleksost b016d95
Merge branch 'main' into routing_stats
oleksost 7b9ac8c
polishing
oleksost 0577b2c
simplified: all metrics from forward are reduced
oleksost 9e2ec37
Merge branch 'routing_stats' of https://github.com/ServiceNow/Fast-LLβ¦
oleksost efd16bf
nvm
oleksost 1202f5f
moved runner test to a new file
oleksost 9855b82
parameter for MoE metrics calculation
oleksost 9c47764
Merge branch 'main' into routing_stats
oleksost File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,36 @@ | |
logger = logging.getLogger(__name__) | ||
|
||
|
||
def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Calculates routing entropy for each token, then averages over all tokens. | ||
If low, means a lot of mass is put on a single expert in all tokens, which can indicate collapse or specialization. | ||
""" | ||
n_experts = probs.size(-1) | ||
entropy_values = entropy(probs) | ||
average_entropy = entropy_values.mean() # Average over batch and tokens | ||
return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype)) | ||
|
||
|
||
def entropy(probs: torch.Tensor) -> torch.Tensor: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
probs = torch.clamp(probs, min=1e-9) # Avoid log(0) | ||
return -torch.sum(probs * torch.log(probs), dim=-1) | ||
|
||
|
||
def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Calculates the difference between the entropy of the average routing and | ||
the average routing entropy, we average across all tokens of all examples in the batch. | ||
If low, means that routing is not informative. | ||
""" | ||
n_experts = probs.size(-1) | ||
average_routing = torch.mean(probs.view(-1, n_experts), dim=0) # Average over tokens | ||
entropy_avg_routing = entropy(average_routing) / torch.log(torch.tensor(n_experts, dtype=probs.dtype)) # H[E[X]] | ||
entropy_routing = calculate_normalized_average_entropy(probs) # E[H[X]] | ||
|
||
return entropy_avg_routing - entropy_routing | ||
|
||
|
||
class MixtureOfExpertMLP(MLPBase): | ||
""" | ||
MoeLayer following implementation from | ||
|
@@ -174,6 +204,20 @@ def _topk_routing( | |
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) | ||
if losses is not None or (self.training and grad_scale is not None): | ||
probs = torch.softmax(logits, dim=-1, dtype=torch.float32) | ||
|
||
# Calculate and log entropy and mutual information | ||
entropy = calculate_normalized_average_entropy(probs) | ||
mutual_info = calculate_mutual_information(probs) | ||
|
||
# Store these metrics | ||
if "router_entropy" not in losses: | ||
losses["router_entropy"] = [] | ||
if "router_mutual_info" not in losses: | ||
losses["router_mutual_info"] = [] | ||
|
||
losses["router_entropy"].append(entropy.detach()) | ||
losses["router_mutual_info"].append(mutual_info.detach()) | ||
|
||
mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) | ||
# Auxiliary loss, corresponding to the sum of probabilities for the top experts. | ||
# In the optimal case (uniform distribution), loss = experts_per_token / num_experts. | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import pytest | ||
import torch | ||
|
||
from fast_llm.layers.transformer.mixture_of_experts import ( | ||
calculate_mutual_information, | ||
calculate_normalized_average_entropy, | ||
) | ||
|
||
|
||
def test_diversity_entropy(): | ||
""" | ||
collapse routing would have low entropy and low mutual information | ||
""" | ||
|
||
collapased_probs = torch.tensor( | ||
[ | ||
# Batch 1 | ||
[ | ||
[0.99, 0.01, 0.0, 0.0], | ||
[0.99, 0.01, 0.0, 0.0], | ||
[0.99, 0.01, 0.0, 0.0], | ||
], | ||
# Batch 2 | ||
[ | ||
[0.99, 0.01, 0.0, 0.0], | ||
[0.99, 0.01, 0.0, 0.0], | ||
[0.99, 0.01, 0.0, 0.0], | ||
], | ||
] | ||
) | ||
norm_entropy = calculate_normalized_average_entropy(collapased_probs) | ||
mutual_info = calculate_mutual_information(collapased_probs) | ||
assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected 0.0, got {norm_entropy}" | ||
assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got {mutual_info}" | ||
|
||
# diverse but no collapse | ||
# should give low entropy and high mutual information | ||
diverse_probs = torch.tensor( | ||
[ | ||
# Batch 1 | ||
[ | ||
[0.99, 0.01, 0.0, 0.0], | ||
[0.01, 0.99, 0.0, 0.0], | ||
[0.01, 0.01, 0.99, 0.0], | ||
], | ||
# Batch 2 | ||
[ | ||
[0.01, 0.01, 0.99, 0.0], | ||
[0.99, 0.01, 0.0, 0.0], | ||
[0.01, 0.01, 0.01, 0.99], | ||
], | ||
] | ||
) | ||
norm_entropy = calculate_normalized_average_entropy(diverse_probs) | ||
mutual_info = calculate_mutual_information(diverse_probs) | ||
assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected 0.0, got {norm_entropy}" | ||
assert torch.isclose(mutual_info, torch.tensor(0.9), atol=1e-1), f"Expected 1.0, got {mutual_info}" | ||
|
||
|
||
def test_calculate_normalized_average_entropy(): | ||
# AI generated test case | ||
# Create a batch of routing probabilities | ||
batch_size = 2 | ||
seq_len = 3 | ||
n_experts = 4 | ||
|
||
# Test 1: Uniform distribution (should give normalized entropy of 1.0) | ||
uniform_probs = torch.ones(batch_size, seq_len, n_experts) / n_experts | ||
norm_entropy = calculate_normalized_average_entropy(uniform_probs) | ||
assert torch.isclose(norm_entropy, torch.tensor(1.0), atol=1e-5), f"Expected 1.0, got {norm_entropy}" | ||
|
||
# Test 2: One-hot distribution (should give normalized entropy of 0.0) | ||
one_hot = torch.zeros(batch_size, seq_len, n_experts) | ||
for b in range(batch_size): | ||
for s in range(seq_len): | ||
one_hot[b, s, b % n_experts] = 1.0 | ||
norm_entropy = calculate_normalized_average_entropy(one_hot) | ||
assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got {norm_entropy}" | ||
|
||
# Test 3: Mixed distribution | ||
mixed_probs = torch.tensor( | ||
[ | ||
# Batch 1 | ||
[ | ||
[0.7, 0.1, 0.1, 0.1], # Token 1: mostly expert 0 | ||
[0.1, 0.7, 0.1, 0.1], # Token 2: mostly expert 1 | ||
[0.25, 0.25, 0.25, 0.25], # Token 3: uniform | ||
], | ||
# Batch 2 | ||
[ | ||
[0.4, 0.4, 0.1, 0.1], # Token 1: split between experts 0 and 1 | ||
[0.1, 0.1, 0.4, 0.4], # Token 2: split between experts 2 and 3 | ||
[0.1, 0.1, 0.1, 0.7], # Token 3: mostly expert 3 | ||
], | ||
] | ||
) | ||
norm_entropy = calculate_normalized_average_entropy(mixed_probs) | ||
# The expected value is between 0 and 1 | ||
assert 0.0 < norm_entropy < 1.0, f"Expected value between 0 and 1, got {norm_entropy}" | ||
|
||
|
||
def test_calculate_mutual_information(): | ||
# AI generated test cases | ||
# Create a batch of routing probabilities | ||
batch_size = 2 | ||
seq_len = 3 | ||
n_experts = 4 | ||
|
||
# Test 1: All tokens route to the same expert (low mutual information) | ||
same_expert = torch.zeros(batch_size, seq_len, n_experts) | ||
same_expert[:, :, 0] = 1.0 # All tokens route to expert 0 | ||
mutual_info = calculate_mutual_information(same_expert) | ||
assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}" | ||
|
||
# Test 2: Each token routes to a different expert (high mutual information) | ||
different_experts = torch.zeros(batch_size, seq_len, n_experts) | ||
for b in range(batch_size): | ||
for s in range(seq_len): | ||
different_experts[b, s, s % n_experts] = 1.0 | ||
mutual_info = calculate_mutual_information(different_experts) | ||
# The value should be positive and closer to 1 | ||
assert mutual_info > 0.0, f"Expected positive value, got {mutual_info}" | ||
|
||
# Test 3: Mixed routing pattern | ||
mixed_probs = torch.tensor( | ||
[ | ||
# Batch 1 | ||
[ | ||
[0.7, 0.1, 0.1, 0.1], # Token 1: mostly expert 0 | ||
[0.1, 0.7, 0.1, 0.1], # Token 2: mostly expert 1 | ||
[0.1, 0.1, 0.7, 0.1], # Token 3: mostly expert 2 | ||
], | ||
# Batch 2 | ||
[ | ||
[0.1, 0.1, 0.1, 0.7], # Token 1: mostly expert 3 | ||
[0.7, 0.1, 0.1, 0.1], # Token 2: mostly expert 0 | ||
[0.1, 0.7, 0.1, 0.1], # Token 3: mostly expert 1 | ||
], | ||
] | ||
) | ||
mutual_info = calculate_mutual_information(mixed_probs) | ||
# The expected value is between 0 and 1 | ||
assert 0.0 < mutual_info < 1.0, f"Expected value between 0 and 1, got {mutual_info}" | ||
|
||
|
||
def test_edge_cases(): | ||
# AI generated test cases | ||
# Test with very small batch and sequence length | ||
tiny_probs = torch.tensor([[[0.25, 0.25, 0.25, 0.25]]]) # batch=1, seq_len=1, n_experts=4 | ||
norm_entropy = calculate_normalized_average_entropy(tiny_probs) | ||
mutual_info = calculate_mutual_information(tiny_probs) | ||
assert torch.isclose(norm_entropy, torch.tensor(1.0)), f"Expected 1.0, got {norm_entropy}" | ||
assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}" | ||
|
||
# Test with very small probabilities | ||
small_probs = torch.ones(2, 3, 4) * 1e-8 | ||
small_probs[:, :, 0] = 1.0 - 3e-8 # Make sure they sum to 1 | ||
norm_entropy = calculate_normalized_average_entropy(small_probs) | ||
mutual_info = calculate_mutual_information(small_probs) | ||
assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {norm_entropy}" | ||
assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {mutual_info}" | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could try
@torch.compile
on these for a free performance boost.