Skip to content

Commit f29dc55

Browse files
committed
fix patch for torch.utils.data.Dataset
1 parent 119002a commit f29dc55

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

torchdrug/patch.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
2-
import sys
32
import inspect
3+
import importlib
44

55
import torch
66
from torch import nn
@@ -13,8 +13,6 @@
1313
from torchdrug import core, data
1414
from torchdrug.core import Registry as R
1515

16-
module = sys.modules[__name__]
17-
1816

1917
class PatchedModule(nn.Module):
2018

@@ -121,36 +119,34 @@ def _get_build_directory(name, verbose):
121119
return build_directory
122120

123121

124-
nn._Module = nn.Module
125-
nn.Module = PatchedModule
126-
127-
nn.parallel._DistributedDataParallel = nn.parallel.DistributedDataParallel
128-
nn.parallel.DistributedDataParallel = PatchedDistributedDataParallel
122+
def patch(module, name, cls):
123+
backup = getattr(module, name)
124+
setattr(module, "_%s" % name, backup)
125+
setattr(module, name, cls)
129126

130-
cpp_extension.__get_build_directory = cpp_extension._get_build_directory
131-
cpp_extension._get_build_directory = _get_build_directory
132127

128+
patch(nn, "Module", PatchedModule)
129+
patch(nn.parallel, "DistributedDataParallel", PatchedDistributedDataParallel)
130+
patch(cpp_extension, "_get_build_directory", _get_build_directory)
133131

134132
Optimizer = optim.Optimizer
135133
for name, cls in inspect.getmembers(optim):
136134
if inspect.isclass(cls) and issubclass(cls, Optimizer):
137-
setattr(optim, "_%s" % name, cls)
138135
cls = core.make_configurable(cls, ignore_args=("params",))
139136
cls = R.register("optim.%s" % name)(cls)
140-
setattr(optim, name, cls)
137+
patch(optim, name, cls)
141138

142139
Scheduler = scheduler._LRScheduler
143140
for name, cls in inspect.getmembers(scheduler):
144141
if inspect.isclass(cls) and issubclass(cls, Scheduler):
145-
setattr(scheduler, "_%s" % name, cls)
146142
cls = core.make_configurable(cls, ignore_args=("optimizer",))
147143
cls = R.register("scheduler.%s" % name)(cls)
148-
setattr(optim, name, cls)
144+
patch(scheduler, name, cls)
149145

150146
Dataset = dataset.Dataset
151147
for name, cls in inspect.getmembers(dataset):
152148
if inspect.isclass(cls) and issubclass(cls, Dataset):
153-
setattr(dataset, "_%s" % name, cls)
154149
cls = core.make_configurable(cls)
155150
cls = R.register("dataset.%s" % name)(cls)
156-
setattr(dataset, name, cls)
151+
patch(dataset, name, cls)
152+
importlib.reload(torch.utils.data)

0 commit comments

Comments
 (0)