Skip to content

Commit 1433482

Browse files
fix mem leakage (#10344)
1 parent c3c6695 commit 1433482

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

paddlenlp/transformers/token_dispatcher.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,11 @@ def __init__(self, token_dispatcher):
228228
self.probs_origin_shape = None
229229

230230
def reset_status(self):
231-
self.prob = None
231+
self.probs = None
232232
self.reshaped_probs = None
233233
self.token_indices = None
234234

235+
@paddle.no_grad()
235236
def forward(self, routing_map, probs):
236237
num_tokens = routing_map.shape[0]
237238
self.probs_origin_shape = probs.shape
@@ -243,8 +244,10 @@ def forward(self, routing_map, probs):
243244
reshaped_probs, self.token_dispatcher._comm_manager.router_topk, axis=-1
244245
)
245246
self.token_indices = token_indices
247+
token_probs.stop_gradient = False
246248
return token_indices, token_probs
247249

250+
@paddle.no_grad()
248251
def backward(self, token_probs_g):
249252
probs_grad = paddle._C_ops.topk_grad(
250253
self.reshaped_probs,

0 commit comments

Comments
 (0)