Skip to content
This repository was archived by the owner on Sep 26, 2025. It is now read-only.

Commit 51a4e9e

Browse files
author
Laurent
committed
update gaussian_blur fluxion util, see pytorch/vision@45e053b
1 parent a51d695 commit 51a4e9e

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/refiners/fluxion/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,20 @@ def gaussian_blur(
7070
) -> Float[Tensor, "*batch channels height width"]:
7171
assert torch.is_floating_point(tensor)
7272

73-
def get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Float[Tensor, "kernel_size"]:
73+
def get_gaussian_kernel1d(
74+
kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device
75+
) -> Float[Tensor, "kernel_size"]:
7476
ksize_half = (kernel_size - 1) * 0.5
75-
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
77+
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=device, dtype=dtype)
7678
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
7779
kernel1d = pdf / pdf.sum()
7880
return kernel1d
7981

8082
def get_gaussian_kernel2d(
8183
kernel_size_x: int, kernel_size_y: int, sigma_x: float, sigma_y: float, dtype: DType, device: Device
8284
) -> Float[Tensor, "kernel_size_y kernel_size_x"]:
83-
kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x).to(device, dtype=dtype)
84-
kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y).to(device, dtype=dtype)
85+
kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x, dtype, device)
86+
kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y, dtype, device)
8587
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
8688
return kernel2d
8789

0 commit comments

Comments
 (0)