
Commit e423ba4

Author: Laurent
remove a couple from torch import ... from the code
1 parent 45143e2 commit e423ba4
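Every hunk below applies the same mechanical change: names previously imported with from torch import ... (cat, tensor, sqrt, float32, and so on) are now called through the torch namespace, with a plain import torch added where it was missing. A minimal before/after sketch of the pattern (illustrative only, not taken from the diff):

# before: bare names, easy to shadow and harder to grep
from torch import cat, zeros

x = cat([zeros(2), zeros(2)])

# after: one package import, qualified call sites
import torch

x = torch.cat([torch.zeros(2), torch.zeros(2)])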

File tree: 13 files changed (+76 -80 lines)

src/refiners/fluxion/layers/chain.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
 
 import torch
-from torch import Tensor, cat, device as Device, dtype as DType
+from torch import Tensor, device as Device, dtype as DType
 
 from refiners.fluxion.context import ContextProvider, Contexts
 from refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule
@@ -950,7 +950,7 @@ def __init__(self, *modules: Module, dim: int = 0) -> None:
 
     def forward(self, *args: Any) -> Tensor:
         outputs = [module(*args) for module in self]
-        return cat(
+        return torch.cat(
             [output for output in outputs if output is not None],
             dim=self.dim,
         )

src/refiners/fluxion/layers/norm.py

Lines changed: 5 additions & 4 deletions
@@ -1,5 +1,6 @@
+import torch
 from jaxtyping import Float
-from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
+from torch import Tensor, device as Device, dtype as DType
 from torch.nn import (
     GroupNorm as _GroupNorm,
     InstanceNorm2d as _InstanceNorm2d,
@@ -111,8 +112,8 @@ def __init__(
         dtype: DType | None = None,
     ) -> None:
         super().__init__()
-        self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
-        self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
+        self.weight = TorchParameter(torch.ones(channels, device=device, dtype=dtype))
+        self.bias = TorchParameter(torch.zeros(channels, device=device, dtype=dtype))
         self.eps = eps
 
     def forward(
@@ -121,7 +122,7 @@ def forward(
     ) -> Float[Tensor, "batch channels height width"]:
         x_mean = x.mean(1, keepdim=True)
         x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
-        x_norm = (x - x_mean) / sqrt(x_var + self.eps)
+        x_norm = (x - x_mean) / torch.sqrt(x_var + self.eps)
         x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x_out
 

src/refiners/fluxion/utils.py

Lines changed: 5 additions & 13 deletions
@@ -8,30 +8,22 @@
 from PIL import Image
 from safetensors import safe_open as _safe_open  # type: ignore
 from safetensors.torch import save_file as _save_file  # type: ignore
-from torch import (
-    Tensor,
-    cat,
-    device as Device,
-    dtype as DType,
-    manual_seed as _manual_seed,  # type: ignore
-    no_grad as _no_grad,  # type: ignore
-    norm as _norm,  # type: ignore
-)
+from torch import Tensor, device as Device, dtype as DType
 from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad  # type: ignore
 
 T = TypeVar("T")
 E = TypeVar("E")
 
 
 def norm(x: Tensor) -> Tensor:
-    return _norm(x)  # type: ignore
+    return torch.norm(x)  # type: ignore
 
 
 def manual_seed(seed: int) -> None:
-    _manual_seed(seed)
+    torch.manual_seed(seed)  # type: ignore
 
 
-class no_grad(_no_grad):
+class no_grad(torch.no_grad):
     def __new__(cls, orig_func: Any | None = None) -> "no_grad":  # type: ignore
         return object.__new__(cls)
 
@@ -123,7 +115,7 @@ def default_sigma(kernel_size: int) -> float:
 def images_to_tensor(
     images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
 ) -> Tensor:
-    return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
+    return torch.cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
 
 
 def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
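For context, the no_grad wrapper above now subclasses torch.no_grad directly, so existing call sites keep working. A minimal usage sketch, assuming the subclass preserves torch.no_grad's context-manager and decorator behavior (double is a made-up function, not from this repository):

import torch

from refiners.fluxion.utils import no_grad

# As a context manager: no autograd graph is recorded inside the block.
with no_grad():
    y = torch.ones(3) * 2

# As a decorator, like torch.no_grad itself.
@no_grad()
def double(t: torch.Tensor) -> torch.Tensor:
    return t * 2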

src/refiners/foundationals/clip/common.py

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
-from torch import Tensor, arange, device as Device, dtype as DType
+import torch
+from torch import Tensor, device as Device, dtype as DType
 
 import refiners.fluxion.layers as fl
 
@@ -25,7 +26,7 @@ def __init__(
 
     @property
     def position_ids(self) -> Tensor:
-        return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
+        return torch.arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
 
     def get_position_ids(self, x: Tensor) -> Tensor:
         return self.position_ids[:, : x.shape[1]]

src/refiners/foundationals/clip/concepts.py

Lines changed: 12 additions & 4 deletions
@@ -1,8 +1,9 @@
 import re
 from typing import cast
 
+import torch
 import torch.nn.functional as F
-from torch import Tensor, cat, zeros
+from torch import Tensor
 from torch.nn import Parameter
 
 import refiners.fluxion.layers as fl
@@ -22,19 +23,26 @@ def __init__(
         with self.setup_adapter(target):
             super().__init__(fl.Lambda(func=self.lookup))
         p = Parameter(
-            zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
+            torch.zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
         )  # requires_grad=True by default
         self.old_weight = cast(Parameter, target.weight)
         self.new_weight = p
 
     # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
     def lookup(self, x: Tensor) -> Tensor:
         # Concatenate old and new weights for dynamic embedding updates during training
-        return F.embedding(x, cat([self.old_weight, self.new_weight]))
+        return F.embedding(x, torch.cat([self.old_weight, self.new_weight]))
 
     def add_embedding(self, embedding: Tensor) -> None:
         assert embedding.shape == (self.old_weight.shape[1],)
-        p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
+        p = Parameter(
+            torch.cat(
+                [
+                    self.new_weight,
+                    embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype),
+                ]
+            )
+        )
         self.new_weight = p
 
     @property
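The comment in lookup above is the key design point: F.embedding is applied to the concatenation of the pretrained rows (old_weight) and the newly learned concept rows (new_weight), so lookups see both while gradients can be restricted to the new rows. A simplified, self-contained sketch of that behavior, with made-up shapes and the old rows explicitly frozen:

import torch
import torch.nn.functional as F

old_weight = torch.randn(10, 4)                      # pretrained rows, frozen
new_weight = torch.randn(2, 4, requires_grad=True)   # learned concept rows

ids = torch.tensor([1, 11])                          # one old row, one new row
out = F.embedding(ids, torch.cat([old_weight, new_weight]))
out.sum().backward()

assert old_weight.grad is None                       # frozen rows stay untouched
assert new_weight.grad is not None                   # only new rows receive gradients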

src/refiners/foundationals/latent_diffusion/image_prompt.py

Lines changed: 11 additions & 10 deletions
@@ -1,9 +1,10 @@
 import math
 from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
 
+import torch
 from jaxtyping import Float
 from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
+from torch import Tensor, device as Device, dtype as DType, nn
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -98,7 +99,7 @@ def forward(
         v = self.reshape_tensor(value)
 
         attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
-        attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
+        attention = torch.softmax(input=attention.float(), dim=-1).type(attention.dtype)
         attention = attention @ v
 
         return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
@@ -159,7 +160,7 @@ def __init__(
         )
 
     def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
-        return cat((x, latents), dim=-2)
+        return torch.cat((x, latents), dim=-2)
 
 
 class LatentsToken(fl.Chain):
@@ -484,7 +485,7 @@ def compute_clip_image_embedding(
            image_prompt = self.preprocess_image(image_prompt)
        elif isinstance(image_prompt, list):
            assert all(isinstance(image, Image.Image) for image in image_prompt)
-            image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
+            image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])
 
        negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
 
@@ -493,28 +494,28 @@ def compute_clip_image_embedding(
            assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
            if any(weight != 1.0 for weight in weights):
                conditional_embedding *= (
-                    tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
+                    torch.tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
                    .unsqueeze(-1)
                    .unsqueeze(-1)
                )
 
        if batch_size > 1 and concat_batches:
            # Create a longer image tokens sequence when a batch of images is given
            # See https://github.com/tencent-ailab/IP-Adapter/issues/99
-            negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
-            conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
+            negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)
+            conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)
 
-        return cat((negative_embedding, conditional_embedding))
+        return torch.cat((negative_embedding, conditional_embedding))
 
    def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
        image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
        clip_embedding = image_encoder(image_prompt)
        conditional_embedding = self.image_proj(clip_embedding)
        if not self.fine_grained:
-            negative_embedding = self.image_proj(zeros_like(clip_embedding))
+            negative_embedding = self.image_proj(torch.zeros_like(clip_embedding))
        else:
            # See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
-            clip_embedding = image_encoder(zeros_like(image_prompt))
+            clip_embedding = image_encoder(torch.zeros_like(image_prompt))
            negative_embedding = self.image_proj(clip_embedding)
        return negative_embedding, conditional_embedding
 
src/refiners/foundationals/latent_diffusion/range_adapter.py

Lines changed: 5 additions & 4 deletions
@@ -1,7 +1,8 @@
 import math
 
+import torch
 from jaxtyping import Float, Int
-from torch import Tensor, arange, cat, cos, device as Device, dtype as DType, exp, float32, sin
+from torch import Tensor, device as Device, dtype as DType
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -14,10 +15,10 @@ def compute_sinusoidal_embedding(
     half_dim = embedding_dim // 2
     # Note: it is important that this computation is done in float32.
     # The result can be cast to lower precision later if necessary.
-    exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
+    exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=x.device)
     exponent /= half_dim
-    embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
-    embedding = cat([cos(embedding), sin(embedding)], dim=-1)
+    embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
+    embedding = torch.cat([torch.cos(embedding), torch.sin(embedding)], dim=-1)
     return embedding
 
 

src/refiners/foundationals/latent_diffusion/solvers/ddim.py

Lines changed: 6 additions & 5 deletions
@@ -1,6 +1,7 @@
 import dataclasses
 
-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
+import torch
+from torch import Generator, Tensor, device as Device, dtype as Dtype
 
 from refiners.foundationals.latent_diffusion.solvers.solver import (
     BaseSolverParams,
@@ -28,7 +29,7 @@ def __init__(
         first_inference_step: int = 0,
         params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
-        dtype: Dtype = float32,
+        dtype: Dtype = torch.float32,
     ) -> None:
         """Initializes a new DDIM solver.
 
@@ -71,7 +72,7 @@ def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Gen
             (
                 self.timesteps[step + 1]
                 if step < self.num_inference_steps - 1
-                else tensor(data=[0], device=self.device, dtype=self.dtype)
+                else torch.tensor(data=[0], device=self.device, dtype=self.dtype)
             ),
         )
         current_scale_factor, previous_scale_factor = (
@@ -82,8 +83,8 @@ def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Gen
                 else self.cumulative_scale_factors[0]
             ),
         )
-        predicted_x = (x - sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
-        noise_factor = sqrt(1 - previous_scale_factor**2)
+        predicted_x = (x - torch.sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
+        noise_factor = torch.sqrt(1 - previous_scale_factor**2)
 
         # Do not add noise at the last step to avoid visual artifacts.
         if step == self.num_inference_steps - 1:

src/refiners/foundationals/latent_diffusion/solvers/dpm.py

Lines changed: 6 additions & 6 deletions
@@ -3,7 +3,7 @@
 
 import numpy as np
 import torch
-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
+from torch import Generator, Tensor, device as Device, dtype as Dtype
 
 from refiners.foundationals.latent_diffusion.solvers.solver import (
     BaseSolverParams,
@@ -38,7 +38,7 @@ def __init__(
         params: BaseSolverParams | None = None,
         last_step_first_order: bool = False,
         device: Device | str = "cpu",
-        dtype: Dtype = float32,
+        dtype: Dtype = torch.float32,
     ):
         """Initializes a new DPM solver.
 
@@ -62,7 +62,7 @@ def __init__(
             device=device,
             dtype=dtype,
         )
-        self.estimated_data = deque([tensor([])] * 2, maxlen=2)
+        self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
         self.last_step_first_order = last_step_first_order
 
     def rebuild(
@@ -94,7 +94,7 @@ def _generate_timesteps(self) -> Tensor:
         offset = self.params.timesteps_offset
         max_timestep = self.params.num_train_timesteps - 1 + offset
         np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
-        return tensor(np_space).flip(0)
+        return torch.tensor(np_space).flip(0)
 
     def dpm_solver_first_order_update(
         self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
@@ -110,7 +110,7 @@ def dpm_solver_first_order_update(
             The denoised version of the input data `x`.
         """
         current_timestep = self.timesteps[step]
-        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
+        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
 
         previous_ratio = self.signal_to_noise_ratios[previous_timestep]
         current_ratio = self.signal_to_noise_ratios[current_timestep]
@@ -144,7 +144,7 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noi
         Returns:
             The denoised version of the input data `x`.
         """
-        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
+        previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
         current_timestep = self.timesteps[step]
         next_timestep = self.timesteps[step - 1]
 
src/refiners/foundationals/latent_diffusion/solvers/euler.py

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
+from torch import Generator, Tensor, device as Device, dtype as Dtype
 
 from refiners.foundationals.latent_diffusion.solvers.solver import (
     BaseSolverParams,
@@ -23,7 +23,7 @@ def __init__(
         first_inference_step: int = 0,
         params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
-        dtype: Dtype = float32,
+        dtype: Dtype = torch.float32,
     ):
         """Initializes a new Euler solver.
 
@@ -57,7 +57,7 @@ def _generate_sigmas(self) -> Tensor:
         """Generate the sigmas used by the solver."""
         sigmas = self.noise_std / self.cumulative_scale_factors
         sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
-        sigmas = torch.cat([sigmas, tensor([0.0])])
+        sigmas = torch.cat([sigmas, torch.tensor([0.0])])
         return sigmas.to(device=self.device, dtype=self.dtype)
 
     def scale_model_input(self, x: Tensor, step: int) -> Tensor:
