# suppress all warnings
warnings.filterwarnings("ignore")

-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
dtype_to_bytes_linear = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}


-def init_cuda():
+def init_accelerator():
    torch.manual_seed(0)
    if device == "cpu":
        return

-    torch.cuda.reset_peak_memory_stats()
-    torch.cuda.manual_seed_all(0)
+    device_module = getattr(torch, device, torch.cuda)
+    device_module.reset_peak_memory_stats()
+    device_module.manual_seed_all(0)

    # might not be necessary, but just to be sure
    nn.Linear(1, 1).to(device)
@@ -106,9 +107,10 @@ def tokenize(samples):


def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
-    init_cuda()
-    cuda_memory_init = torch.cuda.max_memory_allocated()
-    cuda_memory_log = []
+    init_accelerator()
+    device_module = getattr(torch, device, torch.cuda)
+    accelerator_memory_init = device_module.max_memory_allocated()
+    accelerator_memory_log = []

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = max_seq_length
@@ -177,8 +179,8 @@ def unpack(x):
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
-        cuda_memory_log.append(torch.cuda.memory_allocated() - cuda_memory_init)
-        torch.cuda.empty_cache()
+        accelerator_memory_log.append(device_module.memory_allocated() - accelerator_memory_init)
+        device_module.empty_cache()
        gc.collect()
        toc = time.perf_counter()
        print(f"step {i:3d} loss {loss.item():.6f} time {toc - tic:.2f}s", file=sys.stderr)
@@ -191,10 +193,10 @@ def unpack(x):

    toc_total = time.perf_counter()

-    cuda_memory_final = torch.cuda.max_memory_allocated()
-    cuda_memory_avg = int(sum(cuda_memory_log) / len(cuda_memory_log))
-    print(f"cuda memory avg: {cuda_memory_avg // 2 ** 20} MB")
-    print(f"cuda memory max: {(cuda_memory_final - cuda_memory_init) // 2 ** 20} MB")
+    accelerator_memory_final = device_module.max_memory_allocated()
+    accelerator_memory_avg = int(sum(accelerator_memory_log) / len(accelerator_memory_log))
+    print(f"{model.device.type} memory avg: {accelerator_memory_avg // 2 ** 20} MB")
+    print(f"{model.device.type} memory max: {(accelerator_memory_final - accelerator_memory_init) // 2 ** 20} MB")

    print(f"total time: {toc_total - tic_total:.2f}s")

    with tempfile.TemporaryDirectory() as tmp_dir:
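
For context, the device-agnostic pattern this patch applies can be reduced to a small standalone sketch. It assumes PyTorch 2.6+ for the `torch.accelerator` API; the helper name `get_device_module` is illustrative and not part of the patch, and not every backend namespace exposes the full `torch.cuda`-style memory interface:

```python
import torch

# Prefer the torch.accelerator API when present (PyTorch >= 2.6), else assume CUDA.
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"


def get_device_module():
    # torch.cuda, torch.xpu, ... expose similar seeding/memory helpers;
    # fall back to torch.cuda if the namespace lookup fails.
    return getattr(torch, device, torch.cuda)


if device != "cpu":
    device_module = get_device_module()
    device_module.reset_peak_memory_stats()
    print(f"{device} peak memory: {device_module.max_memory_allocated() // 2 ** 20} MB")
```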