1
1
import os
2
- import sys
3
2
import inspect
3
+ import importlib
4
4
5
5
import torch
6
6
from torch import nn
13
13
from torchdrug import core , data
14
14
from torchdrug .core import Registry as R
15
15
16
- module = sys .modules [__name__ ]
17
-
18
16
19
17
class PatchedModule (nn .Module ):
20
18
@@ -121,36 +119,34 @@ def _get_build_directory(name, verbose):
121
119
return build_directory
122
120
123
121
124
- nn ._Module = nn .Module
125
- nn .Module = PatchedModule
126
-
127
- nn .parallel ._DistributedDataParallel = nn .parallel .DistributedDataParallel
128
- nn .parallel .DistributedDataParallel = PatchedDistributedDataParallel
122
+ def patch (module , name , cls ):
123
+ backup = getattr (module , name )
124
+ setattr (module , "_%s" % name , backup )
125
+ setattr (module , name , cls )
129
126
130
- cpp_extension .__get_build_directory = cpp_extension ._get_build_directory
131
- cpp_extension ._get_build_directory = _get_build_directory
132
127
128
+ patch (nn , "Module" , PatchedModule )
129
+ patch (nn .parallel , "DistributedDataParallel" , PatchedDistributedDataParallel )
130
+ patch (cpp_extension , "_get_build_directory" , _get_build_directory )
133
131
134
132
Optimizer = optim .Optimizer
135
133
for name , cls in inspect .getmembers (optim ):
136
134
if inspect .isclass (cls ) and issubclass (cls , Optimizer ):
137
- setattr (optim , "_%s" % name , cls )
138
135
cls = core .make_configurable (cls , ignore_args = ("params" ,))
139
136
cls = R .register ("optim.%s" % name )(cls )
140
- setattr (optim , name , cls )
137
+ patch (optim , name , cls )
141
138
142
139
Scheduler = scheduler ._LRScheduler
143
140
for name , cls in inspect .getmembers (scheduler ):
144
141
if inspect .isclass (cls ) and issubclass (cls , Scheduler ):
145
- setattr (scheduler , "_%s" % name , cls )
146
142
cls = core .make_configurable (cls , ignore_args = ("optimizer" ,))
147
143
cls = R .register ("scheduler.%s" % name )(cls )
148
- setattr ( optim , name , cls )
144
+ patch ( scheduler , name , cls )
149
145
150
146
Dataset = dataset .Dataset
151
147
for name , cls in inspect .getmembers (dataset ):
152
148
if inspect .isclass (cls ) and issubclass (cls , Dataset ):
153
- setattr (dataset , "_%s" % name , cls )
154
149
cls = core .make_configurable (cls )
155
150
cls = R .register ("dataset.%s" % name )(cls )
156
- setattr (dataset , name , cls )
151
+ patch (dataset , name , cls )
152
+ importlib .reload (torch .utils .data )
0 commit comments