Skip to content

Commit 0092356

Browse files
authored
Merge pull request #254 from huangshiyu13/main
update test
2 parents b735781 + 090b617 commit 0092356

File tree

3 files changed

+85
-234
lines changed

3 files changed

+85
-234
lines changed

openrl/envs/snake/common.py

Lines changed: 0 additions & 227 deletions
This file was deleted.

openrl/modules/networks/utils/nlp/causal_policy.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,33 @@ def policy(self):
6565

6666
def _build_model_heads(self, model_name: str, config: str, device: str):
6767
if self.disable_drop_out:
68-
config = AutoConfig.from_pretrained(model_name)
68+
if model_name == "test_gpt2":
69+
from transformers import GPT2Config
70+
71+
config = GPT2Config()
72+
73+
else:
74+
config = AutoConfig.from_pretrained(model_name)
6975
config_dict = config.to_dict()
7076
for key in config_dict:
7177
if "drop" in key:
7278
config_dict[key] = 0.0
7379
config = config.from_dict(config_dict)
7480

75-
self._policy_model = AutoModelForCausalLM.from_pretrained(
76-
model_name, config=config
77-
)
81+
if model_name == "test_gpt2":
82+
from transformers import GPT2LMHeadModel
7883

79-
self._value_model = AutoModelForCausalLM.from_pretrained(
80-
model_name, config=config
81-
)
84+
self._policy_model = GPT2LMHeadModel(config)
85+
self._value_model = GPT2LMHeadModel(config)
86+
87+
else:
88+
self._policy_model = AutoModelForCausalLM.from_pretrained(
89+
model_name, config=config
90+
)
91+
92+
self._value_model = AutoModelForCausalLM.from_pretrained(
93+
model_name, config=config
94+
)
8295

8396
self._value_head = nn.Linear(
8497
self._value_model.config.hidden_size, 1, bias=False
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
import os
20+
import sys
21+
22+
import numpy as np
23+
import pytest
24+
from gymnasium import spaces
25+
26+
from openrl.configs.config import create_config_parser
27+
from openrl.modules.networks.policy_value_network_gpt import (
28+
PolicyValueNetworkGPT as PolicyValueNetwork,
29+
)
30+
31+
32+
@pytest.fixture(scope="module", params=["--model_path test_gpt2"])
def config(request):
    """Build the parsed configuration used by the tests in this module.

    The fixture parameter is a CLI-style string; it is split on whitespace
    and passed to the parser returned by ``create_config_parser()``.
    """
    parser = create_config_parser()
    return parser.parse_args(request.param.split())
37+
38+
39+
@pytest.mark.unittest
def test_gpt_network(config):
    """Smoke-test the entry points of the GPT policy/value network.

    Constructs the network with ``Discrete(2)`` input and action spaces,
    then calls the parameter getters followed by ``get_actions``,
    ``eval_actions`` and ``get_values`` on zero-filled inputs. No return
    value is inspected: the test passes as long as nothing raises.
    """
    network = PolicyValueNetwork(
        cfg=config,
        input_space=spaces.Discrete(2),
        action_space=spaces.Discrete(2),
    )

    network.get_actor_para()
    network.get_critic_para()

    # Zero-filled stand-ins for the observation dict, recurrent state,
    # masks and action that the network methods accept.
    observation = dict(
        input_encoded_pt=np.zeros([1, 2]),
        input_attention_mask_pt=np.zeros([1, 2]),
    )
    recurrent_states = np.zeros(2)
    step_masks = np.zeros(2)
    chosen_action = np.zeros(1)

    # Keyword arguments common to all three calls below.
    shared_inputs = {
        "obs": observation,
        "rnn_states": recurrent_states,
        "masks": step_masks,
    }

    network.get_actions(**shared_inputs)
    network.eval_actions(action=chosen_action, action_masks=None, **shared_inputs)
    network.get_values(**shared_inputs)
61+
net.get_values(obs=obs, rnn_states=rnn_states, masks=masks)
62+
63+
64+
if __name__ == "__main__":
    # Run pytest on this file only (verbose, output not captured) and use
    # its return value as the process exit status.
    this_file = os.path.basename(__file__)
    exit_status = pytest.main(["-sv", this_file])
    sys.exit(exit_status)

0 commit comments

Comments
 (0)