@@ -7,7 +7,7 @@
 from .generate import stream_generate
 from .models.cache import make_prompt_cache
 from .sample_utils import make_sampler
-from .utils import load
+from .utils import load, sharded_load
 
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
@@ -79,35 +79,54 @@ def setup_arg_parser():
         default=None,
         help="System prompt to be used for the chat template",
     )
+    parser.add_argument(
+        "--pipeline",
+        action="store_true",
+        help="Use pipelining instead of tensor parallelism",
+    )
     return parser
 
 
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
 
+    group = mx.distributed.init()
+    rank = group.rank()
+    pipeline_group = group if args.pipeline else None
+    tensor_group = group if not args.pipeline else None
+
+    def rprint(*args, **kwargs):
+        if rank == 0:
+            print(*args, **kwargs)
+
     if args.seed is not None:
         mx.random.seed(args.seed)
 
-    model, tokenizer = load(
-        args.model,
-        adapter_path=args.adapter_path,
-        tokenizer_config={
-            "trust_remote_code": True if args.trust_remote_code else None
-        },
-    )
+    if group.size() > 1:
+        if args.adapter_path:
+            parser.error("Adapters not supported in distributed mode")
+        model, tokenizer = sharded_load(args.model, pipeline_group, tensor_group)
+    else:
+        model, tokenizer = load(
+            args.model,
+            adapter_path=args.adapter_path,
+            tokenizer_config={
+                "trust_remote_code": True if args.trust_remote_code else None
+            },
+        )
 
     def print_help():
-        print("The command list:")
-        print("- 'q' to exit")
-        print("- 'r' to reset the chat")
-        print("- 'h' to display these commands")
+        rprint("The command list:")
+        rprint("- 'q' to exit")
+        rprint("- 'r' to reset the chat")
+        rprint("- 'h' to display these commands")
 
-    print(f"[INFO] Starting chat session with {args.model}.")
+    rprint(f"[INFO] Starting chat session with {args.model}.")
     print_help()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
+        query = input(">> " if rank == 0 else "")
         if query == "q":
             break
         if query == "r":
@@ -120,7 +139,10 @@ def print_help():
         if args.system_prompt is not None:
             messages.append({"role": "system", "content": args.system_prompt})
         messages.append({"role": "user", "content": query})
-        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+        prompt = tokenizer.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+        )
         for response in stream_generate(
             model,
             tokenizer,
@@ -137,8 +159,8 @@ def print_help():
             ),
             prompt_cache=prompt_cache,
         ):
-            print(response.text, flush=True, end="")
-        print()
+            rprint(response.text, flush=True, end="")
+        rprint()
 
 
 if __name__ == "__main__":
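
For reference, a minimal sketch (not part of the diff) of the rank-0-only printing pattern that the `rprint` helper introduces above, using only the `mx.distributed` calls that appear in the diff. In a single-process run `init()` returns a group of size 1, so the helper behaves like a plain `print`:

```python
import mlx.core as mx

# Join the process group; without a distributed launch this is a size-1 group.
group = mx.distributed.init()
rank = group.rank()


def rprint(*args, **kwargs):
    # Only rank 0 writes to stdout, so generated text is not
    # duplicated once per process in a multi-process chat session.
    if rank == 0:
        print(*args, **kwargs)


rprint(f"chat session running on {group.size()} process(es)")
```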