Skip to content

Commit 96eb50e

Browse files
authored
Update for tests (#3)
* [WIP] add files * rename unet * rename scheduler * update * update * update
1 parent 3f71841 commit 96eb50e

11 files changed

+387
-25
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Created by https://www.toptal.com/developers/gitignore/api/python
22
# Edit at https://www.toptal.com/developers/gitignore?templates=python
33

4+
*.png
5+
*.json
6+
*.safetensors
7+
48
### Python ###
59
# Byte-compiled / optimized / DLL files
610
__pycache__/

README.md

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,104 @@
11
# 🤗 Noise Conditional Score Networks
22

3-
[![CI](https://github.com/py-img-gen/diffusers-ncsn/actions/workflows/ci.yaml/badge.svg)](https://github.com/py-img-gen/diffusers-ncsn/actions/workflows/ci.yaml) [![](https://img.shields.io/badge/Official_code-GitHub-green)](https://github.com/ermongroup/ncsn)
3+
[![CI](https://github.com/py-img-gen/diffusers-ncsn/actions/workflows/ci.yaml/badge.svg)](https://github.com/py-img-gen/diffusers-ncsn/actions/workflows/ci.yaml)
4+
[![](https://img.shields.io/badge/Official_code-GitHub-green)](https://github.com/ermongroup/ncsn)
5+
[![Model on HF](https://img.shields.io/badge/🤗%20Model%20on%20HF-py--img--gen/ncsn--mnist-D4AA00)](https://huggingface.co/py-img-gen/ncsn-mnist)
46

57
[`🤗 diffusers`](https://github.com/huggingface/diffusers) implementation of the paper ["Generative Modeling by Estimating Gradients of the Data Distribution" [Yang+ NeurIPS'19]](https://arxiv.org/abs/1907.05600).
68

7-
## Installation
9+
## How to use
10+
11+
### Use without installation
12+
13+
You can load the pretrained pipeline directly from the HF Hub as follows:
14+
15+
```python
16+
import torch
17+
from diffusers import DiffusionPipeline
18+
from diffusers.utils import make_image_grid
19+
20+
# Specify the device to use
21+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22+
23+
#
24+
# Load the pipeline from the Hugging Face Hub
25+
#
26+
pipe = DiffusionPipeline.from_pretrained(
27+
"py-img-gen/ncsn-mnist", trust_remote_code=True
28+
)
29+
pipe = pipe.to(device)
30+
31+
# Generate samples; here, we specify the seed and generate 16 images
32+
output = pipe(
33+
batch_size=16,
34+
generator=torch.manual_seed(42),
35+
)
36+
37+
# Create a grid image from the generated samples
38+
image = make_image_grid(images=output.images, rows=4, cols=4)
39+
image.save("output.png")
40+
```
41+
42+
### Use with installation
43+
44+
First, install the package from this repository:
845

946
```shell
1047
pip install git+https://github.com/py-img-gen/diffusers-ncsn
1148
```
1249

50+
Then, you can use the package as follows:
51+
52+
```python
53+
import torch
54+
55+
from ncsn.pipeline_ncsn import NCSNPipeline
56+
57+
# Specify the device to use
58+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59+
60+
#
61+
# Load the pipeline from the HF Hub through the NCSNPipeline of this library
62+
#
63+
pipe = NCSNPipeline.from_pretrained("py-img-gen/ncsn-mnist", trust_remote_code=True)
64+
pipe = pipe.to(device)
65+
66+
# Generate samples; here, we specify the seed and generate 16 images
67+
output = pipe(
68+
batch_size=16,
69+
generator=torch.manual_seed(42),
70+
)
71+
72+
# Create a grid image from the generated samples
73+
image = make_image_grid(images=output.images, rows=4, cols=4)
74+
image.save("output.png")
75+
```
76+
77+
## Pretrained models and pipeline
78+
79+
[![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm.svg)](https://huggingface.co/py-img-gen/ncsn-mnist)
80+
1381
## Showcase
1482

1583
### MNIST
1684

1785
Example of generating MNIST character images using the model trained with [`train_mnist.py`](https://github.com/py-img-gen/diffusers-ncsn/blob/main/train_mnist.py).
1886

19-
<div align="center">
87+
<p align="center">
2088
<img alt="mnist" src="https://github.com/user-attachments/assets/483b6637-2684-4844-8aa1-12b866d46226" width="50%" />
21-
</div>
89+
</p>
90+
91+
# Notes on uploading pipelines to the HF Hub with custom components
92+
93+
While referring to 📝 [Load community pipelines and components - huggingface diffusers](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview#community-components), pay attention to the following points.
94+
- Change [the `_class_name` attribute](https://huggingface.co/py-img-gen/ncsn-mnist/blob/main/model_index.json#L2) in [`model_index.json`](https://huggingface.co/py-img-gen/ncsn-mnist/blob/main/model_index.json) to `["pipeline_ncsn", "NCSNPipeline"]`.
95+
- Upload [`pipeline_ncsn.py`](https://github.com/py-img-gen/diffusers-ncsn/blob/main/src/ncsn/pipeline_ncsn.py) to [the root of the pipeline repository](https://huggingface.co/py-img-gen/ncsn-mnist/blob/main/pipeline_ncsn.py).
96+
- Upload custom components to each subfolder:
97+
- Upload [`scheduling_ncsn.py`](https://github.com/py-img-gen/diffusers-ncsn/blob/main/src/ncsn/scheduler/scheduling_ncsn.py) to [the `scheduler` subfolder](https://huggingface.co/py-img-gen/ncsn-mnist/tree/main/scheduler).
98+
- Upload [`unet_2d_ncsn.py`](https://github.com/py-img-gen/diffusers-ncsn/blob/main/src/ncsn/unet/unet_2d_ncsn.py) to [the `unet` subfolder](https://huggingface.co/py-img-gen/ncsn-mnist/tree/main/unet).
99+
- Ensure that the custom components are placed in each subfolder because they are referenced by relative paths from `pipeline_ncsn.py`.
100+
- Based on this, the code in this library is also placed in the same directory structure as the HF Hub.
101+
- For example, `pipeline_ncsn.py` imports `unet_2d_ncsn.py` as `from .unet.unet_2d_ncsn import UNet2DModelForNCSN` because it is placed in the `unet` subfolder.
22102

23103
## Acknowledgements
24104

src/ncsn/pipeline_ncsn.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
66
from einops import rearrange
77

8-
from .scheduling_ncsn import (
9-
AnnealedLangevinDynamicOutput,
10-
AnnealedLangevinDynamicScheduler,
8+
from .scheduler.scheduling_ncsn import (
9+
AnnealedLangevinDynamicsOutput,
10+
AnnealedLangevinDynamicsScheduler,
1111
)
12-
from .unet_2d_ncsn import UNet2DModelForNCSN
12+
from .unet.unet_2d_ncsn import UNet2DModelForNCSN
1313

1414

1515
def normalize_images(image: torch.Tensor) -> torch.Tensor:
@@ -42,17 +42,17 @@ class NCSNPipeline(DiffusionPipeline):
4242
Parameters:
4343
unet ([`UNet2DModelForNCSN`]):
4444
A `UNet2DModelForNCSN` to estimate the score of the image.
45-
scheduler ([`AnnealedLangevinDynamicScheduler`]):
46-
A `AnnealedLangevinDynamicScheduler` to be used in combination with `unet` to estimate the score of the image.
45+
scheduler ([`AnnealedLangevinDynamicsScheduler`]):
46+
A `AnnealedLangevinDynamicsScheduler` to be used in combination with `unet` to estimate the score of the image.
4747
"""
4848

4949
unet: UNet2DModelForNCSN
50-
scheduler: AnnealedLangevinDynamicScheduler
50+
scheduler: AnnealedLangevinDynamicsScheduler
5151

5252
_callback_tensor_inputs: List[str] = ["samples"]
5353

5454
def __init__(
55-
self, unet: UNet2DModelForNCSN, scheduler: AnnealedLangevinDynamicScheduler
55+
self, unet: UNet2DModelForNCSN, scheduler: AnnealedLangevinDynamicsScheduler
5656
) -> None:
5757
super().__init__()
5858
self.register_modules(unet=unet, scheduler=scheduler)
@@ -151,7 +151,7 @@ def __call__(
151151
)
152152
samples = (
153153
output.prev_sample
154-
if isinstance(output, AnnealedLangevinDynamicOutput)
154+
if isinstance(output, AnnealedLangevinDynamicsOutput)
155155
else output[0]
156156
)
157157

src/ncsn/scheduler/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .scheduling_ncsn import (
2+
AnnealedLangevinDynamicsOutput,
3+
AnnealedLangevinDynamicsScheduler,
4+
)
5+
6+
__all__ = [
7+
"AnnealedLangevinDynamicsOutput",
8+
"AnnealedLangevinDynamicsScheduler",
9+
]

src/ncsn/scheduling_ncsn.py renamed to src/ncsn/scheduler/scheduling_ncsn.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515

1616

1717
@dataclass
18-
class AnnealedLangevinDynamicOutput(SchedulerOutput):
19-
"""Annealed Langevin Dynamic output class."""
18+
class AnnealedLangevinDynamicsOutput(SchedulerOutput):
19+
"""Annealed Langevin Dynamics output class."""
2020

2121

22-
class AnnealedLangevinDynamicScheduler(SchedulerMixin, ConfigMixin): # type: ignore
22+
class AnnealedLangevinDynamicsScheduler(SchedulerMixin, ConfigMixin): # type: ignore
23+
"""Annealed Langevin Dynamics scheduler for Noise Conditional Score Network (NCSN)."""
24+
2325
order = 1
2426

2527
@register_to_config
@@ -106,13 +108,13 @@ def step(
106108
samples: torch.Tensor,
107109
return_dict: bool = True,
108110
**kwargs,
109-
) -> Union[AnnealedLangevinDynamicOutput, Tuple]:
111+
) -> Union[AnnealedLangevinDynamicsOutput, Tuple]:
110112
z = torch.randn_like(samples)
111113
step_size = self.step_size[timestep]
112114
samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
113115

114116
if return_dict:
115-
return AnnealedLangevinDynamicOutput(prev_sample=samples)
117+
return AnnealedLangevinDynamicsOutput(prev_sample=samples)
116118
else:
117119
return (samples,)
118120

src/ncsn/unet/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .unet_2d_ncsn import UNet2DModelForNCSN
2+
3+
__all__ = [
4+
"UNet2DModelForNCSN",
5+
]
File renamed without changes.

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pathlib
2+
3+
import pytest
4+
import torch
5+
6+
7+
@pytest.fixture
8+
def device() -> torch.device:
9+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
10+
11+
12+
@pytest.fixture
13+
def seed() -> int:
14+
return 19950815
15+
16+
17+
@pytest.fixture
18+
def root_dir() -> pathlib.Path:
19+
return pathlib.Path(__file__).parents[1]
20+
21+
22+
@pytest.fixture
23+
def project_dir(root_dir: pathlib.Path) -> pathlib.Path:
24+
dirpath = root_dir / "outputs"
25+
dirpath.mkdir(parents=True, exist_ok=True)
26+
return dirpath
27+
28+
29+
@pytest.fixture
30+
def lib_dir(root_dir: pathlib.Path) -> pathlib.Path:
31+
return root_dir / "src" / "ncsn"

0 commit comments

Comments
 (0)