|
3 | 3 |
4 | 4 | This script is used to launch the Parallax server. |
5 | 5 | It will start the following services: |
6 | | - 1.Executor with tp_rank=0 in the main process. |
7 | | - 2.Executor with tp_rank>0, each tp_rank as a subprocess. |
8 | | - 3.HTTP server as a subprocess. |
9 | | - 4.P2P server as a thread in the main process. |
| 6 | + 1. Executor: one subprocess per tp_rank.
| 7 | + 2. HTTP server as a subprocess.
| 8 | + 3. P2P server as a subprocess.
10 | 9 |
11 | 10 | Example command: |
12 | 11 | python src/parallax/launch.py \ |
|
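The new layout runs every service in its own spawned process. A minimal sketch of that pattern, with service names and targets as illustrative stand-ins rather than the real Parallax entry points:

```python
import multiprocessing


def _service_stub(name: str) -> None:
    # Stand-in for a real service loop (executor, HTTP server, P2P server).
    print(f"{name} running in pid {multiprocessing.current_process().pid}")


if __name__ == "__main__":
    # "spawn" gives every child a fresh interpreter; runtimes holding GPU
    # state generally cannot be fork()ed safely, so launch.py forces it too.
    multiprocessing.set_start_method("spawn", force=True)
    procs = [
        multiprocessing.Process(target=_service_stub, args=(name,))
        for name in ("executor-tp0", "executor-tp1", "http", "p2p")
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```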
21 | 20 | import multiprocessing |
22 | 21 | import os |
23 | 22 | import tempfile |
24 | | -import threading |
25 | | - |
26 | | -from parallax.p2p.server import ServerState, launch_p2p_server |
27 | | -from parallax.server.executor import ( |
28 | | - Executor, |
29 | | - run_executor_process, |
30 | | - stop_executor_process, |
31 | | -) |
| 23 | +import time |
| 24 | + |
| 25 | +from parallax.p2p.server import ServerState, launch_p2p_server_process, stop_p2p_server |
| 26 | +from parallax.server.executor import run_executor_process, stop_executor_process |
32 | 27 | from parallax.server.http_server import launch_http_server, stop_http_server |
33 | 28 | from parallax.server.server_args import parse_args |
| 29 | +from parallax.utils.shared_state import SharedState |
34 | 30 | from parallax.utils.utils import fetch_model_from_hf, initialize_nccl_port |
35 | 31 | from parallax_utils.ascii_anime import display_parallax_join |
36 | 32 | from parallax_utils.logging_config import get_logger, set_log_level |
|
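`SharedState` itself lives in `parallax.utils.shared_state` and is not part of this diff. Judging only from the calls this file makes (`create()`, `.dict`, `set_status()`, `update()`, `get_model_info()`, `get_layer_allocation_changed()`), a plausible minimal shape is a thin wrapper over a `multiprocessing.Manager` dict; a sketch under that assumption:

```python
import multiprocessing


class SharedStateSketch:
    """Hypothetical stand-in for parallax.utils.shared_state.SharedState."""

    _MODEL_KEYS = ("block_start_index", "block_end_index", "model_name", "tp_size")

    def __init__(self, manager, proxy_dict):
        self._manager = manager  # keep the manager process alive
        # The dict proxy is picklable, so `.dict` can be handed to spawned
        # subprocesses while every process sees the same underlying state.
        self.dict = proxy_dict

    @classmethod
    def create(cls):
        manager = multiprocessing.Manager()
        proxy = manager.dict(
            status=None,
            _layer_allocation_changed=False,
            block_start_index=None,
            block_end_index=None,
            model_name=None,
            tp_size=None,
        )
        return cls(manager, proxy)

    def set_status(self, status):
        self.dict["status"] = status

    def update(self, **kwargs):
        self.dict.update(kwargs)

    def get_model_info(self):
        return {key: self.dict[key] for key in self._MODEL_KEYS}

    def get_layer_allocation_changed(self):
        return bool(self.dict["_layer_allocation_changed"])
```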
39 | 35 | logger = get_logger("parallax.launch") |
40 | 36 |
41 | 37 |
| 38 | +def _update_args_from_shared_state(args, shared_state: SharedState): |
| 39 | + """Update args with layer allocation from shared state""" |
| 40 | + model_info = shared_state.get_model_info() |
| 41 | + args.start_layer = model_info["block_start_index"] |
| 42 | + args.end_layer = model_info["block_end_index"] |
| 43 | + # Update model_path if provided and not already set |
| 44 | + if model_info["model_name"] and args.model_path is None: |
| 45 | + args.model_path = model_info["model_name"] |
| 46 | + # Update tp_size if provided, otherwise keep current value |
| 47 | + args.tp_size = model_info["tp_size"] or args.tp_size |
| 48 | + |
| 49 | + |
| 50 | +def _stop_executor_processes(executor_subprocs): |
| 51 | + """Stop all executor processes""" |
| 52 | + for executor_process in executor_subprocs: |
| 53 | + if executor_process.is_alive(): |
| 54 | + logger.debug(f"Terminating executor process {executor_process.pid}") |
| 55 | + stop_executor_process(executor_process) |
| 56 | + |
| 57 | + |
| 58 | +def _wait_executors_check_layer_change(shared_state: SharedState, executor_subprocs): |
| 59 | + """Wait for executor processes and check if layer allocation changed. |
| 60 | +
| 61 | + Returns: |
| 62 | + True if layer allocation changed (need to reload executors), |
| 63 | + False if all executors exited normally. |
| 64 | + """ |
| 65 | + while any(proc.is_alive() for proc in executor_subprocs): |
| 66 | + for proc in executor_subprocs: |
| 67 | + if proc.is_alive(): |
| 68 | + proc.join(timeout=1.0) # Check every second |
| 69 | + |
| 70 | + if shared_state.get_layer_allocation_changed(): |
| 71 | + return True |
| 72 | + |
| 73 | + # Final check: the flag may have been set between the last poll and the last process exiting
| 74 | + return shared_state.get_layer_allocation_changed() |
| 75 | + |
| 76 | + |
42 | 77 | if __name__ == "__main__": |
43 | 78 | multiprocessing.set_start_method("spawn", force=True) |
44 | 79 |
45 | | - gradient_server = None |
| 80 | + p2p_server_process = None |
46 | 81 | http_server_process = None |
47 | | - executor = None |
48 | 82 | executor_subprocs = [] |
| 83 | + # Shared state for status and layer-allocation info, shared with the P2P server and executor subprocesses
| 84 | + shared_state = SharedState.create() |
| 85 | + shared_state.set_status(ServerState.JOINING.value) |
| 86 | + |
49 | 87 | try: |
50 | 88 | args = parse_args() |
51 | 89 | set_log_level(args.log_level) |
|
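A note on `_wait_executors_check_layer_change` above: joining with `timeout=1.0` rather than blocking indefinitely is what lets the parent notice `_layer_allocation_changed` while executors are still running, and the final check after the loop covers a flag flip that lands between the last poll and the last exit. An event-driven alternative (a sketch, not what this patch does) could wait on process sentinels instead of polling each process in turn:

```python
from multiprocessing.connection import wait


def wait_any_exit_or_flag(procs, flag_is_set, interval=1.0):
    # Block on the processes' sentinels; wake at least every `interval`
    # seconds to re-check the shared reallocation flag.
    live = list(procs)
    while live:
        ready = wait([p.sentinel for p in live], timeout=interval)
        live = [p for p in live if p.sentinel not in ready]
        if flag_is_set():
            return True
    return flag_is_set()
```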
72 | 110 | # only launch http server on head node |
73 | 111 | if args.start_layer == 0: |
74 | 112 | http_server_process = launch_http_server(args) |
75 | | - launch_p2p_server( |
| 113 | + # Launch the P2P server as a subprocess
| 114 | + p2p_server_process = launch_p2p_server_process( |
76 | 115 | initial_peers=args.initial_peers, |
77 | 116 | scheduler_addr=args.scheduler_addr, |
78 | 117 | relay_servers=args.relay_servers, |
|
93 | 132 | max_sequence_length=args.max_sequence_length, |
94 | 133 | param_mem_ratio=args.param_mem_ratio, |
95 | 134 | kvcache_mem_ratio=args.kvcache_mem_ratio, |
| 135 | + shared_state=shared_state.dict, # Pass dict to subprocess |
| 136 | + log_level=args.log_level, |
96 | 137 | ) |
97 | | - if gradient_server is not None: |
98 | | - gradient_server.status = ServerState.READY |
99 | 138 |
100 | | - # For each tp_rank > 0, create a subprocess and run executor |
101 | | - for tp_rank in range(1, args.tp_size): |
| 139 | + # Launch all executor processes (including tp_rank=0) |
| 140 | + for tp_rank in range(args.tp_size): |
102 | 141 | args_copy = argparse.Namespace(**vars(args)) |
103 | 142 | args_copy.tp_rank = tp_rank |
104 | 143 | proc = multiprocessing.Process( |
105 | 144 | target=run_executor_process, |
106 | | - args=(args_copy,), |
| 145 | + args=( |
| 146 | + args_copy, |
| 147 | + shared_state.dict, # Pass dict to subprocess |
| 148 | + ), |
107 | 149 | ) |
108 | 150 | proc.start() |
109 | 151 | executor_subprocs.append(proc) |
110 | | - # Launch executor with tp_rank=0 in the main process |
111 | | - args.tp_rank = 0 |
112 | | - executor = Executor.create_from_args(args) |
113 | | - executor.run_loop() |
| 152 | + |
| 153 | + time.sleep(2) # Grace period so executors can start before the node reports READY
| 154 | + shared_state.set_status(ServerState.READY.value) |
| 155 | + |
| 156 | + # Wait for all executor processes |
| 157 | + for proc in executor_subprocs: |
| 158 | + proc.join() |
114 | 159 | else: |
115 | | - gradient_server = launch_p2p_server( |
| 160 | + # Launch the P2P server as a subprocess (with scheduler)
| 161 | + # Pass the raw dict to the subprocess (arguments to a spawned process must be picklable)
| 162 | + p2p_server_process = launch_p2p_server_process( |
116 | 163 | initial_peers=args.initial_peers, |
117 | 164 | scheduler_addr=args.scheduler_addr, |
118 | 165 | relay_servers=args.relay_servers, |
|
133 | 180 | max_sequence_length=args.max_sequence_length, |
134 | 181 | param_mem_ratio=args.param_mem_ratio, |
135 | 182 | kvcache_mem_ratio=args.kvcache_mem_ratio, |
| 183 | + shared_state=shared_state.dict, # Pass dict to subprocess |
| 184 | + log_level=args.log_level, |
136 | 185 | ) |
137 | | - args.start_layer = gradient_server.block_start_index |
138 | | - args.end_layer = gradient_server.block_end_index |
139 | | - # Only read model_name from scheduler if model_path is not set, so we can use local path as model_path |
140 | | - if args.model_path is None: |
141 | | - args.model_path = gradient_server.model_name |
142 | | - args.tp_size = gradient_server.tp_size |
| 186 | + |
| 187 | + # Wait for layer allocation from scheduler (via shared state) |
| 188 | + logger.debug("Waiting for layer allocation from scheduler...") |
| 189 | + max_wait_time = 300 # 5 minutes |
| 190 | + wait_start = time.time() |
| 191 | + while True: |
| 192 | + model_info = shared_state.get_model_info() |
| 193 | + if ( |
| 194 | + model_info["block_start_index"] is not None |
| 195 | + and model_info["block_end_index"] is not None |
| 196 | + and model_info["model_name"] is not None |
| 197 | + ): |
| 198 | + break |
| 199 | + if time.time() - wait_start > max_wait_time: |
| 200 | + logger.error("Timeout waiting for layer allocation from scheduler") |
| 201 | + raise RuntimeError("Failed to get layer allocation from scheduler") |
| 202 | + time.sleep(1) |
| 203 | + |
| 204 | + # Get layer allocation from shared state |
| 205 | + _update_args_from_shared_state(args, shared_state) |
143 | 206 |
144 | 207 | logger.debug( |
145 | | - f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}" |
| 208 | + f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}, " |
| 209 | + f"model: {args.model_path}" |
146 | 210 | ) |
147 | | - gradient_server.status = ServerState.INITIALIZING |
148 | 211 |
149 | 212 | if args.log_level != "DEBUG": |
150 | 213 | display_parallax_join(args.model_path) |
|
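The producer side of the handshake above lives in the P2P server subprocess and is outside this diff; presumably, once the scheduler assigns layers, it writes them into the shared dict, which is what unblocks the worker's wait loop. A hypothetical sketch of that write, using only the keys this file reads (on a later reallocation it would also set `_layer_allocation_changed=True`, which the reload loop below watches for):

```python
def on_layer_allocation(shared_state_dict, start, end, model_name, tp_size):
    # Hypothetical P2P-server-side callback: publish the scheduler's
    # allocation so launch.py can configure and (re)start its executors.
    shared_state_dict.update(
        block_start_index=start,
        block_end_index=end,
        model_name=model_name,
        tp_size=tp_size,
    )
```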
157 | 220 | # Main execution loop with layer reallocation support |
158 | 221 | while True: |
159 | 222 | try: |
160 | | - # For each tp_rank > 0, create a subprocess and run executor |
161 | | - for tp_rank in range(1, args.tp_size): |
| 223 | + # Launch all executor processes (including tp_rank=0) |
| 224 | + executor_subprocs = [] |
| 225 | + for tp_rank in range(args.tp_size): |
162 | 226 | args_copy = argparse.Namespace(**vars(args)) |
163 | 227 | args_copy.tp_rank = tp_rank |
164 | 228 | proc = multiprocessing.Process( |
165 | 229 | target=run_executor_process, |
166 | | - args=(args_copy,), |
| 230 | + args=( |
| 231 | + args_copy, |
| 232 | + shared_state.dict, # Pass dict to subprocess |
| 233 | + ), |
167 | 234 | ) |
168 | 235 | proc.start() |
169 | 236 | executor_subprocs.append(proc) |
170 | | - # Launch executor with tp_rank=0 in the main process |
171 | | - args.tp_rank = 0 |
172 | | - executor = Executor.create_from_args(args, gradient_server=gradient_server) |
173 | | - if gradient_server is not None: |
174 | | - gradient_server.status = ServerState.READY |
175 | | - |
176 | | - executor.run_loop() |
177 | | - |
178 | | - # Check if layer allocation changed (executor exited due to reallocation) |
179 | | - if gradient_server is not None and gradient_server._layer_allocation_changed: |
180 | | - logger.warning( |
181 | | - "Layer allocation changed! Reloading executor with new layers..." |
182 | | - ) |
183 | | - |
184 | | - # shutdown all executor processes |
185 | | - thread_pool = [] |
186 | | - for executor_process in executor_subprocs: |
187 | | - t = threading.Thread( |
188 | | - target=stop_executor_process, args=(executor_process,) |
189 | | - ) |
190 | | - t.start() |
191 | | - thread_pool.append(t) |
192 | | - executor.shutdown() |
193 | | - for t in thread_pool: |
194 | | - t.join() |
195 | | - |
196 | | - if args.start_layer == 0: |
197 | | - http_server_process = stop_http_server(http_server_process) |
198 | | - if gradient_server.block_start_index == 0: |
199 | | - http_server_process = launch_http_server(args) |
200 | | - |
201 | | - # Update args with new layer allocation |
202 | | - args.start_layer = gradient_server.block_start_index |
203 | | - args.end_layer = gradient_server.block_end_index |
204 | | - if gradient_server.model_name: |
205 | | - args.model_path = gradient_server.model_name |
206 | 237 |
| 238 | + # Wait for executors and restart if layer allocation changes |
| 239 | + if _wait_executors_check_layer_change(shared_state, executor_subprocs): |
| 240 | + logger.warning("Layer allocation changed! Stopping executors to reload...") |
| 241 | + # Reset the flag before stopping executors: a reallocation arriving during teardown sets it again and retriggers the reload on the next iteration
| 242 | + shared_state.update( |
| 243 | + _layer_allocation_changed=False, |
| 244 | + status=ServerState.INITIALIZING.value, |
| 245 | + ) |
| 246 | + _stop_executor_processes(executor_subprocs) |
| 247 | + _update_args_from_shared_state(args, shared_state) |
207 | 248 | logger.info( |
208 | | - f"Creating new executor with layers [{args.start_layer}, {args.end_layer})" |
| 249 | + f"Reloading executor with layers [{args.start_layer}, {args.end_layer})" |
209 | 250 | ) |
| 251 | + continue |
210 | 252 |
211 | | - gradient_server._layer_allocation_changed = False |
212 | | - continue # Create new executor in next iteration |
213 | | - else: |
214 | | - break # Normal exit |
| 253 | + # All processes exited normally |
| 254 | + break |
215 | 255 | except KeyboardInterrupt: |
216 | 256 | logger.debug("Received interrupt signal, shutting down...") |
217 | 257 | break |
218 | 258 | except Exception as e: |
219 | 259 | logger.exception(f"Executor error: {e}") |
220 | | - # If layer allocation changed, try to reload |
221 | | - if gradient_server is not None and gradient_server._layer_allocation_changed: |
222 | | - logger.info("Attempting to reload executor after error...") |
223 | | - if executor is not None: |
224 | | - executor.shutdown() |
225 | | - continue |
226 | | - else: |
227 | | - raise |
| 260 | + # Shutdown all executor processes on error |
| 261 | + for proc in executor_subprocs: |
| 262 | + if proc.is_alive(): |
| 263 | + stop_executor_process(proc) |
| 264 | + raise |
228 | 265 | except KeyboardInterrupt: |
229 | 266 | logger.debug("Received interrupt signal, shutting down...") |
230 | 267 | except Exception as e: |
231 | 268 | logger.exception(e) |
232 | 269 | finally: |
233 | | - thread_pool = [] |
234 | | - |
235 | | - # Shutdown http server |
236 | | - if http_server_process is not None: |
237 | | - t = threading.Thread(target=stop_http_server, args=(http_server_process,)) |
238 | | - t.start() |
239 | | - thread_pool.append(t) |
240 | | - |
241 | | - # Shutdown gradient server |
242 | | - if gradient_server is not None: |
243 | | - gradient_server.shutdown() |
| 270 | + # Shutdown all processes |
| 271 | + logger.debug("Shutting down all processes...") |
244 | 272 |
|
245 | 273 | # Shutdown executor subprocesses |
246 | 274 | for executor_process in executor_subprocs: |
247 | | - t = threading.Thread(target=stop_executor_process, args=(executor_process,)) |
248 | | - t.start() |
249 | | - thread_pool.append(t) |
| 275 | + if executor_process.is_alive(): |
| 276 | + stop_executor_process(executor_process) |
250 | 277 |
251 | | - # Shutdown executor main process |
252 | | - if executor is not None: |
253 | | - executor.shutdown() |
| 278 | + # Shutdown P2P server subprocess |
| 279 | + if p2p_server_process is not None: |
| 280 | + stop_p2p_server(p2p_server_process) |
| 281 | + |
| 282 | + # Shutdown http server |
| 283 | + if http_server_process is not None: |
| 284 | + stop_http_server(http_server_process) |
254 | 285 |
255 | | - for t in thread_pool: |
256 | | - t.join() |
| 286 | + logger.debug("All processes shut down.") |
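Putting the pieces together, a self-contained mini-demo of the reload handshake (illustrative only: a stand-in executor that exits when it sees the flag, whereas in the real flow the parent also actively stops executors via `stop_executor_process`):

```python
import multiprocessing
import time


def fake_executor(shared):
    # Stand-in executor loop: exits once a reallocation is signalled.
    while not shared["_layer_allocation_changed"]:
        time.sleep(0.1)


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    manager = multiprocessing.Manager()
    shared = manager.dict(_layer_allocation_changed=False)
    proc = multiprocessing.Process(target=fake_executor, args=(shared,))
    proc.start()
    shared["_layer_allocation_changed"] = True  # the scheduler reallocates
    proc.join()
    # At this point launch.py would reset the flag, refresh args from the
    # shared state, and spawn a fresh set of executors.
    print("executor exited; parent reloads with the new layer range")
```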