@@ -139,7 +139,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            torch_dtype = torch.float16 if accelerator.device.type in ["cuda", "xpu"] else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
             elif args.prior_generation_precision == "fp16":
@@ -176,6 +176,8 @@ def main(args):
             del pipeline
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif torch.xpu.is_available():
+                torch.xpu.empty_cache()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -263,7 +265,9 @@ def main(args):
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
     if args.enable_xformers_memory_efficient_attention:
-        if is_xformers_available():
+        if accelerator.device.type == "xpu":
+            logger.warning("XPU does not support xformers yet; ignoring the flag.")
+        elif is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
@@ -276,7 +280,7 @@ def main(args):
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-    if args.allow_tf32:
+    if args.allow_tf32 and torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True
 
     if args.scale_lr:
@@ -581,18 +585,27 @@ def main(args):
                     )
 
                     del pipeline
-                    torch.cuda.empty_cache()
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    elif torch.xpu.is_available():
+                        torch.xpu.empty_cache()
 
             if global_step >= args.max_train_steps:
                 break
 
-        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
+        # Printing the accelerator memory usage details such as allocated memory, peak memory, and total memory usage
         if not args.no_tracemalloc:
-            accelerator.print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}")
-            accelerator.print(f"GPU Memory consumed at the end of the train (end-begin): {tracemalloc.used}")
-            accelerator.print(f"GPU Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}")
             accelerator.print(
-                f"GPU Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
+                f"{accelerator.device.type.upper()} Memory before entering the train : {b2mb(tracemalloc.begin)}"
+            )
+            accelerator.print(
+                f"{accelerator.device.type.upper()} Memory consumed at the end of the train (end-begin): {tracemalloc.used}"
+            )
+            accelerator.print(
+                f"{accelerator.device.type.upper()} Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}"
+            )
+            accelerator.print(
+                f"{accelerator.device.type.upper()} Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
             )
 
         accelerator.print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}")
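The patch clears the allocator cache with the same cuda/xpu branch at two call sites (after prior-class image generation and after the validation pipeline is freed). A small helper keyed on `accelerator.device.type` could keep those branches in sync; below is a minimal sketch, assuming a PyTorch build with XPU support (the `empty_device_cache` name is hypothetical, not part of this patch):

```python
import torch


def empty_device_cache(device_type: str) -> None:
    # Hypothetical helper (not in the patch): release cached allocator
    # memory on whichever accelerator backend is active.
    if device_type == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device_type == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    # CPU, MPS, and other backends: nothing to clear, so fall through.


# Usage at either call site in the script:
# empty_device_cache(accelerator.device.type)
```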