From dd6fcb19e58ca14a1748fff4a341a8f8aa2ddbc2 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 3 Aug 2023 18:08:11 -0400 Subject: [PATCH 01/20] Add test to reproduce a bug with _type_ keys Signed-off-by: Fabrice Normandin --- test/test_huggingface_compat.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/test/test_huggingface_compat.py b/test/test_huggingface_compat.py index 5fd17fc0..39059e39 100644 --- a/test/test_huggingface_compat.py +++ b/test/test_huggingface_compat.py @@ -10,8 +10,9 @@ import pytest -from simple_parsing import ArgumentParser +from simple_parsing import ArgumentParser, parse from simple_parsing.docstring import get_attribute_docstring +from simple_parsing.helpers.serialization import load, save from .testutils import TestSetup, needs_yaml, raises_invalid_choice @@ -1292,8 +1293,20 @@ def test_entire_docstring_isnt_used_as_help(): ) def test_serialization(tmp_path: Path, filename: str, args: TrainingArguments): """test that serializing / deserializing a TrainingArguments works.""" - from simple_parsing.helpers.serialization import load, save path = tmp_path / filename save(args, path) assert load(TrainingArguments, path) == args + + +@pytest.mark.xfail( + raises=TypeError, + strict=True, + reason="All fields (non-init ones too) are passed to .set_defaults, which raises a TypeError", +) +@pytest.mark.parametrize("filetype", [".yaml", ".json", ".pkl"]) +def test_parse_with_config_file(tmp_path: Path, filetype: str): + default_args = TrainingArguments(label_smoothing_factor=123.123) + config_path = (tmp_path / "bob").with_suffix(filetype) + save(default_args, config_path, save_dc_types=True) + assert parse(TrainingArguments, config_path=config_path, args="") == default_args From cd87f71fb674af9af9dd517a2bbe9d070a7f6692 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 3 Aug 2023 18:06:03 -0400 Subject: [PATCH 02/20] Fix a bug with parsing with _type_ keys in config Signed-off-by: Fabrice Normandin --- simple_parsing/helpers/serialization/serializable.py | 3 +++ simple_parsing/parsing.py | 10 ++++++++++ simple_parsing/wrappers/dataclass_wrapper.py | 3 +++ test/test_huggingface_compat.py | 12 +++++++++--- 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/simple_parsing/helpers/serialization/serializable.py b/simple_parsing/helpers/serialization/serializable.py index bc845ad5..83ab5498 100644 --- a/simple_parsing/helpers/serialization/serializable.py +++ b/simple_parsing/helpers/serialization/serializable.py @@ -1,4 +1,5 @@ from __future__ import annotations +import functools import json import pickle @@ -832,6 +833,7 @@ def from_dict( if name not in obj_dict: if ( field.metadata.get("to_dict", True) + and field.init and field.default is MISSING and field.default_factory is MISSING ): @@ -928,6 +930,7 @@ def is_dataclass_or_optional_dataclass_type(t: type) -> bool: return is_dataclass(t) or (is_optional(t) and is_dataclass(get_args(t)[0])) +@functools.lru_cache(maxsize=None) def _locate(path: str) -> Any: """COPIED FROM Hydra: https://github.com/facebookresearch/hydra/blob/f8940600d0ab5c695961ad83ab d042ffe9458caf/hydra/_internal/utils.py#L614. diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py index ab1aa524..794e363b 100644 --- a/simple_parsing/parsing.py +++ b/simple_parsing/parsing.py @@ -1127,6 +1127,16 @@ def _create_dataclass_instance( # None. # TODO: (BUG!) 
This doesn't distinguish the case where the defaults are passed via the # command-line from the case where no arguments are passed at all! + if "_type_" in constructor_args: + from simple_parsing.helpers.serialization.serializable import _locate + + dc_type_in_config_file = _locate(constructor_args.pop("_type_")) + logger.info( + f"Overwriting constructor with the dc type from the config file: " + f"{constructor}->{dc_type_in_config_file}" + ) + constructor = dc_type_in_config_file + if wrapper.optional and wrapper.default is None: for field_wrapper in wrapper.fields: arg_value = constructor_args[field_wrapper.name] diff --git a/simple_parsing/wrappers/dataclass_wrapper.py b/simple_parsing/wrappers/dataclass_wrapper.py index 2efa98af..6a86cbe7 100644 --- a/simple_parsing/wrappers/dataclass_wrapper.py +++ b/simple_parsing/wrappers/dataclass_wrapper.py @@ -298,6 +298,9 @@ def set_default(self, value: DataclassT | dict | None): for field_wrapper in self.fields: if field_wrapper.name not in field_default_values: continue + if not field_wrapper.field.init: + # NOTE: Ignore the non-init fields. + continue # Manually set the default value for this argument. field_default_value = field_default_values[field_wrapper.name] field_wrapper.set_default(field_default_value) diff --git a/test/test_huggingface_compat.py b/test/test_huggingface_compat.py index 39059e39..153bf5df 100644 --- a/test/test_huggingface_compat.py +++ b/test/test_huggingface_compat.py @@ -300,14 +300,14 @@ def test_enums_are_parsed_to_enum_member(): # However, it is, once we factor in what's happening in the __post_init__ of TrainingArguments. with pytest.raises(ValueError): - TrainingArguments.setup("--evaluation_strategy invalid") + parse(TrainingArguments, args="--evaluation_strategy invalid") for mode, enum_value in zip( ["no", "steps", "epoch"], [IntervalStrategy.NO, IntervalStrategy.STEPS, IntervalStrategy.EPOCH], ): assert ( - TrainingArguments.setup(f"--evaluation_strategy {mode}").evaluation_strategy + parse(TrainingArguments, args=f"--evaluation_strategy {mode}").evaluation_strategy == enum_value ) @@ -1266,7 +1266,13 @@ def test_docstring_parse_works_with_hf_training_args(): def test_entire_docstring_isnt_used_as_help(): - help_text = TrainingArguments.get_help_text() + parser = ArgumentParser() + parser.add_arguments(TrainingArguments, "config") + with io.StringIO() as f: + parser.print_help(file=f) + f.seek(0) + help_text = f.read() + # help_text = TrainingArguments.get_help_text() help_from_field = "Whether to use Apple Silicon chip based `mps` device." 
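    # The help text above is captured by redirecting `parser.print_help` into an `io.StringIO`
    # buffer, instead of the `TrainingArguments.get_help_text()` helper used previously (kept
    # commented out above); the assertions that follow only check which docstrings end up in
    # that captured text.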
assert help_from_field in help_text From 047b65386f06b1da9727d48f28b55df8160291aa Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 3 Aug 2023 18:22:21 -0400 Subject: [PATCH 03/20] Fix the typing of the `subgroups` fn Signed-off-by: Fabrice Normandin --- simple_parsing/helpers/subgroups.py | 42 ++++------------------------- 1 file changed, 5 insertions(+), 37 deletions(-) diff --git a/simple_parsing/helpers/subgroups.py b/simple_parsing/helpers/subgroups.py index 8b16794c..4f4a9197 100644 --- a/simple_parsing/helpers/subgroups.py +++ b/simple_parsing/helpers/subgroups.py @@ -17,48 +17,16 @@ SubgroupKey: TypeAlias = Union[str, int, bool, Enum] Key = TypeVar("Key", str, int, bool, Enum) +DC = TypeVar("DC") -@overload def subgroups( - subgroups: dict[Key, DataclassT | type[DataclassT] | functools.partial[DataclassT]], + subgroups: dict[Key, type[DC] | functools.partial[DC]], *args, - default: Key | DataclassT, - default_factory: _MISSING_TYPE = MISSING, + default: Key | _MISSING_TYPE = MISSING, + default_factory: type[DC] | functools.partial[DC] | _MISSING_TYPE = MISSING, **kwargs, -) -> DataclassT: - ... - - -@overload -def subgroups( - subgroups: dict[Key, DataclassT | type[DataclassT] | functools.partial[DataclassT]], - *args, - default: _MISSING_TYPE = MISSING, - default_factory: type[DataclassT] | functools.partial[DataclassT], - **kwargs, -) -> DataclassT: - ... - - -@overload -def subgroups( - subgroups: dict[Key, DataclassT | type[DataclassT] | functools.partial[DataclassT]], - *args, - default: _MISSING_TYPE = MISSING, - default_factory: _MISSING_TYPE = MISSING, - **kwargs, -) -> DataclassT: - ... - - -def subgroups( - subgroups: dict[Key, DataclassT | type[DataclassT] | functools.partial[DataclassT]], - *args, - default: Key | DataclassT | _MISSING_TYPE = MISSING, - default_factory: type[DataclassT] | functools.partial[DataclassT] | _MISSING_TYPE = MISSING, - **kwargs, -) -> DataclassT: +) -> DC: """Creates a field that will be a choice between different subgroups of arguments. This is different than adding a subparser action. 
There can only be one subparser action, while From beb8cd0637b01a256d97de9af032b1dccf9c4835 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 3 Aug 2023 18:22:53 -0400 Subject: [PATCH 04/20] Add test to reproduce the subgroup issue#276 Signed-off-by: Fabrice Normandin --- test/test_huggingface_compat.py | 10 ++--- test/test_subgroups.py | 37 +++++++++++++++++++ .../test_help[Config---help].md | 2 +- ...est_help[Config---model=model_a --help].md | 2 +- ...est_help[Config---model=model_b --help].md | 2 +- ...gWithFrozen---conf=even --a 100 --help].md | 2 +- ...lp[ConfigWithFrozen---conf=even --help].md | 2 +- ...igWithFrozen---conf=odd --a 123 --help].md | 2 +- ...elp[ConfigWithFrozen---conf=odd --help].md | 2 +- .../test_help[ConfigWithFrozen---help].md | 2 +- 10 files changed, 50 insertions(+), 13 deletions(-) diff --git a/test/test_huggingface_compat.py b/test/test_huggingface_compat.py index 153bf5df..9fa7d530 100644 --- a/test/test_huggingface_compat.py +++ b/test/test_huggingface_compat.py @@ -1305,11 +1305,11 @@ def test_serialization(tmp_path: Path, filename: str, args: TrainingArguments): assert load(TrainingArguments, path) == args -@pytest.mark.xfail( - raises=TypeError, - strict=True, - reason="All fields (non-init ones too) are passed to .set_defaults, which raises a TypeError", -) +# @pytest.mark.xfail( +# raises=TypeError, +# strict=True, +# reason="All fields (non-init ones too) are passed to .set_defaults, which raises a TypeError", +# ) @pytest.mark.parametrize("filetype", [".yaml", ".json", ".pkl"]) def test_parse_with_config_file(tmp_path: Path, filetype: str): default_args = TrainingArguments(label_smoothing_factor=123.123) diff --git a/test/test_subgroups.py b/test/test_subgroups.py index 1fb2688a..56164a19 100644 --- a/test/test_subgroups.py +++ b/test/test_subgroups.py @@ -11,6 +11,7 @@ from typing import Callable, TypeVar import pytest +from simple_parsing.helpers.serialization import save from pytest_regressions.file_regression import FileRegressionFixture from typing_extensions import Annotated @@ -935,3 +936,39 @@ def test_ordering_of_args_doesnt_matter(): model=ModelAConfig(lr=0.0003, optimizer="Adam", betas=(0.0, 1.0)), dataset=Dataset2Config(data_dir="data/bar", bar=1.2), ) + + +@dataclass +class A1: + a_val: int = 1 + + +@dataclass +class A2: + a_val: int = 2 + + +@dataclass +class A1OrA2: + a: A1 | A2 = subgroups({"a1": A1, "a2": A2}, default="a1") + + +@pytest.mark.parametrize( + ("value_in_config", "args", "expected"), + [ + (A1OrA2(a=A2()), "", A1OrA2(a=A1())), + (A1OrA2(a=A1()), "", A1OrA2(a=A1())), + ], +) +@pytest.mark.parametrize("filetype", [".yaml", ".json", ".pkl"]) +def test_parse_with_config_file_with_different_subgroup( + tmp_path: Path, + filetype: str, + value_in_config: A1OrA2, + args: str, + expected: A1OrA2, +): + config_path = (tmp_path / "bob").with_suffix(filetype) + + save(value_in_config, config_path, save_dc_types=True) + assert parse(A1OrA2, config_path=config_path, args=args) == expected diff --git a/test/test_subgroups/test_help[Config---help].md b/test/test_subgroups/test_help[Config---help].md index 9a51d7d7..b0632885 100644 --- a/test/test_subgroups/test_help[Config---help].md +++ b/test/test_subgroups/test_help[Config---help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:731) Given Source code: diff --git a/test/test_subgroups/test_help[Config---model=model_a --help].md b/test/test_subgroups/test_help[Config---model=model_a 
--help].md index 7d2f8970..c20f7faf 100644 --- a/test/test_subgroups/test_help[Config---model=model_a --help].md +++ b/test/test_subgroups/test_help[Config---model=model_a --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:731) Given Source code: diff --git a/test/test_subgroups/test_help[Config---model=model_b --help].md b/test/test_subgroups/test_help[Config---model=model_b --help].md index 1e2fb4c0..39fdbaa6 100644 --- a/test/test_subgroups/test_help[Config---model=model_b --help].md +++ b/test/test_subgroups/test_help[Config---model=model_b --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:731) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md index 5b82f578..0aab6941 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:731) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md index 312b7218..9057799f 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:731) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md index cbfa7eeb..262aeb67 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:731) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md index 890e7326..61714b33 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:731) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---help].md b/test/test_subgroups/test_help[ConfigWithFrozen---help].md index d05b94ef..2773c36d 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:731) Given Source code: From 45bd7c03fdadfda24e04e3a9d058eae65408335e Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 27 Oct 2023 11:41:28 -0400 Subject: [PATCH 05/20] Fix issues with * import in helpers module Signed-off-by: Fabrice 
Normandin --- examples/simple/flag.py | 12 ++-------- simple_parsing/helpers/__init__.py | 35 +++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/examples/simple/flag.py b/examples/simple/flag.py index f3f28f44..328050ee 100644 --- a/examples/simple/flag.py +++ b/examples/simple/flag.py @@ -1,17 +1,9 @@ from dataclasses import dataclass -from simple_parsing import ArgumentParser +from simple_parsing import ArgumentParser, parse from simple_parsing.helpers import flag -def parse(cls, args: str = ""): - """Removes some boilerplate code from the examples.""" - parser = ArgumentParser() # Create an argument parser - parser.add_arguments(cls, dest="hparams") # add arguments for the dataclass - ns = parser.parse_args(args.split()) # parse the given `args` - return ns.hparams - - @dataclass class HParams: """Set of options for the training of a Model.""" @@ -32,7 +24,7 @@ class HParams: """ # Example 2 using the flags negative prefix -assert parse(HParams, "--no-train") == HParams(train=False) +assert parse(HParams, args="--no-train") == HParams(train=False) # showing what --help outputs diff --git a/simple_parsing/helpers/__init__.py b/simple_parsing/helpers/__init__.py index b0e635bd..b8546e1c 100644 --- a/simple_parsing/helpers/__init__.py +++ b/simple_parsing/helpers/__init__.py @@ -1,5 +1,15 @@ """Collection of helper classes and functions to reduce boilerplate code.""" -from .fields import * +from .fields import ( + choice, + dict_field, + field, + flag, + flags, + list_field, + mutable_field, + set_field, + subparsers, +) from .flatten import FlattenedAccess from .hparams import HyperParameters from .partial import Partial, config_for @@ -13,3 +23,26 @@ # For backward compatibility purposes JsonSerializable = Serializable SimpleEncoder = SimpleJsonEncoder + +__all__ = [ + "FlattenedAccess", + "HyperParameters", + "Partial", + "config_for", + "FrozenSerializable", + "Serializable", + "SimpleJsonEncoder", + "encode", + "JsonSerializable", + "SimpleEncoder", + "YamlSerializable", + "field", + "choice", + "list_field", + "dict_field", + "set_field", + "mutable_field", + "subparsers", + "flag", + "flags", +] From 28092e0cf629e35152b6e6d798c1703e241b50da Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 27 Oct 2023 11:51:55 -0400 Subject: [PATCH 06/20] Fix issue with import of `subgroups` from helpers Signed-off-by: Fabrice Normandin --- simple_parsing/helpers/__init__.py | 2 ++ simple_parsing/parsing.py | 42 ++++++++++++++++++++++-------- simple_parsing/utils.py | 9 ++++--- 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/simple_parsing/helpers/__init__.py b/simple_parsing/helpers/__init__.py index b8546e1c..5a2088d4 100644 --- a/simple_parsing/helpers/__init__.py +++ b/simple_parsing/helpers/__init__.py @@ -14,6 +14,7 @@ from .hparams import HyperParameters from .partial import Partial, config_for from .serialization import FrozenSerializable, Serializable, SimpleJsonEncoder, encode +from .subgroups import subgroups try: from .serialization import YamlSerializable @@ -45,4 +46,5 @@ "subparsers", "flag", "flags", + "subgroups", ] diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py index 794e363b..c355fbc0 100644 --- a/simple_parsing/parsing.py +++ b/simple_parsing/parsing.py @@ -531,8 +531,8 @@ def _preprocessing(self, args: Sequence[str] = (), namespace: Namespace | None = # Create one argument group per dataclass for wrapped_dataclass in wrapped_dataclasses: logger.debug( - f"Parser {id(self)} is Adding 
arguments for dataclass: {wrapped_dataclass.dataclass} " - f"at destinations {wrapped_dataclass.destinations}" + f"Parser {id(self)} is Adding arguments for dataclass: " + f"{wrapped_dataclass.dataclass} at destinations {wrapped_dataclass.destinations}" ) wrapped_dataclass.add_arguments(parser=self) @@ -636,7 +636,8 @@ def _resolve_subgroups( # Do rounds of parsing with just the subgroup arguments, until all the subgroups # are resolved to a dataclass type. logger.debug( - f"Starting subgroup parsing round {current_nesting_level}: {list(unresolved_subgroups.keys())}" + f"Starting subgroup parsing round {current_nesting_level}: " + f"{list(unresolved_subgroups.keys())}" ) # Add all the unresolved subgroups arguments. for dest, subgroup_field in unresolved_subgroups.items(): @@ -877,8 +878,9 @@ def _instantiate_dataclasses( existing = getattr(parsed_args, destination) if dc_wrapper.dest in self._defaults: logger.debug( - f"Overwriting defaults in the namespace at destination '{destination}' " - f"on the Namespace ({existing}) to a value of {value_for_dataclass_field}" + f"Overwriting defaults in the namespace at destination " + f"'{destination}' on the Namespace ({existing}) to a value of " + f"{value_for_dataclass_field}" ) setattr(parsed_args, destination, value_for_dataclass_field) else: @@ -938,9 +940,28 @@ def _fill_constructor_arguments_with_fields( parsed_arg_values = vars(parsed_args) deleted_values: dict[str, Any] = {} - for wrapper in wrappers: - for field in wrapper.fields: - if argparse.SUPPRESS in wrapper.defaults and field.dest not in parsed_args: + # BUG: Need to check that the non-init fields DO have a FieldWrapper here, and that there + # isn't a value for that field in the constructor arguments! + + for dc_wrapper in wrappers: + for non_init_field in [ + f for f in dataclasses.fields(dc_wrapper.dataclass) if not f.init + ]: + field_dest = dc_wrapper.dest + "." + non_init_field.name + # We fetch the constructor arguments for the containing dataclass and check that it + # doesn't have a value set. + dc_constructor_args = constructor_arguments + for dest_part in dc_wrapper.dest.split("."): + dc_constructor_args = dc_constructor_args[dest_part] + if non_init_field.name in dc_constructor_args: + logger.warning( + f"Field {field_dest} is a field with init=False, but a value is " + f"present in the serialized config. This value will be ignored." + ) + dc_constructor_args.pop(non_init_field.name) + + for field in dc_wrapper.fields: + if argparse.SUPPRESS in dc_wrapper.defaults and field.dest not in parsed_args: continue if field.is_subgroup: @@ -948,9 +969,8 @@ def _fill_constructor_arguments_with_fields( logger.debug(f"Not calling the subgroup FieldWrapper for dest {field.dest}") continue - if not field.field.init: - # The field isn't an argument of the dataclass constructor. - continue + # We only create FieldWrappers for fields that have init=True. + assert field.field.init # NOTE: If the field is reused (when using the ConflictResolution.ALWAYS_MERGE # strategy), then we store the multiple values in the `dest` of the first field. diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index 752fbbcb..a4bb4d8b 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -210,7 +210,8 @@ def get_argparse_type_for_container( 'str' is returned. Arguments: - container_type {Type} -- A container type (ideally a typing.Type such as List, Tuple, along with an item annotation: List[str], Tuple[int, int], etc.) 
+ container_type -- A container type (ideally a typing.Type such as List, Tuple, along + with an item annotation: List[str], Tuple[int, int], etc.) Returns: typing.Type -- the type that should be used in argparse 'type' argument option. @@ -649,7 +650,8 @@ def _parse(value: str) -> list[Any]: # if it doesn't work, fall back to the parse_fn. values = _fallback_parse(value) - # we do the default 'argparse' action, which is to add the values to a bigger list of values. + # we do the default 'argparse' action, which is to add the values to a bigger list of + # values. # result.extend(values) logger.debug(f"returning values: {values}") return values @@ -662,7 +664,8 @@ def _parse_literal(value: str) -> list[Any] | Any: literal = ast.literal_eval(value) logger.debug(f"Parsed literal: {literal}") if not isinstance(literal, (list, tuple)): - # we were passed a single-element container, like "--some_list 1", which should give [1]. + # we were passed a single-element container, like "--some_list 1", which should give + # [1]. # We therefore return the literal itself, and argparse will append it. return T(literal) else: From 1d1a4704716d53d04b49791abcd5fc1150ccc9fd Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 27 Oct 2023 11:52:28 -0400 Subject: [PATCH 07/20] Fix non-init fields being passed to DC's __init__ Signed-off-by: Fabrice Normandin --- simple_parsing/parsing.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py index c355fbc0..ad676111 100644 --- a/simple_parsing/parsing.py +++ b/simple_parsing/parsing.py @@ -948,11 +948,14 @@ def _fill_constructor_arguments_with_fields( f for f in dataclasses.fields(dc_wrapper.dataclass) if not f.init ]: field_dest = dc_wrapper.dest + "." + non_init_field.name - # We fetch the constructor arguments for the containing dataclass and check that it + # Fetch the constructor arguments for the containing dataclass and check that it # doesn't have a value set. - dc_constructor_args = constructor_arguments - for dest_part in dc_wrapper.dest.split("."): - dc_constructor_args = dc_constructor_args[dest_part] + # NOTE: The `constructor_arguments` dict is FLAT here, so each dataclass has its + # own corresponding arguments at their destination, like `"a.b": {}`. 
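                # (Illustrative example with made-up destinations: for a dataclass at dest
                # "config" with a nested dataclass field "inner", this flat dict looks like
                # {"config": {...}, "config.inner": {...}} rather than {"config": {"inner": {...}}}.)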
+ dc_constructor_args = constructor_arguments[dc_wrapper.dest] + + # for dest_part in dc_wrapper.dest.split("."): + # dc_constructor_args = dc_constructor_args[dest_part] if non_init_field.name in dc_constructor_args: logger.warning( f"Field {field_dest} is a field with init=False, but a value is " From b47f0367a28278a65d763a8bdb8a860d0d72f0ae Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 27 Oct 2023 11:59:23 -0400 Subject: [PATCH 08/20] Deprecate a few unused functions in utils.py Signed-off-by: Fabrice Normandin --- simple_parsing/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index a4bb4d8b..8271e5b5 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -35,7 +35,7 @@ overload, ) -from typing_extensions import Literal, Protocol, TypeGuard, get_args, get_origin +from typing_extensions import Literal, Protocol, TypeGuard, deprecated, get_args, get_origin # There are cases where typing.Literal doesn't match typing_extensions.Literal: # https://github.com/python/typing_extensions/pull/148 @@ -114,6 +114,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) +@deprecated("This is unused internally and will be removed soon.") def camel_case(name): s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() @@ -613,6 +614,7 @@ def get_container_nargs(container_type: type) -> int | str: raise NotImplementedError(f"Not sure what 'nargs' should be for type {container_type}") +@deprecated("This is likely going to be removed soon.") def _parse_multiple_containers( container_type: type, append_action: bool = False ) -> Callable[[str], list[Any]]: @@ -887,6 +889,7 @@ def dict_union(*dicts: dict[K, V], recurse: bool = True, dict_factory=dict) -> d return result +@deprecated("This is buggy and unused internally and will be removed soon.") def flatten(nested: PossiblyNestedMapping[K, V]) -> dict[tuple[K, ...], V]: """Flatten a dictionary of dictionaries. The returned dictionary's keys are tuples, one entry per layer. @@ -953,16 +956,19 @@ def unflatten_split( return unflatten({tuple(key.split(sep)): value for key, value in flattened.items()}) +@deprecated("This is unused internally and will be removed soon.") @overload def getitem_recursive(d: PossiblyNestedDict[K, V], keys: Iterable[K]) -> V: ... +@deprecated("This is unused internally and will be removed soon.") @overload def getitem_recursive(d: PossiblyNestedDict[K, V], keys: Iterable[K], default: T) -> V | T: ... 
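# NOTE: `typing_extensions.deprecated` on the `@overload` stubs above is mainly a hint for
# static type checkers; any runtime DeprecationWarning (depending on the installed
# typing_extensions version) comes from decorating the actual implementation below.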
+@deprecated("This is unused internally and will be removed soon.") def getitem_recursive( d: PossiblyNestedDict[K, V], keys: Iterable[K], default: T | _MISSING_TYPE = MISSING ) -> V | T: From 8b56786a0e606caccb8d1344b1ba800683b1ecfc Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 27 Oct 2023 15:27:50 -0400 Subject: [PATCH 09/20] Add temporary fix for the config _type_ issue Signed-off-by: Fabrice Normandin --- simple_parsing/helpers/subgroups.py | 18 ++- simple_parsing/parsing.py | 182 +++++++++++++++++++++-- simple_parsing/replace.py | 8 +- simple_parsing/wrappers/field_wrapper.py | 6 +- test/test_subgroups.py | 16 +- 5 files changed, 204 insertions(+), 26 deletions(-) diff --git a/simple_parsing/helpers/subgroups.py b/simple_parsing/helpers/subgroups.py index 4f4a9197..509b2b91 100644 --- a/simple_parsing/helpers/subgroups.py +++ b/simple_parsing/helpers/subgroups.py @@ -6,7 +6,7 @@ from dataclasses import _MISSING_TYPE, MISSING from enum import Enum from logging import getLogger as get_logger -from typing import Any, Callable, TypeVar, Union, overload +from typing import Any, Callable, Mapping, TypeVar, Union from typing_extensions import TypeAlias @@ -21,7 +21,7 @@ def subgroups( - subgroups: dict[Key, type[DC] | functools.partial[DC]], + subgroups: Mapping[Key, type[DC] | functools.partial[DC]], *args, default: Key | _MISSING_TYPE = MISSING, default_factory: type[DC] | functools.partial[DC] | _MISSING_TYPE = MISSING, @@ -59,8 +59,8 @@ def subgroups( "dataclass." ) if default not in subgroups.values(): - # TODO: (@lebrice): Do we really need to enforce this? What is the reasoning behind this - # restriction again? + # NOTE: The reason we enforce this is perhaps artificial, but it's because the way we + # implement subgroups requires us to know the key that is selected in the dict. raise ValueError(f"Default value {default} needs to be a value in the subgroups dict.") elif default is not MISSING and default not in subgroups.keys(): raise ValueError("default must be a key in the subgroups dict!") @@ -186,7 +186,8 @@ def _get_dataclass_type_from_callable( return dataclass_fn.func # partial to a function that should return a dataclass. Hopefully it has a return type # annotation, otherwise we'd have to call the function just to know the return type! - # NOTE: recurse here, so it also works with `partial(partial(...))` and `partial(some_function)` + # NOTE: recurse here, so it also works with `partial(partial(...))` and + # `partial(some_function)` return _get_dataclass_type_from_callable( dataclass_fn=dataclass_fn.func, caller_frame=caller_frame ) @@ -218,7 +219,8 @@ def _get_dataclass_type_from_callable( caller_globals = caller_frame.f_globals try: - # NOTE: This doesn't seem to be very often different than just calling `get_type_hints` + # NOTE: This doesn't seem to be very often different than just calling + # `get_type_hints` type_hints = typing.get_type_hints( dataclass_fn, globalns=caller_globals, localns=caller_locals ) @@ -229,8 +231,8 @@ def _get_dataclass_type_from_callable( type_hints = typing.get_type_hints(dataclass_fn) dataclass_fn_type = type_hints["return"] - # Recursing here would be a bit extra, let's be real. Might be good enough to just assume that - # the return annotation needs to be a dataclass. + # Recursing here would be a bit extra, let's be real. Might be good enough to just assume + # that the return annotation needs to be a dataclass. 
# return _get_dataclass_type_from_callable(dataclass_fn_type, caller_frame=caller_frame) assert is_dataclass_type(dataclass_fn_type) return dataclass_fn_type diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py index ad676111..a4dad2fe 100644 --- a/simple_parsing/parsing.py +++ b/simple_parsing/parsing.py @@ -7,6 +7,7 @@ import argparse import dataclasses import functools +import inspect import itertools import shlex import sys @@ -15,18 +16,23 @@ from collections import defaultdict from logging import getLogger from pathlib import Path -from typing import Any, Callable, Sequence, Type, overload - +from typing import Any, Callable, Mapping, Sequence, Type, overload +from typing_extensions import TypeGuard +import warnings from simple_parsing.helpers.subgroups import SubgroupKey +from simple_parsing.replace import SUBGROUP_KEY_FLAG from simple_parsing.wrappers.dataclass_wrapper import DataclassWrapperType from . import utils from .conflicts import ConflictResolution, ConflictResolver from .help_formatter import SimpleHelpFormatter -from .helpers.serialization.serializable import read_file +from .helpers.serialization.serializable import DC_TYPE_KEY, read_file from .utils import ( + K, + V, Dataclass, DataclassT, + PossiblyNestedDict, dict_union, is_dataclass_instance, is_dataclass_type, @@ -593,7 +599,7 @@ def _resolve_subgroups( This modifies the wrappers in-place, by possibly adding children to the wrappers in the list. - Returns a list with the modified wrappers. + Returns a list with the (now modified) wrappers. Each round does the following: 1. Resolve any conflicts using the conflict resolver. Two subgroups at the same nesting @@ -618,13 +624,7 @@ def _resolve_subgroups( # times. subgroup_choice_parser = argparse.ArgumentParser( add_help=False, - # conflict_resolution=self.conflict_resolution, - # add_option_string_dash_variants=self.add_option_string_dash_variants, - # argument_generation_mode=self.argument_generation_mode, - # nested_mode=self.nested_mode, formatter_class=self.formatter_class, - # add_config_path_arg=self.add_config_path_arg, - # config_path=self.config_path, # NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues # for example if you have —a_or_b and A has a field —a then it will error out if you # pass —a=1 because 1 isn’t a choice for the a_or_b argument (because --a matches it @@ -644,10 +644,27 @@ def _resolve_subgroups( flags = subgroup_field.option_strings argument_options = subgroup_field.arg_options + # Sanity checks: if subgroup_field.subgroup_default is dataclasses.MISSING: assert argument_options["required"] + elif isinstance(argument_options["default"], dict): + # BUG #276: The default here is a dict because it came from a config file. + # Here we want the subgroup field to have a 'str' default, because we just want + # to be able to choose between the subgroup names. + _default = argument_options["default"] + _default_key = _infer_subgroup_key_to_use_from_config( + default_in_config=_default, + # subgroup_default=subgroup_field.subgroup_default, + subgroup_choices=subgroup_field.subgroup_choices, + ) + # We'd like this field to (at least temporarily) have a different default + # value that is the subgroup key instead of the dictionary. 
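                # (For example, a config default such as {"_type_": "some_module.A2", "a_val": 2}
                # would be mapped back to its key in the subgroups dict, e.g. "a2", so the
                # subgroup argument still offers a choice between the subgroup names; the names
                # here are purely illustrative.)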
+ argument_options["default"] = _default_key + else: - assert argument_options["default"] is subgroup_field.subgroup_default + assert ( + argument_options["default"] is subgroup_field.subgroup_default + ), argument_options["default"] assert not is_dataclass_instance(argument_options["default"]) # TODO: Do we really need to care about this "SUPPRESS" stuff here? @@ -1177,3 +1194,146 @@ def _create_dataclass_instance( return None logger.debug(f"Calling constructor: {constructor}(**{constructor_args})") return constructor(**constructor_args) + + +def _has_values_of_type( + mapping: Mapping[K, Any], value_type: type[V] | tuple[type[V], ...] +) -> TypeGuard[Mapping[K, V]]: + # Utility functions used to narrow the type of dictionaries. + return all(isinstance(v, value_type) for v in mapping.values()) + + +def _has_keys_of_type( + mapping: Mapping[Any, V], key_type: type[K] | tuple[type[K], ...] +) -> TypeGuard[Mapping[K, V]]: + # Utility functions used to narrow the type of dictionaries. + return all(isinstance(k, key_type) for k in mapping.keys()) + + +def _has_items_of_type( + mapping: Mapping[Any, Any], + item_type: tuple[type[K] | tuple[type[K], ...], type[V] | tuple[type[V], ...]], +) -> TypeGuard[Mapping[K, V]]: + # Utility functions used to narrow the type of a dictionary or mapping. + key_type, value_type = item_type + return _has_keys_of_type(mapping, key_type) and _has_values_of_type(mapping, value_type) + + +def _infer_subgroup_key_to_use_from_config( + default_in_config: dict[str, Any], + # subgroup_default: Hashable, + subgroup_choices: Mapping[SubgroupKey, type[Dataclass] | functools.partial[Dataclass]], +) -> SubgroupKey: + config_default = default_in_config + + if SUBGROUP_KEY_FLAG in default_in_config: + return default_in_config[SUBGROUP_KEY_FLAG] + + for subgroup_key, subgroup_value in subgroup_choices.items(): + if default_in_config == subgroup_value: + return subgroup_key + + assert ( + DC_TYPE_KEY in config_default + ), f"FIXME: assuming that the {DC_TYPE_KEY} is in the config dict." + _default_type_name: str = config_default[DC_TYPE_KEY] + + if _has_values_of_type(subgroup_choices, type) and all( + dataclasses.is_dataclass(subgroup_option) for subgroup_option in subgroup_choices.values() + ): + # Simpler case: All the subgroup options are dataclass types. We just get the key that + # matches the type that was saved in the config dict. + subgroup_keys_with_value_matching_config_default_type: list[SubgroupKey] = [ + k + for k, v in subgroup_choices.items() + if (isinstance(v, type) and f"{v.__module__}.{v.__qualname__}" == _default_type_name) + ] + # NOTE: There could be duplicates I guess? Something like `subgroups({"a": A, "aa": A})` + assert len(subgroup_keys_with_value_matching_config_default_type) >= 1 + return subgroup_keys_with_value_matching_config_default_type[0] + + # IDEA: Try to find the best subgroup key to use, based on the number of matching constructor + # arguments between the default in the config and the defaults for each subgroup. 
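    # (Sketch of that idea with illustrative values: if the config default contains {"a_val": 2}
    # and subgroup "a2" defaults to a_val=2 while "a1" defaults to a_val=1, then "a2" gets the
    # higher match count below and is tried first.)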
+ constructor_args_in_each_subgroup = { + key: _default_constructor_argument_values(subgroup_value) + for key, subgroup_value in subgroup_choices.items() + } + n_matching_values = { + k: _num_matching_values(config_default, constructor_args_in_subgroup_value) + for k, constructor_args_in_subgroup_value in constructor_args_in_each_subgroup.items() + } + closest_subgroups_first = sorted( + subgroup_choices.keys(), + key=n_matching_values.__getitem__, + reverse=True, + ) + warnings.warn( + # TODO: Return the dataclass type instead, and be done with it! + RuntimeWarning( + f"TODO: The config file contains a default value for a subgroup that isn't in the " + f"dict of subgroup options. Because of how subgroups are currently implemented, we " + f"need to find the key in the subgroup choice dict ({subgroup_choices}) that most " + f"closely matches the value {config_default}." + f"The current implementation tries to use the dataclass type of this closest match " + f"to parse the additional values from the command-line. " + f"{default_in_config}. Consider adding the " + f"{SUBGROUP_KEY_FLAG}: " + ) + ) + return closest_subgroups_first[0] + return closest_subgroups_first[0] + + sorted( + [k for k, v in subgroup_choices.items()], + key=_num_matching_values, + reversed=True, + ) + # _default_values = copy.deepcopy(config_default) + # _default_values.pop(DC_TYPE_KEY) + + # default_constructor_args_for_each_subgroup = { + # k: _default_constructor_argument_values(dc_type) if dataclasses.is_dataclass(dc_type) + # } + + +def _default_constructor_argument_values( + some_dataclass_type: type[Dataclass] | functools.partial[Dataclass], +) -> PossiblyNestedDict[str, Any]: + result = {} + if isinstance(some_dataclass_type, functools.partial) and is_dataclass_type( + some_dataclass_type.func + ): + constructor_arguments_from_classdef = _default_constructor_argument_values( + some_dataclass_type.func + ) + # TODO: will probably raise an error! 
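        # `bind_partial(...)` recovers the positional and keyword arguments that were pre-bound
        # on the functools.partial object, so that they can overwrite the class-level field
        # defaults gathered above.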
+ constructor_arguments_from_partial = ( + inspect.signature(some_dataclass_type.func) + .bind_partial(*some_dataclass_type.args, **some_dataclass_type.keywords) + .arguments + ) + constructor_arguments_from_classdef.update(constructor_arguments_from_partial) + return constructor_arguments_from_classdef + + assert is_dataclass_type(some_dataclass_type) + for field in dataclasses.fields(some_dataclass_type): + key = field.name + if field.default is not dataclasses.MISSING: + result[key] = field.default + elif is_dataclass_type(field.type) or ( + isinstance(field.default_factory, functools.partial) + and dataclasses.is_dataclass(field.default_factory.func) + ): + result[key] = _default_constructor_argument_values(field.type) + return result + + +def _num_matching_values(subgroup_default: dict[str, Any], subgroup_choice: dict[str, Any]) -> int: + """Returns the number of matching entries in the subgroup dict w/ the default from the + config.""" + return sum( + _num_matching_values(default_v, subgroup_choice[k]) + if isinstance(subgroup_choice.get(k), dict) and isinstance(default_v, dict) + else int(subgroup_choice.get(k) == default_v) + for k, default_v in subgroup_default.items() + ) diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index db350fba..0f25ec61 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -180,8 +180,14 @@ def replace_subgroups( return dataclasses.replace(obj, **replace_kwargs) +SUBGROUP_KEY_FLAG = "__key__" + + def _unflatten_selection_dict( - flattened: Mapping[str, V], keyword: str = "__key__", sep: str = ".", recursive: bool = True + flattened: Mapping[str, V], + keyword: str = SUBGROUP_KEY_FLAG, + sep: str = ".", + recursive: bool = True, ) -> PossiblyNestedDict[str, V]: """This function convert a flattened dict into a nested dict and it inserts the `keyword` as the selection into the nested dict. diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 3a4d1860..fa2839dd 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -962,10 +962,12 @@ def subgroup_choices(self) -> dict[Hashable, Callable[[], Dataclass] | Dataclass return self.field.metadata["subgroups"] @property - def subgroup_default(self) -> Hashable | Literal[dataclasses.MISSING] | None: + def subgroup_default(self) -> Hashable | Literal[dataclasses.MISSING]: if not self.is_subgroup: raise RuntimeError(f"Field {self.field} doesn't have subgroups! ") - return self.field.metadata.get("subgroup_default") + subgroup_default_key = self.field.metadata.get("subgroup_default") + assert subgroup_default_key is not None + return subgroup_default_key @property def type_arguments(self) -> tuple[type, ...] 
| None: diff --git a/test/test_subgroups.py b/test/test_subgroups.py index 56164a19..0a39cb4d 100644 --- a/test/test_subgroups.py +++ b/test/test_subgroups.py @@ -938,26 +938,31 @@ def test_ordering_of_args_doesnt_matter(): ) -@dataclass +@dataclass(frozen=True, unsafe_hash=True) class A1: a_val: int = 1 -@dataclass +@dataclass(frozen=True, unsafe_hash=True) class A2: a_val: int = 2 +also_a2_default = functools.partial(A2, a_val=123) + + @dataclass class A1OrA2: - a: A1 | A2 = subgroups({"a1": A1, "a2": A2}, default="a1") + a: A1 | A2 = subgroups({"a1": A1, "a2": A2, "also_a2": also_a2_default}, default="a1") @pytest.mark.parametrize( ("value_in_config", "args", "expected"), [ - (A1OrA2(a=A2()), "", A1OrA2(a=A1())), + (A1OrA2(a=A2()), "", A1OrA2(a=A2())), (A1OrA2(a=A1()), "", A1OrA2(a=A1())), + (A1OrA2(a=A1()), "--a=a2", A1OrA2(a=A2())), + (A1OrA2(a=also_a2_default()), "", A1OrA2(a=also_a2_default())), ], ) @pytest.mark.parametrize("filetype", [".yaml", ".json", ".pkl"]) @@ -968,6 +973,9 @@ def test_parse_with_config_file_with_different_subgroup( args: str, expected: A1OrA2, ): + """TODO: Honestly not 100% sure what I was testing here.""" + # I think I was trying to reproduce the issue from #276 + config_path = (tmp_path / "bob").with_suffix(filetype) save(value_in_config, config_path, save_dc_types=True) From aacecd629b94d43e6c63e5e11ea0bc3373d11243 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 27 Oct 2023 15:47:45 -0400 Subject: [PATCH 10/20] Leave a TODO for nested edge case, move fns around Signed-off-by: Fabrice Normandin --- simple_parsing/parsing.py | 48 ++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py index a4dad2fe..f2bd2d9c 100644 --- a/simple_parsing/parsing.py +++ b/simple_parsing/parsing.py @@ -648,6 +648,8 @@ def _resolve_subgroups( if subgroup_field.subgroup_default is dataclasses.MISSING: assert argument_options["required"] elif isinstance(argument_options["default"], dict): + # TODO: In this case here, the value of a nested subgroup in this default dict + # should also be used! # BUG #276: The default here is a dict because it came from a config file. # Here we want the subgroup field to have a 'str' default, because we just want # to be able to choose between the subgroup names. @@ -1196,29 +1198,6 @@ def _create_dataclass_instance( return constructor(**constructor_args) -def _has_values_of_type( - mapping: Mapping[K, Any], value_type: type[V] | tuple[type[V], ...] -) -> TypeGuard[Mapping[K, V]]: - # Utility functions used to narrow the type of dictionaries. - return all(isinstance(v, value_type) for v in mapping.values()) - - -def _has_keys_of_type( - mapping: Mapping[Any, V], key_type: type[K] | tuple[type[K], ...] -) -> TypeGuard[Mapping[K, V]]: - # Utility functions used to narrow the type of dictionaries. - return all(isinstance(k, key_type) for k in mapping.keys()) - - -def _has_items_of_type( - mapping: Mapping[Any, Any], - item_type: tuple[type[K] | tuple[type[K], ...], type[V] | tuple[type[V], ...]], -) -> TypeGuard[Mapping[K, V]]: - # Utility functions used to narrow the type of a dictionary or mapping. 
- key_type, value_type = item_type - return _has_keys_of_type(mapping, key_type) and _has_values_of_type(mapping, value_type) - - def _infer_subgroup_key_to_use_from_config( default_in_config: dict[str, Any], # subgroup_default: Hashable, @@ -1296,6 +1275,29 @@ def _infer_subgroup_key_to_use_from_config( # } +def _has_values_of_type( + mapping: Mapping[K, Any], value_type: type[V] | tuple[type[V], ...] +) -> TypeGuard[Mapping[K, V]]: + # Utility functions used to narrow the type of dictionaries. + return all(isinstance(v, value_type) for v in mapping.values()) + + +def _has_keys_of_type( + mapping: Mapping[Any, V], key_type: type[K] | tuple[type[K], ...] +) -> TypeGuard[Mapping[K, V]]: + # Utility functions used to narrow the type of dictionaries. + return all(isinstance(k, key_type) for k in mapping.keys()) + + +def _has_items_of_type( + mapping: Mapping[Any, Any], + item_type: tuple[type[K] | tuple[type[K], ...], type[V] | tuple[type[V], ...]], +) -> TypeGuard[Mapping[K, V]]: + # Utility functions used to narrow the type of a dictionary or mapping. + key_type, value_type = item_type + return _has_keys_of_type(mapping, key_type) and _has_values_of_type(mapping, value_type) + + def _default_constructor_argument_values( some_dataclass_type: type[Dataclass] | functools.partial[Dataclass], ) -> PossiblyNestedDict[str, Any]: From 234612ca3f23eef3c595b44218f53346867cfd99 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 21 Dec 2023 10:47:43 -0500 Subject: [PATCH 11/20] Remove non-init fields in set_default Signed-off-by: Fabrice Normandin --- simple_parsing/wrappers/dataclass_wrapper.py | 24 +++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/simple_parsing/wrappers/dataclass_wrapper.py b/simple_parsing/wrappers/dataclass_wrapper.py index 6a86cbe7..33e27241 100644 --- a/simple_parsing/wrappers/dataclass_wrapper.py +++ b/simple_parsing/wrappers/dataclass_wrapper.py @@ -171,7 +171,8 @@ def __init__( # a "normal" attribute field_wrapper = self.field_wrapper_class(field, parent=self, prefix=self.prefix) logger.debug( - f"wrapped field at {field_wrapper.dest} has a default value of {field_wrapper.default}" + f"wrapped field at {field_wrapper.dest} has a default value of " + f"{field_wrapper.default}" ) if field_default is not dataclasses.MISSING: field_wrapper.set_default(field_default) @@ -216,9 +217,12 @@ def add_arguments(self, parser: argparse.ArgumentParser): def equivalent_argparse_code(self, leading="group") -> str: code = "" code += textwrap.dedent( - f""" - group = parser.add_argument_group(title="{self.title.strip()}", description="{self.description.strip()}") - """ + f"""\ + group = parser.add_argument_group( + title="{self.title.strip()}", + description="{self.description.strip()}", + ) + """ ) for wrapped_field in self.fields: if wrapped_field.is_subparser: @@ -294,6 +298,11 @@ def set_default(self, value: DataclassT | dict | None): self._default = value if field_default_values is None: return + # Ignore default values for fields that have init=False. 
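            # (Such values typically come from defaults loaded from a config file written with
            # `save(...)`: non-init fields do get serialized, but they cannot be passed to the
            # dataclass constructor, so they are dropped from the defaults here.)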
+ for field in dataclasses.fields(self.dataclass): + if not field.init and field.name in field_default_values: + field_default_values.pop(field.name) + unknown_names = set(field_default_values) for field_wrapper in self.fields: if field_wrapper.name not in field_default_values: @@ -314,7 +323,8 @@ def set_default(self, value: DataclassT | dict | None): unknown_names.discard("_type_") if unknown_names: raise RuntimeError( - f"{sorted(unknown_names)} are not fields of {self.dataclass} at path {self.dest!r}!" + f"{sorted(unknown_names)} are not fields of {self.dataclass} at path " + f"{self.dest!r}!" ) @property @@ -431,7 +441,9 @@ def merge(self, other: DataclassWrapper): # logger.debug(f"merging \n{self}\n with \n{other}") logger.debug(f"self destinations: {self.destinations}") logger.debug(f"other destinations: {other.destinations}") - # assert not set(self.destinations).intersection(set(other.destinations)), "shouldn't have overlap in destinations" + # assert not set(self.destinations).intersection(set(other.destinations)), ( + # "shouldn't have overlap in destinations" + # ) # self.destinations.extend(other.destinations) for dest in other.destinations: if dest not in self.destinations: From 8715594417a771699f42147733a913404ae0d6db Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 21 Dec 2023 11:15:35 -0500 Subject: [PATCH 12/20] Add test and use trick from anivegesena Signed-off-by: Fabrice Normandin --- .pre-commit-config.yaml | 3 +- .../helpers/serialization/serializable.py | 24 +++++++++---- simple_parsing/helpers/subgroups.py | 14 +++++--- test/test_subgroups.py | 34 +++++++++++++++---- 4 files changed, 56 insertions(+), 19 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 92165901..e3df32e5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: # python docstring formatting - repo: https://github.com/myint/docformatter - rev: v1.5.1 + rev: v1.7.5 hooks: - id: docformatter exclude: ^test/test_docstrings.py @@ -63,7 +63,6 @@ repos: - id: nbstripout require_serial: true - # md formatting - repo: https://github.com/executablebooks/mdformat rev: 0.7.16 diff --git a/simple_parsing/helpers/serialization/serializable.py b/simple_parsing/helpers/serialization/serializable.py index 83ab5498..4206f6cc 100644 --- a/simple_parsing/helpers/serialization/serializable.py +++ b/simple_parsing/helpers/serialization/serializable.py @@ -208,7 +208,8 @@ def __init_subclass__( if parent in SerializableMixin.subclasses and parent is not SerializableMixin: decode_into_subclasses = parent.decode_into_subclasses logger.debug( - f"Parent class {parent} has decode_into_subclasses = {decode_into_subclasses}" + f"Parent class {parent} has decode_into_subclasses = " + f"{decode_into_subclasses}" ) break @@ -220,7 +221,10 @@ def __init_subclass__( register_decoding_fn(cls, cls.from_dict) def to_dict( - self, dict_factory: type[dict] = dict, recurse: bool = True, save_dc_types: bool = False + self, + dict_factory: type[dict] = dict, + recurse: bool = True, + save_dc_types: bool | int = False, ) -> dict: """Serializes this dataclass to a dict. @@ -596,6 +600,7 @@ def loads_yaml( def read_file(path: str | Path) -> dict: """Returns the contents of the given file as a dictionary. 
+ Uses the right function depending on `path.suffix`: { ".yml": yaml.safe_load, @@ -614,7 +619,7 @@ def save( obj: Any, path: str | Path, format: FormatExtension | None = None, - save_dc_types: bool = False, + save_dc_types: bool | int = False, **kwargs, ) -> None: """Save the given dataclass or dictionary to the given file.""" @@ -705,7 +710,7 @@ def to_dict( dc: DataclassT, dict_factory: type[dict] = dict, recurse: bool = True, - save_dc_types: bool = False, + save_dc_types: bool | int = False, ) -> dict: """Serializes this dataclass to a dict. @@ -737,6 +742,11 @@ def to_dict( else: d[DC_TYPE_KEY] = module + "." + class_name + # Decrement save_dc_types if it is an int, so that we only save the type of the subgroups + # dataclass, not all dataclasses recursively. + if save_dc_types is not True and save_dc_types > 0: + save_dc_types -= 1 + for f in fields(dc): name = f.name value = getattr(dc, name) @@ -764,7 +774,8 @@ def to_dict( encoded = encoding_fn(value) except Exception as e: logger.error( - f"Unable to encode value {value} of type {type(value)}! Leaving it as-is. (exception: {e})" + f"Unable to encode value {value} of type {type(value)}! Leaving it as-is. " + f"(exception: {e})" ) encoded = value d[name] = encoded @@ -971,7 +982,8 @@ def _locate(path: str) -> Any: except ModuleNotFoundError as exc_import: raise ImportError( f"Error loading '{path}':\n{repr(exc_import)}" - + f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?" + + f"\nAre you sure that '{part}' is importable from module " + f"'{parent_dotpath}'?" ) from exc_import except Exception as exc_import: raise ImportError( diff --git a/simple_parsing/helpers/subgroups.py b/simple_parsing/helpers/subgroups.py index 509b2b91..3f9715cd 100644 --- a/simple_parsing/helpers/subgroups.py +++ b/simple_parsing/helpers/subgroups.py @@ -10,7 +10,8 @@ from typing_extensions import TypeAlias -from simple_parsing.utils import DataclassT, is_dataclass_instance, is_dataclass_type +from simple_parsing.helpers.serialization.serializable import to_dict +from simple_parsing.utils import Dataclass, DataclassT, is_dataclass_instance, is_dataclass_type logger = get_logger(__name__) @@ -80,7 +81,11 @@ def subgroups( metadata["subgroup_default"] = default metadata["subgroup_dataclass_types"] = {} - subgroup_dataclass_types: dict[Key, type[DataclassT]] = {} + # Custom encoding function that will add the _type_ key with the subgroup dataclass type. + # Using an int here means that only to the subgroup dataclass. + kwargs.setdefault("encoding_fn", functools.partial(to_dict, save_dc_types=1)) + + subgroup_dataclass_types: dict[Key, type[Dataclass]] = {} choices = subgroups.keys() # NOTE: Perhaps we could raise a warning if the default_factory is a Lambda, since we have to @@ -198,7 +203,8 @@ def _get_dataclass_type_from_callable( f"{dataclass_fn!r}, because it doesn't have a return type annotation, and we don't " f"want to call it just to figure out what it produces." ) - # NOTE: recurse here, so it also works with `partial(partial(...))` and `partial(some_function)` + # NOTE: recurse here, so it also works with `partial(partial(...))` and + # `partial(some_function)` # Recurse, so this also works with partial(partial(...)) (idk why you'd do that though.) if isinstance(signature.return_annotation, str): @@ -241,7 +247,7 @@ def _get_dataclass_type_from_callable( def is_lambda(obj: Any) -> bool: """Returns True if the given object is a lambda expression. 
- Taken froma-lambda + Taken from https://stackoverflow.com/questions/3655842/how-can-i-test-whether-a-variable-holds-a-lambda """ LAMBDA = lambda: 0 # noqa: E731 return isinstance(obj, type(LAMBDA)) and obj.__name__ == LAMBDA.__name__ diff --git a/test/test_subgroups.py b/test/test_subgroups.py index 0a39cb4d..fa2c8fde 100644 --- a/test/test_subgroups.py +++ b/test/test_subgroups.py @@ -11,11 +11,13 @@ from typing import Callable, TypeVar import pytest -from simple_parsing.helpers.serialization import save from pytest_regressions.file_regression import FileRegressionFixture from typing_extensions import Annotated +from simple_parsing.utils import Dataclass +from simple_parsing.helpers.serialization import save from simple_parsing import ArgumentParser, parse, subgroups +from simple_parsing.helpers.serialization.serializable import from_dict, to_dict from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, NestedMode from .test_choice import Color @@ -426,7 +428,10 @@ class Foo(TestSetup): marks=pytest.mark.xfail( strict=True, raises=NotImplementedError, - reason="Lambda expressions aren't allowed in the subgroup dict or default_factory at the moment.", + reason=( + "Lambda expressions aren't allowed in the subgroup dict or default_factory at the " + "moment." + ), ), ) @@ -785,11 +790,12 @@ def test_help( # ModelConfig = _ModelConfig() # SmallModel = _ModelConfig(num_layers=1, hidden_dim=32) # BigModel = _ModelConfig(num_layers=32, hidden_dim=128) - -# @dataclasses.dataclass -# class Config(TestSetup): -# model: Model = subgroups({"small": SmallModel, "big": BigModel}, default_factory=SmallModel) - +# +# @dataclasses.dataclass +# class Config(TestSetup): +# model: Model = subgroups({"small": SmallModel, "big": BigModel}, +# default_factory=SmallModel) +# # assert Config.setup().model == SmallModel() # # Hopefully this illustrates why Annotated aren't exactly great: # # At runtime, they are basically the same as the original dataclass when called. 
@@ -980,3 +986,17 @@ def test_parse_with_config_file_with_different_subgroup( save(value_in_config, config_path, save_dc_types=True) assert parse(A1OrA2, config_path=config_path, args=args) == expected + + +@pytest.mark.parametrize( + "value", + [ + A1OrA2(), + A1OrA2(a=A2(a_val=2)), + ], +) +def test_roundtrip(value: Dataclass): + """Test to reproduce + https://github.com/lebrice/SimpleParsing/pull/284#issuecomment-1783490388.""" + assert from_dict(type(value), to_dict(value)) == value + assert to_dict(from_dict(type(value), to_dict(value))) == to_dict(value) From 459d3f711aaef32f32621a0e5868e58e7c6344a0 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 21 Dec 2023 11:23:03 -0500 Subject: [PATCH 13/20] Fix broken tests for equivalent_argparse_code Signed-off-by: Fabrice Normandin --- examples/simple/basic.py | 5 ++++- examples/simple/inheritance.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/simple/basic.py b/examples/simple/basic.py index f66a6fee..75e94c08 100644 --- a/examples/simple/basic.py +++ b/examples/simple/basic.py @@ -49,7 +49,10 @@ class HParams: expected += """ parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) -group = parser.add_argument_group(title="HParams ['hparams']", description="Set of options for the training of a Model.") +group = parser.add_argument_group( + title="HParams ['hparams']", + description="Set of options for the training of a Model.", +) group.add_argument(*['--num_layers'], **{'type': int, 'required': False, 'dest': 'hparams.num_layers', 'default': 4, 'help': ' '}) group.add_argument(*['--num_units'], **{'type': int, 'required': False, 'dest': 'hparams.num_units', 'default': 64, 'help': ' '}) group.add_argument(*['--optimizer'], **{'type': str, 'required': False, 'dest': 'hparams.optimizer', 'default': 'ADAM', 'help': ' '}) diff --git a/examples/simple/inheritance.py b/examples/simple/inheritance.py index b116135d..de8c42c4 100644 --- a/examples/simple/inheritance.py +++ b/examples/simple/inheritance.py @@ -58,7 +58,10 @@ class MAML(Method): expected += """ parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) -group = parser.add_argument_group(title="MAML ['hparams']", description="Overwrites some of the default values and adds new arguments/attributes.") +group = parser.add_argument_group( + title="MAML ['hparams']", + description="Overwrites some of the default values and adds new arguments/attributes.", +) group.add_argument(*['--num_layers'], **{'type': int, 'required': False, 'dest': 'hparams.num_layers', 'default': 6, 'help': ' '}) group.add_argument(*['--num_units'], **{'type': int, 'required': False, 'dest': 'hparams.num_units', 'default': 128, 'help': ' '}) group.add_argument(*['--optimizer'], **{'type': str, 'required': False, 'dest': 'hparams.optimizer', 'default': 'ADAM', 'help': ' '}) From f4b68fcfe15ea4a9732b7328d99cbb388b21ee88 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 21 Dec 2023 11:27:00 -0500 Subject: [PATCH 14/20] Fix changed regression files Signed-off-by: Fabrice Normandin --- test/test_subgroups/test_help[Config---help].md | 4 ++++ .../test_help[Config---model=model_a --help].md | 4 ++++ .../test_help[Config---model=model_b --help].md | 4 ++++ .../test_help[ConfigWithFrozen---conf=even --a 100 --help].md | 4 ++++ .../test_help[ConfigWithFrozen---conf=even --help].md | 4 ++++ .../test_help[ConfigWithFrozen---conf=odd --a 123 --help].md | 4 ++++ .../test_help[ConfigWithFrozen---conf=odd --help].md | 4 ++++ 
test/test_subgroups/test_help[ConfigWithFrozen---help].md | 4 ++++ 8 files changed, 32 insertions(+) diff --git a/test/test_subgroups/test_help[Config---help].md b/test/test_subgroups/test_help[Config---help].md index b0632885..931e65ff 100644 --- a/test/test_subgroups/test_help[Config---help].md +++ b/test/test_subgroups/test_help[Config---help].md @@ -1,4 +1,8 @@ +<<<<<<< HEAD # Regression file for [this test](test/test_subgroups.py:731) +======= +# Regression file for [this test](test/test_subgroups.py:736) +>>>>>>> Fix changed regression files Given Source code: diff --git a/test/test_subgroups/test_help[Config---model=model_a --help].md b/test/test_subgroups/test_help[Config---model=model_a --help].md index c20f7faf..7bc268e4 100644 --- a/test/test_subgroups/test_help[Config---model=model_a --help].md +++ b/test/test_subgroups/test_help[Config---model=model_a --help].md @@ -1,4 +1,8 @@ +<<<<<<< HEAD # Regression file for [this test](test/test_subgroups.py:731) +======= +# Regression file for [this test](test/test_subgroups.py:736) +>>>>>>> Fix changed regression files Given Source code: diff --git a/test/test_subgroups/test_help[Config---model=model_b --help].md b/test/test_subgroups/test_help[Config---model=model_b --help].md index 39fdbaa6..21cc825d 100644 --- a/test/test_subgroups/test_help[Config---model=model_b --help].md +++ b/test/test_subgroups/test_help[Config---model=model_b --help].md @@ -1,4 +1,8 @@ +<<<<<<< HEAD # Regression file for [this test](test/test_subgroups.py:731) +======= +# Regression file for [this test](test/test_subgroups.py:736) +>>>>>>> Fix changed regression files Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md index 0aab6941..1c23e793 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md @@ -1,4 +1,8 @@ +<<<<<<< HEAD # Regression file for [this test](test/test_subgroups.py:731) +======= +# Regression file for [this test](test/test_subgroups.py:736) +>>>>>>> Fix changed regression files Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md index 9057799f..f4b558d3 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md @@ -1,4 +1,8 @@ +<<<<<<< HEAD # Regression file for [this test](test/test_subgroups.py:731) +======= +# Regression file for [this test](test/test_subgroups.py:736) +>>>>>>> Fix changed regression files Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md index 262aeb67..c2169036 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md @@ -1,4 +1,8 @@ +<<<<<<< HEAD # Regression file for [this test](test/test_subgroups.py:731) +======= +# Regression file for [this test](test/test_subgroups.py:736) +>>>>>>> Fix changed regression files Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md index 
61714b33..c5511112 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md @@ -1,4 +1,8 @@ +<<<<<<< HEAD # Regression file for [this test](test/test_subgroups.py:731) +======= +# Regression file for [this test](test/test_subgroups.py:736) +>>>>>>> Fix changed regression files Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---help].md b/test/test_subgroups/test_help[ConfigWithFrozen---help].md index 2773c36d..2ef6ddc9 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---help].md @@ -1,4 +1,8 @@ +<<<<<<< HEAD # Regression file for [this test](test/test_subgroups.py:731) +======= +# Regression file for [this test](test/test_subgroups.py:736) +>>>>>>> Fix changed regression files Given Source code: From f0800c445512f6dd050ae20ec9d443b27355bb73 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 12 Jan 2024 13:10:32 -0500 Subject: [PATCH 15/20] Refactoring subgroup default Signed-off-by: Fabrice Normandin --- simple_parsing/helpers/subgroups.py | 8 +- simple_parsing/parsing.py | 262 ++++++++++++++++------------ test/test_subgroups.py | 44 ++++- 3 files changed, 200 insertions(+), 114 deletions(-) diff --git a/simple_parsing/helpers/subgroups.py b/simple_parsing/helpers/subgroups.py index 3f9715cd..3287b393 100644 --- a/simple_parsing/helpers/subgroups.py +++ b/simple_parsing/helpers/subgroups.py @@ -57,12 +57,16 @@ def subgroups( if not isinstance(default, Hashable): raise ValueError( "'default' can either be a key of the subgroups dict or a hashable (frozen) " - "dataclass." + "dataclass in the values of the subgroup dict." ) if default not in subgroups.values(): # NOTE: The reason we enforce this is perhaps artificial, but it's because the way we # implement subgroups requires us to know the key that is selected in the dict. - raise ValueError(f"Default value {default} needs to be a value in the subgroups dict.") + raise ValueError( + f"When passing a dataclass instance as the `default` for the subgroups, it needs " + f"to be a hashable value (e.g. frozen dataclass) in the subgroups dict. " + f"Got {default}" + ) elif default is not MISSING and default not in subgroups.keys(): raise ValueError("default must be a key in the subgroups dict!") diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py index f2bd2d9c..7623ca8a 100644 --- a/simple_parsing/parsing.py +++ b/simple_parsing/parsing.py @@ -16,7 +16,7 @@ from collections import defaultdict from logging import getLogger from pathlib import Path -from typing import Any, Callable, Mapping, Sequence, Type, overload +from typing import Any, Callable, Hashable, Mapping, Sequence, Type, overload from typing_extensions import TypeGuard import warnings from simple_parsing.helpers.subgroups import SubgroupKey @@ -647,32 +647,18 @@ def _resolve_subgroups( # Sanity checks: if subgroup_field.subgroup_default is dataclasses.MISSING: assert argument_options["required"] - elif isinstance(argument_options["default"], dict): - # TODO: In this case here, the value of a nested subgroup in this default dict - # should also be used! - # BUG #276: The default here is a dict because it came from a config file. - # Here we want the subgroup field to have a 'str' default, because we just want - # to be able to choose between the subgroup names. 
- _default = argument_options["default"] - _default_key = _infer_subgroup_key_to_use_from_config( - default_in_config=_default, - # subgroup_default=subgroup_field.subgroup_default, - subgroup_choices=subgroup_field.subgroup_choices, - ) - # We'd like this field to (at least temporarily) have a different default - # value that is the subgroup key instead of the dictionary. - argument_options["default"] = _default_key - + if "default" in argument_options: + # todo: should ideally not set this in the first place... + assert argument_options["default"] is dataclasses.MISSING + argument_options.pop("default") + assert "default" not in argument_options else: - assert ( - argument_options["default"] is subgroup_field.subgroup_default - ), argument_options["default"] - assert not is_dataclass_instance(argument_options["default"]) - - # TODO: Do we really need to care about this "SUPPRESS" stuff here? - if argparse.SUPPRESS in subgroup_field.parent.defaults: - assert argument_options["default"] is argparse.SUPPRESS - argument_options["default"] = argparse.SUPPRESS + assert "default" in argument_options + assert argument_options["default"] == subgroup_field.default + argument_options["default"] = _adjust_default_value_for_subgroup_field( + subgroup_field=subgroup_field, + subgroup_default=argument_options["default"], + ) logger.debug( f"Adding subgroup argument: add_argument(*{flags} **{str(argument_options)})" @@ -1198,83 +1184,6 @@ def _create_dataclass_instance( return constructor(**constructor_args) -def _infer_subgroup_key_to_use_from_config( - default_in_config: dict[str, Any], - # subgroup_default: Hashable, - subgroup_choices: Mapping[SubgroupKey, type[Dataclass] | functools.partial[Dataclass]], -) -> SubgroupKey: - config_default = default_in_config - - if SUBGROUP_KEY_FLAG in default_in_config: - return default_in_config[SUBGROUP_KEY_FLAG] - - for subgroup_key, subgroup_value in subgroup_choices.items(): - if default_in_config == subgroup_value: - return subgroup_key - - assert ( - DC_TYPE_KEY in config_default - ), f"FIXME: assuming that the {DC_TYPE_KEY} is in the config dict." - _default_type_name: str = config_default[DC_TYPE_KEY] - - if _has_values_of_type(subgroup_choices, type) and all( - dataclasses.is_dataclass(subgroup_option) for subgroup_option in subgroup_choices.values() - ): - # Simpler case: All the subgroup options are dataclass types. We just get the key that - # matches the type that was saved in the config dict. - subgroup_keys_with_value_matching_config_default_type: list[SubgroupKey] = [ - k - for k, v in subgroup_choices.items() - if (isinstance(v, type) and f"{v.__module__}.{v.__qualname__}" == _default_type_name) - ] - # NOTE: There could be duplicates I guess? Something like `subgroups({"a": A, "aa": A})` - assert len(subgroup_keys_with_value_matching_config_default_type) >= 1 - return subgroup_keys_with_value_matching_config_default_type[0] - - # IDEA: Try to find the best subgroup key to use, based on the number of matching constructor - # arguments between the default in the config and the defaults for each subgroup. 
- constructor_args_in_each_subgroup = { - key: _default_constructor_argument_values(subgroup_value) - for key, subgroup_value in subgroup_choices.items() - } - n_matching_values = { - k: _num_matching_values(config_default, constructor_args_in_subgroup_value) - for k, constructor_args_in_subgroup_value in constructor_args_in_each_subgroup.items() - } - closest_subgroups_first = sorted( - subgroup_choices.keys(), - key=n_matching_values.__getitem__, - reverse=True, - ) - warnings.warn( - # TODO: Return the dataclass type instead, and be done with it! - RuntimeWarning( - f"TODO: The config file contains a default value for a subgroup that isn't in the " - f"dict of subgroup options. Because of how subgroups are currently implemented, we " - f"need to find the key in the subgroup choice dict ({subgroup_choices}) that most " - f"closely matches the value {config_default}." - f"The current implementation tries to use the dataclass type of this closest match " - f"to parse the additional values from the command-line. " - f"{default_in_config}. Consider adding the " - f"{SUBGROUP_KEY_FLAG}: " - ) - ) - return closest_subgroups_first[0] - return closest_subgroups_first[0] - - sorted( - [k for k, v in subgroup_choices.items()], - key=_num_matching_values, - reversed=True, - ) - # _default_values = copy.deepcopy(config_default) - # _default_values.pop(DC_TYPE_KEY) - - # default_constructor_args_for_each_subgroup = { - # k: _default_constructor_argument_values(dc_type) if dataclasses.is_dataclass(dc_type) - # } - - def _has_values_of_type( mapping: Mapping[K, Any], value_type: type[V] | tuple[type[V], ...] ) -> TypeGuard[Mapping[K, V]]: @@ -1330,12 +1239,143 @@ def _default_constructor_argument_values( return result -def _num_matching_values(subgroup_default: dict[str, Any], subgroup_choice: dict[str, Any]) -> int: - """Returns the number of matching entries in the subgroup dict w/ the default from the - config.""" - return sum( - _num_matching_values(default_v, subgroup_choice[k]) - if isinstance(subgroup_choice.get(k), dict) and isinstance(default_v, dict) - else int(subgroup_choice.get(k) == default_v) - for k, default_v in subgroup_default.items() +def _adjust_default_value_for_subgroup_field( + subgroup_field: FieldWrapper, subgroup_default: Any +) -> str | Hashable: + + if argparse.SUPPRESS in subgroup_field.parent.defaults: + assert subgroup_default is argparse.SUPPRESS + assert isinstance(subgroup_default, str) + return subgroup_default + + if isinstance(subgroup_default, dict): + default_from_config_file = subgroup_default + default_from_dataclass_field = subgroup_field.subgroup_default + + if SUBGROUP_KEY_FLAG in default_from_config_file: + _default_subgroup = default_from_config_file[SUBGROUP_KEY_FLAG] + logger.debug(f"Using subgroup key {_default_subgroup} as default (from config file)") + return _default_subgroup + + if DC_TYPE_KEY in default_from_config_file: + # The type of dataclass is specified in the config file. + # We can use that to figure out which subgroup to use. + default_dataclass_type_from_config = default_from_config_file[DC_TYPE_KEY] + if isinstance(default_dataclass_type_from_config, str): + from simple_parsing.helpers.serialization.serializable import _locate + + # Try to import the type of dataclass given its import path as a string in the + # config file. 
+ default_dataclass_type_from_config = _locate(default_dataclass_type_from_config) + assert is_dataclass_type(default_dataclass_type_from_config) + + from simple_parsing.helpers.subgroups import _get_dataclass_type_from_callable + + subgroup_choices_with_matching_type: dict[ + Hashable, Dataclass | Callable[[], Dataclass] + ] = { + subgroup_key: subgroup_value + for subgroup_key, subgroup_value in subgroup_field.subgroup_choices.items() + if is_dataclass_type(subgroup_value) + and subgroup_value == default_dataclass_type_from_config + or is_dataclass_instance(subgroup_value) + and type(subgroup_value) == default_dataclass_type_from_config + or _get_dataclass_type_from_callable(subgroup_value) + == default_dataclass_type_from_config + } + logger.debug( + f"Subgroup choices that match the type in the config file: " + f"{subgroup_choices_with_matching_type}" + ) + + # IDEA: Try to find the best subgroup key to use, based on the number of matching + # constructor arguments between the default in the config and the defaults for each + # subgroup. + constructor_args_of_each_subgroup_val = { + key: ( + dataclasses.asdict(subgroup_value) + if is_dataclass_instance(subgroup_value) + # (the type should have been narrowed by the is_dataclass_instance typeguard, + # but somehow isn't...) + else _default_constructor_argument_values(subgroup_value) # type: ignore + ) + for key, subgroup_value in subgroup_choices_with_matching_type.items() + } + logger.debug( + f"Constructor arguments for each subgroup choice: " + f"{constructor_args_of_each_subgroup_val}" + ) + + def _num_overlapping_keys( + subgroup_default_in_config: PossiblyNestedDict[str, Any], + subgroup_option_from_field: PossiblyNestedDict[str, Any], + ) -> int: + """Returns the number of matching entries in the subgroup dict w/ the default from + the config.""" + overlap = 0 + for key, value in subgroup_default_in_config.items(): + if key in subgroup_option_from_field: + overlap += 1 + if isinstance(value, dict) and isinstance( + subgroup_option_from_field[key], dict + ): + overlap += _num_overlapping_keys( + value, subgroup_option_from_field[key] + ) + return overlap + + n_matching_values = { + k: _num_overlapping_keys(default_from_config_file, constructor_args_in_value) + for k, constructor_args_in_value in constructor_args_of_each_subgroup_val.items() + } + logger.debug( + f"Number of overlapping keys for each subgroup choice: {n_matching_values}" + ) + closest_subgroups_first = sorted( + subgroup_choices_with_matching_type.keys(), + key=n_matching_values.__getitem__, + reverse=True, + ) + closest_subgroup_key = closest_subgroups_first[0] + + warnings.warn( + RuntimeWarning( + f"The config file contains a default value for a subgroup field that isn't in " + f"the dict of subgroup options. " + f"Because of how subgroups are currently implemented, we need to find the key " + f"in the subgroup choice dict that most closely matches the value " + f"{default_from_config_file} in order to populate the default values for " + f"other fields.\n" + f"The default in the config file: {default_from_config_file}\n" + f"The default in the dataclass field: {default_from_dataclass_field}\n" + f"The subgroups dict: {subgroup_field.subgroup_choices}\n" + f"The current implementation tries to use the dataclass type of this closest " + f"match to parse the additional values from the command-line. 
" + f"Consider adding a {SUBGROUP_KEY_FLAG!r}: item " + f"in the dict entry for that subgroup field in your config, to make it easier " + f"to tell directly which subgroup to use." + ) + ) + return closest_subgroup_key + + logger.debug( + f"Using subgroup key {default_from_dataclass_field} as default (from the dataclass " + f"field)" + ) + return default_from_dataclass_field + + if subgroup_default in subgroup_field.subgroup_choices.keys(): + return subgroup_default + + if subgroup_default in subgroup_field.subgroup_choices.values(): + matching_keys = [ + k for k, v in subgroup_field.subgroup_choices.items() if v == subgroup_default + ] + return matching_keys[0] + + raise RuntimeError( + f"Error: Unable to figure out what key matches the default value for the subgroup at " + f"{subgroup_field.dest}! (expected to either have the {SUBGROUP_KEY_FLAG!r} flag set, or " + f"one of the keys or values of the subgroups dict of that field: " + f"{subgroup_field.subgroup_choices})" ) diff --git a/test/test_subgroups.py b/test/test_subgroups.py index fa2c8fde..13e102be 100644 --- a/test/test_subgroups.py +++ b/test/test_subgroups.py @@ -3,6 +3,7 @@ import dataclasses import functools import inspect +import json import shlex import sys from dataclasses import dataclass, field @@ -970,6 +971,7 @@ class A1OrA2: (A1OrA2(a=A1()), "--a=a2", A1OrA2(a=A2())), (A1OrA2(a=also_a2_default()), "", A1OrA2(a=also_a2_default())), ], + ids=repr, ) @pytest.mark.parametrize("filetype", [".yaml", ".json", ".pkl"]) def test_parse_with_config_file_with_different_subgroup( @@ -983,8 +985,8 @@ def test_parse_with_config_file_with_different_subgroup( # I think I was trying to reproduce the issue from #276 config_path = (tmp_path / "bob").with_suffix(filetype) - save(value_in_config, config_path, save_dc_types=True) + assert parse(A1OrA2, config_path=config_path, args=args) == expected @@ -1000,3 +1002,43 @@ def test_roundtrip(value: Dataclass): https://github.com/lebrice/SimpleParsing/pull/284#issuecomment-1783490388.""" assert from_dict(type(value), to_dict(value)) == value assert to_dict(from_dict(type(value), to_dict(value))) == to_dict(value) + + +@dataclass +class AorB: + a_or_b: A | B = subgroups( + {"a": A, "b": B, "also_a": functools.partial(A, a=1.23)}, default="a" + ) + + +def test_saved_with_key_as_default(tmp_path: Path): + """Test to try to reproduce + https://github.com/lebrice/SimpleParsing/pull/284#discussion_r1434421587 + """ + + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({"a_or_b": "b"})) + + assert parse(AorB, args="") == AorB(a_or_b=A()) + assert parse(AorB, config_path=config_path, args="") == AorB(a_or_b=B()) + assert parse(AorB, config_path=config_path, args="--a_or_b=a") == AorB(a_or_b=A()) + + +def test_saved_with_custom_dict_as_default(tmp_path: Path): + """Test when a customized dict is set in the config for a subgroups field. + + We expect to have a warning but for things to work. + """ + + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({"a_or_b": {"b": "somefoo"}})) + assert parse(AorB, args="") == AorB(a_or_b=A()) + + with pytest.raises(TypeError): + # Default is 'a', so we should get a TypeError because b="somefoo" is passed to `A`. 
+ assert parse(AorB, config_path=config_path, args="") + + with pytest.warns(RuntimeWarning): + assert parse(AorB, config_path=config_path, args="--b=bobo") == AorB(a_or_b=B(b="bobo")) + + assert parse(AorB, config_path=config_path, args="--a_or_b=a") == AorB(a_or_b=A()) From b930cfc9d4853f73e74c6cbc23301d812c206eaa Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 22 Jan 2024 12:32:58 -0500 Subject: [PATCH 16/20] Clarify goal of test with comment Signed-off-by: Fabrice Normandin --- test/test_subgroups.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/test_subgroups.py b/test/test_subgroups.py index 13e102be..bb6dc9c1 100644 --- a/test/test_subgroups.py +++ b/test/test_subgroups.py @@ -981,13 +981,10 @@ def test_parse_with_config_file_with_different_subgroup( args: str, expected: A1OrA2, ): - """TODO: Honestly not 100% sure what I was testing here.""" - # I think I was trying to reproduce the issue from #276 - + """Test the case where a subgroup different from the default is saved in the config file.""" config_path = (tmp_path / "bob").with_suffix(filetype) save(value_in_config, config_path, save_dc_types=True) - - assert parse(A1OrA2, config_path=config_path, args=args) == expected + assert parse(type(value_in_config), config_path=config_path, args=args) == expected @pytest.mark.parametrize( From ea8a0e16dd22f6b5eaa087c216cb1ea0d6b0acc0 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 31 Jan 2024 11:49:01 -0500 Subject: [PATCH 17/20] Move subgroup stuff into a new separate file Signed-off-by: Fabrice Normandin --- simple_parsing/parsing.py | 438 +++-------------------------- simple_parsing/subgroup_parsing.py | 413 +++++++++++++++++++++++++++ 2 files changed, 455 insertions(+), 396 deletions(-) create mode 100644 simple_parsing/subgroup_parsing.py diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py index 7623ca8a..d4a19355 100644 --- a/simple_parsing/parsing.py +++ b/simple_parsing/parsing.py @@ -7,32 +7,32 @@ import argparse import dataclasses import functools +import gettext import inspect -import itertools import shlex import sys import typing -from argparse import SUPPRESS, Action, HelpFormatter, Namespace, _ +from argparse import SUPPRESS, HelpFormatter, Namespace from collections import defaultdict from logging import getLogger from pathlib import Path -from typing import Any, Callable, Hashable, Mapping, Sequence, Type, overload +from typing import Any, Callable, Mapping, Sequence, Type, overload + from typing_extensions import TypeGuard -import warnings -from simple_parsing.helpers.subgroups import SubgroupKey -from simple_parsing.replace import SUBGROUP_KEY_FLAG + from simple_parsing.wrappers.dataclass_wrapper import DataclassWrapperType from . 
import utils from .conflicts import ConflictResolution, ConflictResolver from .help_formatter import SimpleHelpFormatter -from .helpers.serialization.serializable import DC_TYPE_KEY, read_file +from .helpers.serialization.serializable import read_file +from .subgroup_parsing import remove_subgroups_from_namespace, resolve_subgroups from .utils import ( - K, - V, Dataclass, DataclassT, + K, PossiblyNestedDict, + V, dict_union, is_dataclass_instance, is_dataclass_type, @@ -158,7 +158,7 @@ def __init__( default_prefix * 2 + "help", action="help", default=SUPPRESS, - help=_("show this help message and exit"), + help=gettext.gettext("show this help message and exit"), ) self.config_path = Path(config_path) if isinstance(config_path, str) else config_path @@ -167,17 +167,6 @@ def __init__( add_config_path_arg = bool(config_path) self.add_config_path_arg = add_config_path_arg - # TODO: Remove, since the base class already has nicer type hints. - def add_argument( - self, - *name_or_flags: str, - **kwargs, - ) -> Action: - return super().add_argument( - *name_or_flags, - **kwargs, - ) - @overload def add_arguments( self, @@ -389,37 +378,42 @@ def set_defaults(self, config_path: str | Path | None = None, **kwargs: Any) -> kwargs = {self._wrappers[0].dest: kwargs} # Also include the values from **kwargs. kwargs = dict_union(defaults, kwargs) + self._set_defaults(**kwargs) + def _set_defaults(self, **kwargs: Any) -> None: # The kwargs that are set in the dataclasses, rather than on the namespace. kwarg_defaults_set_in_dataclasses = {} for wrapper in self._wrappers: - if wrapper.dest in kwargs: - default_for_dataclass = kwargs[wrapper.dest] - - if isinstance(default_for_dataclass, (str, Path)): - default_for_dataclass = read_file(path=default_for_dataclass) - elif not isinstance(default_for_dataclass, dict) and not dataclasses.is_dataclass( - default_for_dataclass - ): - raise ValueError( - f"Got a default for field {wrapper.dest} that isn't a dataclass, dict or " - f"path: {default_for_dataclass}" - ) + if wrapper.dest not in kwargs: + continue - # Set the .default attribute on the DataclassWrapper (which also updates the - # defaults of the fields and any nested dataclass fields). - wrapper.set_default(default_for_dataclass) + default_for_dataclass = kwargs[wrapper.dest] - # It's impossible for multiple wrappers in kwargs to have the same destination. - assert wrapper.dest not in kwarg_defaults_set_in_dataclasses - value_for_constructor_arguments = ( - default_for_dataclass - if isinstance(default_for_dataclass, dict) - else dataclasses.asdict(default_for_dataclass) + if isinstance(default_for_dataclass, (str, Path)): + default_for_dataclass = read_file(path=default_for_dataclass) + elif not isinstance(default_for_dataclass, dict) and not dataclasses.is_dataclass( + default_for_dataclass + ): + raise ValueError( + f"Got a default for field {wrapper.dest} that isn't a dataclass, dict or " + f"path: {default_for_dataclass}" ) - kwarg_defaults_set_in_dataclasses[wrapper.dest] = value_for_constructor_arguments - # Remove this from the **kwargs, so they don't get set on the namespace. - kwargs.pop(wrapper.dest) + + # Set the .default attribute on the DataclassWrapper (which also updates the + # defaults of the fields and any nested dataclass fields). + wrapper.set_default(default_for_dataclass) + + # It's impossible for multiple wrappers in kwargs to have the same destination. 
+ assert wrapper.dest not in kwarg_defaults_set_in_dataclasses + value_for_constructor_arguments = ( + default_for_dataclass + if isinstance(default_for_dataclass, dict) + else dataclasses.asdict(default_for_dataclass) + ) + kwarg_defaults_set_in_dataclasses[wrapper.dest] = value_for_constructor_arguments + # Remove this from the **kwargs, so they don't get set on the namespace. + kwargs.pop(wrapper.dest) + # TODO: Stop using a defaultdict for the very important `self.constructor_arguments`! self.constructor_arguments = dict_union( self.constructor_arguments, @@ -526,8 +520,8 @@ def _preprocessing(self, args: Sequence[str] = (), namespace: Namespace | None = # Fix the potential conflicts between dataclass fields with the same names. wrapped_dataclasses = self._conflict_resolver.resolve_and_flatten(wrapped_dataclasses) - wrapped_dataclasses, chosen_subgroups = self._resolve_subgroups( - wrappers=wrapped_dataclasses, args=args, namespace=namespace + wrapped_dataclasses, chosen_subgroups = resolve_subgroups( + self, wrappers=wrapped_dataclasses, args=args, namespace=namespace ) # NOTE: We keep the subgroup fields in their dataclasses so they show up with the other @@ -571,7 +565,7 @@ def _postprocessing(self, parsed_args: Namespace) -> Namespace: logger.debug("\nPOST PROCESSING\n") logger.debug(f"(raw) parsed args: {parsed_args}") - self._remove_subgroups_from_namespace(parsed_args) + remove_subgroups_from_namespace(self, parsed_args) # create the constructor arguments for each instance by consuming all # the relevant attributes from `parsed_args` wrappers = _flatten_wrappers(self._wrappers) @@ -589,201 +583,6 @@ def _postprocessing(self, parsed_args: Namespace) -> Namespace: ) return parsed_args - def _resolve_subgroups( - self, - wrappers: list[DataclassWrapper], - args: list[str], - namespace: Namespace | None = None, - ) -> tuple[list[DataclassWrapper], dict[str, str]]: - """Iteratively add and resolve all the choice of argument subgroups, if any. - - This modifies the wrappers in-place, by possibly adding children to the wrappers in the - list. - Returns a list with the (now modified) wrappers. - - Each round does the following: - 1. Resolve any conflicts using the conflict resolver. Two subgroups at the same nesting - level, with the same name, get a different prefix, for example "--generator.optimizer" - and "--discriminator.optimizer". - 2. Add all the subgroup choice arguments to a parser. - 3. Add the chosen dataclasses to the list of dataclasses to parse later in the main - parser. This is done by adding wrapping the dataclass and adding it to the `wrappers` - list. - """ - - unresolved_subgroups = _get_subgroup_fields(wrappers) - # Dictionary of the subgroup choices that were resolved (key: subgroup dest, value: chosen - # subgroup name). - resolved_subgroups: dict[str, SubgroupKey] = {} - - if not unresolved_subgroups: - # No subgroups to parse. - return wrappers, {} - - # Use a temporary parser, to avoid parsing "vanilla argparse" arguments of `self` multiple - # times. - subgroup_choice_parser = argparse.ArgumentParser( - add_help=False, - formatter_class=self.formatter_class, - # NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues - # for example if you have —a_or_b and A has a field —a then it will error out if you - # pass —a=1 because 1 isn’t a choice for the a_or_b argument (because --a matches it - # with the abbreviation feature turned on). 
- allow_abbrev=False, - ) - - for current_nesting_level in itertools.count(): - # Do rounds of parsing with just the subgroup arguments, until all the subgroups - # are resolved to a dataclass type. - logger.debug( - f"Starting subgroup parsing round {current_nesting_level}: " - f"{list(unresolved_subgroups.keys())}" - ) - # Add all the unresolved subgroups arguments. - for dest, subgroup_field in unresolved_subgroups.items(): - flags = subgroup_field.option_strings - argument_options = subgroup_field.arg_options - - # Sanity checks: - if subgroup_field.subgroup_default is dataclasses.MISSING: - assert argument_options["required"] - if "default" in argument_options: - # todo: should ideally not set this in the first place... - assert argument_options["default"] is dataclasses.MISSING - argument_options.pop("default") - assert "default" not in argument_options - else: - assert "default" in argument_options - assert argument_options["default"] == subgroup_field.default - argument_options["default"] = _adjust_default_value_for_subgroup_field( - subgroup_field=subgroup_field, - subgroup_default=argument_options["default"], - ) - - logger.debug( - f"Adding subgroup argument: add_argument(*{flags} **{str(argument_options)})" - ) - subgroup_choice_parser.add_argument(*flags, **argument_options) - - # Parse `args` repeatedly until all the subgroup choices are resolved. - parsed_args, unused_args = subgroup_choice_parser.parse_known_args( - args=args, namespace=namespace - ) - logger.debug( - f"Nesting level {current_nesting_level}: args: {args}, " - f"parsed_args: {parsed_args}, unused_args: {unused_args}" - ) - - for dest, subgroup_field in list(unresolved_subgroups.items()): - # NOTE: There should always be a parsed value for the subgroup argument on the - # namespace. This is because we added all the subgroup arguments before we get - # here. - subgroup_dict = subgroup_field.subgroup_choices - chosen_subgroup_key: SubgroupKey = getattr(parsed_args, dest) - assert chosen_subgroup_key in subgroup_dict - - # Changing the default value of the (now parsed) field for the subgroup choice, - # just so it shows (default: {chosen_subgroup_key}) on the command-line. - # Note: This really isn't required, we could have it just be the default value, but - # it seems a bit more consistent with us then showing the --help string for the - # chosen dataclass type (as we're doing below). - # subgroup_field.set_default(chosen_subgroup_key) - logger.debug( - f"resolved the subgroup at {dest!r}: will use the subgroup at key " - f"{chosen_subgroup_key!r}" - ) - - default_or_dataclass_fn = subgroup_dict[chosen_subgroup_key] - if is_dataclass_instance(default_or_dataclass_fn): - # The chosen value in the subgroup dict is a frozen dataclass instance. - default = default_or_dataclass_fn - dataclass_fn = functools.partial(dataclasses.replace, default) - dataclass_type = type(default) - else: - default = None - dataclass_fn = default_or_dataclass_fn - dataclass_type = subgroup_field.field.metadata["subgroup_dataclass_types"][ - chosen_subgroup_key - ] - - assert default is None or is_dataclass_instance(default) - assert callable(dataclass_fn) - assert is_dataclass_type(dataclass_type) - - name = dest.split(".")[-1] - parent_dataclass_wrapper = subgroup_field.parent - # NOTE: Using self._add_arguments so it returns the modified wrapper and doesn't - # affect the `self._wrappers` list. 
- new_wrapper = self._add_arguments( - dataclass_type=dataclass_type, - name=name, - dataclass_fn=dataclass_fn, - default=default, - parent=parent_dataclass_wrapper, - ) - # Make the new wrapper a child of the class which contains the field. - # - it isn't already a child - # - it's parent is the parent dataclass wrapper - # - the parent is already in the tree of DataclassWrappers. - assert new_wrapper not in parent_dataclass_wrapper._children - parent_dataclass_wrapper._children.append(new_wrapper) - assert new_wrapper.parent is parent_dataclass_wrapper - assert parent_dataclass_wrapper in _flatten_wrappers(wrappers) - assert new_wrapper in _flatten_wrappers(wrappers) - - # Mark this subgroup as resolved. - unresolved_subgroups.pop(dest) - resolved_subgroups[dest] = chosen_subgroup_key - # TODO: Should we remove the FieldWrapper for the subgroups now that it's been - # resolved? - - # Find the new subgroup fields that weren't resolved before. - # TODO: What if a name conflict occurs between a subgroup field and one of the new - # fields below it? For example, something like --model model_a (and inside the `ModelA` - # dataclass, there's a field called `model`. Then, this will cause a conflict!) - # For now, I'm just going to wait and see how this plays out. I'm hoping that the - # auto conflict resolution shouldn't run into any issues in this case. - - wrappers = self._conflict_resolver.resolve(wrappers) - - all_subgroup_fields = _get_subgroup_fields(wrappers) - unresolved_subgroups = { - k: v for k, v in all_subgroup_fields.items() if k not in resolved_subgroups - } - logger.debug(f"All subgroups: {list(all_subgroup_fields.keys())}") - logger.debug(f"Resolved subgroups: {resolved_subgroups}") - logger.debug(f"Unresolved subgroups: {list(unresolved_subgroups.keys())}") - - if not unresolved_subgroups: - logger.debug("Done parsing all the subgroups!") - break - else: - logger.debug( - f"Done parsing a round of subparsers at nesting level " - f"{current_nesting_level}. Moving to the next round which has " - f"{len(unresolved_subgroups)} unresolved subgroup choices." - ) - return wrappers, resolved_subgroups - - def _remove_subgroups_from_namespace(self, parsed_args: argparse.Namespace) -> None: - """Removes the subgroup choice results from the namespace. - - Modifies the namespace in-place. - """ - # find all subgroup fields - subgroup_fields = _get_subgroup_fields(self._wrappers) - - if not subgroup_fields: - return - # IDEA: Store the choices in a `subgroups` dict on the namespace. 
- if not hasattr(parsed_args, "subgroups"): - parsed_args.subgroups = {} - - for dest in subgroup_fields: - chosen_value = getattr(parsed_args, dest) - parsed_args.subgroups[dest] = chosen_value - delattr(parsed_args, dest) - def _instantiate_dataclasses( self, parsed_args: argparse.Namespace, @@ -1105,17 +904,6 @@ def parse_known_args( return config, unknown_args -def _get_subgroup_fields(wrappers: list[DataclassWrapper]) -> dict[str, FieldWrapper]: - subgroup_fields = {} - all_wrappers = _flatten_wrappers(wrappers) - for wrapper in all_wrappers: - for field in wrapper.fields: - if field.is_subgroup: - assert field not in subgroup_fields.values() - subgroup_fields[field.dest] = field - return subgroup_fields - - def _remove_duplicates(wrappers: list[DataclassWrapper]) -> list[DataclassWrapper]: return list(set(wrappers)) @@ -1237,145 +1025,3 @@ def _default_constructor_argument_values( ): result[key] = _default_constructor_argument_values(field.type) return result - - -def _adjust_default_value_for_subgroup_field( - subgroup_field: FieldWrapper, subgroup_default: Any -) -> str | Hashable: - - if argparse.SUPPRESS in subgroup_field.parent.defaults: - assert subgroup_default is argparse.SUPPRESS - assert isinstance(subgroup_default, str) - return subgroup_default - - if isinstance(subgroup_default, dict): - default_from_config_file = subgroup_default - default_from_dataclass_field = subgroup_field.subgroup_default - - if SUBGROUP_KEY_FLAG in default_from_config_file: - _default_subgroup = default_from_config_file[SUBGROUP_KEY_FLAG] - logger.debug(f"Using subgroup key {_default_subgroup} as default (from config file)") - return _default_subgroup - - if DC_TYPE_KEY in default_from_config_file: - # The type of dataclass is specified in the config file. - # We can use that to figure out which subgroup to use. - default_dataclass_type_from_config = default_from_config_file[DC_TYPE_KEY] - if isinstance(default_dataclass_type_from_config, str): - from simple_parsing.helpers.serialization.serializable import _locate - - # Try to import the type of dataclass given its import path as a string in the - # config file. - default_dataclass_type_from_config = _locate(default_dataclass_type_from_config) - assert is_dataclass_type(default_dataclass_type_from_config) - - from simple_parsing.helpers.subgroups import _get_dataclass_type_from_callable - - subgroup_choices_with_matching_type: dict[ - Hashable, Dataclass | Callable[[], Dataclass] - ] = { - subgroup_key: subgroup_value - for subgroup_key, subgroup_value in subgroup_field.subgroup_choices.items() - if is_dataclass_type(subgroup_value) - and subgroup_value == default_dataclass_type_from_config - or is_dataclass_instance(subgroup_value) - and type(subgroup_value) == default_dataclass_type_from_config - or _get_dataclass_type_from_callable(subgroup_value) - == default_dataclass_type_from_config - } - logger.debug( - f"Subgroup choices that match the type in the config file: " - f"{subgroup_choices_with_matching_type}" - ) - - # IDEA: Try to find the best subgroup key to use, based on the number of matching - # constructor arguments between the default in the config and the defaults for each - # subgroup. - constructor_args_of_each_subgroup_val = { - key: ( - dataclasses.asdict(subgroup_value) - if is_dataclass_instance(subgroup_value) - # (the type should have been narrowed by the is_dataclass_instance typeguard, - # but somehow isn't...) 
- else _default_constructor_argument_values(subgroup_value) # type: ignore - ) - for key, subgroup_value in subgroup_choices_with_matching_type.items() - } - logger.debug( - f"Constructor arguments for each subgroup choice: " - f"{constructor_args_of_each_subgroup_val}" - ) - - def _num_overlapping_keys( - subgroup_default_in_config: PossiblyNestedDict[str, Any], - subgroup_option_from_field: PossiblyNestedDict[str, Any], - ) -> int: - """Returns the number of matching entries in the subgroup dict w/ the default from - the config.""" - overlap = 0 - for key, value in subgroup_default_in_config.items(): - if key in subgroup_option_from_field: - overlap += 1 - if isinstance(value, dict) and isinstance( - subgroup_option_from_field[key], dict - ): - overlap += _num_overlapping_keys( - value, subgroup_option_from_field[key] - ) - return overlap - - n_matching_values = { - k: _num_overlapping_keys(default_from_config_file, constructor_args_in_value) - for k, constructor_args_in_value in constructor_args_of_each_subgroup_val.items() - } - logger.debug( - f"Number of overlapping keys for each subgroup choice: {n_matching_values}" - ) - closest_subgroups_first = sorted( - subgroup_choices_with_matching_type.keys(), - key=n_matching_values.__getitem__, - reverse=True, - ) - closest_subgroup_key = closest_subgroups_first[0] - - warnings.warn( - RuntimeWarning( - f"The config file contains a default value for a subgroup field that isn't in " - f"the dict of subgroup options. " - f"Because of how subgroups are currently implemented, we need to find the key " - f"in the subgroup choice dict that most closely matches the value " - f"{default_from_config_file} in order to populate the default values for " - f"other fields.\n" - f"The default in the config file: {default_from_config_file}\n" - f"The default in the dataclass field: {default_from_dataclass_field}\n" - f"The subgroups dict: {subgroup_field.subgroup_choices}\n" - f"The current implementation tries to use the dataclass type of this closest " - f"match to parse the additional values from the command-line. " - f"Consider adding a {SUBGROUP_KEY_FLAG!r}: item " - f"in the dict entry for that subgroup field in your config, to make it easier " - f"to tell directly which subgroup to use." - ) - ) - return closest_subgroup_key - - logger.debug( - f"Using subgroup key {default_from_dataclass_field} as default (from the dataclass " - f"field)" - ) - return default_from_dataclass_field - - if subgroup_default in subgroup_field.subgroup_choices.keys(): - return subgroup_default - - if subgroup_default in subgroup_field.subgroup_choices.values(): - matching_keys = [ - k for k, v in subgroup_field.subgroup_choices.items() if v == subgroup_default - ] - return matching_keys[0] - - raise RuntimeError( - f"Error: Unable to figure out what key matches the default value for the subgroup at " - f"{subgroup_field.dest}! 
(expected to either have the {SUBGROUP_KEY_FLAG!r} flag set, or " - f"one of the keys or values of the subgroups dict of that field: " - f"{subgroup_field.subgroup_choices})" - ) diff --git a/simple_parsing/subgroup_parsing.py b/simple_parsing/subgroup_parsing.py new file mode 100644 index 00000000..91ab29f0 --- /dev/null +++ b/simple_parsing/subgroup_parsing.py @@ -0,0 +1,413 @@ +from __future__ import annotations + +import argparse +import dataclasses +import functools +import itertools +import typing +import warnings +from argparse import Namespace +from logging import getLogger as get_logger +from typing import Any, Callable, Hashable + +from simple_parsing.helpers.serialization.serializable import DC_TYPE_KEY +from simple_parsing.helpers.subgroups import SubgroupKey +from simple_parsing.replace import SUBGROUP_KEY_FLAG +from simple_parsing.utils import ( + Dataclass, + PossiblyNestedDict, + is_dataclass_instance, + is_dataclass_type, +) +from simple_parsing.wrappers import DataclassWrapper, FieldWrapper + +if typing.TYPE_CHECKING: + from simple_parsing.parsing import ArgumentParser + + +logger = get_logger(__name__) + + +def resolve_subgroups( + parser: ArgumentParser, + wrappers: list[DataclassWrapper], + args: list[str], + namespace: Namespace | None = None, +) -> tuple[list[DataclassWrapper], dict[str, str]]: + """Iteratively add and resolve all the choice of argument subgroups, if any. + + This modifies the wrappers in-place, by possibly adding children to the wrappers in the + list. + Returns a list with the (now modified) wrappers. + + Each round does the following: + 1. Resolve any conflicts using the conflict resolver. Two subgroups at the same nesting + level, with the same name, get a different prefix, for example "--generator.optimizer" + and "--discriminator.optimizer". + 2. Add all the subgroup choice arguments to a parser. + 3. Add the chosen dataclasses to the list of dataclasses to parse later in the main + parser. This is done by adding wrapping the dataclass and adding it to the `wrappers` + list. + """ + + unresolved_subgroups = _get_subgroup_fields(wrappers) + # Dictionary of the subgroup choices that were resolved (key: subgroup dest, value: chosen + # subgroup name). + resolved_subgroups: dict[str, SubgroupKey] = {} + + if not unresolved_subgroups: + # No subgroups to parse. + return wrappers, {} + + # Use a temporary parser, to avoid parsing "vanilla argparse" arguments of `self` multiple + # times. + subgroup_choice_parser = argparse.ArgumentParser( + add_help=False, + formatter_class=parser.formatter_class, + # NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues + # for example if you have —a_or_b and A has a field —a then it will error out if you + # pass —a=1 because 1 isn’t a choice for the a_or_b argument (because --a matches it + # with the abbreviation feature turned on). + allow_abbrev=False, + ) + + for current_nesting_level in itertools.count(): + # Do rounds of parsing with just the subgroup arguments, until all the subgroups + # are resolved to a dataclass type. + logger.debug( + f"Starting subgroup parsing round {current_nesting_level}: " + f"{list(unresolved_subgroups.keys())}" + ) + # Add all the unresolved subgroups arguments. 
+ for dest, subgroup_field in unresolved_subgroups.items(): + flags = subgroup_field.option_strings + argument_options = subgroup_field.arg_options + + # Sanity checks: + if subgroup_field.subgroup_default is dataclasses.MISSING: + assert argument_options["required"] + if "default" in argument_options: + # todo: should ideally not set this in the first place... + assert argument_options["default"] is dataclasses.MISSING + argument_options.pop("default") + assert "default" not in argument_options + else: + assert "default" in argument_options + assert argument_options["default"] == subgroup_field.default + argument_options["default"] = _adjust_default_value_for_subgroup_field( + subgroup_field=subgroup_field, + subgroup_default=argument_options["default"], + ) + + logger.debug( + f"Adding subgroup argument: add_argument(*{flags} **{str(argument_options)})" + ) + subgroup_choice_parser.add_argument(*flags, **argument_options) + + # Parse `args` repeatedly until all the subgroup choices are resolved. + parsed_args, unused_args = subgroup_choice_parser.parse_known_args( + args=args, namespace=namespace + ) + logger.debug( + f"Nesting level {current_nesting_level}: args: {args}, " + f"parsed_args: {parsed_args}, unused_args: {unused_args}" + ) + + for dest, subgroup_field in list(unresolved_subgroups.items()): + # NOTE: There should always be a parsed value for the subgroup argument on the + # namespace. This is because we added all the subgroup arguments before we get + # here. + subgroup_dict = subgroup_field.subgroup_choices + chosen_subgroup_key: SubgroupKey = getattr(parsed_args, dest) + assert chosen_subgroup_key in subgroup_dict + + # Changing the default value of the (now parsed) field for the subgroup choice, + # just so it shows (default: {chosen_subgroup_key}) on the command-line. + # Note: This really isn't required, we could have it just be the default value, but + # it seems a bit more consistent with us then showing the --help string for the + # chosen dataclass type (as we're doing below). + # subgroup_field.set_default(chosen_subgroup_key) + logger.debug( + f"resolved the subgroup at {dest!r}: will use the subgroup at key " + f"{chosen_subgroup_key!r}" + ) + + default_or_dataclass_fn = subgroup_dict[chosen_subgroup_key] + if is_dataclass_instance(default_or_dataclass_fn): + # The chosen value in the subgroup dict is a frozen dataclass instance. + default = default_or_dataclass_fn + dataclass_fn = functools.partial(dataclasses.replace, default) + dataclass_type = type(default) + else: + default = None + dataclass_fn = default_or_dataclass_fn + dataclass_type = subgroup_field.field.metadata["subgroup_dataclass_types"][ + chosen_subgroup_key + ] + + assert default is None or is_dataclass_instance(default) + assert callable(dataclass_fn) + assert is_dataclass_type(dataclass_type) + + name = dest.split(".")[-1] + parent_dataclass_wrapper = subgroup_field.parent + # NOTE: Using self._add_arguments so it returns the modified wrapper and doesn't + # affect the `self._wrappers` list. + new_wrapper = parser._add_arguments( + dataclass_type=dataclass_type, + name=name, + dataclass_fn=dataclass_fn, + default=default, + parent=parent_dataclass_wrapper, + ) + # Make the new wrapper a child of the class which contains the field. + # - it isn't already a child + # - it's parent is the parent dataclass wrapper + # - the parent is already in the tree of DataclassWrappers. 
+ assert new_wrapper not in parent_dataclass_wrapper._children + parent_dataclass_wrapper._children.append(new_wrapper) + assert new_wrapper.parent is parent_dataclass_wrapper + assert parent_dataclass_wrapper in _flatten_wrappers(wrappers) + assert new_wrapper in _flatten_wrappers(wrappers) + + # Mark this subgroup as resolved. + unresolved_subgroups.pop(dest) + resolved_subgroups[dest] = chosen_subgroup_key + # TODO: Should we remove the FieldWrapper for the subgroups now that it's been + # resolved? + + # Find the new subgroup fields that weren't resolved before. + # TODO: What if a name conflict occurs between a subgroup field and one of the new + # fields below it? For example, something like --model model_a (and inside the `ModelA` + # dataclass, there's a field called `model`. Then, this will cause a conflict!) + # For now, I'm just going to wait and see how this plays out. I'm hoping that the + # auto conflict resolution shouldn't run into any issues in this case. + + wrappers = parser._conflict_resolver.resolve(wrappers) + + all_subgroup_fields = _get_subgroup_fields(wrappers) + unresolved_subgroups = { + k: v for k, v in all_subgroup_fields.items() if k not in resolved_subgroups + } + logger.debug(f"All subgroups: {list(all_subgroup_fields.keys())}") + logger.debug(f"Resolved subgroups: {resolved_subgroups}") + logger.debug(f"Unresolved subgroups: {list(unresolved_subgroups.keys())}") + + if not unresolved_subgroups: + logger.debug("Done parsing all the subgroups!") + break + else: + logger.debug( + f"Done parsing a round of subparsers at nesting level " + f"{current_nesting_level}. Moving to the next round which has " + f"{len(unresolved_subgroups)} unresolved subgroup choices." + ) + return wrappers, resolved_subgroups + + +def _get_subgroup_fields(wrappers: list[DataclassWrapper]) -> dict[str, FieldWrapper]: + subgroup_fields = {} + all_wrappers = _flatten_wrappers(wrappers) + for wrapper in all_wrappers: + for field in wrapper.fields: + if field.is_subgroup: + assert field not in subgroup_fields.values() + subgroup_fields[field.dest] = field + return subgroup_fields + + +def _flatten_wrappers(wrappers: list[DataclassWrapper]) -> list[DataclassWrapper]: + """Takes a list of nodes, returns a flattened list of all nodes in the tree.""" + _assert_no_duplicates(wrappers) + roots_only = _unflatten_wrappers(wrappers) + return sum(([w] + list(w.descendants) for w in roots_only), []) + + +def _assert_no_duplicates(wrappers: list[DataclassWrapper]) -> None: + if len(wrappers) != len(set(wrappers)): + raise RuntimeError( + "Duplicate wrappers found! This is a potentially nasty bug on our " + "part. Please make an issue at https://www.github.com/lebrice/SimpleParsing/issues " + ) + + +def _unflatten_wrappers(wrappers: list[DataclassWrapper]) -> list[DataclassWrapper]: + """Given a list of nodes in one or more trees, returns only the root nodes. + + In our context, this is all the dataclass arg groups that were added with + `parser.add_arguments`. + """ + _assert_no_duplicates(wrappers) + return [w for w in wrappers if w.parent is None] + + +def remove_subgroups_from_namespace( + argument_parser: ArgumentParser, parsed_args: argparse.Namespace +) -> None: + """Removes the subgroup choice results from the namespace. + + Modifies the namespace in-place. + """ + # find all subgroup fields + subgroup_fields = _get_subgroup_fields(argument_parser._wrappers) + + if not subgroup_fields: + return + # IDEA: Store the choices in a `subgroups` dict on the namespace. 
+ if not hasattr(parsed_args, "subgroups"): + parsed_args.subgroups = {} + + for dest in subgroup_fields: + chosen_value = getattr(parsed_args, dest) + parsed_args.subgroups[dest] = chosen_value + delattr(parsed_args, dest) + + +def _adjust_default_value_for_subgroup_field( + subgroup_field: FieldWrapper, subgroup_default: Any +) -> str | Hashable: + if argparse.SUPPRESS in subgroup_field.parent.defaults: + assert subgroup_default is argparse.SUPPRESS + assert isinstance(subgroup_default, str) + return subgroup_default + + if isinstance(subgroup_default, dict): + default_from_config_file = subgroup_default + default_from_dataclass_field = subgroup_field.subgroup_default + + if SUBGROUP_KEY_FLAG in default_from_config_file: + _default_subgroup = default_from_config_file[SUBGROUP_KEY_FLAG] + logger.debug(f"Using subgroup key {_default_subgroup} as default (from config file)") + return _default_subgroup + + if DC_TYPE_KEY in default_from_config_file: + # The type of dataclass is specified in the config file. + # We can use that to figure out which subgroup to use. + default_dataclass_type_from_config = default_from_config_file[DC_TYPE_KEY] + if isinstance(default_dataclass_type_from_config, str): + from simple_parsing.helpers.serialization.serializable import _locate + + # Try to import the type of dataclass given its import path as a string in the + # config file. + default_dataclass_type_from_config = _locate(default_dataclass_type_from_config) + assert is_dataclass_type(default_dataclass_type_from_config) + + from simple_parsing.helpers.subgroups import _get_dataclass_type_from_callable + + subgroup_choices_with_matching_type: dict[ + Hashable, Dataclass | Callable[[], Dataclass] + ] = { + subgroup_key: subgroup_value + for subgroup_key, subgroup_value in subgroup_field.subgroup_choices.items() + if is_dataclass_type(subgroup_value) + and subgroup_value == default_dataclass_type_from_config + or is_dataclass_instance(subgroup_value) + and type(subgroup_value) == default_dataclass_type_from_config + or _get_dataclass_type_from_callable(subgroup_value) + == default_dataclass_type_from_config + } + logger.debug( + f"Subgroup choices that match the type in the config file: " + f"{subgroup_choices_with_matching_type}" + ) + if len(subgroup_choices_with_matching_type) > 1: + raise RuntimeError( + f"The dataclass type {default_dataclass_type_from_config} matches more than " + f"one value in the subgroups dict:\n" + f"{subgroup_field.subgroup_choices}\n" + f"Use the {SUBGROUP_KEY_FLAG!r} flag to specify which subgroup key to use as " + f"the default." + ) + return subgroup_choices_with_matching_type.popitem()[0] + + # IDEA: Try to find the best subgroup key to use, based on the number of matching + # constructor arguments between the default in the config and the defaults for each + # subgroup. + constructor_args_of_each_subgroup_val = { + key: ( + dataclasses.asdict(subgroup_value) + if is_dataclass_instance(subgroup_value) + # (the type should have been narrowed by the is_dataclass_instance typeguard, + # but somehow isn't...) 
+ else _default_constructor_argument_values(subgroup_value) # type: ignore + ) + for key, subgroup_value in subgroup_choices_with_matching_type.items() + } + logger.debug( + f"Constructor arguments for each subgroup choice: " + f"{constructor_args_of_each_subgroup_val}" + ) + + def _num_overlapping_keys( + subgroup_default_in_config: PossiblyNestedDict[str, Any], + subgroup_option_from_field: PossiblyNestedDict[str, Any], + ) -> int: + """Returns the number of matching entries in the subgroup dict w/ the default from + the config.""" + overlap = 0 + for key, value in subgroup_default_in_config.items(): + if key in subgroup_option_from_field: + overlap += 1 + if isinstance(value, dict) and isinstance( + subgroup_option_from_field[key], dict + ): + overlap += _num_overlapping_keys( + value, subgroup_option_from_field[key] + ) + return overlap + + n_matching_values = { + k: _num_overlapping_keys(default_from_config_file, constructor_args_in_value) + for k, constructor_args_in_value in constructor_args_of_each_subgroup_val.items() + } + logger.debug( + f"Number of overlapping keys for each subgroup choice: {n_matching_values}" + ) + closest_subgroups_first = sorted( + subgroup_choices_with_matching_type.keys(), + key=n_matching_values.__getitem__, + reverse=True, + ) + closest_subgroup_key = closest_subgroups_first[0] + + warnings.warn( + RuntimeWarning( + f"The config file contains a default value for a subgroup field that isn't in " + f"the dict of subgroup options. " + f"Because of how subgroups are currently implemented, we need to find the key " + f"in the subgroup choice dict that most closely matches the value " + f"{default_from_config_file} in order to populate the default values for " + f"other fields.\n" + f"The default in the config file: {default_from_config_file}\n" + f"The default in the dataclass field: {default_from_dataclass_field}\n" + f"The subgroups dict: {subgroup_field.subgroup_choices}\n" + f"The current implementation tries to use the dataclass type of this closest " + f"match to parse the additional values from the command-line. " + f"Consider adding a {SUBGROUP_KEY_FLAG!r}: item " + f"in the dict entry for that subgroup field in your config, to make it easier " + f"to tell directly which subgroup to use." + ) + ) + return closest_subgroup_key + + logger.debug( + f"Using subgroup key {default_from_dataclass_field} as default (from the dataclass " + f"field)" + ) + return default_from_dataclass_field + + if subgroup_default in subgroup_field.subgroup_choices.keys(): + return subgroup_default + + if subgroup_default in subgroup_field.subgroup_choices.values(): + matching_keys = [ + k for k, v in subgroup_field.subgroup_choices.items() if v == subgroup_default + ] + return matching_keys[0] + + raise RuntimeError( + f"Error: Unable to figure out what key matches the default value for the subgroup at " + f"{subgroup_field.dest}! 
(expected to either have the {SUBGROUP_KEY_FLAG!r} flag set, or " + f"one of the keys or values of the subgroups dict of that field: " + f"{subgroup_field.subgroup_choices})" + ) From a3525f59199768743c570e51ba0b9c9834063275 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 31 Jan 2024 14:59:05 -0500 Subject: [PATCH 18/20] Fix issues with regression files Signed-off-by: Fabrice Normandin --- .pre-commit-config.yaml | 3 +++ examples/docstrings/docstrings_example.py | 4 ++-- examples/merging/multiple_example.py | 2 +- examples/ugly/ugly_example_after.py | 2 +- examples/ugly/ugly_example_before.py | 2 +- simple_parsing/docstring.py | 1 + simple_parsing/helpers/fields.py | 4 ++-- simple_parsing/helpers/serialization/serializable.py | 2 +- simple_parsing/helpers/subgroups.py | 3 ++- simple_parsing/utils.py | 10 +++++----- test/test_base.py | 2 +- test/test_future_annotations.py | 2 +- test/test_huggingface_compat.py | 2 +- test/test_issue64.py | 2 +- test/test_set_defaults.py | 2 +- test/test_subgroups.py | 4 ++-- test/test_subgroups/test_help[Config---help].md | 6 +----- .../test_help[Config---model=model_a --help].md | 6 +----- .../test_help[Config---model=model_b --help].md | 6 +----- ...elp[ConfigWithFrozen---conf=even --a 100 --help].md | 6 +----- .../test_help[ConfigWithFrozen---conf=even --help].md | 6 +----- ...help[ConfigWithFrozen---conf=odd --a 123 --help].md | 6 +----- .../test_help[ConfigWithFrozen---conf=odd --help].md | 6 +----- .../test_help[ConfigWithFrozen---help].md | 6 +----- 24 files changed, 34 insertions(+), 61 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e3df32e5..9f86c346 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,6 +25,9 @@ repos: require_serial: true - id: check-added-large-files require_serial: true + - id: check-merge-conflict + require_serial: true + args: ["--assume-in-merge"] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. diff --git a/examples/docstrings/docstrings_example.py b/examples/docstrings/docstrings_example.py index 97f3e85f..d2911d4c 100644 --- a/examples/docstrings/docstrings_example.py +++ b/examples/docstrings/docstrings_example.py @@ -22,7 +22,7 @@ class DocStringsExample: # comment above 42 attribute4: float = 1.0 # inline comment - """docstring below (this appears in --help)""" + """Docstring below (this appears in --help)""" # comment above (this appears in --help) 46 attribute5: float = 1.0 # inline comment @@ -30,7 +30,7 @@ class DocStringsExample: attribute6: float = 1.0 # inline comment (this appears in --help) attribute7: float = 1.0 # inline comment - """docstring below (this appears in --help)""" + """Docstring below (this appears in --help)""" parser.add_arguments(DocStringsExample, "example") diff --git a/examples/merging/multiple_example.py b/examples/merging/multiple_example.py index f582748e..047ee140 100644 --- a/examples/merging/multiple_example.py +++ b/examples/merging/multiple_example.py @@ -19,7 +19,7 @@ class Config: run_name: str = "train" # Some parameter for the run name. some_int: int = 10 # an optional int parameter. log_dir: str = "logs" # an optional string parameter. - """the logging directory to use. + """The logging directory to use. (This is an attribute docstring for the log_dir attribute, and shows up when using the "--help" argument!) 
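Note on the closest-match fallback in `_adjust_default_value_for_subgroup_field` (patch 17 above): the idea is easier to see outside the diff. The sketch below is a standalone restatement of the `_num_overlapping_keys` helper with made-up, optimizer-style subgroup choices; the names and values are illustrative only and it does not use simple_parsing's FieldWrapper machinery. In the patch itself this heuristic is a last resort: an explicit SUBGROUP_KEY_FLAG entry in the config, or a DC_TYPE_KEY entry naming the dataclass type, is honored first.

    from __future__ import annotations

    from typing import Any


    def num_overlapping_keys(config_default: dict[str, Any], subgroup_defaults: dict[str, Any]) -> int:
        """Count keys of `config_default` that also appear in `subgroup_defaults`, recursing into nested dicts."""
        overlap = 0
        for key, value in config_default.items():
            if key in subgroup_defaults:
                overlap += 1
                if isinstance(value, dict) and isinstance(subgroup_defaults[key], dict):
                    overlap += num_overlapping_keys(value, subgroup_defaults[key])
        return overlap


    # Hypothetical subgroup choices (keyed by subgroup name) and a config-file default
    # that names neither a subgroup key nor a dataclass type directly.
    subgroup_choices = {
        "adam": {"lr": 1e-3, "betas": {"b1": 0.9, "b2": 0.999}},
        "sgd": {"lr": 1e-2, "momentum": 0.9},
    }
    config_default = {"lr": 3e-4, "betas": {"b1": 0.9}}

    scores = {key: num_overlapping_keys(config_default, defaults) for key, defaults in subgroup_choices.items()}
    print(scores)                               # {'adam': 3, 'sgd': 1}
    print(max(scores, key=scores.__getitem__))  # 'adam': the key whose defaults overlap most with the config

The key with the highest overlap is then used as the subgroup default, so that the matching dataclass type's fields can be populated and parsed from the command line, and the RuntimeWarning in the patch nudges users to add an explicit subgroup-key entry to their config rather than rely on this guess.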
diff --git a/examples/ugly/ugly_example_after.py b/examples/ugly/ugly_example_after.py index f0fc1cec..ea2d1e4f 100644 --- a/examples/ugly/ugly_example_after.py +++ b/examples/ugly/ugly_example_after.py @@ -189,7 +189,7 @@ class RenderingParams: @dataclass class Parameters: - """base options.""" + """Base options.""" # Dataset parameters. dataset: DatasetParams = field(default_factory=DatasetParams) diff --git a/examples/ugly/ugly_example_before.py b/examples/ugly/ugly_example_before.py index c659e2f6..505a6679 100644 --- a/examples/ugly/ugly_example_before.py +++ b/examples/ugly/ugly_example_before.py @@ -9,7 +9,7 @@ class Parameters: - """base options.""" + """Base options.""" def __init__(self): """Constructor.""" diff --git a/simple_parsing/docstring.py b/simple_parsing/docstring.py index 4cc9e129..0be1cc7a 100644 --- a/simple_parsing/docstring.py +++ b/simple_parsing/docstring.py @@ -47,6 +47,7 @@ def get_attribute_docstring( dataclass: type, field_name: str, accumulate_from_bases: bool = True ) -> AttributeDocString: """Returns the docstrings of a dataclass field. + NOTE: a docstring can either be: - An inline comment, starting with <#> - A Comment on the preceding line, starting with <#> diff --git a/simple_parsing/helpers/fields.py b/simple_parsing/helpers/fields.py index a2008084..bf27ba5f 100644 --- a/simple_parsing/helpers/fields.py +++ b/simple_parsing/helpers/fields.py @@ -262,7 +262,7 @@ def _decoding_fn(value: Any) -> Any: def list_field(*default_items: T, **kwargs) -> list[T]: - """shorthand function for setting a `list` attribute on a dataclass, so that every instance of + """Shorthand function for setting a `list` attribute on a dataclass, so that every instance of the dataclass doesn't share the same list. Accepts any of the arguments of the `dataclasses.field` function. @@ -285,7 +285,7 @@ def list_field(*default_items: T, **kwargs) -> list[T]: def dict_field(default_items: dict[K, V] | Iterable[tuple[K, V]] = (), **kwargs) -> dict[K, V]: - """shorthand function for setting a `dict` attribute on a dataclass, so that every instance of + """Shorthand function for setting a `dict` attribute on a dataclass, so that every instance of the dataclass doesn't share the same `dict`. NOTE: Do not use keyword arguments as you usually would with a dictionary diff --git a/simple_parsing/helpers/serialization/serializable.py b/simple_parsing/helpers/serialization/serializable.py index 4206f6cc..1e12f96b 100644 --- a/simple_parsing/helpers/serialization/serializable.py +++ b/simple_parsing/helpers/serialization/serializable.py @@ -1,6 +1,6 @@ from __future__ import annotations -import functools +import functools import json import pickle import warnings diff --git a/simple_parsing/helpers/subgroups.py b/simple_parsing/helpers/subgroups.py index 3287b393..30c7bc00 100644 --- a/simple_parsing/helpers/subgroups.py +++ b/simple_parsing/helpers/subgroups.py @@ -251,7 +251,8 @@ def _get_dataclass_type_from_callable( def is_lambda(obj: Any) -> bool: """Returns True if the given object is a lambda expression. 
- Taken from https://stackoverflow.com/questions/3655842/how-can-i-test-whether-a-variable-holds-a-lambda + Taken from + https://stackoverflow.com/questions/3655842/how-can-i-test-whether-a-variable-holds-a-lambda """ LAMBDA = lambda: 0 # noqa: E731 return isinstance(obj, type(LAMBDA)) and obj.__name__ == LAMBDA.__name__ diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index 8271e5b5..38c54663 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -272,7 +272,7 @@ def is_literal(t: type) -> bool: def is_list(t: type) -> bool: - """returns True when `t` is a List type. + """Returns True when `t` is a List type. Args: t (Type): a type. @@ -303,7 +303,7 @@ def is_list(t: type) -> bool: def is_tuple(t: type) -> bool: - """returns True when `t` is a tuple type. + """Returns True when `t` is a tuple type. Args: t (Type): a type. @@ -334,7 +334,7 @@ def is_tuple(t: type) -> bool: def is_dict(t: type) -> bool: - """returns True when `t` is a dict type or annotation. + """Returns True when `t` is a dict type or annotation. Args: t (Type): a type. @@ -371,7 +371,7 @@ def is_dict(t: type) -> bool: def is_set(t: type) -> bool: - """returns True when `t` is a set type or annotation. + """Returns True when `t` is a set type or annotation. Args: t (Type): a type. @@ -659,7 +659,7 @@ def _parse(value: str) -> list[Any]: return values def _parse_literal(value: str) -> list[Any] | Any: - """try to parse the string to a python expression directly. + """Try to parse the string to a python expression directly. (useful for nested lists or tuples.) """ diff --git a/test/test_base.py b/test/test_base.py index 66acf6aa..6f7b7f35 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -138,7 +138,7 @@ class Extended(Base): """Some extension of base-class `Base`""" d: int = 5 - """docstring for 'd' in Extended.""" + """Docstring for 'd' in Extended.""" e: Color = Color.BLUE diff --git a/test/test_future_annotations.py b/test/test_future_annotations.py index aa8f443e..fac89e54 100644 --- a/test/test_future_annotations.py +++ b/test/test_future_annotations.py @@ -246,7 +246,7 @@ class OptimizerConfig(TestSetup): @dataclass class SubclassOfOptimizerConfig(OptimizerConfig): bar: int | float = 123 - """some dummy arg bar.""" + """Some dummy arg bar.""" def test_missing_annotation_on_subclass(): diff --git a/test/test_huggingface_compat.py b/test/test_huggingface_compat.py index 9fa7d530..d5383fd8 100644 --- a/test/test_huggingface_compat.py +++ b/test/test_huggingface_compat.py @@ -1298,7 +1298,7 @@ def test_entire_docstring_isnt_used_as_help(): ], ) def test_serialization(tmp_path: Path, filename: str, args: TrainingArguments): - """test that serializing / deserializing a TrainingArguments works.""" + """Test that serializing / deserializing a TrainingArguments works.""" path = tmp_path / filename save(args, path) diff --git a/test/test_issue64.py b/test/test_issue64.py index 27e51418..d806251c 100644 --- a/test/test_issue64.py +++ b/test/test_issue64.py @@ -85,7 +85,7 @@ def test_vanilla_argparse_issue64(): def test_solved_issue64(): - """test that shows that Issue 64 is solved now, by adding a single space as the 'help' + """Test that shows that Issue 64 is solved now, by adding a single space as the 'help' argument, the help formatter can then add the "(default: bbb)" after the argument.""" parser = ArgumentParser("issue64") parser.add_arguments(Options, dest="options") diff --git a/test/test_set_defaults.py b/test/test_set_defaults.py index cf4bfa32..a7695fba 100644 --- 
a/test/test_set_defaults.py +++ b/test/test_set_defaults.py @@ -66,7 +66,7 @@ def test_set_broken_defaults_from_file(tmp_path: Path): def test_set_defaults_from_file_without_root(tmp_path: Path): - """test that set_defaults accepts the fields of the dataclass directly, when the parser has + """Test that set_defaults accepts the fields of the dataclass directly, when the parser has nested_mode=NestedMode.WITHOUT_ROOT.""" parser = ArgumentParser(nested_mode=NestedMode.WITHOUT_ROOT) parser.add_arguments(Foo, dest="foo") diff --git a/test/test_subgroups.py b/test/test_subgroups.py index bb6dc9c1..375bf2ae 100644 --- a/test/test_subgroups.py +++ b/test/test_subgroups.py @@ -15,10 +15,10 @@ from pytest_regressions.file_regression import FileRegressionFixture from typing_extensions import Annotated -from simple_parsing.utils import Dataclass -from simple_parsing.helpers.serialization import save from simple_parsing import ArgumentParser, parse, subgroups +from simple_parsing.helpers.serialization import save from simple_parsing.helpers.serialization.serializable import from_dict, to_dict +from simple_parsing.utils import Dataclass from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, NestedMode from .test_choice import Color diff --git a/test/test_subgroups/test_help[Config---help].md b/test/test_subgroups/test_help[Config---help].md index 931e65ff..9a51d7d7 100644 --- a/test/test_subgroups/test_help[Config---help].md +++ b/test/test_subgroups/test_help[Config---help].md @@ -1,8 +1,4 @@ -<<<<<<< HEAD -# Regression file for [this test](test/test_subgroups.py:731) -======= -# Regression file for [this test](test/test_subgroups.py:736) ->>>>>>> Fix changed regression files +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[Config---model=model_a --help].md b/test/test_subgroups/test_help[Config---model=model_a --help].md index 7bc268e4..7d2f8970 100644 --- a/test/test_subgroups/test_help[Config---model=model_a --help].md +++ b/test/test_subgroups/test_help[Config---model=model_a --help].md @@ -1,8 +1,4 @@ -<<<<<<< HEAD -# Regression file for [this test](test/test_subgroups.py:731) -======= -# Regression file for [this test](test/test_subgroups.py:736) ->>>>>>> Fix changed regression files +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[Config---model=model_b --help].md b/test/test_subgroups/test_help[Config---model=model_b --help].md index 21cc825d..1e2fb4c0 100644 --- a/test/test_subgroups/test_help[Config---model=model_b --help].md +++ b/test/test_subgroups/test_help[Config---model=model_b --help].md @@ -1,8 +1,4 @@ -<<<<<<< HEAD -# Regression file for [this test](test/test_subgroups.py:731) -======= -# Regression file for [this test](test/test_subgroups.py:736) ->>>>>>> Fix changed regression files +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md index 1c23e793..5b82f578 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md @@ -1,8 +1,4 @@ -<<<<<<< HEAD -# Regression file for [this test](test/test_subgroups.py:731) -======= -# Regression file for [this test](test/test_subgroups.py:736) ->>>>>>> Fix changed 
regression files +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md index f4b558d3..312b7218 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md @@ -1,8 +1,4 @@ -<<<<<<< HEAD -# Regression file for [this test](test/test_subgroups.py:731) -======= -# Regression file for [this test](test/test_subgroups.py:736) ->>>>>>> Fix changed regression files +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md index c2169036..cbfa7eeb 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md @@ -1,8 +1,4 @@ -<<<<<<< HEAD -# Regression file for [this test](test/test_subgroups.py:731) -======= -# Regression file for [this test](test/test_subgroups.py:736) ->>>>>>> Fix changed regression files +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md index c5511112..890e7326 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md @@ -1,8 +1,4 @@ -<<<<<<< HEAD -# Regression file for [this test](test/test_subgroups.py:731) -======= -# Regression file for [this test](test/test_subgroups.py:736) ->>>>>>> Fix changed regression files +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---help].md b/test/test_subgroups/test_help[ConfigWithFrozen---help].md index 2ef6ddc9..d05b94ef 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---help].md @@ -1,8 +1,4 @@ -<<<<<<< HEAD -# Regression file for [this test](test/test_subgroups.py:731) -======= -# Regression file for [this test](test/test_subgroups.py:736) ->>>>>>> Fix changed regression files +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: From dc263058ed0f10de5afbd27a74a466b3ca6c0b22 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 31 Jan 2024 15:00:55 -0500 Subject: [PATCH 19/20] Update regression files slightly Signed-off-by: Fabrice Normandin --- test/test_subgroups/test_help[Config---help].md | 2 +- test/test_subgroups/test_help[Config---model=model_a --help].md | 2 +- test/test_subgroups/test_help[Config---model=model_b --help].md | 2 +- .../test_help[ConfigWithFrozen---conf=even --a 100 --help].md | 2 +- .../test_help[ConfigWithFrozen---conf=even --help].md | 2 +- .../test_help[ConfigWithFrozen---conf=odd --a 123 --help].md | 2 +- .../test_help[ConfigWithFrozen---conf=odd --help].md | 2 +- test/test_subgroups/test_help[ConfigWithFrozen---help].md | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/test_subgroups/test_help[Config---help].md b/test/test_subgroups/test_help[Config---help].md index 9a51d7d7..94fce246 100644 --- 
a/test/test_subgroups/test_help[Config---help].md +++ b/test/test_subgroups/test_help[Config---help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:737) Given Source code: diff --git a/test/test_subgroups/test_help[Config---model=model_a --help].md b/test/test_subgroups/test_help[Config---model=model_a --help].md index 7d2f8970..7b992dd9 100644 --- a/test/test_subgroups/test_help[Config---model=model_a --help].md +++ b/test/test_subgroups/test_help[Config---model=model_a --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:737) Given Source code: diff --git a/test/test_subgroups/test_help[Config---model=model_b --help].md b/test/test_subgroups/test_help[Config---model=model_b --help].md index 1e2fb4c0..11a8c76d 100644 --- a/test/test_subgroups/test_help[Config---model=model_b --help].md +++ b/test/test_subgroups/test_help[Config---model=model_b --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:737) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md index 5b82f578..fccd9926 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:737) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md index 312b7218..f3a3fc88 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:737) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md index cbfa7eeb..2c79832f 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:737) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md index 890e7326..8938ffcb 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:737) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---help].md b/test/test_subgroups/test_help[ConfigWithFrozen---help].md index d05b94ef..189200cd 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---help].md +++ 
b/test/test_subgroups/test_help[ConfigWithFrozen---help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:730) +# Regression file for [this test](test/test_subgroups.py:737) Given Source code: From ce02fd1de3cf45a892deecb3b23f21ba47a3e689 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 31 Jan 2024 15:04:38 -0500 Subject: [PATCH 20/20] Add note for later Signed-off-by: Fabrice Normandin --- simple_parsing/subgroup_parsing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/simple_parsing/subgroup_parsing.py b/simple_parsing/subgroup_parsing.py index 91ab29f0..06327944 100644 --- a/simple_parsing/subgroup_parsing.py +++ b/simple_parsing/subgroup_parsing.py @@ -49,6 +49,8 @@ def resolve_subgroups( parser. This is done by adding wrapping the dataclass and adding it to the `wrappers` list. """ + # TODO: Check if there is anything useful in the `_defaults` here. + _defaults_not_associated_with_known_dataclasses = parser._defaults.copy() unresolved_subgroups = _get_subgroup_fields(wrappers) # Dictionary of the subgroup choices that were resolved (key: subgroup dest, value: chosen