
enable internvl UTs on XPU #37779


Merged
7 commits merged on Apr 30, 2025
Changes from 2 commits
113 changes: 98 additions & 15 deletions tests/models/internvl/test_modeling_internvl.py
@@ -27,11 +27,13 @@
is_vision_available,
)
from transformers.testing_utils import (
Expectations,
cleanup,
require_av,
require_bitsandbytes,
require_deterministic_for_xpu,
require_torch,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -177,7 +179,7 @@ def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_
model = InternVLForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
with torch.autocast(device_type="cuda", dtype=torch.float16):
with torch.autocast(device_type=torch_device, dtype=torch.float16):
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
@@ -279,7 +281,7 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):


@slow
@require_torch_gpu
@require_torch_accelerator
class InternVLQwen2IntegrationTest(unittest.TestCase):
def setUp(self):
self.small_model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
@@ -326,14 +328,22 @@ def test_qwen2_small_model_integration_forward(self):
output = model(**inputs)

actual_logits = output.logits[0, -1, :5].cpu()
expected_logits = torch.tensor([11.9375, 14.8750, 14.0625, 10.7500, 6.9062], dtype=torch.bfloat16)
expected_logits_all = Expectations(
{
("xpu", 3): torch.tensor([11.7500, 14.7500, 14.1250, 10.5625, 6.7812], dtype=torch.bfloat16),
("cuda", 7): torch.tensor([11.9375, 14.8750, 14.0625, 10.7500, 6.9062], dtype=torch.bfloat16),
}
)
expected_logits = expected_logits_all.get_expectation()

self.assertTrue(
torch.allclose(actual_logits, expected_logits, atol=0.1),
f"Actual logits: {actual_logits}"
f"\nExpected logits: {expected_logits}"
f"\nDifference: {torch.abs(actual_logits - expected_logits)}",
)

@require_deterministic_for_xpu
def test_qwen2_small_model_integration_generate_text_only(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
@@ -346,7 +356,15 @@ def test_qwen2_small_model_integration_generate_text_only(self):
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "Whispers of dawn,\nSilent whispers of the night,\nNew day's light begins."

expected_outputs = Expectations(
{
("xpu", 3): "Whispers of dawn,\nSilent whispers of the night,\nNew day's light.",
("cuda", 7): "Whispers of dawn,\nSilent whispers of the night,\nNew day's light begins.",
}
)
expected_output = expected_outputs.get_expectation()

self.assertEqual(decoded_output, expected_output)

def test_qwen2_small_model_integration_generate_chat_template(self):
@@ -375,6 +393,7 @@ def test_qwen2_small_model_integration_generate_chat_template(self):
expected_output = "The image shows two cats lying on a pink blanket. The cat on the left is a tabby"
self.assertEqual(decoded_output, expected_output)

@require_deterministic_for_xpu
def test_qwen2_small_model_integration_batched_generate(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
@@ -404,7 +423,15 @@
)
# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of' # fmt: skip

expected_outputs = Expectations(
{
("xpu", 3): 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate"',
("cuda", 7): 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()

self.assertEqual(
decoded_output,
expected_output,
@@ -455,7 +482,14 @@ def test_qwen2_small_model_integration_batched_generate_multi_image(self):

# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = 'user\n\nWhat are the differences between these two images?\nassistant\nThe images show the Statue of Liberty and the Golden Gate Bridge from different angles. Here are the differences:\n\n1. **Angle' # fmt: skip
expected_outputs = Expectations(
{
("xpu", 3): 'user\n\nWhat are the differences between these two images?\nassistant\nThe images show the Statue of Liberty and the Golden Gate Bridge from different angles. Here are the differences:\n\n1. **Foreground',
("cuda", 7): 'user\n\nWhat are the differences between these two images?\nassistant\nThe images show the Statue of Liberty and the Golden Gate Bridge from different angles. Here are the differences:\n\n1. **Angle',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()

self.assertEqual(
decoded_output,
expected_output,
@@ -495,14 +529,21 @@ def test_qwen2_medium_model_integration_video(self):
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)

decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
expected_output = 'The man is performing a forehand shot.' # fmt: skip
expected_outputs = Expectations(
{
("xpu", 3): "The man is performing a volley.",
("cuda", 7): "The man is performing a forehand shot.",
}
)
expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)

@require_av
@require_deterministic_for_xpu
def test_qwen2_small_model_integration_interleaved_images_videos(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
@@ -564,15 +605,27 @@ def test_qwen2_small_model_integration_interleaved_images_videos(self):

decoded_output = processor.decode(output[0], skip_special_tokens=True)
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
expected_output = 'user\n\n\nWhat are the differences between these two images?\nassistant\nThe images depict two distinct scenes:\n\n1. **Left Image**: This shows the Statue of Liberty on Liberty Island, with the' # fmt: skip
expected_outputs = Expectations(
{
("xpu", 3): 'user\n\n\nWhat are the differences between these two images?\nassistant\nThe images depict two distinct scenes:\n\n1. **Left Image:**\n - The Statue of Liberty is prominently featured on an',
("cuda", 7): 'user\n\n\nWhat are the differences between these two images?\nassistant\nThe images depict two distinct scenes:\n\n1. **Left Image**: This shows the Statue of Liberty on Liberty Island, with the',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot' # fmt: skip
expected_outputs = Expectations(
{
("xpu", 3): 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nThe man is performing a forehand shot.',
("cuda", 7): 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
@@ -590,7 +643,7 @@ def test_qwen2_small_model_integration_interleaved_images_videos(self):


@slow
@require_torch_gpu
@require_torch_accelerator
class InternVLLlamaIntegrationTest(unittest.TestCase):
def setUp(self):
self.small_model_checkpoint = "OpenGVLab/InternVL2_5-2B-MPO-hf"
@@ -711,7 +764,13 @@ def test_llama_small_model_integration_batched_generate(self):

# Check first output
decoded_output = processor.decode(output[0], skip_special_tokens=True)
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.' # fmt: skip
expected_outputs = Expectations(
{
("xpu", 3): 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden path leads to calm lake,\nNature\'s peaceful grace.',
("cuda", 7): 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
@@ -880,7 +939,19 @@ def test_llama_small_model_integration_interleaved_images_videos(self):

decoded_output = processor.decode(output[0], skip_special_tokens=True)
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
expected_output = 'user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. Upon closer inspection, the differences between the two images are:\n\n1. **' # fmt: skip
expected_outputs = Expectations(
{
("xpu", 3): "user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. After re-examining the images, I can see that they are actually",
("cuda", 7): "user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. Upon closer inspection, the differences between the two images are:\n\n1. **",
}
)
expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
@@ -889,7 +960,13 @@ def test_llama_small_model_integration_interleaved_images_videos(self):

# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nThe man is performing a forehand shot. This is a common shot in tennis where the player swings the racket across their' # fmt: skip
expected_outputs = Expectations(
{
("xpu", 3): 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nThe man is performing a forehand shot. This is a common shot in tennis where the player swings the racket across their',
("cuda", 7): 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nThe man is performing a forehand shot. This is a common shot in tennis where the player swings the racket across their',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
@@ -898,7 +975,13 @@ def test_llama_small_model_integration_interleaved_images_videos(self):

# Check third output
decoded_output = processor.decode(output[2], skip_special_tokens=True)
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nA wooden path leads to the sea,\nPeaceful, untouched dreams.' # fmt: skip
expected_outputs = Expectations(
{
("xpu", 3): 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.',
("cuda", 7): 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nA wooden path leads to the sea,\nPeaceful, untouched dreams.',
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
self.assertEqual(
decoded_output,
expected_output,
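The recurring pattern in this file keys each expected value by a (device type, major version) tuple and resolves it at runtime with get_expectation(). Below is a minimal, self-contained sketch of that resolution, assuming exact-key matching with an optional device-wide fallback; the real helper is transformers.testing_utils.Expectations and may match keys more flexibly.

# Hedged sketch of the Expectations lookup used throughout this diff.
# This is an assumed re-implementation for illustration, not the actual
# transformers.testing_utils code; key shapes mirror ("cuda", 7) / ("xpu", 3).
import torch

def get_expectation_sketch(data):
    if torch.cuda.is_available():
        # Use the CUDA compute capability major as the version key.
        key = ("cuda", torch.cuda.get_device_capability()[0])
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        key = ("xpu", 3)  # assumption: the XPU entries in these tests use 3
    else:
        key = ("cpu", None)
    # Prefer an exact (device, version) match, then a device-wide default.
    return data.get(key, data.get((key[0], None)))

expected = get_expectation_sketch(
    {
        ("xpu", 3): "xpu-specific expectation",
        ("cuda", 7): "cuda-specific expectation",
    }
)

This keeps one test body per check while the hardware-dependent constants live in a single table, which is why the PR only swaps scalar expectations rather than forking whole tests per backend.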
8 changes: 3 additions & 5 deletions tests/test_modeling_common.py
@@ -80,7 +80,6 @@
require_bitsandbytes,
require_deepspeed,
require_flash_attn,
require_non_xpu,
require_safetensors,
require_torch,
require_torch_accelerator,
@@ -2604,7 +2603,7 @@ def test_inputs_embeds_matches_input_ids(self):
)[0]
torch.testing.assert_close(out_embeds, out_ids)

@require_non_xpu
@require_torch_gpu
@require_torch_multi_gpu
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -3874,7 +3873,6 @@ def test_sdpa_can_dispatch_on_flash(self):
with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
_ = model(**inputs_dict)

@require_non_xpu
@require_torch_sdpa
@require_torch_accelerator
@slow
@@ -3887,8 +3885,8 @@ def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif device_type == "rocm" and major < 9:
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
else:
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
elif device_type not in ["cuda", "rocm", "xpu"]:
self.skipTest(reason="This test requires a Nvidia or AMD GPU or an Intel XPU")

torch.compiler.reset()

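The final hunk turns test_sdpa_can_compile_dynamic's unconditional skip on other devices into an explicit allow-list that includes XPU. Here is a sketch of the resulting gate, with a hypothetical helper name; the test itself derives device_type and major from the accelerator properties.

# Hedged sketch of the capability gate after this change; the function is
# hypothetical, but the branches mirror the edited test logic.
def compile_dynamic_skip_reason(device_type, major):
    """Return a skip reason, or None when the test may run."""
    if device_type == "cuda" and major is not None and major < 8:
        return "This test requires an NVIDIA GPU with compute capability >= 8.0"
    if device_type == "rocm" and major is not None and major < 9:
        return "This test requires an AMD GPU with compute capability >= 9.0"
    if device_type not in ["cuda", "rocm", "xpu"]:
        return "This test requires a Nvidia or AMD GPU or an Intel XPU"
    return None  # recent CUDA/ROCm, or any XPU: run the compile test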