1
1
import numpy as np
2
- from gymnasium .wrappers import FlattenObservation
3
2
import torch
3
+ from gymnasium .wrappers import FlattenObservation
4
4
5
5
from openrl .configs .config import create_config_parser
6
6
from openrl .envs .common import make
7
7
from openrl .envs .wrappers .base_wrapper import BaseWrapper
8
- from openrl .envs .wrappers .extra_wrappers import FrameSkip , GIFWrapper ,ConvertEmptyBoxWrapper
8
+ from openrl .envs .wrappers .extra_wrappers import (
9
+ ConvertEmptyBoxWrapper ,
10
+ FrameSkip ,
11
+ GIFWrapper ,
12
+ )
9
13
from openrl .modules .common import PPONet as Net
10
14
from openrl .runners .common import PPOAgent as Agent
11
15
12
-
13
16
env_name = "dm_control/cartpole-balance-v0"
14
17
# env_name = "dm_control/walker-walk-v0"
15
18
@@ -25,7 +28,7 @@ def train():
25
28
env_name ,
26
29
env_num = env_num ,
27
30
asynchronous = True ,
28
- env_wrappers = [FrameSkip , FlattenObservation ,ConvertEmptyBoxWrapper ],
31
+ env_wrappers = [FrameSkip , FlattenObservation , ConvertEmptyBoxWrapper ],
29
32
)
30
33
31
34
net = Net (env , cfg = cfg , device = "cuda" if torch .cuda .is_available () else "cpu" )
@@ -52,10 +55,10 @@ def evaluation():
52
55
render_mode = render_mode ,
53
56
env_num = 4 ,
54
57
asynchronous = True ,
55
- env_wrappers = [FrameSkip , FlattenObservation ,ConvertEmptyBoxWrapper ],
58
+ env_wrappers = [FrameSkip , FlattenObservation , ConvertEmptyBoxWrapper ],
56
59
)
57
60
# Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
58
- # env = GIFWrapper(env, gif_path="./new.gif", fps=5)
61
+ env = GIFWrapper (env , gif_path = "./new.gif" , fps = 5 )
59
62
60
63
net = Net (env , cfg = cfg , device = "cuda" if torch .cuda .is_available () else "cpu" )
61
64
# initialize the trainer
0 commit comments