Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Deprecations:

Others:
^^^^^^^
- Performance optimization: Use ``np.asarray()`` instead of ``np.array()`` in ``ReplayBuffer`` and ``RolloutBuffer`` to avoid unnecessary array copies (@sxngt)

Documentation:
^^^^^^^^^^^^^^
Expand Down Expand Up @@ -1844,4 +1845,4 @@ And all the contributors:
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @sxngt
42 changes: 21 additions & 21 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,19 +263,19 @@ def add(
action = action.reshape((self.n_envs, self.action_dim))

# Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs)
self.observations[self.pos] = np.asarray(obs).copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the difference between that and simply np.array()?

I also need to check if it's needed at all (in the sense if side effects are possible)


if self.optimize_memory_usage:
self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs)
self.observations[(self.pos + 1) % self.buffer_size] = np.asarray(next_obs).copy()
else:
self.next_observations[self.pos] = np.array(next_obs)
self.next_observations[self.pos] = np.asarray(next_obs).copy()

self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.dones[self.pos] = np.array(done)
self.actions[self.pos] = np.asarray(action)
self.rewards[self.pos] = np.asarray(reward)
self.dones[self.pos] = np.asarray(done)

if self.handle_timeout_termination:
self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
self.timeouts[self.pos] = np.asarray([info.get("TimeLimit.truncated", False) for info in infos])

self.pos += 1
if self.pos == self.buffer_size:
Expand Down Expand Up @@ -468,10 +468,10 @@ def add(
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.observations[self.pos] = np.array(obs)
self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.episode_starts[self.pos] = np.array(episode_start)
self.observations[self.pos] = np.asarray(obs).copy()
self.actions[self.pos] = np.asarray(action)
self.rewards[self.pos] = np.asarray(reward)
self.episode_starts[self.pos] = np.asarray(episode_start)
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.pos += 1
Expand Down Expand Up @@ -623,22 +623,22 @@ def add( # type: ignore[override]
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs[key] = obs[key].reshape((self.n_envs,) + self.obs_shape[key])
self.observations[key][self.pos] = np.array(obs[key])
self.observations[key][self.pos] = np.asarray(obs[key]).copy()

for key in self.next_observations.keys():
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
next_obs[key] = next_obs[key].reshape((self.n_envs,) + self.obs_shape[key])
self.next_observations[key][self.pos] = np.array(next_obs[key])
self.next_observations[key][self.pos] = np.asarray(next_obs[key]).copy()

# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.dones[self.pos] = np.array(done)
self.actions[self.pos] = np.asarray(action)
self.rewards[self.pos] = np.asarray(reward)
self.dones[self.pos] = np.asarray(done)

if self.handle_timeout_termination:
self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
self.timeouts[self.pos] = np.asarray([info.get("TimeLimit.truncated", False) for info in infos])

self.pos += 1
if self.pos == self.buffer_size:
Expand Down Expand Up @@ -780,7 +780,7 @@ def add( # type: ignore[override]
log_prob = log_prob.reshape(-1, 1)

for key in self.observations.keys():
obs_ = np.array(obs[key])
obs_ = np.asarray(obs[key]).copy()
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
Expand All @@ -790,9 +790,9 @@ def add( # type: ignore[override]
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.episode_starts[self.pos] = np.array(episode_start)
self.actions[self.pos] = np.asarray(action)
self.rewards[self.pos] = np.asarray(reward)
self.episode_starts[self.pos] = np.asarray(episode_start)
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.pos += 1
Expand Down
105 changes: 105 additions & 0 deletions tests/test_buffer_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import numpy as np
import pytest
from stable_baselines3.common.buffers import ReplayBuffer
from gymnasium import spaces


def test_replay_buffer_no_copy_when_already_array():
    """ReplayBuffer must store independent copies of ndarray inputs.

    Buffer storage is laid out as ``(buffer_size, n_envs, *shape)``, so a
    stored slot such as ``buffer.observations[0]`` carries a leading
    ``n_envs`` axis of length 1.  ``np.array_equal`` is shape-sensitive
    (``(1, 4)`` vs ``(4,)`` compares unequal), so comparisons below index
    the env axis away with ``[0, 0]``.
    """
    obs_space = spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32)
    action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
    buffer = ReplayBuffer(buffer_size=10, observation_space=obs_space, action_space=action_space)

    # Inputs are already numpy arrays (one env's worth of data each).
    obs = np.array([1, 2, 3, 4], dtype=np.float32)
    next_obs = np.array([2, 3, 4, 5], dtype=np.float32)
    action = np.array([0.5, -0.5], dtype=np.float32)
    reward = np.array([1.0], dtype=np.float32)
    done = np.array([False], dtype=np.float32)

    buffer.add(obs, next_obs, action, reward, done, [{}])

    # Verify data was stored correctly (index the n_envs axis for the
    # per-env entries; rewards/dones are already (n_envs,) per slot).
    assert np.array_equal(buffer.observations[0, 0], obs)
    assert np.array_equal(buffer.next_observations[0, 0], next_obs)
    assert np.array_equal(buffer.actions[0, 0], action)
    assert np.array_equal(buffer.rewards[0], reward)
    assert np.array_equal(buffer.dones[0], done)

    # Mutating the caller's arrays must not leak into the buffer: add()
    # copies into the preallocated storage, so stored data stays intact.
    obs[:] = 0
    next_obs[:] = 0
    assert not np.array_equal(buffer.observations[0, 0], obs)
    assert not np.array_equal(buffer.next_observations[0, 0], next_obs)

    # The same isolation holds for actions/rewards/dones: assignment into
    # a preallocated buffer slice always copies the values.
    action[:] = 99
    reward[:] = 99
    done[:] = 1
    assert not np.array_equal(buffer.actions[0, 0], action)
    assert not np.array_equal(buffer.rewards[0], reward)
    assert not np.array_equal(buffer.dones[0], done)


def test_replay_buffer_handles_lists_and_scalars():
    """ReplayBuffer should accept list observations and scalar reward/done.

    Note: the action must be an ndarray — ``ReplayBuffer.add()`` calls
    ``action.reshape((n_envs, action_dim))``, which a plain Python int
    does not support.
    """
    obs_space = spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32)
    action_space = spaces.Discrete(3)
    buffer = ReplayBuffer(buffer_size=10, observation_space=obs_space, action_space=action_space)

    obs_list = [1.0, 2.0, 3.0, 4.0]
    next_obs_list = [2.0, 3.0, 4.0, 5.0]
    # add() reshapes actions, so pass an array rather than a bare int.
    action = np.array([1])
    reward_scalar = 2.5
    done_bool = True

    buffer.add(obs_list, next_obs_list, action, reward_scalar, done_bool, [{}])

    # Stored slots carry the leading n_envs axis: (n_envs, *shape).
    assert buffer.observations[0].shape == (1, 4)
    assert np.array_equal(buffer.observations[0, 0], np.asarray(obs_list, dtype=np.float32))
    assert buffer.actions[0].shape == (1, 1)
    assert isinstance(buffer.rewards[0], np.ndarray)
    assert isinstance(buffer.dones[0], np.ndarray)


def test_replay_buffer_memory_optimization_mode():
    """optimize_memory_usage stores next_obs at (pos + 1) in the same array.

    ``optimize_memory_usage=True`` is incompatible with the default
    ``handle_timeout_termination=True`` (the constructor raises
    ``ValueError``), so timeout handling is explicitly disabled here.
    """
    obs_space = spaces.Box(low=0, high=255, shape=(84, 84, 4), dtype=np.uint8)
    action_space = spaces.Discrete(4)

    buffer = ReplayBuffer(
        buffer_size=100,
        observation_space=obs_space,
        action_space=action_space,
        optimize_memory_usage=True,
        # Required: may not be combined with optimize_memory_usage=True.
        handle_timeout_termination=False,
    )

    obs = np.random.randint(0, 255, size=(84, 84, 4), dtype=np.uint8)
    next_obs = np.random.randint(0, 255, size=(84, 84, 4), dtype=np.uint8)

    # Action must be an ndarray: add() calls action.reshape(...).
    buffer.add(obs, next_obs, np.array([2]), 1.0, False, [{}])

    # In optimize_memory_usage mode, next_obs is stored at
    # (pos + 1) % buffer_size inside the shared observations array.
    # Index away the n_envs axis for the shape-sensitive comparison.
    assert np.array_equal(buffer.observations[0, 0], obs)
    assert np.array_equal(buffer.observations[1, 0], next_obs)

    # No separate next_observations array exists in this mode.
    assert getattr(buffer, "next_observations", None) is None


def test_replay_buffer_discrete_observation_space():
    """Discrete observations are reshaped to (n_envs, 1) before storage."""
    obs_space = spaces.Discrete(10)
    action_space = spaces.Discrete(2)
    buffer = ReplayBuffer(buffer_size=10, observation_space=obs_space, action_space=action_space)

    # For Discrete spaces add() calls .reshape() on obs/next_obs/action,
    # so they must be ndarrays, not bare Python ints.
    obs = np.array(5)
    next_obs = np.array(7)
    action = np.array([1])

    buffer.add(obs, next_obs, action, 1.0, False, [{}])

    # Buffer layout is (buffer_size, n_envs, *obs_shape) == (10, 1, 1),
    # so a single stored slot has shape (1, 1).
    assert buffer.observations[0].shape == (1, 1)
    assert buffer.observations[0][0, 0] == 5
Loading