changes for latest versions of BenchMARL #3
base: main
@@ -74,7 +74,10 @@ def __init__(
             else None
         )  # Components that maps std_dev according to scale_mapping

-        self.input_features = self.input_leaf_spec.shape[-1]
+        # self.input_features = self.input_leaf_spec.shape[-1]
+        self.input_features = sum(
+            [spec.shape[-1] for spec in self.input_spec.values(True, True)]
+        )
         self.output_features = self.output_leaf_spec.shape[-1]

         self.shared_mlp = MultiAgentMLP(
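For context, here is a minimal sketch (not part of the PR) of what the new `input_features` computation does when the model's input spec carries more than one entry. It assumes a recent TorchRL where `Composite` and `Unbounded` specs are available and `values(True, True)` iterates over leaf specs; the spec entries below are invented for illustration.

```python
from torchrl.data import Composite, Unbounded

n_agents = 3
# Hypothetical input spec with two observation entries (names are made up)
input_spec = Composite(
    observation=Unbounded(shape=(n_agents, 6)),
    extra_state=Unbounded(shape=(n_agents, 2)),
)

# Old code assumed a single leaf spec and read its last dimension.
# The patched code sums the last dimension over every leaf spec instead:
input_features = sum(spec.shape[-1] for spec in input_spec.values(True, True))
print(input_features)  # 8 = 6 + 2
```

With a single observation entry the sum reduces to the old value, so the behaviour should only differ when several input keys are present.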
@@ -132,9 +135,12 @@ def _forward(
     ) -> TensorDictBase:
         # Gather in_key

-        input = tensordict.get(
-            self.in_key
-        )  # Observation tensor of shape [*batch, n_agents, n_features]
+        # input = tensordict.get(
+        #     self.in_key
+        # )
+        input = torch.cat([tensordict.get(in_key) for in_key in self.in_keys], dim=-1)
+
+        # Observation tensor of shape [*batch, n_agents, n_features]
         shared_out = self.shared_mlp.forward(input)
         if agent_index is None:  # Gather outputs for all agents on the obs
             # tensor of shape [*batch, n_agents, n_actions], where the outputs
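Similarly, a small sketch (not from the PR, key names invented) of the concatenation the patched `_forward` now performs: each `in_key` is read from the tensordict and joined along the feature dimension, so the resulting tensor matches the `input_features` computed above.

```python
import torch
from tensordict import TensorDict

batch, n_agents = 32, 3
td = TensorDict(
    {
        "observation": torch.randn(batch, n_agents, 6),
        "extra_state": torch.randn(batch, n_agents, 2),
    },
    batch_size=[batch],
)
in_keys = ["observation", "extra_state"]

# Same pattern as the patched line: one tensor of shape [*batch, n_agents, n_features]
model_input = torch.cat([td.get(k) for k in in_keys], dim=-1)
print(model_input.shape)  # torch.Size([32, 3, 8])
```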
@@ -143,7 +149,9 @@ def _forward(
         else:  # Gather outputs for one agent on the obs
             # tensor of shape [*batch, n_agents, n_actions], where the outputs
             # along the n_agent dimension are taken with the same (agent_index) agent network
-            agent_out = self.agent_mlps.agent_networks[agent_index].forward(input)
+            # agent_out = self.agent_mlps.agent_networks[agent_index].forward(input)
+            with self.agent_mlps.params[agent_index].to_module(self.agent_mlps._empty_net):
+                agent_out = self.agent_mlps._empty_net(input)

         shared_out = self.process_shared_out(shared_out)

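The `with params[...].to_module(...)` construction above replaces the old `agent_networks[agent_index]` lookup. A rough sketch of the idea, assuming the TorchRL `MultiAgentMLP` internals this PR relies on (`params` holding stacked per-agent parameters and `_empty_net` acting as a stateless template); the constructor arguments are illustrative.

```python
import torch
from torchrl.modules import MultiAgentMLP

mlp = MultiAgentMLP(
    n_agent_inputs=8,
    n_agent_outputs=4,
    n_agents=3,
    centralised=False,
    share_params=False,
    depth=2,
    num_cells=32,
)

obs = torch.randn(32, 3, 8)  # [*batch, n_agents, n_features]
agent_index = 1

# With share_params=False the per-agent weights are stacked along a leading
# n_agents dimension; indexing selects one agent's parameter TensorDict.
agent_params = mlp.params[agent_index]

# Temporarily load those parameters into the stateless template and call it,
# applying that single agent's network across every agent slot of the observation.
with agent_params.to_module(mlp._empty_net):
    agent_out = mlp._empty_net(obs)
print(agent_out.shape)  # torch.Size([32, 3, 4])
```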
@@ -166,6 +174,9 @@ def _forward(
                 or distance.isnan().any()  # It is the first iteration
                 or self.n_agents == 1
             ):
+                distance = self.estimate_snd(input)
+                if update_estimate:
+                    self.estimated_snd[:] = distance.detach()
Comment on lines +177 to +179

Reviewer: Could you explain this a bit? If those conditions are met, we can avoid computing …

PR author: I did this to be able to log the `estimated_snd` during training when the desired snd is …

Reviewer: I see, but you can still see it under eval/snd, no?

PR author: Yes, but if I understand correctly that is only during evaluation, right? This was helpful in understanding how the SND evolves while training. But you are right, eval/snd is enough. Should we roll back to the previous version?

Reviewer: OK, got it, I'll take care of things, don't worry.
                 scaling_ratio = 1.0
             else:  # DiCo scaling
                 scaling_ratio = torch.where(
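As a side note, the buffer update added in this hunk follows a common PyTorch idiom. A tiny sketch (the names below are stand-ins, not the model's attributes): writing with `[:]` updates the existing tensor in place, and `detach()` keeps gradients from flowing into the stored running estimate.

```python
import torch

estimated_snd = torch.full((1,), float("nan"))        # stand-in for the registered buffer
distance = torch.tensor([0.37], requires_grad=True)   # stand-in for estimate_snd(input)

update_estimate = True
if update_estimate:
    estimated_snd[:] = distance.detach()  # in-place write, no autograd history kept
print(estimated_snd)  # tensor([0.3700])
```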
@@ -237,9 +248,14 @@ def estimate_snd(self, obs: torch.Tensor):
         """
         agent_actions = []
         # Gather what actions each agent would take if given the obs tensor
-        for agent_net in self.agent_mlps.agent_networks:
-            agent_outputs = agent_net(obs)
-            agent_actions.append(agent_outputs)
+        # for agent_net in self.agent_mlps.agent_networks:
+        #     agent_outputs = agent_net(obs)
+        #     agent_actions.append(agent_outputs)
+        for agent_index in range(self.n_agents):
+            with self.agent_mlps.params[agent_index].to_module(self.agent_mlps._empty_net):
+                agent_out = self.agent_mlps._empty_net(obs)
+            agent_actions.append(agent_out)
+

         distance = (
             compute_behavioral_distance(agent_actions=agent_actions, just_mean=True)
This is the key change, as `agent_networks` is no longer supported in TorchRL.
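A compact before/after view of that change, not taken from the PR (constructor arguments and shapes are illustrative): the old code iterated `agent_networks` directly, while the patched code loads each agent's stacked parameters into the stateless template, as in `estimate_snd` above.

```python
import torch
from torchrl.modules import MultiAgentMLP

mlp = MultiAgentMLP(
    n_agent_inputs=8, n_agent_outputs=4, n_agents=3,
    centralised=False, share_params=False, depth=2, num_cells=32,
)
obs = torch.randn(32, 3, 8)

# Old pattern (fails on recent TorchRL, where agent_networks no longer exists):
# agent_actions = [net(obs) for net in mlp.agent_networks]

# New pattern used throughout this PR:
agent_actions = []
for i in range(mlp.n_agents):
    with mlp.params[i].to_module(mlp._empty_net):
        agent_actions.append(mlp._empty_net(obs))
print(len(agent_actions), agent_actions[0].shape)  # 3 torch.Size([32, 3, 4])
```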