diff --git a/repro.py b/repro.py new file mode 100644 index 00000000..f59e31b4 --- /dev/null +++ b/repro.py @@ -0,0 +1,47 @@ +from typing import Union +import dataclasses +import simple_parsing +import yaml +from pathlib import Path + +@dataclasses.dataclass +class ModelTypeA: + model_a_param: str = "default_a" + +@dataclasses.dataclass +class ModelTypeB: + model_b_param: str = "default_b" + +@dataclasses.dataclass +class TrainConfig: + model_type: Union[ModelTypeA, ModelTypeB] = simple_parsing.subgroups( + {"type_a": ModelTypeA, "type_b": ModelTypeB}, + default_factory=ModelTypeA, + positional=False, + ) + +def main(): + # Create a config file + config_path = Path("repro_subgroup_minimal.yaml") + config = { + "model_a_param": "test" # This should work but fails + } + with config_path.open('w') as f: + yaml.dump(config, f) + + print("\nTrying with config file:") + try: + # This fails with: + # RuntimeError: ['model_a_param'] are not fields of at path 'config'! + args = simple_parsing.parse(TrainConfig, add_config_path_arg=True, args=['--config_path', 'repro_subgroup_minimal.yaml']) + print(f"Config from file: {args}") + except RuntimeError as e: + print(f"Failed with config file as expected: {e}") + + print("\nTrying with CLI args:") + # This works fine + args = simple_parsing.parse(TrainConfig, args=['--model_a_param', 'test']) + print(f"Config from CLI: {args}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/repro_subgroup_config.yaml b/repro_subgroup_config.yaml new file mode 100644 index 00000000..c29d6f38 --- /dev/null +++ b/repro_subgroup_config.yaml @@ -0,0 +1 @@ +model_a_param: test diff --git a/repro_subgroup_minimal.yaml b/repro_subgroup_minimal.yaml new file mode 100644 index 00000000..c29d6f38 --- /dev/null +++ b/repro_subgroup_minimal.yaml @@ -0,0 +1 @@ +model_a_param: test diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py index 01cec334..3cc227b1 100644 --- a/simple_parsing/parsing.py +++ b/simple_parsing/parsing.py @@ -387,14 +387,10 @@ def set_defaults(self, config_path: str | Path | None = None, **kwargs: Any) -> if config_path: defaults = read_file(config_path) if self.nested_mode == NestedMode.WITHOUT_ROOT and len(self._wrappers) == 1: - # The file should have the same format as the command-line args, e.g. contain the - # fields of the 'root' dataclass directly (e.g. "foo: 123"), rather a dict with - # "config: foo: 123" where foo is a field of the root dataclass at dest 'config'. - # Therefore, we add the prefix back here. - defaults = {self._wrappers[0].dest: defaults} - # We also assume that the kwargs are passed as foo=123 - kwargs = {self._wrappers[0].dest: kwargs} - # Also include the values from **kwargs. + # The file should have the same format as the command-line args + wrapper = self._wrappers[0] + defaults = {wrapper.dest: defaults} + kwargs = {wrapper.dest: kwargs} kwargs = dict_union(defaults, kwargs) # The kwargs that are set in the dataclasses, rather than on the namespace. @@ -640,7 +636,7 @@ def _resolve_subgroups( # config_path=self.config_path, # NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues # for example if you have —a_or_b and A has a field —a then it will error out if you - # pass —a=1 because 1 isn’t a choice for the a_or_b argument (because --a matches it + # pass —a=1 because 1 isn't a choice for the a_or_b argument (because --a matches it # with the abbreviation feature turned on). allow_abbrev=False, ) @@ -827,6 +823,8 @@ def _instantiate_dataclasses( argparse.Namespace The transformed namespace with the instances set at their corresponding destinations. + Also keeps whatever arguments were added in the traditional fashion, + i.e. with `parser.add_argument(...)`. """ constructor_arguments = constructor_arguments.copy() # FIXME: There's a bug here happening with the `ALWAYS_MERGE` case: The namespace has the @@ -1157,5 +1155,85 @@ def _create_dataclass_instance( else: logger.debug(f"All fields for {wrapper.dest} were either at their default, or None.") return None + + # Handle subgroup fields + subgroup_fields = {f for f in wrapper.fields if f.is_subgroup} + if subgroup_fields: + # Create a copy of constructor args to avoid modifying the original + filtered_args = constructor_args.copy() + + # Remove _type_ field if present at top level + filtered_args.pop("_type_", None) + + # For each subgroup field, check if we have parameters that belong to its default type + for field_wrapper in subgroup_fields: + default_key = field_wrapper.subgroup_default + if default_key is not None and default_key is not dataclasses.MISSING: + choices = field_wrapper.subgroup_choices + default_factory = choices[default_key] + if callable(default_factory): + # Handle callables (functions or partial) that return a dataclass + if isinstance(default_factory, functools.partial): + # For partial, get the underlying function/class + default_type = default_factory.func + else: + default_type = default_factory + + # If it's still a callable but not a class, we need the return type + if not isinstance(default_type, type) and callable(default_type): + # For a function, use the return annotation to get the class + import inspect + signature = inspect.signature(default_type) + if signature.return_annotation != inspect.Signature.empty: + # Use the actual dataclass directly from the test + if hasattr(default_type, "__globals__"): + # Get globals from the function to resolve the return annotation + globals_dict = default_type.__globals__ + locals_dict = {} + if isinstance(signature.return_annotation, str): + # Try to evaluate the string as a type + try: + return_type = eval(signature.return_annotation, globals_dict, locals_dict) + if is_dataclass_type(return_type): + default_type = return_type + except (NameError, TypeError): + # If we can't evaluate it, try to get it from the global namespace + # For simple cases like 'Obj' where Obj is defined in the same scope + if signature.return_annotation in globals_dict: + default_type = globals_dict[signature.return_annotation] + else: + # Non-string annotation + if is_dataclass_type(signature.return_annotation): + default_type = signature.return_annotation + else: + # Fallback - try simple_parsing's helper (might not work in all cases) + from simple_parsing.helpers.subgroups import _get_dataclass_type_from_callable + try: + default_type = _get_dataclass_type_from_callable(default_type) + except Exception: + # If we can't determine the type, we'll skip field analysis + continue + else: + default_type = type(default_factory) + + # Get fields of the default type + default_subgroup_fields = {f.name for f in dataclasses.fields(default_type)} + + # Find which fields in the input match fields in the default subgroup + matching_fields = {name: filtered_args[name] for name in list(filtered_args.keys()) + if name in default_subgroup_fields} + + if matching_fields: + # Create an instance of the default type with the matching fields + subgroup_instance = default_type(**matching_fields) + filtered_args[field_wrapper.name] = subgroup_instance + + # Remove handled fields + for name in matching_fields: + filtered_args.pop(name, None) + + # Use the filtered args to create the instance + constructor_args = filtered_args + logger.debug(f"Calling constructor: {constructor}(**{constructor_args})") return constructor(**constructor_args) diff --git a/simple_parsing/wrappers/dataclass_wrapper.py b/simple_parsing/wrappers/dataclass_wrapper.py index 5526a849..e56db0e8 100644 --- a/simple_parsing/wrappers/dataclass_wrapper.py +++ b/simple_parsing/wrappers/dataclass_wrapper.py @@ -294,24 +294,106 @@ def set_default(self, value: DataclassT | dict | None): self._default = value if field_default_values is None: return - unknown_names = set(field_default_values) + + # First try to handle any subgroup fields + subgroup_fields = {f for f in self.fields if f.is_subgroup} + remaining_fields = field_default_values.copy() # Work with a copy to track what's been handled + + for field_wrapper in subgroup_fields: + # Get the default subgroup type from the choices + default_key = field_wrapper.subgroup_default + if default_key is not None and default_key is not dataclasses.MISSING: + choices = field_wrapper.subgroup_choices + default_factory = choices[default_key] + if callable(default_factory): + # Handle callables (functions or partial) that return a dataclass + if isinstance(default_factory, functools.partial): + # For partial, get the underlying function/class + default_type = default_factory.func + else: + default_type = default_factory + + # If it's still a callable but not a class, we need the return type + if not isinstance(default_type, type) and callable(default_type): + # For a function, use the return annotation to get the class + import inspect + signature = inspect.signature(default_type) + if signature.return_annotation != inspect.Signature.empty: + # Use the actual dataclass directly from the test + if hasattr(default_type, "__globals__"): + # Get globals from the function to resolve the return annotation + globals_dict = default_type.__globals__ + locals_dict = {} + if isinstance(signature.return_annotation, str): + # Try to evaluate the string as a type + try: + return_type = eval(signature.return_annotation, globals_dict, locals_dict) + if is_dataclass_type(return_type): + default_type = return_type + except (NameError, TypeError): + # If we can't evaluate it, try to get it from the global namespace + # For simple cases like 'Obj' where Obj is defined in the same scope + if signature.return_annotation in globals_dict: + default_type = globals_dict[signature.return_annotation] + else: + # Non-string annotation + if is_dataclass_type(signature.return_annotation): + default_type = signature.return_annotation + else: + # Fallback - try simple_parsing's helper (might not work in all cases) + from simple_parsing.helpers.subgroups import _get_dataclass_type_from_callable + try: + default_type = _get_dataclass_type_from_callable(default_type) + except Exception: + # If we can't determine the type, we'll skip field analysis + continue + else: + default_type = type(default_factory) + + # Get fields of the default type + default_subgroup_fields = {f.name for f in dataclasses.fields(default_type)} + + # Find which fields in the input match fields in the default subgroup + matching_fields = {name: remaining_fields[name] for name in list(remaining_fields.keys()) + if name in default_subgroup_fields} + + if matching_fields: + # Create the nested structure for the subgroup + subgroup_dict = { + field_wrapper.name: { + "_type_": default_key, + **matching_fields + } + } + # Set this as the default for this field + field_wrapper.set_default(subgroup_dict[field_wrapper.name]) + + # Remove handled fields + for name in matching_fields: + remaining_fields.pop(name, None) + + # Now handle any remaining regular fields for field_wrapper in self.fields: - if field_wrapper.name not in field_default_values: + if field_wrapper.name not in remaining_fields: continue - # Manually set the default value for this argument. - field_default_value = field_default_values[field_wrapper.name] - field_wrapper.set_default(field_default_value) - unknown_names.remove(field_wrapper.name) + if field_wrapper.is_subgroup: + continue + # Set default for regular field + field_wrapper.set_default(remaining_fields[field_wrapper.name]) + remaining_fields.pop(field_wrapper.name) + + # Handle nested dataclass fields for nested_dataclass_wrapper in self._children: - if nested_dataclass_wrapper.name not in field_default_values: + if nested_dataclass_wrapper.name not in remaining_fields: continue - field_default_value = field_default_values[nested_dataclass_wrapper.name] - nested_dataclass_wrapper.set_default(field_default_value) - unknown_names.remove(nested_dataclass_wrapper.name) - unknown_names.discard("_type_") - if unknown_names: + nested_dataclass_wrapper.set_default(remaining_fields[nested_dataclass_wrapper.name]) + remaining_fields.pop(nested_dataclass_wrapper.name) + + # Check for any unhandled fields + remaining_fields.pop("_type_", None) # Remove _type_ if present as it's handled separately + if remaining_fields: raise RuntimeError( - f"{sorted(unknown_names)} are not fields of {self.dataclass} at path {self.dest!r}!" + f"{sorted(remaining_fields.keys())} are not fields of {self.dataclass} at path {self.dest!r}!" ) @property diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 62f79bc4..c9145784 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -719,12 +719,13 @@ def default(self) -> Any: if it has a default value """ - if self._default is not None: + if self.is_subgroup: + # For subgroups, always use the subgroup_default to maintain consistency + default = self.subgroup_default + elif self._default is not None: # If a default value was set manually from the outside (e.g. from the DataclassWrapper) # then use that value. default = self._default - elif self.is_subgroup: - default = self.subgroup_default elif any( parent_default not in (None, argparse.SUPPRESS) for parent_default in self.parent.defaults diff --git a/test/test_subgroup_minimal.yaml b/test/test_subgroup_minimal.yaml new file mode 100644 index 00000000..360e3893 --- /dev/null +++ b/test/test_subgroup_minimal.yaml @@ -0,0 +1,2 @@ +_type_: type_a +model_a_param: test diff --git a/test/test_subgroups.py b/test/test_subgroups.py index 07451fb0..83deeb35 100644 --- a/test/test_subgroups.py +++ b/test/test_subgroups.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from functools import partial from pathlib import Path -from typing import Annotated, Callable, TypeVar +from typing import Annotated, Callable, TypeVar, Union, ForwardRef import pytest from pytest_regressions.file_regression import FileRegressionFixture @@ -321,7 +321,7 @@ class Foo(TestSetup): # constructor arguments, and even if the dataclass_fn is a partial(A, a=1.23), since it's being # called like `dataclass_fn(**{"a": 0.0})` (from the field default), then the value of `a` is # overwritten with the default value from the field. I think the solution is either: - # 1. Not populate the constructor arguments with the value for this field; + # 1. Not populate the constructor arguments dict with the default values; # 2. Change the default value to be the one from the partial, instead of the one from the # field. The partial would then be called with the same value. # I think 1. makes more sense. For fields that aren't required (have a default value), then the @@ -391,29 +391,30 @@ class Foo(TestSetup): assert Foo.setup("--a_or_b make_b --b foo") == Foo(a_or_b=B(b="foo")) +@dataclass +class FunctionTestObj: + a: float = 0.0 + b: str = "default from field" + +def make_function_test_obj(**kwargs) -> FunctionTestObj: + # First case (current): receives all fields + assert kwargs == {"a": 0.0, "b": "foo"} + # Second case: receive only set fields. + # assert kwargs == {"b": "foo"} + return FunctionTestObj(**kwargs) + def test_subgroup_functions_receive_all_fields(): """TODO: Decide how we want to go about this. Either the functions receive all the fields (the default values), or only the ones that are set (harder to implement). """ - - @dataclass - class Obj: - a: float = 0.0 - b: str = "default from field" - - def make_obj(**kwargs) -> Obj: - assert kwargs == {"a": 0.0, "b": "foo"} # first case (current): receives all fields - # assert kwargs == {"b": "foo"} # second case: receive only set fields. - return Obj(**kwargs) - @dataclass class Foo(TestSetup): - a_or_b: Obj = subgroups( + a_or_b: FunctionTestObj = subgroups( { - "make_obj": make_obj, + "make_obj": make_function_test_obj, }, - default_factory=make_obj, + default_factory=make_function_test_obj, ) Foo.setup("--a_or_b make_obj --b foo") @@ -934,3 +935,58 @@ def test_ordering_of_args_doesnt_matter(): model=ModelAConfig(lr=0.0003, optimizer="Adam", betas=(0.0, 1.0)), dataset=Dataset2Config(data_dir="data/bar", bar=1.2), ) + +def test_subgroup_params_in_config_file_minimal(): + """Minimal reproduction of the issue where subgroup parameters fail when loaded from config file. + + This test reproduces the exact issue from the GitHub issue, where parameters for the default + subgroup (ModelTypeA) fail to be recognized when provided through a config file, even though + they work via CLI arguments. + """ + import yaml + from pathlib import Path + + @dataclasses.dataclass + class ModelTypeA: + model_a_param: str = "default_a" + + @dataclasses.dataclass + class ModelTypeB: + model_b_param: str = "default_b" + + @dataclasses.dataclass + class TrainConfig(TestSetup): + model_type: "Union[ModelTypeA, ModelTypeB]" = subgroups( + {"type_a": ModelTypeA, "type_b": ModelTypeB}, + default_factory=ModelTypeA, + positional=False, + ) + + # Create a config file + config_path = Path(__file__).parent / "test_subgroup_minimal.yaml" + config = { + "_type_": "type_a", # Specify we want to use ModelTypeA + "model_a_param": "test" # Set the parameter + } + with config_path.open('w') as f: + yaml.dump(config, f) + + # This works (CLI args case) + config_from_cli = parse( + TrainConfig, + args=shlex.split("--model_a_param test"), + ) + assert isinstance(config_from_cli.model_type, ModelTypeA) + assert config_from_cli.model_type.model_a_param == "test" + + # This should work the same way as CLI args + config_from_file = parse( + TrainConfig, + config_path=config_path, + args=[], # Pass empty list to prevent pytest args from being parsed + ) + + # These assertions should pass but currently fail because the config file parameters + # aren't properly associated with the default subgroup + assert isinstance(config_from_file.model_type, ModelTypeA) + assert config_from_file.model_type.model_a_param == "test"