diff --git a/lib/python/picongpu/picmi/diagnostics/auto.py b/lib/python/picongpu/picmi/diagnostics/auto.py index a8783e8d4c..79b0110803 100644 --- a/lib/python/picongpu/picmi/diagnostics/auto.py +++ b/lib/python/picongpu/picmi/diagnostics/auto.py @@ -1,11 +1,13 @@ """ This file is part of PIConGPU. Copyright 2025 PIConGPU contributors -Authors: Pawel Ordyna +Authors: Pawel Ordyna, Masoud Afshari License: GPLv3+ """ import typeguard +from typing import Union + from ...pypicongpu.output.auto import Auto as PyPIConGPUAuto from ..copy_attributes import default_converts_to @@ -20,13 +22,22 @@ class Auto: Parameters ---------- - period: int + period: int or TimeStepSpec Number of simulation steps between consecutive outputs. Unit: steps (simulation time steps). """ - period: TimeStepSpec - """Number of simulation steps between consecutive outputs. Unit: steps (simulation time steps).""" + def __init__(self, period: Union[int, TimeStepSpec]) -> None: + if not isinstance(period, (int, TimeStepSpec)): + raise TypeError("period must be an integer or TimeStepSpec") + if isinstance(period, int): + if period < 0: + raise ValueError("period must be non-negative") + self.period = TimeStepSpec[::period]("steps") if period > 0 else TimeStepSpec([])("steps") + else: + self.period = period - def __init__(self, period: TimeStepSpec) -> None: - self.period = period + def check(self, *args, **kwargs): + """Validate that period is a valid TimeStepSpec.""" + if not isinstance(self.period, TimeStepSpec): + raise TypeError("period must be a TimeStepSpec") diff --git a/test/python/picongpu/quick/picmi/diagnostics/__init__.py b/test/python/picongpu/quick/picmi/diagnostics/__init__.py index 2d4b6a07c1..8afd0d0fda 100644 --- a/test/python/picongpu/quick/picmi/diagnostics/__init__.py +++ b/test/python/picongpu/quick/picmi/diagnostics/__init__.py @@ -1,9 +1,10 @@ """ This file is part of PIConGPU. Copyright 2025 PIConGPU contributors -Authors: Julian Lenz +Authors: Julian Lenz, Masoud Afshari License: GPLv3+ """ # flake8: noqa from .timestepspec import * # pyflakes.ignore +from .auto import * # pyflakes.ignore diff --git a/test/python/picongpu/quick/picmi/diagnostics/auto.py b/test/python/picongpu/quick/picmi/diagnostics/auto.py new file mode 100644 index 0000000000..bb20bc6439 --- /dev/null +++ b/test/python/picongpu/quick/picmi/diagnostics/auto.py @@ -0,0 +1,64 @@ +""" +This file is part of PIConGPU. +Copyright 2025 PIConGPU contributors +Authors: Masoud Afshari +License: GPLv3+ +""" + +import unittest +import typeguard +from picongpu.picmi.diagnostics import Auto, TimeStepSpec +from picongpu.pypicongpu.output.auto import Auto as PyPIConGPUAuto +from picongpu.pypicongpu.output.timestepspec import TimeStepSpec as PyPIConGPUTimeStepSpec + + +TESTCASES_VALID = [ + (10, [{"start": 0, "stop": -1, "step": 10}]), + (TimeStepSpec(slice(None, None, 10)), [{"start": 0, "stop": -1, "step": 10}]), +] + +TESTCASES_INVALID = [ + ("invalid", "period must be an integer or TimeStepSpec"), + (-10, "period must be non-negative"), +] + + +class PICMI_TestAuto(unittest.TestCase): + def test_valid_periods(self): + """Test Auto instantiation, validation, and conversion.""" + for period, expected_specs in TESTCASES_VALID: + with self.subTest(period=period): + auto = Auto(period=period) + self.assertIsInstance(auto.period, TimeStepSpec) + auto.check() + + # Convert to PyPIConGPUAuto + pypicongpu_auto = auto.get_as_pypicongpu(0.5, 200) + self.assertIsInstance(pypicongpu_auto, PyPIConGPUAuto) + self.assertIsInstance(pypicongpu_auto.period, PyPIConGPUTimeStepSpec) + + # Validate rendered specs + serialized = pypicongpu_auto.get_rendering_context() + self.assertTrue(serialized["typeID"]["auto"]) + self.assertEqual(serialized["data"]["period"]["specs"], expected_specs) + + def test_invalid_period_type(self): + """Test invalid input types.""" + with self.assertRaises(typeguard.TypeCheckError): + Auto(period="invalid") + + def test_negative_period(self): + """Test negative integer period raises error.""" + with self.assertRaisesRegex(ValueError, "period must be non-negative"): + Auto(period=-5) + + def test_check_invalid_period(self): + """Test that check() catches wrong type.""" + auto = Auto(period=10) + auto.period = "invalid" + with self.assertRaisesRegex(TypeError, "period must be a TimeStepSpec"): + auto.check() + + +if __name__ == "__main__": + unittest.main()