Skip to content

Commit b411837

Browse files
committed
make register_buffer consistent with PyTorch 1.6.0+
1 parent f29dc55 commit b411837

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

torchdrug/patch.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ def device(self):
6161
tensor = next(self.buffers())
6262
return tensor.device
6363

64-
def register_buffer(self, name, tensor):
64+
def register_buffer(self, name, tensor, persistent=True):
65+
if persistent is False and isinstance(self, torch.jit.ScriptModule):
66+
raise RuntimeError("ScriptModule does not support non-persistent buffers")
67+
6568
if '_buffers' not in self.__dict__:
6669
raise AttributeError(
6770
"cannot assign buffer before Module.__init__() call")
@@ -80,6 +83,10 @@ def register_buffer(self, name, tensor):
8083
.format(torch.typename(tensor), name))
8184
else:
8285
self._buffers[name] = tensor
86+
if persistent:
87+
self._non_persistent_buffers_set.discard(name)
88+
else:
89+
self._non_persistent_buffers_set.add(name)
8390

8491

8592
class PatchedDistributedDataParallel(nn.parallel.DistributedDataParallel):

0 commit comments

Comments (0)