Skip to content

Commit 2dffbd3

Browse files
committed
Work on module path + schema file connection
1 parent 52b7ee4 commit 2dffbd3

File tree

2 files changed

+42
-32
lines changed

2 files changed

+42
-32
lines changed

src/uwtools/api/driver.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ def execute(
6161
:param stdin_ok: OK to read from stdin?
6262
:return: ``True`` if task completes without raising an exception.
6363
"""
64-
if not (class_ := _get_driver_class(module, classname)):
64+
class_, module_path = _get_driver_class(module, classname)
65+
if not class_:
6566
return False
67+
assert module_path is not None
6668
args = dict(locals())
6769
accepted = set(getfullargspec(class_).args)
6870
non_optional = {STR.cycle, STR.leadtime}
@@ -78,7 +80,7 @@ def execute(
7880
config=ensure_data_source(config, bool(stdin_ok)),
7981
dry_run=dry_run,
8082
key_path=key_path,
81-
schema_file=schema_file or Path(module).with_suffix(".jsonschema"),
83+
schema_file=schema_file or module_path.with_suffix(".jsonschema"),
8284
)
8385
required = non_optional & accepted
8486
for arg in sorted([STR.batch, *required]):
@@ -93,46 +95,50 @@ def execute(
9395
return True
9496

9597

96-
def tasks(module: str, classname: str) -> dict[str, str]:
98+
def tasks(module: Union[Path, str], classname: str) -> dict[str, str]:
9799
"""
98100
Returns a mapping from task names to their one-line descriptions.
99101
100102
:param module: Name of driver module.
101103
:param classname: Name of driver class to instantiate.
102104
"""
103-
if not (class_ := _get_driver_class(module, classname)):
105+
class_, _ = _get_driver_class(module, classname)
106+
if not class_:
104107
log.error("Could not get tasks from class %s in module %s", classname, module)
105108
return {}
106109
return _tasks(class_)
107110

108111

109-
def _get_driver_class(module: Union[Path, str], classname: str) -> Optional[Type]:
112+
def _get_driver_class(
113+
module: Union[Path, str], classname: str
114+
) -> tuple[Optional[Type], Optional[Path]]:
110115
"""
111116
Returns the driver class.
112117
113118
:param module: Name of driver module to load.
114119
:param classname: Name of driver class to instantiate.
115120
"""
116-
module = str(module)
117-
if not (m := _get_driver_module_explicit(module)):
118-
if not (m := _get_driver_module_implicit(module)):
121+
if not (m := _get_driver_module_explicit(Path(module))):
122+
if not (m := _get_driver_module_implicit(str(module))):
119123
log.error("Could not load module %s", module)
120-
return None
124+
return None, None
125+
assert m.__file__ is not None
126+
module_path = Path(m.__file__)
121127
if hasattr(m, classname):
122128
c: Type = getattr(m, classname)
123-
return c
129+
return c, module_path
124130
log.error("Module %s has no class %s", module, classname)
125-
return None
131+
return None, module_path
126132

127133

128-
def _get_driver_module_explicit(module: str) -> Optional[ModuleType]:
134+
def _get_driver_module_explicit(module: Path) -> Optional[ModuleType]:
129135
"""
130136
Returns the named module found via explicit lookup of given path.
131137
132138
:param module: Name of driver module to load.
133139
"""
134140
log.debug("Loading module %s", module)
135-
if spec := spec_from_file_location(Path(module).name, module):
141+
if spec := spec_from_file_location(module.name, module):
136142
m = module_from_spec(spec)
137143
if loader := spec.loader:
138144
try:
@@ -150,8 +156,8 @@ def _get_driver_module_implicit(module: str) -> Optional[ModuleType]:
150156
151157
:param module: Name of driver module to load.
152158
"""
159+
log.debug("Loading module %s from sys.path", module)
153160
try:
154-
log.debug("Loading module %s from sys.path", module)
155161
return import_module(module)
156162
except Exception: # pylint: disable=broad-exception-caught
157163
return None

src/uwtools/tests/api/test_driver.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -117,82 +117,86 @@ def test_tasks_fail_no_cycle(args, caplog, kwargs):
117117
assert logged(caplog, "%s requires argument '%s'" % (args.classname, "cycle"))
118118

119119

120-
def test_tasks_pass(args):
121-
tasks = driver_api.tasks(classname=args.classname, module=args.module)
120+
@mark.parametrize("f", [Path, str])
121+
def test_tasks_pass(args, f):
122+
tasks = driver_api.tasks(classname=args.classname, module=f(args.module))
122123
assert tasks["eighty_eight"] == "88"
123124

124125

125126
def test__get_driver_class_explicit_fail_bad_class(caplog, args):
126127
log.setLevel(logging.DEBUG)
127128
bad_class = "BadClass"
128-
c = driver_api._get_driver_class(classname=bad_class, module=args.module)
129+
c, module_path = driver_api._get_driver_class(classname=bad_class, module=args.module)
129130
assert c is None
131+
assert module_path == args.module
130132
assert logged(caplog, "Module %s has no class %s" % (args.module, bad_class))
131133

132134

133135
def test__get_driver_class_explicit_fail_bad_name(caplog, args):
134136
log.setLevel(logging.DEBUG)
135-
bad_name = "bad_name"
136-
c = driver_api._get_driver_class(classname=args.classname, module=bad_name)
137+
bad_name = Path("bad_name")
138+
c, module_path = driver_api._get_driver_class(classname=args.classname, module=bad_name)
137139
assert c is None
140+
assert module_path is None
138141
assert logged(caplog, "Could not load module %s" % bad_name)
139142

140143

141144
def test__get_driver_class_explicit_fail_bad_path(caplog, args, tmp_path):
142145
log.setLevel(logging.DEBUG)
143146
module = tmp_path / "not.py"
144-
c = driver_api._get_driver_class(classname=args.classname, module=module)
147+
c, module_path = driver_api._get_driver_class(classname=args.classname, module=module)
145148
assert c is None
149+
assert module_path is None
146150
assert logged(caplog, "Could not load module %s" % module)
147151

148152

149153
def test__get_driver_class_explicit_fail_bad_spec(caplog, args):
150154
log.setLevel(logging.DEBUG)
151155
with patch.object(driver_api, "spec_from_file_location", return_value=None):
152-
c = driver_api._get_driver_class(classname=args.classname, module=args.module)
156+
c, module_path = driver_api._get_driver_class(classname=args.classname, module=args.module)
153157
assert c is None
158+
assert module_path is None
154159
assert logged(caplog, "Could not load module %s" % args.module)
155160

156161

157162
def test__get_driver_class_explicit_pass(args):
158163
log.setLevel(logging.DEBUG)
159-
c = driver_api._get_driver_class(classname=args.classname, module=args.module)
164+
c, module_path = driver_api._get_driver_class(classname=args.classname, module=args.module)
160165
assert c
161166
assert c.__name__ == "TestDriver"
167+
assert module_path == args.module
162168

163169

164170
def test__get_driver_class_implicit_pass(args):
165171
log.setLevel(logging.DEBUG)
166172
with patch.object(Path, "cwd", return_value=fixture_path()):
167-
c = driver_api._get_driver_class(classname=args.classname, module=args.module)
168-
assert c
169-
assert c.__name__ == "TestDriver"
173+
c, module_path = driver_api._get_driver_class(classname=args.classname, module=args.module)
174+
assert c
175+
assert c.__name__ == "TestDriver"
176+
assert module_path == args.module
170177

171178

172179
def test__get_driver_module_explicit_absolute_fail(args):
173180
assert args.module.is_absolute()
174-
module = str(args.module.with_suffix(".bad"))
181+
module = args.module.with_suffix(".bad")
175182
assert not driver_api._get_driver_module_explicit(module=module)
176183

177184

178185
def test__get_driver_module_explicit_absolute_pass(args):
179186
assert args.module.is_absolute()
180-
module = str(args.module)
181-
assert driver_api._get_driver_module_explicit(module=module)
187+
assert driver_api._get_driver_module_explicit(module=args.module)
182188

183189

184190
def test__get_driver_module_explicit_relative_fail(args):
185191
args.module = Path(os.path.relpath(args.module)).with_suffix(".bad")
186192
assert not args.module.is_absolute()
187-
module = str(args.module)
188-
assert not driver_api._get_driver_module_explicit(module=module)
193+
assert not driver_api._get_driver_module_explicit(module=args.module)
189194

190195

191196
def test__get_driver_module_explicit_relative_pass(args):
192197
args.module = Path(os.path.relpath(args.module))
193198
assert not args.module.is_absolute()
194-
module = str(args.module)
195-
assert driver_api._get_driver_module_explicit(module=module)
199+
assert driver_api._get_driver_module_explicit(module=args.module)
196200

197201

198202
def test__get_driver_module_implicit_pass_full_package():

0 commit comments

Comments
 (0)