enable xpu in test_trainer #37774

Open · wants to merge 8 commits into main
Changes from 2 commits
30 changes: 30 additions & 0 deletions src/transformers/testing_utils.py
@@ -2970,6 +2970,21 @@ def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable],
"cpu": lambda: 0,
"default": lambda: 1,
}
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.reset_max_memory_allocated,
"cpu": None,
"default": None,
}
BACKEND_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.max_memory_allocated,
"cpu": 0,
"default": 0,
Comment on lines +2980 to +2981
Collaborator
I believe this should be a callable, based on the definition of _device_agnostic_dispatch, which has

    if fn is None:
        return None
    return fn(*args, **kwargs)

Could you double-check this part?

Collaborator
If so, changes are required in similar places in this PR

Contributor Author
@yao-matrix, Apr 29, 2025
@ydshieh I tried to enhance _device_agnostic_dispatch as below, since I think it's usable for cases that need a default value rather than None. Please let me know if it makes sense to you.

-# Some device agnostic functions return values. Need to guard against `None`
-# instead at user level.
-if fn is None:
-     return None
+# Some device agnostic functions return values or None, will return then directly.
+if not callable(fn):
+   return fn

Collaborator
ok, that could simplify stuff 👍

}
BACKEND_MEMORY_ALLOCATED = {
"cuda": torch.cuda.memory_allocated,
"cpu": 0,
"default": 0,
}
else:
BACKEND_MANUAL_SEED = {"default": None}
BACKEND_EMPTY_CACHE = {"default": None}
@@ -2993,6 +3008,9 @@ def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable],
BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache
BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated

if is_torch_xla_available():
BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache
@@ -3012,6 +3030,18 @@ def backend_device_count(device: str):
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)


def backend_reset_max_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)


def backend_max_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)


def backend_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)


if is_torch_available():
# If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
# into device to function mappings.
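A minimal sketch of how _device_agnostic_dispatch could look once the not callable(fn) change agreed on in the review thread above is applied. The dict.get fallback to the "default" entry is an assumption about the existing lookup, not quoted from the file; only the callable check comes from the discussion.

    from typing import Callable


    def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs):
        # Assumed lookup: fall back to the "default" entry when the device has
        # no dedicated entry; the real code in testing_utils.py may differ.
        fn = dispatch_table.get(device, dispatch_table["default"])

        # Proposed change: plain values (such as the 0 used for "cpu" in
        # BACKEND_MAX_MEMORY_ALLOCATED) and None entries are returned as-is
        # instead of being treated as callables.
        if not callable(fn):
            return fn
        return fn(*args, **kwargs)

With this in place, backend_max_memory_allocated("cpu") would simply return 0, while backend_max_memory_allocated("xpu") would call torch.xpu.max_memory_allocated().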
56 changes: 29 additions & 27 deletions tests/trainer/test_trainer.py
@@ -62,6 +62,10 @@
TemporaryHubRepo,
TestCasePlus,
backend_device_count,
backend_empty_cache,
backend_max_memory_allocated,
backend_memory_allocated,
backend_reset_max_memory_allocated,
evaluate_side_effect_factory,
execute_subprocess_async,
get_gpu_count,
Expand All @@ -78,7 +82,6 @@
require_liger_kernel,
require_lomo,
require_non_hpu,
require_non_xpu,
require_optuna,
require_peft,
require_ray,
@@ -244,18 +247,18 @@ def bytes2megabytes(x):
class TorchTracemalloc:
def __enter__(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = torch.cuda.memory_allocated()
if torch_device in ["cuda", "xpu"]:
backend_empty_cache(torch_device)
backend_reset_max_memory_allocated(torch_device) # reset the peak gauge to zero
self.begin = backend_memory_allocated(torch_device)
return self

def __exit__(self, *exc):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.end = torch.cuda.memory_allocated()
self.peak = torch.cuda.max_memory_allocated()
if torch_device in ["cuda", "xpu"]:
backend_empty_cache(torch_device)
self.end = backend_memory_allocated(torch_device)
self.peak = backend_max_memory_allocated(torch_device)
self.used = bytes2megabytes(self.end - self.begin)
self.peaked = bytes2megabytes(self.peak - self.begin)
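For context, a hedged usage sketch of the reworked context manager; the trainer setup is hypothetical and only the used/peaked attributes come from the class above. Because it now dispatches through the backend_* helpers on torch_device, the same measurement code works on CUDA and XPU alike.

    from transformers.testing_utils import torch_device

    # Hypothetical helper around the TorchTracemalloc class defined above;
    # `trainer` is assumed to be an already-configured Trainer instance.
    def train_and_report_memory(trainer):
        with TorchTracemalloc() as tracemalloc:
            trainer.train()
        # used/peaked are reported in megabytes via bytes2megabytes above.
        print(f"[{torch_device}] used: {tracemalloc.used} MB, peaked: {tracemalloc.peaked} MB")
        return tracemalloc.peaked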

@@ -1243,7 +1246,6 @@ def test_mixed_bf16(self):

# will add more specific tests once there are some bugs to fix

@require_non_xpu
@require_torch_gpu
@require_torch_tf32
def test_tf32(self):
@@ -1835,7 +1837,7 @@ def test_use_liger_kernel_trainer(self):
_ = trainer.train()

@require_lomo
@require_torch_gpu
@require_torch_accelerator
def test_lomo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -1858,7 +1860,7 @@ def test_adalomo(self):
self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))

@require_lomo
@require_torch_gpu
@require_torch_accelerator
def test_adalomo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2024,7 +2026,7 @@ def test_galore_matched_modules(self):
self.assertFalse(is_regex)

@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2045,7 +2047,7 @@ def test_galore_extra_args(self):
_ = trainer.train()

@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_extra_args(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2067,7 +2069,7 @@ def test_galore_layerwise(self):
_ = trainer.train()

@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_layerwise(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2088,7 +2090,7 @@ def test_galore_layerwise_with_scheduler(self):
_ = trainer.train()

@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_layerwise_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2110,7 +2112,7 @@ def test_galore_adamw_8bit(self):
_ = trainer.train()

@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_adamw_8bit(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2131,7 +2133,7 @@ def test_galore_adafactor(self):
_ = trainer.train()

@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_adafactor(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
@@ -2163,7 +2165,7 @@ def test_galore_adafactor_attention_only(self):
self.assertTrue(lower_bound_pm < galore_peak_memory)

@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_adafactor_attention_only(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
@@ -2194,7 +2196,7 @@ def test_galore_adafactor_all_linear(self):
self.assertTrue(lower_bound_pm < galore_peak_memory)

@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_adafactor_all_linear(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
@@ -2302,7 +2304,7 @@ def test_galore_lr_display_with_scheduler(self):
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))

@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2323,7 +2325,7 @@ def test_apollo_extra_args(self):
_ = trainer.train()

@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo_extra_args(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2345,7 +2347,7 @@ def test_apollo_layerwise(self):
_ = trainer.train()

@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo_layerwise(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2366,7 +2368,7 @@ def test_apollo_layerwise_with_scheduler(self):
_ = trainer.train()

@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo_layerwise_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2388,7 +2390,7 @@ def test_apollo_lr_display_without_scheduler(self):
_ = trainer.train()

@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo_lr_display_without_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2413,7 +2415,7 @@ def test_apollo_lr_display_with_scheduler(self):
self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])

@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo_lr_display_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -3992,7 +3994,7 @@ def test_fp16_full_eval(self):
# perfect world: fp32_init/2 == fp16_eval
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)

@require_non_xpu
@require_torch_gpu
@require_torch_non_multi_gpu
@require_torch_tensorrt_fx
def test_torchdynamo_full_eval(self):
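The swap from require_torch_gpu to require_torch_accelerator that recurs in the hunks above is what lets these tests run on XPU: the former effectively restricts a test to CUDA, while the latter only asks for some non-CPU accelerator. A rough sketch of the distinction, assuming both decorators gate on torch_device; this is illustrative (hence the _sketch suffix), not the actual implementation in testing_utils.py.

    import unittest

    from transformers.testing_utils import torch_device

    def require_torch_gpu_sketch(test_case):
        # Skips on anything that is not a CUDA device, including XPU.
        return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)

    def require_torch_accelerator_sketch(test_case):
        # Accepts any non-CPU backend (cuda, xpu, mps, npu, ...).
        return unittest.skipUnless(torch_device != "cpu", "test requires an accelerator")(test_case)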