Skip to content

Commit 87fe26f

Browse files
authored
Fix for driver-config validation in combination with keypath (#568)
1 parent cb2700c commit 87fe26f

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

src/uwtools/drivers/driver.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,13 @@ def __init__(
5454
}
5555
)
5656
self._config_full: dict = config_input.data
57-
config_intermediate, _ = walk_key_path(self._config_full, key_path or [])
58-
self._platform = config_intermediate.get("platform")
57+
self._config_intermediate, _ = walk_key_path(self._config_full, key_path or [])
5958
try:
60-
self._config: dict = config_intermediate[self.driver_name]
59+
self._config: dict = self._config_intermediate[self.driver_name]
6160
except KeyError as e:
6261
raise UWConfigError("Required '%s' block missing in config" % self.driver_name) from e
6362
if controller:
64-
self._config[STR.rundir] = config_intermediate[controller][STR.rundir]
63+
self._config[STR.rundir] = self._config_intermediate[controller][STR.rundir]
6564
self._validate(schema_file)
6665
dryrun(enable=dry_run)
6766

@@ -200,10 +199,10 @@ def _validate(self, schema_file: Optional[Path] = None) -> None:
200199
:raises: UWConfigError if config fails validation.
201200
"""
202201
if schema_file:
203-
validate_external(schema_file=schema_file, config=self.config_full)
202+
validate_external(schema_file=schema_file, config=self._config_intermediate)
204203
else:
205204
validate_internal(
206-
schema_name=self.driver_name.replace("_", "-"), config=self.config_full
205+
schema_name=self.driver_name.replace("_", "-"), config=self._config_intermediate
207206
)
208207

209208

@@ -390,13 +389,13 @@ def _run_resources(self) -> dict[str, Any]:
390389
"""
391390
Returns platform configuration data.
392391
"""
393-
if not self._platform:
392+
if not (platform := self._config_intermediate.get("platform")):
394393
raise UWConfigError("Required 'platform' block missing in config")
395394
threads = self.config.get(STR.execution, {}).get(STR.threads)
396395
return {
397-
STR.account: self._platform[STR.account],
396+
STR.account: platform[STR.account],
398397
STR.rundir: self.rundir,
399-
STR.scheduler: self._platform[STR.scheduler],
398+
STR.scheduler: platform[STR.scheduler],
400399
STR.stdout: "%s.out" % self._runscript_path.name, # config may override
401400
**({STR.threads: threads} if threads else {}),
402401
**self.config.get(STR.execution, {}).get(STR.batchargs, {}),
@@ -481,10 +480,10 @@ def _validate(self, schema_file: Optional[Path] = None) -> None:
481480
:raises: UWConfigError if config fails validation.
482481
"""
483482
if schema_file:
484-
validate_external(schema_file=schema_file, config=self.config_full)
483+
validate_external(schema_file=schema_file, config=self._config_intermediate)
485484
else:
486485
validate_internal(
487-
schema_name=self.driver_name.replace("_", "-"), config=self.config_full
486+
schema_name=self.driver_name.replace("_", "-"), config=self._config_intermediate
488487
)
489488
validate_internal(schema_name=STR.platform, config=self.config_full)
490489

src/uwtools/tests/drivers/test_driver.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ def test_Assets_key_path(config, tmp_path):
219219
config=config_file, dry_run=False, key_path=["foo", "bar"]
220220
)
221221
assert assetsobj.config == config[assetsobj.driver_name]
222-
assert assetsobj._platform == config["platform"]
223222

224223

225224
def test_Assets_leadtime(config):
@@ -236,6 +235,14 @@ def test_Assets_validate(assetsobj, caplog):
236235
assert regex_logged(caplog, "State: Ready")
237236

238237

238+
def test_Assets_validate_key_path(config, controller_schema):
239+
config = {"a": {"b": config}}
240+
with patch.object(ConcreteAssetsTimeInvariant, "_validate", driver.Assets._validate):
241+
assert ConcreteAssetsTimeInvariant(
242+
config=config, key_path=["a", "b"], schema_file=controller_schema
243+
)
244+
245+
239246
@mark.parametrize(
240247
"base_file,update_values,expected",
241248
[
@@ -442,7 +449,7 @@ def test_Driver__namelist_schema_default_disable(driverobj):
442449

443450

444451
def test_Driver__run_resources_fail(driverobj):
445-
driverobj._platform = None
452+
del driverobj._config_intermediate["platform"]
446453
with raises(UWConfigError) as e:
447454
assert driverobj._run_resources
448455
assert str(e.value) == "Required 'platform' block missing in config"

0 commit comments

Comments
 (0)