
Commit 1e8c2d2

Nupur Kumari authored and committed
update
1 parent 10b6350 commit 1e8c2d2


2 files changed: +5 -4 lines changed


src/diffusers_model_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -414,7 +414,7 @@ def __init__(
             scheduler,
             safety_checker,
             feature_extractor,
-            requires_safety_checker)
+            requires_safety_checker=requires_safety_checker)

         # change attn class
         self.modifier_token = modifier_token
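The pipeline now forwards requires_safety_checker to the base class by keyword instead of by position. A hypothetical, minimal sketch of why keyword passing is more robust when a parent signature grows; it is not from the repository, and the image_encoder parameter below is only an example of a newly inserted argument:

    # Hypothetical base-class signature that gains a new parameter over time.
    def base_init(feature_extractor, image_encoder=None, requires_safety_checker=True):
        return {"image_encoder": image_encoder, "requires_safety_checker": requires_safety_checker}

    # Positional: the flag lands in the newly inserted image_encoder slot.
    print(base_init(None, True))
    # Keyword: the value reaches requires_safety_checker regardless of inserted parameters.
    print(base_init(None, requires_safety_checker=True))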

src/diffusers_training.py

Lines changed: 4 additions & 3 deletions
@@ -231,13 +231,13 @@
 import diffusers
 from accelerate.logging import get_logger
 from accelerate import Accelerator
-from accelerate.utils import set_seed
+from accelerate.utils import set_seed, ProjectConfiguration
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.optimization import get_scheduler
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
-from diffusers.models.cross_attention import CrossAttention
+from diffusers.models.attention import Attention
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils import check_min_version, is_wandb_available

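Both the old and the new import appear in the hunk above, so a compatibility shim is straightforward if the script needs to run against older and newer diffusers releases alike. A minimal sketch; the try/except itself is not part of the commit:

    # Sketch: resolve the attention class across diffusers versions.
    try:
        from diffusers.models.attention import Attention                          # newer diffusers
    except ImportError:
        from diffusers.models.cross_attention import CrossAttention as Attention  # older diffusers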
@@ -273,7 +273,7 @@ def create_custom_diffusion(unet, freeze_model):
     # change attn class
     def change_attn(unet):
        for layer in unet.children():
-            if type(layer) == CrossAttention:
+            if type(layer) == Attention:
                bound_method = set_use_memory_efficient_attention_xformers.__get__(layer, layer.__class__)
                setattr(layer, 'set_use_memory_efficient_attention_xformers', bound_method)
            else:
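The surrounding code binds a plain function onto each Attention instance through the descriptor protocol, so it behaves like a regular method of that object. A self-contained sketch of the same __get__ binding trick; the class and function names here are illustrative only:

    class Layer:
        pass

    def describe(self):
        return f"bound to an instance of {self.__class__.__name__}"

    layer = Layer()
    # Same pattern as set_use_memory_efficient_attention_xformers.__get__(layer, layer.__class__)
    layer.describe = describe.__get__(layer, Layer)
    print(layer.describe())   # -> "bound to an instance of Layer"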
@@ -593,6 +593,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:

 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
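The hunk is cut off before the rest of the Accelerator(...) call, so the diff does not show where accelerator_project_config ends up; presumably it is passed as project_config. A minimal, self-contained sketch of that pattern, assuming accelerate's standard ProjectConfiguration/Accelerator API; all argument values are placeholders rather than the script's real args:

    from pathlib import Path

    from accelerate import Accelerator
    from accelerate.utils import ProjectConfiguration

    output_dir = "custom-diffusion-output"     # placeholder for args.output_dir
    logging_dir = Path(output_dir, "logs")     # placeholder for args.logging_dir

    project_config = ProjectConfiguration(project_dir=output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=1,         # placeholder for args.gradient_accumulation_steps
        project_config=project_config,
    )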
