diff --git a/pyproject.toml b/pyproject.toml index 40d8ba16c..e0ad7c043 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -197,6 +197,11 @@ imaginAIry = "imaginAIry" filterwarnings = [ "ignore::UserWarning:segment_anything_hq.modeling.tiny_vit_sam.*", "ignore::DeprecationWarning:timm.models.layers.*", - "ignore::DeprecationWarning:timm.models.registry.*" + "ignore::DeprecationWarning:timm.models.registry.*", + "ignore::FutureWarning:timm.models.layers.*", + "ignore::FutureWarning:timm.models.registry.*", + "ignore:jsonschema.RefResolver is deprecated:DeprecationWarning", + # https://github.com/pytorch/pytorch/issues/136264 + "ignore:__array__ implementation doesn't accept a copy keyword:DeprecationWarning", ] addopts = "--import-mode=importlib" diff --git a/requirements.lock b/requirements.lock index c8cb5164d..84ad85cc0 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,42 +10,42 @@ # universal: false -e file:. -aiohappyeyeballs==2.4.2 +aiohappyeyeballs==2.4.4 # via aiohttp -aiohttp==3.10.6 +aiohttp==3.11.11 # via datasets # via fsspec -aiosignal==1.3.1 +aiosignal==1.3.2 # via aiohttp annotated-types==0.7.0 # via pydantic arrow==1.3.0 # via isoduration -async-timeout==4.0.3 +async-timeout==5.0.1 # via aiohttp -attrs==24.2.0 +attrs==24.3.0 # via aiohttp # via jsonschema # via referencing babel==2.16.0 # via mkdocs-material -bitsandbytes==0.44.0 +bitsandbytes==0.45.0 # via refiners -black==24.8.0 +black==24.10.0 # via refiners -boto3==1.35.28 +boto3==1.35.84 # via neptune -botocore==1.35.28 +botocore==1.35.84 # via boto3 # via s3transfer bravado==11.0.3 # via neptune bravado-core==6.1.1 # via bravado -certifi==2024.8.30 +certifi==2024.12.14 # via requests # via sentry-sdk -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 # via requests click==8.1.7 # via black @@ -56,9 +56,9 @@ click==8.1.7 colorama==0.4.6 # via griffe # via mkdocs-material -datasets==3.0.1 +datasets==3.2.0 # via refiners -diffusers==0.30.3 +diffusers==0.31.0 # via refiners dill==0.3.8 # via datasets @@ -76,10 +76,10 @@ filelock==3.16.1 # via triton fqdn==1.5.1 # via jsonschema -frozenlist==1.4.1 +frozenlist==1.5.0 # via aiohttp # via aiosignal -fsspec==2024.6.1 +fsspec==2024.9.0 # via datasets # via huggingface-hub # via torch @@ -93,9 +93,9 @@ gitpython==3.1.43 # via neptune # via refiners # via wandb -griffe==1.3.1 +griffe==1.5.1 # via mkdocstrings-python -huggingface-hub==0.25.1 +huggingface-hub==0.27.0 # via datasets # via diffusers # via refiners @@ -114,7 +114,7 @@ iniconfig==2.0.0 # via pytest isoduration==20.11.0 # via jsonschema -jaxtyping==0.2.34 +jaxtyping==0.2.36 # via refiners jinja2==3.1.4 # via mkdocs @@ -131,9 +131,9 @@ jsonref==1.1.0 jsonschema==4.23.0 # via bravado-core # via swagger-spec-validator -jsonschema-specifications==2023.12.1 +jsonschema-specifications==2024.10.1 # via jsonschema -loguru==0.7.2 +loguru==0.7.3 # via refiners markdown==3.7 # via mkdocs @@ -141,7 +141,7 @@ markdown==3.7 # via mkdocs-material # via mkdocstrings # via pymdown-extensions -markupsafe==2.1.5 +markupsafe==3.0.2 # via jinja2 # via mkdocs # via mkdocs-autorefs @@ -161,14 +161,14 @@ mkdocs-get-deps==0.2.0 # via mkdocs mkdocs-literate-nav==0.6.1 # via refiners -mkdocs-material==9.5.38 +mkdocs-material==9.5.49 # via refiners mkdocs-material-extensions==1.3.1 # via mkdocs-material -mkdocstrings==0.26.1 +mkdocstrings==0.27.0 # via mkdocstrings-python # via refiners -mkdocstrings-python==1.11.1 +mkdocstrings-python==1.12.2 # via mkdocstrings monotonic==1.6 # via bravado @@ -184,51 +184,51 @@ multiprocess==0.70.16 # via datasets mypy-extensions==1.0.0 # via black -neptune==1.11.1 +neptune==1.13.0 # via refiners -networkx==3.3 +networkx==3.4.2 # via torch -numpy==2.1.1 +numpy==2.2.0 # via bitsandbytes # via datasets # via diffusers # via pandas - # via pyarrow # via refiners # via torchvision # via transformers -nvidia-cublas-cu12==12.1.3.1 +nvidia-cublas-cu12==12.4.5.8 # via nvidia-cudnn-cu12 # via nvidia-cusolver-cu12 # via torch -nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-cupti-cu12==12.4.127 # via torch -nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.4.127 # via torch -nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.4.127 # via torch nvidia-cudnn-cu12==9.1.0.70 # via torch -nvidia-cufft-cu12==11.0.2.54 +nvidia-cufft-cu12==11.2.1.3 # via torch -nvidia-curand-cu12==10.3.2.106 +nvidia-curand-cu12==10.3.5.147 # via torch -nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusolver-cu12==11.6.1.9 # via torch -nvidia-cusparse-cu12==12.1.0.106 +nvidia-cusparse-cu12==12.3.1.170 # via nvidia-cusolver-cu12 # via torch -nvidia-nccl-cu12==2.20.5 +nvidia-nccl-cu12==2.21.5 # via torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.4.127 # via nvidia-cusolver-cu12 # via nvidia-cusparse-cu12 -nvidia-nvtx-cu12==12.1.105 + # via torch +nvidia-nvtx-cu12==12.4.127 # via torch oauthlib==3.2.2 # via neptune # via requests-oauthlib -packaging==24.1 +packaging==24.2 # via black # via datasets # via huggingface-hub @@ -246,7 +246,7 @@ pandas==2.2.3 pathspec==0.12.1 # via black # via mkdocs -pillow==10.4.0 +pillow==11.0.0 # via diffusers # via neptune # via refiners @@ -260,29 +260,33 @@ platformdirs==4.3.6 # via wandb pluggy==1.5.0 # via pytest -prodigyopt==1.0 +prodigyopt==1.1.1 # via refiners -protobuf==5.28.2 +propcache==0.2.1 + # via aiohttp + # via yarl +protobuf==5.29.2 # via wandb -psutil==6.0.0 +psutil==6.1.0 # via neptune # via wandb -pyarrow==17.0.0 +pyarrow==18.1.0 # via datasets -pydantic==2.9.2 +pydantic==2.10.4 # via refiners -pydantic-core==2.23.4 + # via wandb +pydantic-core==2.27.2 # via pydantic pygments==2.18.0 # via mkdocs-material -pyjwt==2.9.0 +pyjwt==2.10.1 # via neptune -pymdown-extensions==10.10.2 +pymdown-extensions==10.12 # via mkdocs-material # via mkdocstrings -pytest==8.3.3 +pytest==8.3.4 # via pytest-rerunfailures -pytest-rerunfailures==14.0 +pytest-rerunfailures==15.0 # via refiners python-dateutil==2.9.0.post0 # via arrow @@ -312,7 +316,7 @@ pyyaml-env-tag==0.1 referencing==0.35.1 # via jsonschema # via jsonschema-specifications -regex==2024.9.11 +regex==2024.11.6 # via diffusers # via mkdocs-material # via transformers @@ -334,10 +338,10 @@ rfc3339-validator==0.1.4 # via jsonschema rfc3986-validator==0.1.1 # via jsonschema -rpds-py==0.20.0 +rpds-py==0.22.3 # via jsonschema # via referencing -s3transfer==0.10.2 +s3transfer==0.10.4 # via boto3 safetensors==0.4.5 # via diffusers @@ -350,16 +354,16 @@ segment-anything-py==1.0.1 # via refiners sentencepiece==0.2.0 # via refiners -sentry-sdk==2.14.0 +sentry-sdk==2.19.2 # via wandb -setproctitle==1.3.3 +setproctitle==1.3.4 # via wandb -setuptools==75.1.0 +setuptools==75.6.0 # via wandb simplejson==3.19.3 # via bravado # via bravado-core -six==1.16.0 +six==1.17.0 # via bravado # via bravado-core # via docker-pycreds @@ -371,43 +375,42 @@ smmap==5.0.1 swagger-spec-validator==3.0.4 # via bravado-core # via neptune -sympy==1.13.3 +sympy==1.13.1 # via torch -timm==1.0.9 +timm==1.0.12 # via refiners -tokenizers==0.20.0 +tokenizers==0.21.0 # via transformers -tomli==2.0.1 +tomli==2.2.1 # via black # via pytest # via refiners -torch==2.4.1 +torch==2.5.1 # via bitsandbytes # via refiners # via segment-anything-hq # via segment-anything-py # via timm # via torchvision -torchvision==0.19.1 +torchvision==0.20.1 # via piq # via refiners # via segment-anything-hq # via segment-anything-py # via timm -tqdm==4.66.5 +tqdm==4.67.1 # via datasets # via huggingface-hub # via refiners # via transformers -transformers==4.45.1 +transformers==4.47.1 # via refiners -triton==3.0.0 +triton==3.1.0 # via torch -typeguard==2.13.3 - # via jaxtyping -types-python-dateutil==2.9.0.20240906 +types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 + # via bitsandbytes # via black # via bravado # via huggingface-hub @@ -417,6 +420,7 @@ typing-extensions==4.12.2 # via pydantic-core # via swagger-spec-validator # via torch + # via wandb tzdata==2024.2 # via pandas uri-template==1.3.0 @@ -426,17 +430,17 @@ urllib3==2.2.3 # via neptune # via requests # via sentry-sdk -wandb==0.18.1 +wandb==0.19.1 # via refiners -watchdog==5.0.2 +watchdog==6.0.0 # via mkdocs -webcolors==24.8.0 +webcolors==24.11.1 # via jsonschema websocket-client==1.8.0 # via neptune xxhash==3.5.0 # via datasets -yarl==1.13.0 +yarl==1.18.3 # via aiohttp -zipp==3.20.2 +zipp==3.21.0 # via importlib-metadata diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 899702b47..5ee9f3170 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -248,27 +248,34 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str: f"dtype={str(object=tensor.dtype).removeprefix('torch.')}", f"device={tensor.device}", ] + + numel = tensor.numel() + if numel == 0: + return "Tensor(" + ", ".join(info_list) + ")" + if tensor.is_complex(): tensor_f = tensor.real.float() else: - if tensor.numel() > 0: - info_list.extend( - [ - f"min={tensor.min():.2f}", # type: ignore - f"max={tensor.max():.2f}", # type: ignore - ] - ) + info_list.extend( + [ + f"min={tensor.min():.2f}", # type: ignore + f"max={tensor.max():.2f}", # type: ignore + ] + ) tensor_f = tensor.float() info_list.extend( [ f"mean={tensor_f.mean():.2f}", - f"std={tensor_f.std():.2f}", f"norm={norm(x=tensor_f):.2f}", - f"grad={tensor.requires_grad}", ] ) + if numel > 1: + info_list.append(f"std={tensor_f.std():.2f}") + + info_list.append(f"grad={tensor.requires_grad}") + return "Tensor(" + ", ".join(info_list) + ")" diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index 740442f5a..56e374b30 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -22,12 +22,10 @@ def __init__( self, training_duration: TimeValue, gradient_accumulation: Step, - lr_scheduler_interval: TimeValue, verbose: bool = True, ) -> None: self.training_duration = training_duration self.gradient_accumulation = gradient_accumulation - self.lr_scheduler_interval = lr_scheduler_interval self.verbose = verbose self.start_time = None self.end_time = None diff --git a/src/refiners/training_utils/common.py b/src/refiners/training_utils/common.py index c8a21f91d..f46ea6148 100644 --- a/src/refiners/training_utils/common.py +++ b/src/refiners/training_utils/common.py @@ -70,7 +70,8 @@ def __enter__(self) -> None: self.random_state = random.getstate() self.numpy_state = np.random.get_state() self.torch_state = torch.get_rng_state() - self.cuda_torch_state = cuda.get_rng_state() + if torch.cuda.is_available(): + self.cuda_torch_state = cuda.get_rng_state() seed_everything(seed) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: @@ -78,7 +79,8 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: random.setstate(self.random_state) np.random.set_state(self.numpy_state) torch.set_rng_state(self.torch_state) - cuda.set_rng_state(self.cuda_torch_state) + if torch.cuda.is_available(): + cuda.set_rng_state(self.cuda_torch_state) @dataclass diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index adecfe3c4..09bf608d8 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -130,12 +130,15 @@ def __init__(self, config: ConfigType) -> None: self._load_models() self._call_callbacks(event_name="on_init_end") + # Ensure the lr_scheduler is initialized before calling `step` on the optimizer. + # See `patch_track_step_called` in LRScheduler constructor. + assert self.lr_scheduler + @register_callback() def clock(self, config: ClockConfig) -> TrainingClock: return TrainingClock( training_duration=self.config.training.duration, gradient_accumulation=self.config.training.gradient_accumulation, - lr_scheduler_interval=self.config.lr_scheduler.update_interval, verbose=config.verbose, ) @@ -299,10 +302,13 @@ def backward(self) -> None: self.optimizer.step() self.optimizer.zero_grad() self._call_callbacks(event_name="on_optimizer_step_end") - if self.clock.is_due(self.config.lr_scheduler.update_interval): - self._call_callbacks(event_name="on_lr_scheduler_step_begin") - self.lr_scheduler.step() - self._call_callbacks(event_name="on_lr_scheduler_step_end") + if self.clock.is_due(self.config.lr_scheduler.update_interval): + # TODO: if the update interval is in Epochs, this will be called + # at every optimizer step during targeted epochs. It should probably + # only be called once instead. + self._call_callbacks(event_name="on_lr_scheduler_step_begin") + self.lr_scheduler.step() + self._call_callbacks(event_name="on_lr_scheduler_step_end") def step(self, batch: Batch) -> None: """Perform a single training step.""" diff --git a/tests/e2e/test_diffusion_ref/expected_std_random_init_bfloat16.png b/tests/e2e/test_diffusion_ref/expected_std_random_init_bfloat16.png index c46dd89f1..fbc3dc0a4 100644 Binary files a/tests/e2e/test_diffusion_ref/expected_std_random_init_bfloat16.png and b/tests/e2e/test_diffusion_ref/expected_std_random_init_bfloat16.png differ diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index e370a80c5..927b68570 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -205,7 +205,6 @@ def training_clock() -> TrainingClock: return TrainingClock( training_duration=Epoch(5), gradient_accumulation=Step(1), - lr_scheduler_interval=Epoch(1), ) @@ -280,7 +279,7 @@ def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None: warnings.filterwarnings( "ignore", category=UserWarning, - message=r"Detected call of `lr_scheduler.step\(\)` before `optimizer.step\(\)`", + message=r"Detected call of `lr_scheduler\.step\(\)` before `optimizer\.step\(\)`", ) for _ in range(102): warmup_scheduler.step()