
[Bug]: bug title #2145

@annimukherjee

Description

🐛 Bug

I have found what I think is an inconsistency in how evaluation.py is structured.

I am trying to evaluate DQN on Atari environments (specifically Breakout). I have defined my environments like this:

    import random

    import numpy as np
    import torch

    from stable_baselines3.common.env_util import make_atari_env
    from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

    # === Setup Environment ===

    # Training environment
    env = make_atari_env(args.env_name, n_envs=args.n_envs, seed=args.seed)
    env = VecFrameStack(env, n_stack=args.frame_stack)
    env = VecTransposeImage(env)

    # Evaluation environment
    eval_env = make_atari_env(args.env_name, n_envs=args.n_envs, seed=args.seed + 1000)
    eval_env = VecFrameStack(eval_env, n_stack=args.frame_stack)
    eval_env = VecTransposeImage(eval_env)

    # Set seeds for reproducibility
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
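
For context, this is roughly how I then train the model and run evaluation. The hyperparameters and timestep counts here are placeholders rather than my exact values; the point is just where evaluate_policy enters the picture:

    from stable_baselines3 import DQN
    from stable_baselines3.common.evaluation import evaluate_policy

    # Placeholder hyperparameters -- the exact values don't matter for this issue
    model = DQN("CnnPolicy", env, buffer_size=100_000, seed=args.seed)
    model.learn(total_timesteps=args.total_timesteps)

    # evaluate_policy is where the done / "episode" handling discussed below happens
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")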

Now, in the evaluation.py script located at stable_baselines3/common/evaluation.py, whenever a done is emitted there are two logical paths evaluate_policy can follow (see the excerpt under "To Reproduce" below):

  1. Check whether the environment is Monitor-wrapped.
  2. If it is not, assume that done being True means the episode is over, append current_rewards and current_lengths to episode_rewards and episode_lengths respectively, and reset them.

My issue is:
When an environment is Monitor-wrapped, shouldn't current_rewards and current_lengths be reset to zero if and only if "episode" in info.keys()? Otherwise you are just tracking the rewards and lengths for one life in the Atari game.

So my argument is essentially this in one line:

When an environment is Monitor-wrapped, instead of prematurely resetting current_rewards and current_lengths to zero based on done, I think it is prudent to reset them only when the episode actually ends.

I realize that current_rewards and current_lengths aren't really important in Atari environments (or, more generally, in Monitor-wrapped environments), because ultimately episode_rewards and episode_lengths come from info["episode"]["r"]. However, this is something that should be looked into, and I'm happy to open a PR with proper comments and explanations.
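
To make that concrete, here is a rough sketch of the change I have in mind (untested, just to illustrate the idea): in the Monitor-wrapped case, move the reset of the per-env accumulators so it only happens once the "episode" key confirms the real end of the episode:

    if dones[i]:
        if is_monitor_wrapped:
            # Atari wrappers can emit done when a life is lost,
            # so only trust the Monitor "episode" entry
            if "episode" in info.keys():
                episode_rewards.append(info["episode"]["r"])
                episode_lengths.append(info["episode"]["l"])
                episode_counts[i] += 1
                # Reset only at the real end of an episode
                current_rewards[i] = 0
                current_lengths[i] = 0
        else:
            episode_rewards.append(current_rewards[i])
            episode_lengths.append(current_lengths[i])
            episode_counts[i] += 1
            current_rewards[i] = 0
            current_lengths[i] = 0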

To Reproduce

import warnings
from typing import Any, Callable, Optional, Union

import gymnasium as gym
import numpy as np

from stable_baselines3.common import type_aliases
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped


def evaluate_policy(
    model: "type_aliases.PolicyPredictor",
    env: Union[gym.Env, VecEnv],
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Optional[Callable[[dict[str, Any], dict[str, Any]], None]] = None,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
) -> Union[tuple[float, float], tuple[list[float], list[int]]]:

......
......
......

                if dones[i]:
                    if is_monitor_wrapped:
                        # Atari wrapper can send a "done" signal when
                        # the agent loses a life, but it does not correspond
                        # to the true end of episode
                        if "episode" in info.keys():
                            # Do not trust "done" with episode endings.
                            # Monitor wrapper includes "episode" key in info if environment
                            # has been wrapped with it. Use those rewards instead.
                            episode_rewards.append(info["episode"]["r"])
                            episode_lengths.append(info["episode"]["l"])
                            # Only increment at the real end of an episode
                            episode_counts[i] += 1
                    else:
                        episode_rewards.append(current_rewards[i])
                        episode_lengths.append(current_lengths[i])
                        episode_counts[i] += 1
                    
                    current_rewards[i] = 0
                    current_lengths[i] = 0

        observations = new_observations

Relevant log output / Error message

System Info

- OS: Linux-6.8.0-59-generic-x86_64-with-glibc2.35 # 61~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Apr 15 17:03:15 UTC 2
- Python: 3.10.12
- Stable-Baselines3: 2.6.0
- PyTorch: 2.7.0+cu126
- GPU Enabled: True
- Numpy: 2.2.6
- Cloudpickle: 3.1.1
- Gymnasium: 1.1.1

Checklist

  • My issue does not relate to a custom gym environment. (Use the custom gym env template instead)
  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • I have provided a minimal and working example to reproduce the bug
  • I've used the markdown code blocks for both code and stack traces.
