Skip to content

Commit 0491ad4

Browse files
committed
merge dm_control to gymnasium
1 parent 340e59e commit 0491ad4

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

examples/dm_control/train_ppo.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import numpy as np
2-
from gymnasium.wrappers import FlattenObservation
32
import torch
3+
from gymnasium.wrappers import FlattenObservation
44

55
from openrl.configs.config import create_config_parser
66
from openrl.envs.common import make
77
from openrl.envs.wrappers.base_wrapper import BaseWrapper
8-
from openrl.envs.wrappers.extra_wrappers import FrameSkip, GIFWrapper,ConvertEmptyBoxWrapper
8+
from openrl.envs.wrappers.extra_wrappers import (
9+
ConvertEmptyBoxWrapper,
10+
FrameSkip,
11+
GIFWrapper,
12+
)
913
from openrl.modules.common import PPONet as Net
1014
from openrl.runners.common import PPOAgent as Agent
1115

12-
1316
env_name = "dm_control/cartpole-balance-v0"
1417
# env_name = "dm_control/walker-walk-v0"
1518

@@ -25,7 +28,7 @@ def train():
2528
env_name,
2629
env_num=env_num,
2730
asynchronous=True,
28-
env_wrappers=[FrameSkip, FlattenObservation,ConvertEmptyBoxWrapper],
31+
env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
2932
)
3033

3134
net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
@@ -52,10 +55,10 @@ def evaluation():
5255
render_mode=render_mode,
5356
env_num=4,
5457
asynchronous=True,
55-
env_wrappers=[FrameSkip, FlattenObservation,ConvertEmptyBoxWrapper],
58+
env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
5659
)
5760
# Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
58-
# env = GIFWrapper(env, gif_path="./new.gif", fps=5)
61+
env = GIFWrapper(env, gif_path="./new.gif", fps=5)
5962

6063
net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
6164
# initialize the trainer

0 commit comments

Comments
 (0)