Skip to content

Fix subgroup #343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions repro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Union
import dataclasses
import simple_parsing
import yaml
from pathlib import Path

@dataclasses.dataclass
class ModelTypeA:
model_a_param: str = "default_a"

@dataclasses.dataclass
class ModelTypeB:
model_b_param: str = "default_b"

@dataclasses.dataclass
class TrainConfig:
model_type: Union[ModelTypeA, ModelTypeB] = simple_parsing.subgroups(
{"type_a": ModelTypeA, "type_b": ModelTypeB},
default_factory=ModelTypeA,
positional=False,
)

def main():
# Create a config file
config_path = Path("repro_subgroup_minimal.yaml")
config = {
"model_a_param": "test" # This should work but fails
}
with config_path.open('w') as f:
yaml.dump(config, f)

print("\nTrying with config file:")
try:
# This fails with:
# RuntimeError: ['model_a_param'] are not fields of <class '__main__.TrainConfig'> at path 'config'!
args = simple_parsing.parse(TrainConfig, add_config_path_arg=True, args=['--config_path', 'repro_subgroup_minimal.yaml'])
print(f"Config from file: {args}")
except RuntimeError as e:
print(f"Failed with config file as expected: {e}")

print("\nTrying with CLI args:")
# This works fine
args = simple_parsing.parse(TrainConfig, args=['--model_a_param', 'test'])
print(f"Config from CLI: {args}")

if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions repro_subgroup_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
model_a_param: test
1 change: 1 addition & 0 deletions repro_subgroup_minimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
model_a_param: test
96 changes: 87 additions & 9 deletions simple_parsing/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,14 +387,10 @@ def set_defaults(self, config_path: str | Path | None = None, **kwargs: Any) ->
if config_path:
defaults = read_file(config_path)
if self.nested_mode == NestedMode.WITHOUT_ROOT and len(self._wrappers) == 1:
# The file should have the same format as the command-line args, e.g. contain the
# fields of the 'root' dataclass directly (e.g. "foo: 123"), rather a dict with
# "config: foo: 123" where foo is a field of the root dataclass at dest 'config'.
# Therefore, we add the prefix back here.
defaults = {self._wrappers[0].dest: defaults}
# We also assume that the kwargs are passed as foo=123
kwargs = {self._wrappers[0].dest: kwargs}
# Also include the values from **kwargs.
# The file should have the same format as the command-line args
wrapper = self._wrappers[0]
defaults = {wrapper.dest: defaults}
kwargs = {wrapper.dest: kwargs}
kwargs = dict_union(defaults, kwargs)

# The kwargs that are set in the dataclasses, rather than on the namespace.
Expand Down Expand Up @@ -640,7 +636,7 @@ def _resolve_subgroups(
# config_path=self.config_path,
# NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues
# for example if you have —a_or_b and A has a field —a then it will error out if you
# pass —a=1 because 1 isnt a choice for the a_or_b argument (because --a matches it
# pass —a=1 because 1 isn't a choice for the a_or_b argument (because --a matches it
# with the abbreviation feature turned on).
allow_abbrev=False,
)
Expand Down Expand Up @@ -827,6 +823,8 @@ def _instantiate_dataclasses(
argparse.Namespace
The transformed namespace with the instances set at their
corresponding destinations.
Also keeps whatever arguments were added in the traditional fashion,
i.e. with `parser.add_argument(...)`.
"""
constructor_arguments = constructor_arguments.copy()
# FIXME: There's a bug here happening with the `ALWAYS_MERGE` case: The namespace has the
Expand Down Expand Up @@ -1157,5 +1155,85 @@ def _create_dataclass_instance(
else:
logger.debug(f"All fields for {wrapper.dest} were either at their default, or None.")
return None

# Handle subgroup fields
subgroup_fields = {f for f in wrapper.fields if f.is_subgroup}
if subgroup_fields:
# Create a copy of constructor args to avoid modifying the original
filtered_args = constructor_args.copy()

# Remove _type_ field if present at top level
filtered_args.pop("_type_", None)

# For each subgroup field, check if we have parameters that belong to its default type
for field_wrapper in subgroup_fields:
default_key = field_wrapper.subgroup_default
if default_key is not None and default_key is not dataclasses.MISSING:
choices = field_wrapper.subgroup_choices
default_factory = choices[default_key]
if callable(default_factory):
# Handle callables (functions or partial) that return a dataclass
if isinstance(default_factory, functools.partial):
# For partial, get the underlying function/class
default_type = default_factory.func
else:
default_type = default_factory

# If it's still a callable but not a class, we need the return type
if not isinstance(default_type, type) and callable(default_type):
# For a function, use the return annotation to get the class
import inspect
signature = inspect.signature(default_type)
if signature.return_annotation != inspect.Signature.empty:
# Use the actual dataclass directly from the test
if hasattr(default_type, "__globals__"):
# Get globals from the function to resolve the return annotation
globals_dict = default_type.__globals__
locals_dict = {}
if isinstance(signature.return_annotation, str):
# Try to evaluate the string as a type
try:
return_type = eval(signature.return_annotation, globals_dict, locals_dict)
if is_dataclass_type(return_type):
default_type = return_type
except (NameError, TypeError):
# If we can't evaluate it, try to get it from the global namespace
# For simple cases like 'Obj' where Obj is defined in the same scope
if signature.return_annotation in globals_dict:
default_type = globals_dict[signature.return_annotation]
else:
# Non-string annotation
if is_dataclass_type(signature.return_annotation):
default_type = signature.return_annotation
else:
# Fallback - try simple_parsing's helper (might not work in all cases)
from simple_parsing.helpers.subgroups import _get_dataclass_type_from_callable
try:
default_type = _get_dataclass_type_from_callable(default_type)
except Exception:
# If we can't determine the type, we'll skip field analysis
continue
else:
default_type = type(default_factory)

# Get fields of the default type
default_subgroup_fields = {f.name for f in dataclasses.fields(default_type)}

# Find which fields in the input match fields in the default subgroup
matching_fields = {name: filtered_args[name] for name in list(filtered_args.keys())
if name in default_subgroup_fields}

if matching_fields:
# Create an instance of the default type with the matching fields
subgroup_instance = default_type(**matching_fields)
filtered_args[field_wrapper.name] = subgroup_instance

# Remove handled fields
for name in matching_fields:
filtered_args.pop(name, None)

# Use the filtered args to create the instance
constructor_args = filtered_args

logger.debug(f"Calling constructor: {constructor}(**{constructor_args})")
return constructor(**constructor_args)
108 changes: 95 additions & 13 deletions simple_parsing/wrappers/dataclass_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,24 +294,106 @@ def set_default(self, value: DataclassT | dict | None):
self._default = value
if field_default_values is None:
return
unknown_names = set(field_default_values)

# First try to handle any subgroup fields
subgroup_fields = {f for f in self.fields if f.is_subgroup}
remaining_fields = field_default_values.copy() # Work with a copy to track what's been handled

for field_wrapper in subgroup_fields:
# Get the default subgroup type from the choices
default_key = field_wrapper.subgroup_default
if default_key is not None and default_key is not dataclasses.MISSING:
choices = field_wrapper.subgroup_choices
default_factory = choices[default_key]
if callable(default_factory):
# Handle callables (functions or partial) that return a dataclass
if isinstance(default_factory, functools.partial):
# For partial, get the underlying function/class
default_type = default_factory.func
else:
default_type = default_factory

# If it's still a callable but not a class, we need the return type
if not isinstance(default_type, type) and callable(default_type):
# For a function, use the return annotation to get the class
import inspect
signature = inspect.signature(default_type)
if signature.return_annotation != inspect.Signature.empty:
# Use the actual dataclass directly from the test
if hasattr(default_type, "__globals__"):
# Get globals from the function to resolve the return annotation
globals_dict = default_type.__globals__
locals_dict = {}
if isinstance(signature.return_annotation, str):
# Try to evaluate the string as a type
try:
return_type = eval(signature.return_annotation, globals_dict, locals_dict)
if is_dataclass_type(return_type):
default_type = return_type
except (NameError, TypeError):
# If we can't evaluate it, try to get it from the global namespace
# For simple cases like 'Obj' where Obj is defined in the same scope
if signature.return_annotation in globals_dict:
default_type = globals_dict[signature.return_annotation]
else:
# Non-string annotation
if is_dataclass_type(signature.return_annotation):
default_type = signature.return_annotation
else:
# Fallback - try simple_parsing's helper (might not work in all cases)
from simple_parsing.helpers.subgroups import _get_dataclass_type_from_callable
try:
default_type = _get_dataclass_type_from_callable(default_type)
except Exception:
# If we can't determine the type, we'll skip field analysis
continue
else:
default_type = type(default_factory)

# Get fields of the default type
default_subgroup_fields = {f.name for f in dataclasses.fields(default_type)}

# Find which fields in the input match fields in the default subgroup
matching_fields = {name: remaining_fields[name] for name in list(remaining_fields.keys())
if name in default_subgroup_fields}

if matching_fields:
# Create the nested structure for the subgroup
subgroup_dict = {
field_wrapper.name: {
"_type_": default_key,
**matching_fields
}
}
# Set this as the default for this field
field_wrapper.set_default(subgroup_dict[field_wrapper.name])

# Remove handled fields
for name in matching_fields:
remaining_fields.pop(name, None)

# Now handle any remaining regular fields
for field_wrapper in self.fields:
if field_wrapper.name not in field_default_values:
if field_wrapper.name not in remaining_fields:
continue
# Manually set the default value for this argument.
field_default_value = field_default_values[field_wrapper.name]
field_wrapper.set_default(field_default_value)
unknown_names.remove(field_wrapper.name)
if field_wrapper.is_subgroup:
continue
# Set default for regular field
field_wrapper.set_default(remaining_fields[field_wrapper.name])
remaining_fields.pop(field_wrapper.name)

# Handle nested dataclass fields
for nested_dataclass_wrapper in self._children:
if nested_dataclass_wrapper.name not in field_default_values:
if nested_dataclass_wrapper.name not in remaining_fields:
continue
field_default_value = field_default_values[nested_dataclass_wrapper.name]
nested_dataclass_wrapper.set_default(field_default_value)
unknown_names.remove(nested_dataclass_wrapper.name)
unknown_names.discard("_type_")
if unknown_names:
nested_dataclass_wrapper.set_default(remaining_fields[nested_dataclass_wrapper.name])
remaining_fields.pop(nested_dataclass_wrapper.name)

# Check for any unhandled fields
remaining_fields.pop("_type_", None) # Remove _type_ if present as it's handled separately
if remaining_fields:
raise RuntimeError(
f"{sorted(unknown_names)} are not fields of {self.dataclass} at path {self.dest!r}!"
f"{sorted(remaining_fields.keys())} are not fields of {self.dataclass} at path {self.dest!r}!"
)

@property
Expand Down
7 changes: 4 additions & 3 deletions simple_parsing/wrappers/field_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,12 +719,13 @@ def default(self) -> Any:
if it has a default value
"""

if self._default is not None:
if self.is_subgroup:
# For subgroups, always use the subgroup_default to maintain consistency
default = self.subgroup_default
elif self._default is not None:
# If a default value was set manually from the outside (e.g. from the DataclassWrapper)
# then use that value.
default = self._default
elif self.is_subgroup:
default = self.subgroup_default
elif any(
parent_default not in (None, argparse.SUPPRESS)
for parent_default in self.parent.defaults
Expand Down
2 changes: 2 additions & 0 deletions test/test_subgroup_minimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_type_: type_a
model_a_param: test
Loading
Loading