Skip to content

Commit a908e29

Browse files
authored
fix minor (#62)
1 parent 1d523c9 commit a908e29

File tree

4 files changed

+10
-4
lines changed

4 files changed

+10
-4
lines changed

jorldy/core/agent/ddpg.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212

1313

1414
class DDPG(BaseAgent):
15+
action_type = "continuous"
1516
"""Deep deterministic policy gradient (DDPG) agent.
1617
1718
Args:
@@ -65,7 +66,6 @@ def __init__(
6566
if device
6667
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
6768
)
68-
self.action_type = "continuous"
6969

7070
self.actor = Network(
7171
actor, state_size, action_size, D_hidden=hidden_size, head=head

jorldy/core/agent/dqn.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212

1313

1414
class DQN(BaseAgent):
15+
action_type = "discrete"
1516
"""DQN agent.
1617
1718
Args:
@@ -57,7 +58,6 @@ def __init__(
5758
num_workers=1,
5859
**kwargs,
5960
):
60-
6161
self.device = (
6262
torch.device(device)
6363
if device

jorldy/core/agent/ppo.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ class PPO(REINFORCE):
2525

2626
def __init__(
2727
self,
28+
network="discrete_policy_value",
2829
batch_size=32,
2930
n_step=128,
3031
n_epoch=3,
@@ -36,7 +37,7 @@ def __init__(
3637
num_workers=1,
3738
**kwargs,
3839
):
39-
super(PPO, self).__init__(**kwargs)
40+
super(PPO, self).__init__(network=network, **kwargs)
4041

4142
self.batch_size = batch_size
4243
self.n_step = n_step

jorldy/core/agent/vmpo.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@ class VMPO(REINFORCE):
3131

3232
def __init__(
3333
self,
34+
network="discrete_policy_value",
3435
optim_config={"name": "adam"},
3536
batch_size=32,
3637
n_step=128,
@@ -49,7 +50,11 @@ def __init__(
4950
alpha_sigma=1.0,
5051
**kwargs,
5152
):
52-
super(VMPO, self).__init__(optim_config=optim_config, **kwargs)
53+
super(VMPO, self).__init__(
54+
network=network,
55+
optim_config=optim_config,
56+
**kwargs,
57+
)
5358

5459
self.batch_size = batch_size
5560
self.n_step = n_step

0 commit comments

Comments
 (0)