Skip to content

Commit 5e949ea

Browse files
authored
replace zero_only with simpler if statement (axolotl-ai-cloud#2592)
1 parent 89ca14d commit 5e949ea

File tree

6 files changed

+22
-19
lines changed

6 files changed

+22
-19
lines changed

.github/workflows/multi-gpu-e2e.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ on:
99
- 'pyproject.toml'
1010
- '.github/workflows/multi-gpu-e2e.yml'
1111
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
12+
- 'src/axolotl/utils/distributed.py'
1213
workflow_dispatch:
1314
schedule:
1415
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every Monday & Thursday

.github/workflows/tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ concurrency:
2727
group: ${{ github.workflow }}-${{ github.ref }}
2828
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
2929

30+
env:
31+
TRANSFORMERS_IS_CI: "yes"
32+
3033
jobs:
3134
pre-commit:
3235
name: pre-commit

src/axolotl/integrations/cut_cross_entropy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from axolotl.integrations.base import BasePlugin
2727
from axolotl.utils import get_pytorch_version
28-
from axolotl.utils.distributed import zero_only
28+
from axolotl.utils.distributed import is_main_process
2929

3030
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
3131

@@ -76,7 +76,7 @@ def pre_model_load(self, cfg):
7676
cce_patch,
7777
)
7878

79-
with zero_only():
79+
if is_main_process(use_environ=True):
8080
LOG.info(
8181
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
8282
)

src/axolotl/integrations/liger/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
import sys
2424

2525
from axolotl.integrations.base import BasePlugin
26+
from axolotl.utils.distributed import is_main_process
2627

27-
from ...utils.distributed import zero_only
2828
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
2929
from .utils import patch_with_compile_disable
3030

@@ -85,7 +85,7 @@ def pre_model_load(self, cfg):
8585
kwargs["geglu"] = cfg.liger_glu_activation
8686
elif "swiglu" in liger_fn_sig.parameters:
8787
kwargs["swiglu"] = cfg.liger_glu_activation
88-
with zero_only():
88+
if is_main_process(use_environ=True):
8989
LOG.info(
9090
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
9191
)

src/axolotl/utils/distributed.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,27 @@ def barrier():
6969
dist.barrier()
7070

7171

72-
def is_main_process():
72+
def is_main_process(use_environ=False):
7373
"""
7474
Check if the current process is the main process. If not in distributed mode,
7575
always return `True`.
76+
77+
Args:
78+
- use_environ (bool, optional): Use environment variable to determine main process.
79+
80+
Returns:
81+
- bool: `True` if the current process is the main process, `False` otherwise.
7682
"""
83+
if use_environ:
84+
return os.environ.get("LOCAL_RANK", "0") == "0"
7785
if not is_distributed():
7886
return True
7987
return dist.get_rank() == 0
8088

8189

82-
def is_local_main_process():
90+
def is_local_main_process(use_environ=False):
91+
if use_environ:
92+
return os.environ.get("LOCAL_RANK", "0") == "0"
8393
return PartialState().is_local_main_process
8494

8595

@@ -99,17 +109,6 @@ def cleanup_distributed():
99109
torch.distributed.destroy_process_group()
100110

101111

102-
@contextmanager
103-
def zero_only():
104-
"""
105-
Context manager that only runs the enclosed block on the main rank.
106-
"""
107-
if is_main_process():
108-
yield
109-
else:
110-
yield None
111-
112-
113112
@contextmanager
114113
def zero_first(is_main):
115114
"""

src/axolotl/utils/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
get_device_count,
6969
get_device_type,
7070
is_local_main_process,
71-
zero_only,
71+
is_main_process,
7272
)
7373
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
7474
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
@@ -437,7 +437,7 @@ def load_tokenizer(cfg):
437437
{"additional_special_tokens": additional_special_tokens}
438438
)
439439

440-
with zero_only():
440+
if is_main_process(use_environ=True):
441441
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
442442
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
443443
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")

0 commit comments

Comments (0)