File tree Expand file tree Collapse file tree 6 files changed +22
-19
lines changed Expand file tree Collapse file tree 6 files changed +22
-19
lines changed Original file line number Diff line number Diff line change 9
9
- ' pyproject.toml'
10
10
- ' .github/workflows/multi-gpu-e2e.yml'
11
11
- ' src/axolotl/core/trainers/mixins/sequence_parallel.py'
12
+ - ' src/axolotl/utils/distributed.py'
12
13
workflow_dispatch :
13
14
schedule :
14
15
- cron : ' 0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
Original file line number Diff line number Diff line change @@ -27,6 +27,9 @@ concurrency:
27
27
group : ${{ github.workflow }}-${{ github.ref }}
28
28
cancel-in-progress : ${{ github.ref != 'refs/heads/main' }}
29
29
30
+ env :
31
+ TRANSFORMERS_IS_CI : " yes"
32
+
30
33
jobs :
31
34
pre-commit :
32
35
name : pre-commit
Original file line number Diff line number Diff line change 25
25
26
26
from axolotl .integrations .base import BasePlugin
27
27
from axolotl .utils import get_pytorch_version
28
- from axolotl .utils .distributed import zero_only
28
+ from axolotl .utils .distributed import is_main_process
29
29
30
30
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
31
31
@@ -76,7 +76,7 @@ def pre_model_load(self, cfg):
76
76
cce_patch ,
77
77
)
78
78
79
- with zero_only ( ):
79
+ if is_main_process ( use_environ = True ):
80
80
LOG .info (
81
81
f"Applying Cut Cross Entropy to model type: { cfg .model_config_type } "
82
82
)
Original file line number Diff line number Diff line change 23
23
import sys
24
24
25
25
from axolotl .integrations .base import BasePlugin
26
+ from axolotl .utils .distributed import is_main_process
26
27
27
- from ...utils .distributed import zero_only
28
28
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
29
29
from .utils import patch_with_compile_disable
30
30
@@ -85,7 +85,7 @@ def pre_model_load(self, cfg):
85
85
kwargs ["geglu" ] = cfg .liger_glu_activation
86
86
elif "swiglu" in liger_fn_sig .parameters :
87
87
kwargs ["swiglu" ] = cfg .liger_glu_activation
88
- with zero_only ( ):
88
+ if is_main_process ( use_environ = True ):
89
89
LOG .info (
90
90
f"Applying LIGER to { cfg .model_config_type } with kwargs: { kwargs } "
91
91
)
Original file line number Diff line number Diff line change @@ -69,17 +69,27 @@ def barrier():
69
69
dist .barrier ()
70
70
71
71
72
- def is_main_process ():
72
+ def is_main_process (use_environ = False ):
73
73
"""
74
74
Check if the current process is the main process. If not in distributed mode,
75
75
always return `True`.
76
+
77
+ Args:
78
+ - use_environ (bool, optional): Use environment variable to determine main process.
79
+
80
+ Returns:
81
+ - bool: `True` if the current process is the main process, `False` otherwise.
76
82
"""
83
+ if use_environ :
84
+ return os .environ .get ("LOCAL_RANK" , "0" ) == "0"
77
85
if not is_distributed ():
78
86
return True
79
87
return dist .get_rank () == 0
80
88
81
89
82
- def is_local_main_process ():
90
+ def is_local_main_process (use_environ = False ):
91
+ if use_environ :
92
+ return os .environ .get ("LOCAL_RANK" , "0" ) == "0"
83
93
return PartialState ().is_local_main_process
84
94
85
95
@@ -99,17 +109,6 @@ def cleanup_distributed():
99
109
torch .distributed .destroy_process_group ()
100
110
101
111
102
- @contextmanager
103
- def zero_only ():
104
- """
105
- Context manager that only runs the enclosed block on the main rank.
106
- """
107
- if is_main_process ():
108
- yield
109
- else :
110
- yield None
111
-
112
-
113
112
@contextmanager
114
113
def zero_first (is_main ):
115
114
"""
Original file line number Diff line number Diff line change 68
68
get_device_count ,
69
69
get_device_type ,
70
70
is_local_main_process ,
71
- zero_only ,
71
+ is_main_process ,
72
72
)
73
73
from axolotl .utils .gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
74
74
from axolotl .utils .lora_embeddings import get_linear_embedding_layers
@@ -437,7 +437,7 @@ def load_tokenizer(cfg):
437
437
{"additional_special_tokens" : additional_special_tokens }
438
438
)
439
439
440
- with zero_only ( ):
440
+ if is_main_process ( use_environ = True ):
441
441
LOG .debug (f"EOS: { tokenizer .eos_token_id } / { tokenizer .eos_token } " )
442
442
LOG .debug (f"BOS: { tokenizer .bos_token_id } / { tokenizer .bos_token } " )
443
443
LOG .debug (f"PAD: { tokenizer .pad_token_id } / { tokenizer .pad_token } " )
You can’t perform that action at this time.
0 commit comments