add test_patch_distributed.py #2944
base: main
Conversation
Signed-off-by: yangqinghao-cmss <yangqinghao_yewu@cmss.chinamobile.com>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
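For example, assuming the new test lands under the repository's unit-test tree (the file path below is an assumption based on this PR's title, not a confirmed location), the added test could be run locally with pytest:

# Hedged example: run the new unit test locally via pytest's Python API.
# The test file path is an assumption, not a confirmed repository layout.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-v", "tests/ut/patch/test_patch_distributed.py"]))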
Code Review
This pull request adds unit tests for the distributed patches. The tests are a good addition, but I found a couple of issues. One test case has a bug that will cause a `NameError` and also suffers from test pollution because it does not clean up monkey-patched global objects. Another test is incomplete: it does not verify an important side effect of the function under test, which could hide a potential bug. I've provided suggestions to fix these issues and make the test suite more robust and reliable.
class TestCommunicationAdaptation(unittest.TestCase):

    def setUp(self):
        import torch.distributed as dist
        from vllm_ascend.patch.platform.patch_common.patch_distributed import communication_adaptation_310p

        self.original_broadcast = dist.broadcast
        self.original_all_reduce = dist.all_reduce

    def test_communication_adaptation_310p(self):
        import torch.distributed as dist
        from vllm_ascend.patch.platform.patch_common.patch_distributed import is_310p
        if is_310p():
            communication_adaptation_310p()
            self.assertNotEqual(dist.broadcast, self.original_broadcast)
            self.assertNotEqual(dist.all_reduce, self.original_all_reduce)
        else:
            self.assertEqual(dist.broadcast, self.original_broadcast)
            self.assertEqual(dist.all_reduce, self.original_all_reduce)
This test case has two issues:

- Test Pollution: It modifies global state by patching `torch.distributed.broadcast` and `torch.distributed.all_reduce` (and their `distributed_c10d` counterparts), but it does not clean up these changes after the test runs. This can lead to test pollution, where the outcome of subsequent tests is affected, making test failures difficult to debug.
- NameError Bug: The function `communication_adaptation_310p` is imported within `setUp` but used in `test_communication_adaptation_310p`, which will cause a `NameError` because the import is not in the correct scope (see the minimal example below).
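For illustration only, here is a minimal standalone sketch (not code from this PR) of why the import placement fails: a name imported inside one method is bound in that method's local scope and is not visible from other test methods.

# Minimal standalone sketch showing the scoping problem behind the NameError.
import unittest


class ScopeDemo(unittest.TestCase):
    def setUp(self):
        from math import sqrt  # noqa: F401 -- bound only in setUp's local scope

    def test_name_from_setup_is_not_visible(self):
        # Looking up `sqrt` here falls through the local, enclosing, global and
        # builtin scopes without a match, so Python raises NameError at runtime.
        with self.assertRaises(NameError):
            sqrt(4)  # noqa: F821


if __name__ == "__main__":
    unittest.main()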
To ensure test isolation and fix the bug, you should:

- Move the import of `communication_adaptation_310p` into `test_communication_adaptation_310p`.
- In `setUp`, also save the original `torch.distributed.distributed_c10d.broadcast` and `torch.distributed.distributed_c10d.all_reduce` functions.
- Implement a `tearDown` method to restore all patched functions to their original state after the test completes.
- Update the test assertions to also verify the patching of the `distributed_c10d` functions for completeness.
class TestCommunicationAdaptation(unittest.TestCase):

    def setUp(self):
        import torch.distributed as dist
        self.original_broadcast = dist.broadcast
        self.original_all_reduce = dist.all_reduce
        self.original_c10d_broadcast = dist.distributed_c10d.broadcast
        self.original_c10d_all_reduce = dist.distributed_c10d.all_reduce

    def tearDown(self):
        import torch.distributed as dist
        dist.broadcast = self.original_broadcast
        dist.all_reduce = self.original_all_reduce
        dist.distributed_c10d.broadcast = self.original_c10d_broadcast
        dist.distributed_c10d.all_reduce = self.original_c10d_all_reduce

    def test_communication_adaptation_310p(self):
        import torch.distributed as dist
        from vllm_ascend.patch.platform.patch_common.patch_distributed import (
            communication_adaptation_310p, is_310p)
        if is_310p():
            communication_adaptation_310p()
            self.assertNotEqual(dist.broadcast, self.original_broadcast)
            self.assertNotEqual(dist.all_reduce, self.original_all_reduce)
            self.assertNotEqual(dist.distributed_c10d.broadcast,
                                self.original_c10d_broadcast)
            self.assertNotEqual(dist.distributed_c10d.all_reduce,
                                self.original_c10d_all_reduce)
        else:
            self.assertEqual(dist.broadcast, self.original_broadcast)
            self.assertEqual(dist.all_reduce, self.original_all_reduce)
            self.assertEqual(dist.distributed_c10d.broadcast,
                             self.original_c10d_broadcast)
            self.assertEqual(dist.distributed_c10d.all_reduce,
                             self.original_c10d_all_reduce)
config = ParallelConfig()
config.data_parallel_master_port = 29500

port = config.get_next_dp_init_port()
self.assertEqual(port, 30000)

port2 = config.get_next_dp_init_port()
self.assertEqual(port2, 30000)
This test correctly checks that the port from the environment variable is used. However, the underlying implementation in `parallel_config_get_dp_port` has an undesirable side effect: it increments `config.data_parallel_master_port` even when its value is not used because the environment variable is present. This can lead to unexpected behavior, such as non-contiguous port allocation if the environment variable is not always set.

To make this behavior explicit and prevent future regressions, the test should also assert the state of `config.data_parallel_master_port`. This will document the current behavior and highlight the side effect. Ideally, the implementation should be refactored to avoid modifying state unnecessarily; a hedged sketch of such a refactor follows the suggestion below.
Suggested change:

config = ParallelConfig()
config.data_parallel_master_port = 29500

port = config.get_next_dp_init_port()
self.assertEqual(port, 30000)
# The current implementation has a side effect of incrementing the port
# even when the environment variable is used. This assertion makes it explicit.
self.assertEqual(config.data_parallel_master_port, 29501)

port2 = config.get_next_dp_init_port()
self.assertEqual(port2, 30000)
self.assertEqual(config.data_parallel_master_port, 29502)
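To complement the suggestion above, here is a minimal sketch of how the patched accessor could avoid the side effect entirely. This is an illustration only, not the project's actual implementation; the environment variable name (`VLLM_DP_MASTER_PORT`) is an assumption for this example.

# Illustrative sketch only -- not the actual vllm_ascend implementation.
# The environment variable name VLLM_DP_MASTER_PORT is an assumption.
import os


def parallel_config_get_dp_port(self) -> int:
    """Return the next DP init port without mutating state unnecessarily.

    If the environment variable is set, use it and leave
    self.data_parallel_master_port untouched; otherwise hand out the
    counter value and advance it, keeping allocation contiguous.
    """
    env_port = os.environ.get("VLLM_DP_MASTER_PORT")
    if env_port is not None:
        return int(env_port)
    port = self.data_parallel_master_port
    self.data_parallel_master_port += 1
    return port


# The patch would then be applied as before, e.g.:
# ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port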
Signed-off-by: yangqinghao-cmss <yangqinghao_yewu@cmss.chinamobile.com>
What this PR does / why we need it?
Adds unit tests (test_patch_distributed.py) for the distributed patches in vllm_ascend.patch.platform.patch_common.patch_distributed.
Does this PR introduce any user-facing change?
No, this change only adds tests.
How was this patch tested?
By running the newly added unit tests.