Commit f9c7c3b

don't use is_main_process during config validation (axolotl-ai-cloud#2569)
1 parent caf5cb6 commit f9c7c3b

File tree

4 files changed: +19 −19 lines


.github/workflows/multi-gpu-e2e.yml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ on:
       - 'setup.py'
       - 'pyproject.toml'
       - '.github/workflows/multi-gpu-e2e.yml'
+      - 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
   workflow_dispatch:
   schedule:
     - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

src/axolotl/utils/schemas/config.py

Lines changed: 11 additions & 14 deletions
@@ -18,7 +18,6 @@
 )
 from transformers.utils.import_utils import is_torch_npu_available

-from axolotl.utils.distributed import is_main_process
 from axolotl.utils.schemas.datasets import (
     DatasetConfig,
     DPODataset,
@@ -719,10 +718,9 @@ def check_eval_packing(cls, data):
             and data.get("eval_sample_packing") is None
             and not data.get("eval_table_size")
         ):
-            if is_main_process():
-                LOG.info(
-                    "explicitly setting `eval_sample_packing` to match `sample_packing`"
-                )
+            LOG.info(
+                "explicitly setting `eval_sample_packing` to match `sample_packing`"
+            )
             data["eval_sample_packing"] = True

         if (
@@ -1179,15 +1177,14 @@ def check_sequence_parallel_degree(self):
             # TODO: monkeypatch / callback to average losses correctly across SP ranks
             # / fix gradient scaling across SP ranks. Losses, grads should be scaled
             # according to the proportion of non-padding tokens per rank.
-            if is_main_process():
-                LOG.warning(
-                    "Sequence parallelism (SP) is enabled with "
-                    f"sequence_parallel_degree={self.sequence_parallel_degree}. "
-                    "Please note that logged losses may differ slightly to the non-SP "
-                    "losses due to transformers Trainer implementation details. "
-                    "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
-                    "for more details."
-                )
+            LOG.warning(
+                "Sequence parallelism (SP) is enabled with "
+                f"sequence_parallel_degree={self.sequence_parallel_degree}. "
+                "Please note that logged losses may differ slightly to the non-SP "
+                "losses due to transformers Trainer implementation details. "
+                "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
+                "for more details."
+            )

         return self
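Context for the schema change: calling `is_main_process()` during config validation appears to touch axolotl's distributed state early (the new guard in src/axolotl/utils/trainer.py below rejects exactly that situation), so the validators now log on every rank instead. Below is a minimal sketch of one way to keep such messages to a single rank without initializing any distributed state, assuming a torchrun-style launcher that exports RANK; `log_once` is a hypothetical helper and not part of this commit or of axolotl.

import logging
import os

LOG = logging.getLogger(__name__)


def log_once(message: str) -> None:
    # Hypothetical helper: decide "rank 0" from the launcher-provided RANK
    # environment variable rather than the accelerate / torch.distributed
    # state, so it remains safe to call during config validation, before
    # any DeepSpeed environment variables are exported.
    if os.environ.get("RANK", "0") == "0":
        LOG.info(message)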

src/axolotl/utils/trainer.py

Lines changed: 7 additions & 0 deletions
@@ -528,6 +528,13 @@ def setup_torch_compile_env(cfg):
 def setup_deepspeed_env(cfg, stage=None):
     from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

+    from axolotl.utils.distributed import distributed_state
+
+    if distributed_state and distributed_state.initialized:
+        raise RuntimeError(
+            "Distributed State already initialized before Deepspeed setup"
+        )
+
     os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
     os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
     if stage:
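The guard above makes `setup_deepspeed_env` fail fast if the distributed state is already initialized by the time it runs, presumably because the exported ACCELERATE_* variables are only honored when that state is first created. A rough sketch of how the failure path could be exercised in isolation; the fake SimpleNamespace state, the config stub, and the test itself are illustrative assumptions rather than part of this commit.

from types import SimpleNamespace

import pytest

import axolotl.utils.distributed as dist_utils
from axolotl.utils.trainer import setup_deepspeed_env


def test_setup_deepspeed_env_rejects_initialized_state(monkeypatch):
    # Stand-in for an already-initialized distributed state (illustrative
    # fake); the real object is axolotl.utils.distributed.distributed_state.
    monkeypatch.setattr(
        dist_utils, "distributed_state", SimpleNamespace(initialized=True)
    )
    cfg = SimpleNamespace(deepspeed="deepspeed_configs/zero1.json")  # hypothetical cfg stub
    with pytest.raises(RuntimeError, match="already initialized"):
        setup_deepspeed_env(cfg)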

tests/e2e/patched/test_sp.py

Lines changed: 0 additions & 5 deletions
@@ -131,11 +131,6 @@ def setup_mocks(self, monkeypatch):
         # Mock the ring_flash_attn module
         monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())

-        # Mock the is_main_process function to return True
-        monkeypatch.setattr(
-            "axolotl.utils.schemas.config.is_main_process", lambda: True
-        )
-
     @pytest.fixture
     def base_cfg(self):
         """Create a base configuration for testing."""
