
Commit 8185373

deepspeed support

2 parents 76d1e05 + fc02030

File tree

16 files changed: +528 -118 lines

examples/nlp/ds_config.json
Lines changed: 1 addition & 3 deletions

@@ -3,9 +3,7 @@
     "train_micro_batch_size_per_gpu": 16,
     "steps_per_print": 10,
     "zero_optimization": {
-        "stage": 2,
-        "reduce_bucket_size": 5e7,
-        "allgather_bucket_size": 5e7
+        "stage": 2
     },
     "fp16": {"enabled": false, "loss_scale_window": 100}
 }
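
Dropping the explicit reduce_bucket_size and allgather_bucket_size leaves ZeRO stage 2 on DeepSpeed's default bucket sizes. A minimal sketch of how a config like this is consumed, mirroring the deepspeed.initialize(model=..., config=...) calls in intent.py and kl_penalty.py below; the linear model and Adam optimizer here are placeholders, not OpenRL's actual networks:

import json

import deepspeed
import torch

model = torch.nn.Linear(8, 2)  # placeholder for the actual policy/value net

with open("examples/nlp/ds_config.json") as file:
    ds_config = json.load(file)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler);
# with "stage": 2 the optimizer state and gradients are partitioned across
# ranks using DeepSpeed's default bucket sizes.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, optimizer=optimizer, config=ds_config
)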

examples/nlp/nlp_ppo.yaml
Lines changed: 0 additions & 2 deletions

@@ -9,11 +9,9 @@ wandb_entity: "openrl-lab"
 ppo_epoch: 5
 episode_length: 128
 num_mini_batch: 20
-use_share_model: true
 
 hidden_size: 1
 
-
 model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog
 env:
   args: {

examples/nlp/nlp_ppo_ds.yaml
Lines changed: 0 additions & 1 deletion

@@ -9,7 +9,6 @@ wandb_entity: "openrl-lab"
 ppo_epoch: 5
 episode_length: 128
 num_mini_batch: 20
-use_share_model: true
 
 hidden_size: 1
 

examples/nlp/train_ppo.py
Lines changed: 3 additions & 4 deletions

@@ -3,9 +3,8 @@
 from openrl.configs.config import create_config_parser
 from openrl.envs.common import make
 from openrl.modules.common import PPONet as Net
-from openrl.modules.networks.policy_value_network_gpt import (
-    PolicyValueNetworkGPT as PolicyValueNetwork,
-)
+from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork
+from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork
 from openrl.runners.common import PPOAgent as Agent
 
 
@@ -29,7 +28,7 @@ def train():
     )
 
     # create the neural network
-    model_dict = {"model": PolicyValueNetwork}
+    model_dict = {"policy": PolicyNetwork, "critic": ValueNetwork}
     net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict)
 
     # initialize the trainer
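
The example now builds separate policy and value networks instead of the single shared PolicyValueNetworkGPT, which matches the per-network optimizer keys stepped in openrl/algorithms/ppo.py below. A sketch of the surrounding script under the assumption that it follows OpenRL's usual example layout; the environment id, env_num, and total_time_steps are illustrative, only the imports and the two changed lines come from this diff:

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork
from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork
from openrl.runners.common import PPOAgent as Agent


def train():
    # load the YAML config, e.g. examples/nlp/nlp_ppo.yaml or nlp_ppo_ds.yaml
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args()

    # illustrative: a dialog environment as used in this example
    env = make("daily_dialog", env_num=2, cfg=cfg)

    # one network per optimizer key stepped in openrl/algorithms/ppo.py
    model_dict = {"policy": PolicyNetwork, "critic": ValueNetwork}
    net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict)

    agent = Agent(net)
    agent.train(total_time_steps=5000)


if __name__ == "__main__":
    train()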

openrl/algorithms/ppo.py
Lines changed: 11 additions & 3 deletions

@@ -45,7 +45,8 @@ def __init__(
 
     def ppo_update(self, sample, turn_on=True):
         for optimizer in self.algo_module.optimizers.values():
-            optimizer.zero_grad()
+            if not self.use_deepspeed:
+                optimizer.zero_grad()
 
         (
             critic_obs_batch,

@@ -152,8 +153,15 @@ def ppo_update(self, sample, turn_on=True):
 
             self.algo_module.scaler.update()
         else:
-            for optimizer in self.algo_module.optimizers.values():
-                optimizer.step()
+            if self.use_deepspeed:
+                if self._use_share_model:
+                    self.algo_module.optimizers["model"].step()
+                else:
+                    self.algo_module.optimizers["policy"].step()
+                    self.algo_module.optimizers["critic"].step()
+            else:
+                for optimizer in self.algo_module.optimizers.values():
+                    optimizer.step()
 
         if self.world_size > 1:
             torch.cuda.synchronize()
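
Under DeepSpeed, gradient zeroing is owned by the engine: its step() applies the update and clears gradients itself, which is why ppo_update now skips the manual zero_grad() and steps the "model" optimizer (shared network) or the "policy" and "critic" optimizers (split networks) directly. A sketch of that engine loop, continuing the model_engine from the ds_config.json example above; the data and loss are made up:

data = [(torch.randn(16, 8), torch.randint(0, 2, (16,))) for _ in range(4)]

for x, y in data:
    logits = model_engine(x.to(model_engine.device))
    loss = torch.nn.functional.cross_entropy(logits, y.to(model_engine.device))
    model_engine.backward(loss)  # replaces loss.backward(); handles loss scaling
    model_engine.step()          # optimizer step plus internal gradient zeroing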

openrl/envs/nlp/daily_dialog_env.py
Lines changed: 12 additions & 13 deletions

@@ -72,18 +72,16 @@ def __init__(
         # set the observation and action space here
         self._vocab_size = self.tokenizer.vocab_size
 
-        self.observation_space = DictSpace(
-            {
-                "input_encoded_pt": spaces.Box(
-                    low=0,
-                    high=self._vocab_size,
-                    shape=(self._max_text_length + self.max_steps,),
-                ),
-                "input_attention_mask_pt": spaces.Box(
-                    low=0, high=1, shape=(self._max_text_length + self.max_steps,)
-                ),
-            }
-        )
+        self.observation_space = DictSpace({
+            "input_encoded_pt": spaces.Box(
+                low=0,
+                high=self._vocab_size,
+                shape=(self._max_text_length + self.max_steps,),
+            ),
+            "input_attention_mask_pt": spaces.Box(
+                low=0, high=1, shape=(self._max_text_length + self.max_steps,)
+            ),
+        })
         self.action_space = Discrete(n=self._vocab_size)
         # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency
 
@@ -113,7 +111,8 @@ def __init__(
         self.__time_step = None
         self.reward_function = None
 
-    def set_reward(self, reward_fn):
+    def set_reward(self, reward_fn=None):
+
         self.reward_function = reward_fn
 
     def step_word(self, word: str) -> Tuple[Dict[str, torch.tensor], int, bool, dict]:
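
Both dialog envs expose the same two-entry observation dict: the running token ids plus an attention mask, each of length _max_text_length + max_steps. A small illustration with gymnasium-style spaces; the length and vocab size are made up, and OpenRL's DictSpace is assumed to behave like gymnasium.spaces.Dict:

from gymnasium import spaces

max_len = 8         # stand-in for _max_text_length + max_steps
vocab_size = 50257  # stand-in for tokenizer.vocab_size

obs_space = spaces.Dict({
    "input_encoded_pt": spaces.Box(low=0, high=vocab_size, shape=(max_len,)),
    "input_attention_mask_pt": spaces.Box(low=0, high=1, shape=(max_len,)),
})
obs = obs_space.sample()
print(obs["input_encoded_pt"].shape)  # (8,)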

openrl/envs/nlp/fake_dialog_env.py
Lines changed: 10 additions & 12 deletions

@@ -30,18 +30,16 @@ def __init__(
         # set the observation and action space here
         self._vocab_size = 2
 
-        self.observation_space = DictSpace(
-            {
-                "input_encoded_pt": spaces.Box(
-                    low=0,
-                    high=self._vocab_size,
-                    shape=(self._max_text_length + self.max_steps,),
-                ),
-                "input_attention_mask_pt": spaces.Box(
-                    low=0, high=1, shape=(self._max_text_length + self.max_steps,)
-                ),
-            }
-        )
+        self.observation_space = DictSpace({
+            "input_encoded_pt": spaces.Box(
+                low=0,
+                high=self._vocab_size,
+                shape=(self._max_text_length + self.max_steps,),
+            ),
+            "input_attention_mask_pt": spaces.Box(
+                low=0, high=1, shape=(self._max_text_length + self.max_steps,)
+            ),
+        })
         self.action_space = Discrete(n=self._vocab_size)
 
         n = 2

openrl/envs/nlp/rewards/intent.py
Lines changed: 25 additions & 11 deletions

@@ -36,7 +36,15 @@ def __init__(
 
         self._intent_coeff = intent_coeff
         self.use_deepspeed = use_deepspeed
+        self.use_half = False
+        self.use_data_parallel = not use_deepspeed  # default to use data parallel
+        self.use_model_parallel = False
+
         if intent_model == "builtin_intent":
+
+            self._device = "cpu"
+            self.use_data_parallel = False
+
             from transformers import GPT2Config, GPT2LMHeadModel
 
             class TestTokenizer:

@@ -62,6 +70,7 @@ def __init__(self, input_ids, attention_mask):
             self._model = GPT2LMHeadModel(config)
 
         else:
+            self._device = "cuda"
             model_path = data_abs_path(intent_model)
             self._tokenizer = AutoTokenizer.from_pretrained(intent_model)
             self._model = AutoModelForSequenceClassification.from_pretrained(model_path)

@@ -77,19 +86,17 @@ def __init__(self, input_ids, attention_mask):
             with open(ds_config) as file:
                 ds_config = json.load(file)
 
-            self._device = "cuda"
-            self._model = self._model.to("cuda")
+            self._model = self._model.to(self._device)
             self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config)
+            self.use_fp16 = ds_config["fp16"]["enabled"]
         else:
-            if torch.cuda.is_available():
-                manager = LocalGPUManager()
-                manager.log_info()
-                self._device = f"cuda:{manager.get_gpu()}"
-            else:
-                self._device = "cpu"
-            print("Intent Model choose to use device:{}".format(self._device))
-
-            self._model = self._model.to(self._device)
+            if self.use_model_parallel:
+                self._model.parallelize()
+            elif self.use_data_parallel:
+                if self.use_half:
+                    self._model = self._model.half()
+                self._model = torch.nn.DataParallel(self._model)
+                self._model = self._model.to(self._device)
 
     def __call__(
         self,

@@ -120,6 +127,13 @@ def get_input_for_classifier(prompt, generated_text):
             input_texts, return_tensors="pt", truncation=True, padding=True
         )
 
+        if self.use_half:
+            encoded.input_ids = encoded.input_ids.int()
+            encoded.attention_mask = encoded.attention_mask.int()
+        else:
+            encoded.input_ids = encoded.input_ids.long()
+            encoded.attention_mask = encoded.attention_mask.long()
+
         with torch.no_grad():
             outputs = self._model(
                 input_ids=encoded.input_ids.to(self._device),
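
The intent reward now picks one of three placements for the classifier: DeepSpeed, torch.nn.DataParallel (the default when DeepSpeed is off), or CPU for the builtin test model. The .int()/.long() casts pin the token-id tensors to an explicit integer width (int32 under use_half, int64 otherwise) so the embedding lookup always receives integer ids. A minimal sketch of the default data-parallel path; the checkpoint name is illustrative, not the classifier this reward actually loads:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative only
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name)

# default placement when use_deepspeed is False: replicate across GPUs
model = torch.nn.DataParallel(model).to("cuda")

encoded = tokenizer(["how are you?"], return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(
        input_ids=encoded.input_ids.long().to("cuda"),
        attention_mask=encoded.attention_mask.long().to("cuda"),
    )
scores = torch.softmax(outputs.logits, dim=-1)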

openrl/envs/nlp/rewards/kl_penalty.py
Lines changed: 42 additions & 20 deletions

@@ -35,11 +35,22 @@ def __init__(
         ds_config: str = "default",
     ):
         super().__init__()
+
+        self.device = "cuda"
         self.use_deepspeed = use_deepspeed
+        self.use_half = False
+        self.use_data_parallel = not use_deepspeed
+        self.use_model_parallel = False
+        assert not (self.use_deepspeed and self.use_data_parallel)
+        assert not (self.use_deepspeed and self.use_model_parallel)
+        assert not (self.use_data_parallel and self.use_model_parallel)
 
         # reference model
-        self._apply_model_parallel = apply_model_parallel
         if ref_model == "builtin_ref":
+
+            self.device = "cpu"
+            self.use_data_parallel = False
+
             from transformers import GPT2Config, GPT2LMHeadModel
 
             config = GPT2Config()

@@ -64,11 +75,15 @@ def __init__(
             self.use_fp16 = False
 
             self._ref_engine, *_ = deepspeed.initialize(model=self, config=ds_config)
-        elif torch.cuda.is_available():
-            if self._apply_model_parallel and self._ref_net.is_parallelizable:
+        else:
+            if self.use_model_parallel:
                 self._ref_net.parallelize()
-            else:  # else defaults to data parallel
-                self._ref_net = torch.nn.DataParallel(self._ref_net)
+            elif self.use_data_parallel:  # else defaults to data parallel
+                if self.use_half:
+                    self._ref_net = self._ref_net.half()
+                else:
+                    self._ref_net = torch.nn.DataParallel(self._ref_net)
+                self._ref_net = self._ref_net.to(self.device)
 
         # alpha adjustment
         self._alpha = 0.2

@@ -106,32 +121,35 @@
             self._ref_net, input_ids, past_model_kwargs
         )
 
-        if self.use_deepspeed:
-            if self.use_fp16:
-                for key in ["input_ids", "position_ids"]:
-                    model_inputs[key] = model_inputs[key].half().int()
-                for key in ["attention_mask"]:
-                    model_inputs[key] = model_inputs[key].half()
+        if self.use_half:
+            for key in ["input_ids", "position_ids", "attention_mask"]:
+                if key in model_inputs:
+                    model_inputs[key] = model_inputs[key].int()
+        else:
+            for key in ["input_ids", "position_ids", "attention_mask"]:
+                if key in model_inputs:
+                    model_inputs[key] = model_inputs[key].long()
 
         with torch.no_grad():
             output = self._ref_net(output_hidden_states=True, **model_inputs)
             output["past_key_values"] = None
             next_token_logits = output.logits[:, -1, :]
+            if self.use_deepspeed and self.use_fp16:
+                next_token_logits = next_token_logits.double()
             dist = self._action_dist.proba_distribution(action_logits=next_token_logits)
             action_input = actions.to(next_token_logits.device)
             ref_log_prob = dist.log_prob(action_input)
 
         ref_log_prob = ref_log_prob.reshape(action_log_probs.shape)
+
         kl_div = action_log_probs.copy() - ref_log_prob.detach().cpu().numpy()
         rew = -self._alpha * kl_div
         infos = []
         for kl in kl_div:
-            infos.append(
-                {
-                    "alpha": self._alpha,
-                    "kl_div": kl.mean(),
-                }
-            )
+            infos.append({
+                "alpha": self._alpha,
+                "kl_div": kl.mean(),
+            })
         return rew, infos
 
     def _prepare_inputs_for_model(

@@ -144,7 +162,7 @@ def _prepare_inputs_for_model(
             input_ids, **model_kwargs
         )
 
-        if self._apply_model_parallel and unwrap_model(model).is_parallelizable:
+        if self.use_model_parallel:
             # if model is in parallel mode, move the tensors to the first device
             model_inputs = {
                 key: (

@@ -155,8 +173,12 @@ def _prepare_inputs_for_model(
                 )
                 for key, value in model_inputs.items()
             }
-
-        if self.use_deepspeed:
+        elif self.use_data_parallel:
+            model_inputs = {
+                key: value.to(self.device) if isinstance(value, torch.Tensor) else value
+                for key, value in model_inputs.items()
+            }
+        elif self.use_deepspeed:
             model_inputs = {
                 key: value.to("cuda") if isinstance(value, torch.Tensor) else value
                 for key, value in model_inputs.items()
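
The reward this file produces is a per-token KL penalty toward the frozen reference model, rew = -alpha * (log pi(a|s) - log pi_ref(a|s)) with alpha fixed at 0.2, matching the kl_div and rew lines above. A tiny self-contained check of that arithmetic with made-up log-probabilities:

import numpy as np

alpha = 0.2
# log-probs of the sampled tokens under the current policy and the reference
action_log_probs = np.array([[-1.2, -0.7, -2.0]])
ref_log_prob = np.array([[-1.0, -0.9, -2.4]])

kl_div = action_log_probs.copy() - ref_log_prob  # per-token KL estimate
rew = -alpha * kl_div  # penalize drifting above the reference likelihood

print(rew)  # [[ 0.04 -0.04 -0.08]]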

openrl/envs/nlp/utils/metrics/meteor.py
Lines changed: 10 additions & 14 deletions

@@ -88,20 +88,16 @@ def _info(self):
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
             features=[
-                datasets.Features(
-                    {
-                        "predictions": datasets.Value("string", id="sequence"),
-                        "references": datasets.Sequence(
-                            datasets.Value("string", id="sequence"), id="references"
-                        ),
-                    }
-                ),
-                datasets.Features(
-                    {
-                        "predictions": datasets.Value("string", id="sequence"),
-                        "references": datasets.Value("string", id="sequence"),
-                    }
-                ),
+                datasets.Features({
+                    "predictions": datasets.Value("string", id="sequence"),
+                    "references": datasets.Sequence(
+                        datasets.Value("string", id="sequence"), id="references"
+                    ),
+                }),
+                datasets.Features({
+                    "predictions": datasets.Value("string", id="sequence"),
+                    "references": datasets.Value("string", id="sequence"),
+                }),
             ],
             codebase_urls=[
                 "https://github.com/nltk/nltk/blob/develop/nltk/translate/meteor_score.py"

openrl/envs/vec_env/wrappers/reward_wrapper.py
Lines changed: 2 additions & 2 deletions

@@ -29,8 +29,8 @@ class RewardWrapper(VecEnvWrapper):
     def __init__(self, env: BaseVecEnv, reward_class: BaseReward):
         super().__init__(env)
         self.reward_class = reward_class
-        if len(self.reward_class.inner_rew_funcs) > 0:
-            env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs})
+        # if len(self.reward_class.inner_rew_funcs) > 0:
+        #     env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs})
 
     def step(
         self, action: ActType, extra_data: Optional[Dict[str, Any]]
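
With this call commented out (and set_reward in daily_dialog_env.py now defaulting to reward_fn=None), inner reward functions are no longer pushed into the sub-environments when the wrapper is constructed. If the old behavior is needed, the removed line suggests it can still be invoked by hand; a sketch reusing the wrapper's own names under that assumption:

# hypothetical manual equivalent of the disabled constructor behavior
if len(reward_class.inner_rew_funcs) > 0:
    env.call("set_reward", reward_fn=reward_class.inner_rew_funcs)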
