
Commit a4b41e7

ENH Support XPU in train_memory.py script (#2729)
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
1 parent e98a59e commit a4b41e7

1 file changed: +15 -13 lines changed

scripts/train_memory.py

Lines changed: 15 additions & 13 deletions
@@ -70,17 +70,18 @@
 # suppress all warnings
 warnings.filterwarnings("ignore")
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
 dtype_to_bytes_linear = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}
 
 
-def init_cuda():
+def init_accelerator():
     torch.manual_seed(0)
     if device == "cpu":
         return
 
-    torch.cuda.reset_peak_memory_stats()
-    torch.cuda.manual_seed_all(0)
+    device_module = getattr(torch, device, torch.cuda)
+    device_module.reset_peak_memory_stats()
+    device_module.manual_seed_all(0)
     # might not be necessary, but just to be sure
     nn.Linear(1, 1).to(device)
 
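For readers unfamiliar with the new API, here is a standalone sketch (not part of the commit) of what the device-selection pattern in this hunk does. It assumes PyTorch 2.6+ where torch.accelerator exists, plus a build with an accelerator backend compiled in; the variable names mirror the script.

import torch

# Resolve the active backend type ("cuda", "xpu", ...); older PyTorch builds fall back to "cuda".
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"

if device != "cpu":
    # torch.cuda and torch.xpu expose the same seeding and memory-statistics helpers,
    # so the backend module can be looked up by name, with torch.cuda as the default.
    device_module = getattr(torch, device, torch.cuda)
    device_module.manual_seed_all(0)
    device_module.reset_peak_memory_stats()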
@@ -106,9 +107,10 @@ def tokenize(samples):
 
 
 def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
-    init_cuda()
-    cuda_memory_init = torch.cuda.max_memory_allocated()
-    cuda_memory_log = []
+    init_accelerator()
+    device_module = getattr(torch, device, torch.cuda)
+    accelerator_memory_init = device_module.max_memory_allocated()
+    accelerator_memory_log = []
 
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.model_max_length = max_seq_length
@@ -177,8 +179,8 @@ def unpack(x):
         loss.backward()
         optimizer.step()
         losses.append(loss.item())
-        cuda_memory_log.append(torch.cuda.memory_allocated() - cuda_memory_init)
-        torch.cuda.empty_cache()
+        accelerator_memory_log.append(device_module.memory_allocated() - accelerator_memory_init)
+        device_module.empty_cache()
         gc.collect()
         toc = time.perf_counter()
         print(f"step {i:3d} loss {loss.item():.6f} time {toc - tic:.2f}s", file=sys.stderr)
@@ -191,10 +193,10 @@ def unpack(x):
 
     toc_total = time.perf_counter()
 
-    cuda_memory_final = torch.cuda.max_memory_allocated()
-    cuda_memory_avg = int(sum(cuda_memory_log) / len(cuda_memory_log))
-    print(f"cuda memory avg: {cuda_memory_avg // 2**20}MB")
-    print(f"cuda memory max: {(cuda_memory_final - cuda_memory_init) // 2**20}MB")
+    accelerator_memory_final = device_module.max_memory_allocated()
+    accelerator_memory_avg = int(sum(accelerator_memory_log) / len(accelerator_memory_log))
+    print(f"{model.device.type} memory avg: {accelerator_memory_avg // 2**20}MB")
+    print(f"{model.device.type} memory max: {(accelerator_memory_final - accelerator_memory_init) // 2**20}MB")
     print(f"total time: {toc_total - tic_total:.2f}s")
 
     with tempfile.TemporaryDirectory() as tmp_dir:
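One design note on the last hunk: the summary labels are no longer hard-coded to "cuda" but derived from model.device.type, so the same script reports, for example, "xpu memory avg" and "xpu memory max" when it runs on an Intel GPU while keeping the old output on NVIDIA hardware.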
