11 changes: 11 additions & 0 deletions tests/assets/exp_config/RegClass.py
@@ -0,0 +1,11 @@
from utilsd.config import Registry


class DummyReg(metaclass=Registry, name="test"):
    pass


@DummyReg.register_module()
class Reg:
    def __init__(self, a: int, b: str):
        self.a = a
        self.b = b
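A note on how this fixture is meant to be consumed: Reg is registered on the DummyReg registry, so a RegistryConfig[DummyReg] field can resolve it by the string name "Reg". A minimal sketch, reusing only the fromdict and type() calls exercised by the tests below (the SketchConfig wrapper itself is illustrative, not part of this PR):

from utilsd.config import PythonConfig, RegistryConfig, configclass

from tests.assets.exp_config.RegClass import DummyReg, Reg


@configclass
class SketchConfig(PythonConfig):
    reg: RegistryConfig[DummyReg]


# "Reg" is looked up in DummyReg's registry; a and b feed Reg.__init__.
sketch = SketchConfig.fromdict({"reg": {"type": "Reg", "a": 1, "b": "hello"}})
assert sketch.reg.type() is Reg  # mirrors test_get_config_types below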
4 changes: 4 additions & 0 deletions tests/assets/exp_config/TestClass.py
@@ -0,0 +1,4 @@
class DummyClass:
    def __init__(self, c: int, d: str):
        self.c = c
        self.d = d
Empty file.
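DummyClass plays the same role for plain ClassConfig fields, where no registry lookup happens and the constructor arguments are given directly. A sketch under the same assumptions (SketchConfig2 is again illustrative):

from utilsd.config import ClassConfig, PythonConfig, configclass

from tests.assets.exp_config.TestClass import DummyClass


@configclass
class SketchConfig2(PythonConfig):
    clss: ClassConfig[DummyClass]


sketch = SketchConfig2.fromdict({"clss": {"c": 2, "d": "world"}})
assert sketch.clss.type() is DummyClass  # no "type" key needed without a registry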
116 changes: 116 additions & 0 deletions tests/test_experiment.py
@@ -0,0 +1,116 @@
import os
import sys
import json
import tempfile
from pathlib import Path

try:
    import torch
    test_torch = True
except ImportError:
    test_torch = False

from utilsd.config import RegistryConfig, ClassConfig, configclass, RuntimeConfig, PythonConfig, Registry
from utilsd.experiment import (
    get_checkpoint_dir,
    get_config_types,
    get_output_dir,
    get_tb_log_dir,
    get_runtime_config,
    is_debugging,
    print_config,
    setup_experiment,
    use_cuda,
)

import pytest


def test_setup_experiment():
    with pytest.raises(AssertionError):
        # raises an AssertionError when runtime_config has not been initialized
        get_runtime_config()

    # Create a temporary directory for the experiment
    with tempfile.TemporaryDirectory() as tmpdir:
        # Set up the experiment
        runtime_config = RuntimeConfig(
            output_dir=Path(tmpdir),
            seed=123,
            debug=True,
        )
        _ = setup_experiment(runtime_config)

        runtime_config_to_check = get_runtime_config()
        assert runtime_config_to_check.output_dir == get_output_dir() == Path(tmpdir)
        assert runtime_config_to_check.checkpoint_dir == get_checkpoint_dir() == Path(tmpdir) / "checkpoints"
        assert runtime_config_to_check.tb_log_dir == get_tb_log_dir() == Path(tmpdir) / "tb"
        assert runtime_config_to_check.seed == 123
        assert runtime_config_to_check.debug == is_debugging() == True

        # Check that PyTorch is using CUDA if available
        if test_torch or "torch" in sys.modules:
            assert use_cuda() == (runtime_config_to_check.use_cuda and torch.cuda.is_available())


from tests.assets.exp_config.RegClass import DummyReg, Reg
from tests.assets.exp_config.TestClass import DummyClass

@configclass
class ExpConfig(PythonConfig):
    reg: RegistryConfig[DummyReg]
    clss: ClassConfig[DummyClass]
    int_var: int = 1
    runtime: RuntimeConfig = RuntimeConfig()

# Define a configuration dictionary
with tempfile.TemporaryDirectory() as tmpdir:
    config = ExpConfig.fromdict({
        "reg": {
            "type": "Reg",
            "a": 1,
            "b": "hello",
        },
        "clss": {
            "c": 2,
            "d": "world",
        },
        "int_var": 3,
        "runtime": {
            "output_dir": tmpdir,
            "debug": True,
        }
    })


def test_get_config_types():
    # Get the types of the configuration values
    types = get_config_types(config)
    # Check that the types are correct
    assert types["_config_type_name"] == ExpConfig.__name__
    assert types["reg"]["_config_type_name"] == Reg.__name__ + "Config"
    assert types["reg"]["_type_module"] == Reg.__module__
    assert types["reg"]["_type_name"] == Reg.__name__
    assert types["clss"]["_config_type_name"] == DummyClass.__name__ + "Config"
    assert types["clss"]["_type_module"] == DummyClass.__module__
    assert types["clss"]["_type_name"] == DummyClass.__name__
    assert types["runtime"]["_config_type_name"] == RuntimeConfig.__name__
    assert "_type_name" not in types["runtime"]  # RuntimeConfig doesn't have a type() method
    assert "int_var" not in types  # non-dict, non-dataclass values should not be included


def test_print_config():
    # Print the configuration
    with tempfile.TemporaryDirectory() as tmpdir:
        print_config(config, output_dir=tmpdir, infer_types=True)

        with open(os.path.join(tmpdir, "config.json")) as fh:
            config_to_check = json.load(fh)
        with open(os.path.join(tmpdir, "config_type.json")) as fh:
            config_type_to_check = json.load(fh)

        assert config_to_check["int_var"] == 3
        assert config_to_check["reg"]["a"] == 1
        assert config_to_check["runtime"]["debug"] == True
        assert config_type_to_check["reg"]["_type_module"] == Reg.__module__
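Piecing the assertions of test_get_config_types together, the dict returned for this config should look roughly like the sketch below. This is reconstructed from the asserted keys, not captured output; the RegConfig/DummyClassConfig names follow the + "Config" convention asserted above, and each nested entry also carries a _config_type_module key that the test does not pin down.

{
    "_config_type_name": "ExpConfig",
    "reg": {
        "_config_type_name": "RegConfig",
        "_type_module": "tests.assets.exp_config.RegClass",
        "_type_name": "Reg",
    },
    "clss": {
        "_config_type_name": "DummyClassConfig",
        "_type_module": "tests.assets.exp_config.TestClass",
        "_type_name": "DummyClass",
    },
    "runtime": {
        "_config_type_name": "RuntimeConfig",
        # no _type_name here: RuntimeConfig has no type() method
    },
    # "int_var" is absent: plain values are not recorded
}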
46 changes: 42 additions & 4 deletions utilsd/experiment.py
@@ -2,6 +2,7 @@
import json
import logging
import os
import sys
import pprint
import random
import warnings
@@ -12,21 +13,25 @@
import numpy as np
try:
    import torch
    _use_torch = True
except ImportError:
    warnings.warn('PyTorch is not installed. Some features of utilsd might not work.')
    _use_torch = False
from .config.builtin import RuntimeConfig
from .config.registry import RegistryConfig
from .logging import mute_logger, print_log, setup_logger, reset_logger

_runtime_config: Optional[RuntimeConfig] = None
_use_cuda: Optional[bool] = None


def seed_everything(seed):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
-    torch.backends.cudnn.deterministic = True
+    if _use_torch or "torch" in sys.modules:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True


def setup_distributed_training():
@@ -94,7 +99,31 @@ def setup_experiment(runtime_config: RuntimeConfig, enable_nni: bool = False,
    return runtime_config


def get_config_types(config):
    type_dict = dict()
    if isinstance(config, dict):
        for key in config:
            if isinstance(config[key], dict) or dataclasses.is_dataclass(config[key]):
                value = get_config_types(config[key])
                if len(value) != 0:
                    type_dict[key] = value
    elif dataclasses.is_dataclass(config):
        type_dict["_config_type_module"] = type(config).__module__
        type_dict["_config_type_name"] = type(config).__name__
        try:
            _type = config.type()
            type_dict["_type_module"] = _type.__module__
            type_dict["_type_name"] = _type.__name__
        except (TypeError, AttributeError):
            pass
        for f in dataclasses.fields(config):
            value = get_config_types(getattr(config, f.name))
            if len(value) != 0:
                type_dict[f.name] = value
    return type_dict


-def print_config(config, dump_config=True, output_dir=None, expand_config=True):
+def print_config(config, dump_config=True, output_dir=None, expand_config=True, infer_types=False):
    class Encoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, Enum):
@@ -103,6 +132,10 @@ def default(self, obj):
                return obj.as_posix()
            return super().default(obj)

    if infer_types:
        config_type = get_config_types(config)
    else:
        config_type = None
    if isinstance(config, dict):
        config_meta = None
    else:
@@ -116,12 +149,17 @@ def default(self, obj):
        print_log('Config (meta): ' + json.dumps(config_meta, cls=Encoder), __name__)
    if expand_config:
        print_log('Config (expanded):\n' + pprint.pformat(config), __name__)
    if config_type is not None and len(config_type) != 0:
        print_log('Config (types):\n' + pprint.pformat(config_type), __name__)
    if dump_config:
        with open(os.path.join(output_dir, 'config.json'), 'w') as fh:
            json.dump(config, fh, cls=Encoder)
        if config_meta is not None:
            with open(os.path.join(output_dir, 'config_meta.json'), 'w') as fh:
                json.dump(config_meta, fh, cls=Encoder)
        if config_type is not None and len(config_type) != 0:
            with open(os.path.join(output_dir, 'config_type.json'), 'w') as fh:
                json.dump(config_type, fh, cls=Encoder)

def get_runtime_config():
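End to end, the new infer_types path can be exercised as in the sketch below, which leans only on the signature added in this diff and reuses the ExpConfig instance built in tests/test_experiment.py:

import tempfile

from tests.test_experiment import config  # the ExpConfig built above
from utilsd.experiment import print_config

with tempfile.TemporaryDirectory() as tmpdir:
    # Logs 'Config (expanded):' and 'Config (types):' via print_log, then
    # writes config.json and config_type.json into tmpdir, since
    # dump_config defaults to True and the inferred type dict is non-empty.
    print_config(config, output_dir=tmpdir, infer_types=True)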
2 changes: 1 addition & 1 deletion utilsd/fileio/config.py
@@ -449,7 +449,7 @@ def _format_dict(input_dict, outest_level=False):
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
-        text, _ = FormatCode(text, style_config=yapf_style, verify=True)
+        text, _ = FormatCode(text, style_config=yapf_style)

        return text

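Dropping verify=True is most plausibly a compatibility fix: the verify argument of yapf's FormatCode was deprecated and later removed, so passing it fails on newer yapf releases. The surviving call shape, as a standalone sketch:

from yapf.yapflib.yapf_api import FormatCode

# FormatCode returns a (formatted_source, changed) pair; no verify kwarg.
formatted, changed = FormatCode("x=1\n", style_config="pep8")
assert formatted == "x = 1\n"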