Description
🐛 Bug
I have found what I believe is a discrepancy in how evaluation.py is structured.
I am trying to evaluate DQN on Atari environments (specifically Breakout). I have defined my environments like this:
import random

import numpy as np
import torch

from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

# === Setup Environment ===
# Training environment
env = make_atari_env(args.env_name, n_envs=args.n_envs, seed=args.seed)
env = VecFrameStack(env, n_stack=args.frame_stack)
env = VecTransposeImage(env)
# Evaluation environment
eval_env = make_atari_env(args.env_name, n_envs=args.n_envs, seed=args.seed + 1000)
eval_env = VecFrameStack(eval_env, n_stack=args.frame_stack)
eval_env = VecTransposeImage(eval_env)
# Set seeds for reproducibility
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)
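The model is then trained and evaluated in the usual way; the snippet below is only a rough sketch with placeholder hyperparameters, not my exact settings:

from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy

# Placeholder hyperparameters, just to show how eval_env reaches evaluate_policy
model = DQN("CnnPolicy", env, buffer_size=100_000, learning_starts=10_000, seed=args.seed)
model.learn(total_timesteps=1_000_000)

mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)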
Now, in the evaluation script located at stable_baselines3/common/evaluation.py, whenever a done is emitted, there are two logical paths the evaluate function can follow:
- It checks whether the environment is Monitor wrapped or not.
- If it is not, it assumes that done being True means the episode is over, appends current_rewards and current_lengths to episode_rewards and episode_lengths respectively, and resets them.
My issue is: when an environment is Monitor wrapped, shouldn't current_rewards and current_lengths be reset to zero if and only if "episode" in info.keys()? Otherwise, you are just tracking the rewards and lengths for a single life in the Atari game.
So my argument, in one line: when an environment is Monitor wrapped, instead of prematurely resetting current_rewards and current_lengths to zero based on the done signal, I think it is prudent to reset them only when the episode actually ends.
I realize that current_rewards and current_lengths are not really important in Atari (or, more generally, Monitor-wrapped) environments, because the reported episode_rewards and episode_lengths ultimately come from info["episode"]["r"] and info["episode"]["l"]. However, I think this is something that should be looked into, and I am happy to open a PR with proper comments and explanations; a rough sketch of the change I have in mind follows the excerpt below.
To Reproduce
import warnings
from typing import Any, Callable, Optional, Union
import gymnasium as gym
import numpy as np
from stable_baselines3.common import type_aliases
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
def evaluate_policy(
model: "type_aliases.PolicyPredictor",
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,
deterministic: bool = True,
render: bool = False,
callback: Optional[Callable[[dict[str, Any], dict[str, Any]], None]] = None,
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
warn: bool = True,
) -> Union[tuple[float, float], tuple[list[float], list[int]]]:
......
......
......
if dones[i]:
if is_monitor_wrapped:
# Atari wrapper can send a "done" signal when
# the agent loses a life, but it does not correspond
# to the true end of episode
if "episode" in info.keys():
# Do not trust "done" with episode endings.
# Monitor wrapper includes "episode" key in info if environment
# has been wrapped with it. Use those rewards instead.
episode_rewards.append(info["episode"]["r"])
episode_lengths.append(info["episode"]["l"])
# Only increment at the real end of an episode
episode_counts[i] += 1
else:
episode_rewards.append(current_rewards[i])
episode_lengths.append(current_lengths[i])
episode_counts[i] += 1
current_rewards[i] = 0
current_lengths[i] = 0
observations = new_observations
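Concretely, the change I have in mind is roughly the following (an untested sketch, touching only the monitor-wrapped branch): reset current_rewards[i] and current_lengths[i] only once "episode" shows up in info, so that a per-life done does not wipe the accumulators mid-episode.

                if dones[i]:
                    if is_monitor_wrapped:
                        if "episode" in info.keys():
                            episode_rewards.append(info["episode"]["r"])
                            episode_lengths.append(info["episode"]["l"])
                            episode_counts[i] += 1
                            # Reset only at the true end of episode reported by Monitor,
                            # not on the per-life "done" emitted by the Atari wrappers
                            current_rewards[i] = 0
                            current_lengths[i] = 0
                    else:
                        episode_rewards.append(current_rewards[i])
                        episode_lengths.append(current_lengths[i])
                        episode_counts[i] += 1
                        current_rewards[i] = 0
                        current_lengths[i] = 0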
Relevant log output / Error message
System Info
- OS: Linux-6.8.0-59-generic-x86_64-with-glibc2.35 # 61~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Apr 15 17:03:15 UTC 2
- Python: 3.10.12
- Stable-Baselines3: 2.6.0
- PyTorch: 2.7.0+cu126
- GPU Enabled: True
- Numpy: 2.2.6
- Cloudpickle: 3.1.1
- Gymnasium: 1.1.1
Checklist
- My issue does not relate to a custom gym environment. (Use the custom gym env template instead)
- I have checked that there is no similar issue in the repo
- I have read the documentation
- I have provided a minimal and working example to reproduce the bug
- I've used the markdown code blocks for both code and stack traces.