
DeepSpeed ZeRO Stage 2/3 incompatibility with no_sync context manager #3481

Open
@alisafaya

Description


When using DeepSpeed with ZeRO stage 2 or 3, the no_sync() context manager of the Accelerator class raises an AssertionError. DeepSpeed's ZeRO gradient partitioning requires gradients to be reduced, whereas no_sync() exists specifically to disable gradient synchronization, so the two conflict.

See this comment and issue for more details.

Current behavior

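```python
import contextlib
from contextlib import contextmanager

# Accelerator method (excerpt)
@contextmanager
def no_sync(self, model):
    context = contextlib.nullcontext
    if self.use_distributed:
        # Use the wrapped model's own no_sync() whenever it defines one.
        context = getattr(model, "no_sync", context)

    with context():
        yield
```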

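This implementation doesn't account for DeepSpeed ZeRO stages 2 and 3. Under DeepSpeed the prepared model is the DeepSpeed engine, so getattr picks up the engine's own no_sync(), which fails with the following error:

```
AssertionError: no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 2
```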
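The failure path can be reproduced without GPUs or DeepSpeed using stand-in classes. This is a minimal sketch under stated assumptions: FakeDeepSpeedEngine is hypothetical and only mimics the assertion shown above, and current_no_sync is a hypothetical free-function copy of the current lookup logic with use_distributed passed as a plain argument. Neither is the real DeepSpeed or Accelerate API.

```python
import contextlib
from contextlib import contextmanager


class FakeDeepSpeedEngine:
    """Hypothetical stand-in for a DeepSpeed engine (not the real class)."""

    def __init__(self, zero_stage):
        self.zero_stage = zero_stage

    @contextmanager
    def no_sync(self):
        # Mimics the guard that rejects no_sync() when gradients are partitioned.
        assert self.zero_stage < 2, (
            "no_sync context manager is incompatible with gradient "
            f"partitioning logic of ZeRO stage {self.zero_stage}"
        )
        yield


@contextmanager
def current_no_sync(model, use_distributed=True):
    # Hypothetical copy of the current lookup: prefer model.no_sync if present.
    context = contextlib.nullcontext
    if use_distributed:
        context = getattr(model, "no_sync", context)
    with context():
        yield


def run_accumulation_step(model):
    """Return None on success, or the assertion message on failure."""
    try:
        with current_no_sync(model):
            pass  # a backward pass without gradient sync would go here
    except AssertionError as err:
        return str(err)
    return None


print(run_accumulation_step(FakeDeepSpeedEngine(zero_stage=1)))  # None
print(run_accumulation_step(FakeDeepSpeedEngine(zero_stage=2)))
# no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 2
```

Because getattr finds a no_sync attribute on the engine, the current code never reaches its nullcontext fallback, and the assertion surfaces as soon as the context is entered.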

Similar issues are:

Proposed solution

Add a condition to check for DeepSpeed with ZeRO stages 2 or 3:

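```python
import contextlib
from contextlib import contextmanager
from accelerate.utils import DistributedType

# Accelerator method (excerpt)
@contextmanager
def no_sync(self, model):
    context = contextlib.nullcontext
    if self.use_distributed:
        # ZeRO stage 2/3 partitions gradients and rejects no_sync(); skip it there.
        if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
            context = getattr(model, "no_sync", context)

    with context():
        yield
```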

This ensures that model.no_sync() is used only where it is compatible, with a fallback to nullcontext() under DeepSpeed ZeRO stage 2 or 3.
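The proposed condition can be checked in isolation. The sketch below is hypothetical: select_sync_context is an illustrative free function (not part of Accelerate) that returns the context-manager factory the proposed no_sync() would pick. DistributedType here is a local stand-in with only two members, and FakeModel is a dummy whose no_sync() records whether it was entered.

```python
import contextlib
from contextlib import contextmanager
from enum import Enum


class DistributedType(Enum):
    # Local stand-in; the real enum in accelerate.utils has more members.
    MULTI_GPU = "MULTI_GPU"
    DEEPSPEED = "DEEPSPEED"


def select_sync_context(model, use_distributed, distributed_type, zero_stage=None):
    """Hypothetical helper: return the context factory the proposed no_sync() uses."""
    context = contextlib.nullcontext
    if use_distributed:
        # Only defer to model.no_sync outside DeepSpeed ZeRO stage 2/3.
        if distributed_type != DistributedType.DEEPSPEED or zero_stage < 2:
            context = getattr(model, "no_sync", context)
    return context


class FakeModel:
    """Hypothetical model whose no_sync() records that it was entered."""

    def __init__(self):
        self.entered = False

    @contextmanager
    def no_sync(self):
        self.entered = True
        yield


ddp_model = FakeModel()
with select_sync_context(ddp_model, True, DistributedType.MULTI_GPU)():
    pass
print(ddp_model.entered)  # True

zero3_model = FakeModel()
with select_sync_context(zero3_model, True, DistributedType.DEEPSPEED, zero_stage=3)():
    pass
print(zero3_model.entered)  # False
```

One consequence of this design: under ZeRO stage 2 or 3 the fallback skips no communication, so the fix avoids the assertion rather than providing the communication saving that no_sync() normally offers.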
