Commit 951e720

ENH XPU support for boft_dreambooth example (#2679)
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
1 parent 49b29c1 commit 951e720

File tree: 5 files changed, +49 / -29 lines

examples/boft_dreambooth/boft_dreambooth.md

Lines changed: 11 additions & 0 deletions
@@ -40,6 +40,7 @@ cd peft/examples/boft_dreambooth
 
 Set up your environment: install PEFT, and all the required libraries. At the time of writing this guide we recommend installing PEFT from source. The following environment setup should work on A100 and H100:
 
+### CUDA
 ```bash
 conda create --name peft python=3.10
 conda activate peft
@@ -48,6 +49,16 @@ conda install xformers -c xformers
 pip install -r requirements.txt
 pip install git+https://github.com/huggingface/peft
 ```
+The following environment setup is validated to work on Intel XPU:
+
+### Intel XPU
+```bash
+conda create --name peft python=3.10
+conda activate peft
+pip install torch==2.8.0.dev20250615+xpu torchvision==0.23.0.dev20250615+xpu torchaudio==2.8.0.dev20250615+xpu --index-url https://download.pytorch.org/whl/nightly/xpu --no-cache-dir
+pip install -r requirements.txt
+pip install git+https://github.com/huggingface/peft
+```
 
 ## Download the data
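Since the Intel XPU setup above installs from the PyTorch nightly index, it is worth confirming the build actually sees the GPU before training. A minimal sanity check, assuming a working Intel driver and runtime (this snippet is illustrative, not part of the commit):

```python
# Verify that the installed PyTorch nightly exposes an Intel XPU device.
import torch

print(torch.__version__)         # expect a 2.8.0.dev*+xpu build
print(torch.xpu.is_available())  # True once driver and runtime are set up
if torch.xpu.is_available():
    print(torch.xpu.get_device_name(0))  # assumes at least one XPU at index 0
```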

examples/boft_dreambooth/dreambooth_inference.ipynb

Lines changed: 3 additions & 9 deletions
@@ -44,8 +44,10 @@
    "outputs": [],
    "source": [
     "def get_boft_sd_pipeline(\n",
-    "    ckpt_dir, base_model_name_or_path=None, epoch=int, dtype=torch.float32, device=\"cuda\", adapter_name=\"default\"\n",
+    "    ckpt_dir, base_model_name_or_path=None, epoch=int, dtype=torch.float32, device=\"auto\", adapter_name=\"default\"\n",
     "):\n",
+    "    if device == \"auto\":\n",
+    "        device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
     "\n",
     "    if base_model_name_or_path is None:\n",
     "        raise ValueError(\"Please specify the base model name or path\")\n",
@@ -152,14 +154,6 @@
    "image = pipe(prompt, num_inference_steps=50, guidance_scale=7, negative_prompt=negative_prompt).images[0]\n",
    "image"
   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "f534eca2-94a4-432b-b092-7149ac44b12f",
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
 "metadata": {
examples/boft_dreambooth/requirements.txt

Lines changed: 5 additions & 5 deletions
@@ -1,13 +1,13 @@
-transformers=>4.48.0
-accelerate==0.25.0
+transformers==4.54.0
+accelerate==1.9.0
 evaluate
 tqdm
-datasets==2.16.1
-diffusers==0.17.1
+datasets==4.0.0
+diffusers==0.34.0
 Pillow
 huggingface_hub
 safetensors
 nb_conda_kernels
 ipykernel
 ipywidgets
-wandb==0.16.1
+wandb==0.21.0

examples/boft_dreambooth/train_dreambooth.py

Lines changed: 22 additions & 9 deletions
@@ -139,7 +139,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            torch_dtype = torch.float16 if accelerator.device.type in ["cuda", "xpu"] else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
             elif args.prior_generation_precision == "fp16":
@@ -176,6 +176,8 @@ def main(args):
             del pipeline
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif torch.xpu.is_available():
+                torch.xpu.empty_cache()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -263,7 +265,9 @@ def main(args):
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
     if args.enable_xformers_memory_efficient_attention:
-        if is_xformers_available():
+        if accelerator.device.type == "xpu":
+            logger.warn("XPU doesn't support xformers yet, ignoring it.")
+        elif is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
@@ -276,7 +280,7 @@ def main(args):
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-    if args.allow_tf32:
+    if args.allow_tf32 and torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True
 
     if args.scale_lr:
@@ -581,18 +585,27 @@ def main(args):
                     )
 
                 del pipeline
-                torch.cuda.empty_cache()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                elif torch.xpu.is_available():
+                    torch.xpu.empty_cache()
 
             if global_step >= args.max_train_steps:
                 break
 
-    # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
+    # Printing the accelerator memory usage details such as allocated memory, peak memory, and total memory usage
     if not args.no_tracemalloc:
-        accelerator.print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}")
-        accelerator.print(f"GPU Memory consumed at the end of the train (end-begin): {tracemalloc.used}")
-        accelerator.print(f"GPU Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}")
         accelerator.print(
-            f"GPU Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
+            f"{accelerator.device.type.upper()} Memory before entering the train : {b2mb(tracemalloc.begin)}"
+        )
+        accelerator.print(
+            f"{accelerator.device.type.upper()} Memory consumed at the end of the train (end-begin): {tracemalloc.used}"
+        )
+        accelerator.print(
+            f"{accelerator.device.type.upper()} Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}"
+        )
+        accelerator.print(
+            f"{accelerator.device.type.upper()} Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
         )
 
         accelerator.print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}")

examples/boft_dreambooth/utils/tracemalloc.py

Lines changed: 8 additions & 6 deletions
@@ -13,10 +13,12 @@ def b2mb(x):
 # This context manager is used to track the peak memory usage of the process
 class TorchTracemalloc:
     def __enter__(self):
+        self.device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+        self.device_module = getattr(torch, self.device_type, torch.cuda)
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
-        self.begin = torch.cuda.memory_allocated()
+        self.device_module.empty_cache()
+        self.device_module.reset_peak_memory_stats()  # reset the peak gauge to zero
+        self.begin = self.device_module.memory_allocated()
         self.process = psutil.Process()
 
         self.cpu_begin = self.cpu_mem_used()
@@ -46,9 +48,9 @@ def __exit__(self, *exc):
         self.peak_monitoring = False
 
         gc.collect()
-        torch.cuda.empty_cache()
-        self.end = torch.cuda.memory_allocated()
-        self.peak = torch.cuda.max_memory_allocated()
+        self.device_module.empty_cache()
+        self.end = self.device_module.memory_allocated()
+        self.peak = self.device_module.max_memory_allocated()
         self.used = b2mb(self.end - self.begin)
         self.peaked = b2mb(self.peak - self.begin)
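With `device_module` resolved once in `__enter__`, the context manager reports memory for CUDA and XPU alike. A usage sketch, assuming the import path used by the example scripts and that `__enter__` returns `self` (neither is visible in the hunks above):

```python
import torch

from utils.tracemalloc import TorchTracemalloc  # path assumed from this example dir

# Pick whichever accelerator is present; falls back to CUDA on older PyTorch.
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"

with TorchTracemalloc() as tracemalloc:
    x = torch.randn(4096, 4096, device=device)  # allocate some accelerator memory
    y = x @ x

# used/peaked are filled in by __exit__, already converted to MB by b2mb
print(f"used: {tracemalloc.used} MB, peaked: {tracemalloc.peaked} MB")
```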
