Skip to content

Commit f8879b3

Browse files
author
Wen-Tse Chen
committed
fix test w/o gpu bug
1 parent 3af7588 commit f8879b3

File tree

2 files changed: +13 additions, −5 deletions

openrl/envs/nlp/rewards/intent.py

Lines changed: 6 additions & 3 deletions
@@ -41,6 +41,10 @@ def __init__(
         self.use_model_parallel = False

         if intent_model == "builtin_intent":
+            self._device = "cpu"
+            self.use_data_parallel = False
+
             from transformers import GPT2Config, GPT2LMHeadModel

             class TestTokenizer:
@@ -66,6 +70,7 @@ def __init__(self, input_ids, attention_mask):
             self._model = GPT2LMHeadModel(config)

         else:
+            self._device = "cuda"
             model_path = data_abs_path(intent_model)
             self._tokenizer = AutoTokenizer.from_pretrained(intent_model)
             self._model = AutoModelForSequenceClassification.from_pretrained(model_path)
@@ -81,12 +86,10 @@ def __init__(self, input_ids, attention_mask):
                 with open(ds_config) as file:
                     ds_config = json.load(file)

-                self._device = "cuda"
-                self._model = self._model.to("cuda")
+                self._model = self._model.to(self._device)
                 self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config)
                 self.use_fp16 = ds_config["fp16"]["enabled"]
             else:
-                self._device = "cuda"
                 if self.use_model_parallel:
                     self._model.parallelize()
                 elif self.use_data_parallel:
elif self.use_data_parallel:

openrl/envs/nlp/rewards/kl_penalty.py

Lines changed: 7 additions & 2 deletions
@@ -47,6 +47,10 @@ def __init__(

         # reference model
         if ref_model == "builtin_ref":
+            self.device = "cpu"
+            self.use_data_parallel = False
+
             from transformers import GPT2Config, GPT2LMHeadModel

             config = GPT2Config()
@@ -77,8 +81,9 @@ def __init__(
         elif self.use_data_parallel:  # else defaults to data parallel
             if self.use_half:
                 self._ref_net = self._ref_net.half()
-            self._ref_net = torch.nn.DataParallel(self._ref_net)
-            self._ref_net = self._ref_net.to(self.device)
+            else:
+                self._ref_net = torch.nn.DataParallel(self._ref_net)
+                self._ref_net = self._ref_net.to(self.device)

         # alpha adjustment
         self._alpha = 0.2

Comments (0)