Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ setup:

.PHONY: install
install:
uv sync
uv sync --extra training

#
# linter/formatter/typecheck
Expand Down
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,23 @@
[![CI](https://github.com/py-img-gen/diffusers-ncsn/actions/workflows/ci.yaml/badge.svg)](https://github.com/py-img-gen/diffusers-ncsn/actions/workflows/ci.yaml) [![](https://img.shields.io/badge/Official_code-GitHub-green)](https://github.com/ermongroup/ncsn)

[`🤗 diffusers`](https://github.com/huggingface/diffusers) implementation of the paper ["Generative Modeling by Estimating Gradients of the Data Distribution" [Yang+ NeurIPS'19]](https://arxiv.org/abs/1907.05600).

## Installation

```shell
pip install git+https://github.com/py-img-gen/diffusers-ncsn
```

## Showcase

### MNIST

Example of generating MNIST handwritten digit images using the model trained with [`train_mnist.py`](https://github.com/py-img-gen/diffusers-ncsn/blob/main/train_mnist.py).

<div align="center">
<img alt="mnist" src="https://github.com/user-attachments/assets/483b6637-2684-4844-8aa1-12b866d46226" width="50%" />
</div>

## Acknowledgements

- JeongJiHeon/ScoreDiffusionModel: The Pytorch Tutorial of Score-based and Diffusion Model https://github.com/JeongJiHeon/ScoreDiffusionModel/tree/main
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@ description = "Diffusers implementation of Noise Conditional Score Networks (NCS
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"diffusers[torch]>=0.31.0",
"diffusers[torch]>=0.31.0,!=0.32.1",
"einops>=0.7.0",
"torch>=1.0.0",
"torchvision>=0.2.1",
"transformers>=4.30.0",
]

[project.optional-dependencies]
training = [
"matplotlib>=3.10.0",
]

[dependency-groups]
dev = ["mypy>=1.0.0", "pytest>=6.0.0", "ruff>=0.1.5"]

Expand Down
102 changes: 72 additions & 30 deletions src/ncsn/pipeline_ncsn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Self, Tuple, Union

import torch
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from einops import rearrange

Expand Down Expand Up @@ -48,12 +49,21 @@ class NCSNPipeline(DiffusionPipeline):
unet: UNet2DModelForNCSN
scheduler: AnnealedLangevinDynamicScheduler

_callback_tensor_inputs: List[str] = ["samples"]

def __init__(
self, unet: UNet2DModelForNCSN, scheduler: AnnealedLangevinDynamicScheduler
) -> None:
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)

def decode_samples(self, samples: torch.Tensor) -> torch.Tensor:
# Normalize the generated image
samples = normalize_images(samples)
# Rearrange the generated image to the correct format
samples = rearrange(samples, "b c w h -> b w h c")
return samples

@torch.no_grad()
def __call__(
self,
Expand All @@ -62,6 +72,14 @@ def __call__(
generator: Optional[torch.Generator] = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[
Callable[[Self, int, int, Dict], Dict],
PipelineCallback,
MultiPipelineCallbacks,
]
] = None,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Expand All @@ -79,55 +97,79 @@ def __call__(
The output format of the generated image. Choose between `PIL.Image` and `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during inference, with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.

Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images.
"""
sample_shape = (
callback_on_step_end_tensor_inputs = (
callback_on_step_end_tensor_inputs or self._callback_tensor_inputs
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

samples_shape = (
batch_size,
self.unet.config.in_channels, # type: ignore
self.unet.config.sample_size, # type: ignore
self.unet.config.sample_size, # type: ignore
)

# Generate a random sample
sample = torch.rand(sample_shape, generator=generator)
sample = sample.to(self.device)
# NOTE: The behavior of random number generation is different between CPU and GPU,
# so first generate random numbers on CPU and then move them to GPU (if available).
samples = torch.rand(samples_shape, generator=generator)
samples = samples.to(self.device)

# Set the number of inference steps for the scheduler
self.scheduler.set_timesteps(num_inference_steps)

# Perform the reverse diffusion process
for t in self.progress_bar(self.scheduler.timesteps):
# Predict the score using the model
model_output = self.unet(sample, t).sample # type: ignore

# Perform the annealed langevin dynamics
output = self.scheduler.step(
model_output=model_output,
model=self.unet,
timestep=t,
sample=sample,
generator=generator,
return_dict=return_dict,
)
sample = (
output.prev_sample
if isinstance(output, AnnealedLangevinDynamicOutput)
else output[0]
)

# Normalize the generated image
sample = normalize_images(sample)

# Rearrange the generated image to the correct format
sample = rearrange(sample, "b c w h -> b w h c")
# Perform `num_annealed_steps` annealing steps
for i in range(self.scheduler.num_annealed_steps):
# Predict the score using the model
model_output = self.unet(samples, t).sample # type: ignore

# Perform the annealed langevin dynamics
output = self.scheduler.step(
model_output=model_output,
timestep=t,
samples=samples,
generator=generator,
return_dict=return_dict,
)
samples = (
output.prev_sample
if isinstance(output, AnnealedLangevinDynamicOutput)
else output[0]
)

# Perform the callback on step end if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]

callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
samples = callback_outputs.pop("samples", samples)

samples = self.decode_samples(samples)

if output_type == "pil":
sample = self.numpy_to_pil(sample.cpu().numpy())
samples = self.numpy_to_pil(samples.cpu().numpy())

if return_dict:
return ImagePipelineOutput(images=sample) # type: ignore
return ImagePipelineOutput(images=samples) # type: ignore
else:
return (sample,)
return (samples,)
32 changes: 8 additions & 24 deletions src/ncsn/scheduling_ncsn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def __init__(
sigma_max: float,
sampling_eps: float,
) -> None:
self._num_train_timesteps = num_train_timesteps
self._num_annealed_steps = num_annealed_steps
self.num_train_timesteps = num_train_timesteps
self.num_annealed_steps = num_annealed_steps

self._sigma_min = sigma_min
self._sigma_max = sigma_max
Expand Down Expand Up @@ -99,38 +99,22 @@ def set_sigmas(
sampling_eps = sampling_eps or self._sampling_eps
self._step_size = sampling_eps * (self.sigmas / self.sigmas[-1]) ** 2

def _step_annealing(
self,
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
) -> torch.Tensor:
z = torch.randn_like(sample)
step_size = self.step_size[timestep]
sample = sample + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
return sample

def step(
self,
model_output: torch.Tensor,
model,
timestep: int,
sample: torch.Tensor,
samples: torch.Tensor,
return_dict: bool = True,
**kwargs,
) -> Union[AnnealedLangevinDynamicOutput, Tuple]:
for _ in range(self._num_annealed_steps):
sample = self._step_annealing(
model_output=model_output,
timestep=timestep,
sample=sample,
)
model_output = model(sample, timestep).sample
z = torch.randn_like(samples)
step_size = self.step_size[timestep]
samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z

if return_dict:
return AnnealedLangevinDynamicOutput(prev_sample=sample)
return AnnealedLangevinDynamicOutput(prev_sample=samples)
else:
return (sample,)
return (samples,)

def add_noise(
self,
Expand Down
5 changes: 3 additions & 2 deletions src/ncsn/unet_2d_ncsn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

import torch
from diffusers import UNet2DModel
from diffusers.configuration_utils import register_to_config
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class UNet2DModelForNCSN(UNet2DModel):
class UNet2DModelForNCSN(UNet2DModel, ModelMixin, ConfigMixin): # type: ignore[misc]
@register_to_config
def __init__(
self,
Expand Down
Loading
Loading