
Commit bc50fd7

Merge pull request #11 from LondonNode/feature/optimizations
Feature/optimizations
2 parents dd5b331 + e3b33b0 commit bc50fd7

9 files changed: +74 -33 lines changed

pearll/buffers/base_buffer.py

Lines changed: 11 additions & 7 deletions

@@ -4,12 +4,14 @@
 
 import numpy as np
 import psutil
+import torch as T
 from gym import Env
 from gym.vector import VectorEnv
 
+from pearll import settings
 from pearll.common.enumerations import TrajectoryType
 from pearll.common.type_aliases import Observation, Trajectories
-from pearll.common.utils import get_space_shape, to_torch
+from pearll.common.utils import get_space_shape
 
 
 class BaseBuffer(ABC):

@@ -128,13 +130,15 @@ def _transform_samples(
 
         # return torch tensors instead of numpy arrays
         if dtype == TrajectoryType.TORCH:
-            observations, actions, rewards, next_observations, dones = to_torch(
-                observations,
-                actions,
-                rewards,
-                next_observations,
-                dones,
+            observations = T.from_numpy(observations).to(
+                settings.DEVICE, non_blocking=True
             )
+            actions = T.from_numpy(actions).to(settings.DEVICE, non_blocking=True)
+            rewards = T.from_numpy(rewards).to(settings.DEVICE, non_blocking=True)
+            next_observations = T.from_numpy(next_observations).to(
+                settings.DEVICE, non_blocking=True
+            )
+            dones = T.from_numpy(dones).to(settings.DEVICE, non_blocking=True)
 
         return Trajectories(
             observations=observations,
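
The `to_torch` helper is replaced with explicit `T.from_numpy(...).to(settings.DEVICE, non_blocking=True)` calls for each field. A minimal standalone sketch of that pattern, using a local DEVICE as a stand-in for `pearll.settings.DEVICE` and hypothetical array shapes:

import numpy as np
import torch as T

# Stand-in for pearll.settings.DEVICE; the library configures the real value.
DEVICE = T.device("cuda" if T.cuda.is_available() else "cpu")

observations = np.zeros((32, 4), dtype=np.float32)  # hypothetical batch
rewards = np.zeros((32, 1), dtype=np.float32)

# from_numpy wraps the array without a host-side copy; non_blocking=True lets
# the host-to-device transfer overlap with other work when the source memory
# is pinned, and is harmless when the target device is the CPU.
obs_t = T.from_numpy(observations).to(DEVICE, non_blocking=True)
rew_t = T.from_numpy(rewards).to(DEVICE, non_blocking=True)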

pearll/buffers/her_buffer.py

Lines changed: 27 additions & 1 deletion

@@ -1,8 +1,10 @@
 from typing import Dict, Tuple, Union
 
 import numpy as np
+import torch as T
 from gym.core import GoalEnv
 
+from pearll import settings
 from pearll.buffers.base_buffer import BaseBuffer
 from pearll.common.enumerations import GoalSelectionStrategy, TrajectoryType
 from pearll.common.type_aliases import DictTrajectories, Tensor

@@ -219,7 +221,31 @@ def sample(
         else:
             batch_inds = np.random.randint(0, end_idx, size=batch_size)
 
-        trajectories = self._sample_trajectories(batch_inds)
+        trajectories = list(self._sample_trajectories(batch_inds))
+
+        if dtype == TrajectoryType.TORCH:
+            trajectories[0]["observation"] = T.from_numpy(
+                trajectories[0]["observation"]
+            ).to(settings.DEVICE, non_blocking=True)
+            trajectories[0]["desired_goal"] = T.from_numpy(
+                trajectories[0]["desired_goal"]
+            ).to(settings.DEVICE, non_blocking=True)
+            trajectories[1] = T.from_numpy(trajectories[1]).to(
+                settings.DEVICE, non_blocking=True
+            )
+            trajectories[2] = T.from_numpy(trajectories[2]).to(
+                settings.DEVICE, non_blocking=True
+            )
+            trajectories[3]["observation"] = T.from_numpy(
+                trajectories[3]["observation"]
+            ).to(settings.DEVICE, non_blocking=True)
+            trajectories[3]["desired_goal"] = T.from_numpy(
+                trajectories[3]["desired_goal"]
+            ).to(settings.DEVICE, non_blocking=True)
+            trajectories[4] = T.from_numpy(trajectories[4]).to(
+                settings.DEVICE, non_blocking=True
+            )
+
         return DictTrajectories(
             observations=trajectories[0],
             actions=trajectories[1],
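
The HER buffer applies the same conversion to every element of the sampled tuple, including the nested observation/goal dicts. A compact sketch of the same idea on a hypothetical trajectory list (the local DEVICE and the `to_device` helper are illustrative, not pearll API):

import numpy as np
import torch as T

DEVICE = T.device("cpu")  # stand-in for pearll.settings.DEVICE

# Hypothetical sampled batch: [obs dict, actions, rewards, next obs dict, dones]
trajectories = [
    {"observation": np.zeros((4, 3), dtype=np.float32),
     "desired_goal": np.zeros((4, 3), dtype=np.float32)},
    np.zeros((4, 1), dtype=np.float32),
    np.zeros((4, 1), dtype=np.float32),
    {"observation": np.zeros((4, 3), dtype=np.float32),
     "desired_goal": np.zeros((4, 3), dtype=np.float32)},
    np.zeros((4, 1), dtype=np.float32),
]

def to_device(x):
    # Convert an array, or every array in a dict, to a tensor on DEVICE.
    if isinstance(x, dict):
        return {k: T.from_numpy(v).to(DEVICE, non_blocking=True) for k, v in x.items()}
    return T.from_numpy(x).to(DEVICE, non_blocking=True)

trajectories = [to_device(t) for t in trajectories]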

pearll/common/utils.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def to_numpy(*data) -> Union[Tuple[np.ndarray], np.ndarray]:
         if isinstance(el, T.Tensor):
             result[i] = el.detach().cpu().numpy()
         else:
-            result[i] = np.array(el)
+            result[i] = np.asarray(el)
 
     if len(data) == 1:
         return result[0]
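
`np.asarray` returns its argument unchanged when it is already an ndarray, whereas `np.array` copies by default, so the new code skips a redundant copy in the common case. A quick illustration (not pearll code):

import numpy as np

x = np.ones(3)
print(np.array(x) is x)    # False: np.array copies by default
print(np.asarray(x) is x)  # True: np.asarray reuses the existing array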

pearll/models/actor_critics.py

Lines changed: 11 additions & 3 deletions

@@ -8,7 +8,7 @@
 from pearll import settings
 from pearll.common.enumerations import Distribution
 from pearll.common.type_aliases import Tensor
-from pearll.common.utils import get_space_range, get_space_shape, to_numpy
+from pearll.common.utils import get_space_range, get_space_shape
 from pearll.models.encoders import IdentityEncoder, MLPEncoder
 from pearll.models.heads import (
     BaseActorHead,

@@ -66,7 +66,10 @@ def __init__(
         self.state_info = {}
         self.make_state_info()
         self.state = np.concatenate(
-            [to_numpy(d.flatten()) for d in self.model.state_dict().values()]
+            [
+                d.flatten().detach().cpu().numpy()
+                for d in self.model.state_dict().values()
+            ]
         )
         self.space = Box(low=-1e6, high=1e6, shape=self.state.shape)
         self.space_shape = get_space_shape(self.space)

@@ -217,7 +220,12 @@ def forward(self, observations: Tensor) -> T.Tensor:
         trigger = T.rand(1).item()
 
         if trigger <= self.epsilon:
-            actions = T.randint(low=0, high=action_size, size=q_values.shape[:-1])
+            actions = T.randint(
+                low=0,
+                high=action_size,
+                size=q_values.shape[:-1],
+                device=settings.DEVICE,
+            )
         else:
             _, actions = T.max(q_values, dim=-1)
 
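
In the epsilon-greedy branch, the random actions are now created directly on the target device via the `device=` argument, so no separate host-to-device copy is needed afterwards. A minimal standalone sketch of that branch, with a local DEVICE as a stand-in for `settings.DEVICE` and hypothetical `q_values`/`epsilon` values:

import torch as T

DEVICE = T.device("cpu")  # stand-in for pearll.settings.DEVICE
epsilon = 0.1
q_values = T.randn(8, 4, device=DEVICE)  # hypothetical batch of Q-values
action_size = q_values.shape[-1]

if T.rand(1).item() <= epsilon:
    # Random exploration: sample actions on the same device as the Q-values.
    actions = T.randint(low=0, high=action_size, size=q_values.shape[:-1], device=DEVICE)
else:
    # Greedy exploitation: argmax over the action dimension.
    _, actions = T.max(q_values, dim=-1)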

pearll/models/encoders.py

Lines changed: 7 additions & 7 deletions

@@ -6,7 +6,7 @@
 
 from pearll.common.type_aliases import Tensor
 from pearll.common.utils import to_numpy
-from pearll.models.utils import concat_obs_actions
+from pearll.models.utils import preprocess_inputs
 
 
 class IdentityEncoder(T.nn.Module):

@@ -19,8 +19,8 @@ def forward(
         self, observations: Tensor, actions: Optional[Tensor] = None
     ) -> T.Tensor:
         # Some algorithms use both the observations and actions as input (e.g. DDPG for conitnuous Q function)
-        observations = concat_obs_actions(observations, actions)
-        return observations
+        input = preprocess_inputs(observations, actions)
+        return input
 
 
 class FlattenEncoder(T.nn.Module):

@@ -34,8 +34,8 @@ def forward(
     ) -> T.Tensor:
         # Some algorithms use both the observations and actions as input (e.g. DDPG for conitnuous Q function)
         # Make sure observations is a torch tensor, get error if numpy for some reason??
-        observations = concat_obs_actions(observations, actions)
-        return T.flatten(observations)
+        input = preprocess_inputs(observations, actions)
+        return T.flatten(input)
 
 
 class MLPEncoder(T.nn.Module):

@@ -48,8 +48,8 @@ def __init__(self, input_size, output_size):
     def forward(
         self, observations: Tensor, actions: Optional[Tensor] = None
     ) -> T.Tensor:
-        observations = concat_obs_actions(observations, actions)
-        return self.model(observations)
+        input = preprocess_inputs(observations, actions)
+        return self.model(input)
 
 
 class CNNEncoder(T.nn.Module):

pearll/models/utils.py

Lines changed: 8 additions & 7 deletions

@@ -2,8 +2,8 @@
 
 import torch as T
 
+from pearll import settings
 from pearll.common.type_aliases import Tensor
-from pearll.common.utils import to_torch
 
 
 def trainable_parameters(model: T.nn.Module) -> list:

@@ -31,12 +31,13 @@ def get_mlp_size(data_shape: Union[int, Tuple[int]]) -> int:
     return data_shape
 
 
-def concat_obs_actions(observations: Tensor, actions: Optional[Tensor]) -> T.Tensor:
+def preprocess_inputs(observations: Tensor, actions: Optional[Tensor]) -> T.Tensor:
+    input = T.as_tensor(observations)
+    if input.dim() == 0:
+        input = input.unsqueeze(0)
     if actions is not None:
-        observations, actions = to_torch(observations, actions)
-        if observations.dim() == 0:
-            observations = observations.unsqueeze(0)
+        actions = T.as_tensor(actions)
         if actions.dim() == 0:
             actions = actions.unsqueeze(0)
-        return T.cat([observations, actions], dim=-1).float()
-    return to_torch(observations).float()
+        input = T.cat([input, actions], dim=-1)
+    return input.float().to(settings.DEVICE, non_blocking=True)
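
The new `preprocess_inputs` helper builds on `T.as_tensor`, which reuses the underlying memory when the input is already a tensor or a numpy array of a compatible dtype, instead of the old `to_torch` round trip. A quick illustration of that behaviour (not pearll code):

import numpy as np
import torch as T

x = np.ones(3, dtype=np.float32)
t = T.as_tensor(x)          # no copy: the tensor shares memory with the array
x[0] = 5.0
print(t[0].item())          # 5.0 -- the tensor sees the in-place change
print(T.as_tensor(t) is t)  # True: tensors are returned as-is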

pearll/updaters/evolution.py

Lines changed: 3 additions & 3 deletions

@@ -2,6 +2,7 @@
 from typing import Optional
 
 import numpy as np
+import torch as T
 from gym.spaces import Discrete, MultiDiscrete
 from torch.distributions import Normal, kl_divergence
 

@@ -11,7 +12,6 @@
     SelectionFunc,
     UpdaterLog,
 )
-from pearll.common.utils import to_torch
 from pearll.models.actor_critics import ActorCritic
 
 

@@ -86,7 +86,7 @@ def __call__(
         """
         # Snapshot current population dist for kl divergence
         # use copy() to avoid modifying the original
-        old_dist = Normal(to_torch(self.mean.copy()), self.std)
+        old_dist = Normal(T.from_numpy(self.mean.copy()), self.std)
 
         # Main update
         self.mean += learning_rate * optimization_direction

@@ -104,7 +104,7 @@ def __call__(
         self.update_networks(population)
 
         # Calculate Log metrics
-        new_dist = Normal(to_torch(self.mean), self.std)
+        new_dist = Normal(T.from_numpy(self.mean), self.std)
         population_entropy = new_dist.entropy().mean()
         population_kl = kl_divergence(old_dist, new_dist).mean()
 
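
Both distribution snapshots are now built with `T.from_numpy` rather than `to_torch` before the entropy and KL metrics are computed. A self-contained sketch of that log computation with hypothetical mean/std values:

import numpy as np
import torch as T
from torch.distributions import Normal, kl_divergence

std = 0.5
old_mean = np.zeros(4)
new_mean = old_mean + 0.1  # hypothetical update step

# copy() snapshots the old mean so the in-place population update cannot alias it
old_dist = Normal(T.from_numpy(old_mean.copy()), std)
new_dist = Normal(T.from_numpy(new_mean), std)

population_entropy = new_dist.entropy().mean()
population_kl = kl_divergence(old_dist, new_dist).mean()
print(population_entropy.item(), population_kl.item())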

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pearll"
-version = "0.4.0"
+version = "0.4.1"
 description = "Adaptable tools to make reinforcement learning and evolutionary computation algorithms"
 license = "MIT"
 authors = ["Rohan Tangri <rohan.tangri@gmail.com>"]

tests/test_her.py

Lines changed: 5 additions & 3 deletions

@@ -291,9 +291,11 @@ def test_her_sample(goal_selection_strategy, buffer_size):
         observations[pos] = next_observations[pos - 1]
 
     trajectories = buffer.sample(4)
-    sampled_observations = trajectories.observations["observation"]
-    sampled_next_observations = trajectories.next_observations["observation"]
-    her_sampled_goals = trajectories.observations["desired_goal"]
+    sampled_observations = np.asarray(trajectories.observations["observation"])
+    sampled_next_observations = np.asarray(
+        trajectories.next_observations["observation"]
+    )
+    her_sampled_goals = np.asarray(trajectories.observations["desired_goal"])
     # Check if sampled next observations are actually the next observations
     for i, obs in enumerate(sampled_observations):
         array_idx = np.where(observations == obs)[0]
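
The test now normalises the sampled fields back to numpy before the element-wise comparisons, which keeps it working whether `sample` returns arrays or tensors; for CPU tensors `np.asarray` does this without a copy. A quick illustration (not part of the test):

import numpy as np
import torch as T

t = T.arange(4)
a = np.asarray(t)  # CPU tensors expose __array__, so this shares memory
print(type(a), a)  # <class 'numpy.ndarray'> [0 1 2 3]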
