Skip to content

Commit 4a2f67a

Browse files
JasonOE and gufengc
authored
feat(backend): subprocess for executor and gradient server (#254)
Co-authored-by: jason <jl@gradient.network>
Co-authored-by: gufengc <gufeng@gradient.network>
1 parent 3747595 commit 4a2f67a

File tree

6 files changed

+510
-246
lines changed

6 files changed

+510
-246
lines changed

src/parallax/launch.py

Lines changed: 134 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
44
This script is used to launch the Parallax server.
55
It will start the following services:
6-
1.Executor with tp_rank=0 in the main process.
7-
2.Executor with tp_rank>0, each tp_rank as a subprocess.
8-
3.HTTP server as a subprocess.
9-
4.P2P server as a thread in the main process.
6+
1.Executor each tp_rank as a subprocess.
7+
2.HTTP server as a subprocess.
8+
3.P2P server as a subprocess.
109
1110
Example command:
1211
python src/parallax/launch.py \
@@ -21,16 +20,13 @@
2120
import multiprocessing
2221
import os
2322
import tempfile
24-
import threading
25-
26-
from parallax.p2p.server import ServerState, launch_p2p_server
27-
from parallax.server.executor import (
28-
Executor,
29-
run_executor_process,
30-
stop_executor_process,
31-
)
23+
import time
24+
25+
from parallax.p2p.server import ServerState, launch_p2p_server_process, stop_p2p_server
26+
from parallax.server.executor import run_executor_process, stop_executor_process
3227
from parallax.server.http_server import launch_http_server, stop_http_server
3328
from parallax.server.server_args import parse_args
29+
from parallax.utils.shared_state import SharedState
3430
from parallax.utils.utils import fetch_model_from_hf, initialize_nccl_port
3531
from parallax_utils.ascii_anime import display_parallax_join
3632
from parallax_utils.logging_config import get_logger, set_log_level
@@ -39,13 +35,55 @@
3935
logger = get_logger("parallax.launch")
4036

4137

38+
def _update_args_from_shared_state(args, shared_state: SharedState):
39+
"""Update args with layer allocation from shared state"""
40+
model_info = shared_state.get_model_info()
41+
args.start_layer = model_info["block_start_index"]
42+
args.end_layer = model_info["block_end_index"]
43+
# Update model_path if provided and not already set
44+
if model_info["model_name"] and args.model_path is None:
45+
args.model_path = model_info["model_name"]
46+
# Update tp_size if provided, otherwise keep current value
47+
args.tp_size = model_info["tp_size"] or args.tp_size
48+
49+
50+
def _stop_executor_processes(executor_subprocs):
51+
"""Stop all executor processes"""
52+
for executor_process in executor_subprocs:
53+
if executor_process.is_alive():
54+
logger.debug(f"Terminating executor process {executor_process.pid}")
55+
stop_executor_process(executor_process)
56+
57+
58+
def _wait_executors_check_layer_change(shared_state: SharedState, executor_subprocs):
59+
"""Wait for executor processes and check if layer allocation changed.
60+
61+
Returns:
62+
True if layer allocation changed (need to reload executors),
63+
False if all executors exited normally.
64+
"""
65+
while any(proc.is_alive() for proc in executor_subprocs):
66+
for proc in executor_subprocs:
67+
if proc.is_alive():
68+
proc.join(timeout=1.0) # Check every second
69+
70+
if shared_state.get_layer_allocation_changed():
71+
return True
72+
73+
# Check race condition: layer allocation changed after all processes exited
74+
return shared_state.get_layer_allocation_changed()
75+
76+
4277
if __name__ == "__main__":
4378
multiprocessing.set_start_method("spawn", force=True)
4479

45-
gradient_server = None
80+
p2p_server_process = None
4681
http_server_process = None
47-
executor = None
4882
executor_subprocs = []
83+
# Shared state for layer allocation info (used when P2P server is in subprocess)
84+
shared_state = SharedState.create()
85+
shared_state.set_status(ServerState.JOINING.value)
86+
4987
try:
5088
args = parse_args()
5189
set_log_level(args.log_level)
@@ -72,7 +110,8 @@
72110
# only launch http server on head node
73111
if args.start_layer == 0:
74112
http_server_process = launch_http_server(args)
75-
launch_p2p_server(
113+
# Launch P2P server as subprocess
114+
p2p_server_process = launch_p2p_server_process(
76115
initial_peers=args.initial_peers,
77116
scheduler_addr=args.scheduler_addr,
78117
relay_servers=args.relay_servers,
@@ -93,26 +132,34 @@
93132
max_sequence_length=args.max_sequence_length,
94133
param_mem_ratio=args.param_mem_ratio,
95134
kvcache_mem_ratio=args.kvcache_mem_ratio,
135+
shared_state=shared_state.dict, # Pass dict to subprocess
136+
log_level=args.log_level,
96137
)
97-
if gradient_server is not None:
98-
gradient_server.status = ServerState.READY
99138

100-
# For each tp_rank > 0, create a subprocess and run executor
101-
for tp_rank in range(1, args.tp_size):
139+
# Launch all executor processes (including tp_rank=0)
140+
for tp_rank in range(args.tp_size):
102141
args_copy = argparse.Namespace(**vars(args))
103142
args_copy.tp_rank = tp_rank
104143
proc = multiprocessing.Process(
105144
target=run_executor_process,
106-
args=(args_copy,),
145+
args=(
146+
args_copy,
147+
shared_state.dict, # Pass dict to subprocess
148+
),
107149
)
108150
proc.start()
109151
executor_subprocs.append(proc)
110-
# Launch executor with tp_rank=0 in the main process
111-
args.tp_rank = 0
112-
executor = Executor.create_from_args(args)
113-
executor.run_loop()
152+
153+
time.sleep(2) # Give executors time to start
154+
shared_state.set_status(ServerState.READY.value)
155+
156+
# Wait for all executor processes
157+
for proc in executor_subprocs:
158+
proc.join()
114159
else:
115-
gradient_server = launch_p2p_server(
160+
# Launch P2P server as subprocess (with scheduler)
161+
# Pass dict to subprocess (multiprocessing requires serializable objects)
162+
p2p_server_process = launch_p2p_server_process(
116163
initial_peers=args.initial_peers,
117164
scheduler_addr=args.scheduler_addr,
118165
relay_servers=args.relay_servers,
@@ -133,18 +180,34 @@
133180
max_sequence_length=args.max_sequence_length,
134181
param_mem_ratio=args.param_mem_ratio,
135182
kvcache_mem_ratio=args.kvcache_mem_ratio,
183+
shared_state=shared_state.dict, # Pass dict to subprocess
184+
log_level=args.log_level,
136185
)
137-
args.start_layer = gradient_server.block_start_index
138-
args.end_layer = gradient_server.block_end_index
139-
# Only read model_name from scheduler if model_path is not set, so we can use local path as model_path
140-
if args.model_path is None:
141-
args.model_path = gradient_server.model_name
142-
args.tp_size = gradient_server.tp_size
186+
187+
# Wait for layer allocation from scheduler (via shared state)
188+
logger.debug("Waiting for layer allocation from scheduler...")
189+
max_wait_time = 300 # 5 minutes
190+
wait_start = time.time()
191+
while True:
192+
model_info = shared_state.get_model_info()
193+
if (
194+
model_info["block_start_index"] is not None
195+
and model_info["block_end_index"] is not None
196+
and model_info["model_name"] is not None
197+
):
198+
break
199+
if time.time() - wait_start > max_wait_time:
200+
logger.error("Timeout waiting for layer allocation from scheduler")
201+
raise RuntimeError("Failed to get layer allocation from scheduler")
202+
time.sleep(1)
203+
204+
# Get layer allocation from shared state
205+
_update_args_from_shared_state(args, shared_state)
143206

144207
logger.debug(
145-
f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}"
208+
f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}, "
209+
f"model: {args.model_path}"
146210
)
147-
gradient_server.status = ServerState.INITIALIZING
148211

149212
if args.log_level != "DEBUG":
150213
display_parallax_join(args.model_path)
@@ -157,100 +220,67 @@
157220
# Main execution loop with layer reallocation support
158221
while True:
159222
try:
160-
# For each tp_rank > 0, create a subprocess and run executor
161-
for tp_rank in range(1, args.tp_size):
223+
# Launch all executor processes (including tp_rank=0)
224+
executor_subprocs = []
225+
for tp_rank in range(args.tp_size):
162226
args_copy = argparse.Namespace(**vars(args))
163227
args_copy.tp_rank = tp_rank
164228
proc = multiprocessing.Process(
165229
target=run_executor_process,
166-
args=(args_copy,),
230+
args=(
231+
args_copy,
232+
shared_state.dict, # Pass dict to subprocess
233+
),
167234
)
168235
proc.start()
169236
executor_subprocs.append(proc)
170-
# Launch executor with tp_rank=0 in the main process
171-
args.tp_rank = 0
172-
executor = Executor.create_from_args(args, gradient_server=gradient_server)
173-
if gradient_server is not None:
174-
gradient_server.status = ServerState.READY
175-
176-
executor.run_loop()
177-
178-
# Check if layer allocation changed (executor exited due to reallocation)
179-
if gradient_server is not None and gradient_server._layer_allocation_changed:
180-
logger.warning(
181-
"Layer allocation changed! Reloading executor with new layers..."
182-
)
183-
184-
# shutdown all executor processes
185-
thread_pool = []
186-
for executor_process in executor_subprocs:
187-
t = threading.Thread(
188-
target=stop_executor_process, args=(executor_process,)
189-
)
190-
t.start()
191-
thread_pool.append(t)
192-
executor.shutdown()
193-
for t in thread_pool:
194-
t.join()
195-
196-
if args.start_layer == 0:
197-
http_server_process = stop_http_server(http_server_process)
198-
if gradient_server.block_start_index == 0:
199-
http_server_process = launch_http_server(args)
200-
201-
# Update args with new layer allocation
202-
args.start_layer = gradient_server.block_start_index
203-
args.end_layer = gradient_server.block_end_index
204-
if gradient_server.model_name:
205-
args.model_path = gradient_server.model_name
206237

238+
# Wait for executors and restart if layer allocation changes
239+
if _wait_executors_check_layer_change(shared_state, executor_subprocs):
240+
logger.warning("Layer allocation changed! Stopping executors to reload...")
241+
# Reset flag and set status to INITIALIZING
242+
shared_state.update(
243+
_layer_allocation_changed=False,
244+
status=ServerState.INITIALIZING.value,
245+
)
246+
_stop_executor_processes(executor_subprocs)
247+
_update_args_from_shared_state(args, shared_state)
207248
logger.info(
208-
f"Creating new executor with layers [{args.start_layer}, {args.end_layer})"
249+
f"Reloading executor with layers [{args.start_layer}, {args.end_layer})"
209250
)
251+
continue
210252

211-
gradient_server._layer_allocation_changed = False
212-
continue # Create new executor in next iteration
213-
else:
214-
break # Normal exit
253+
# All processes exited normally
254+
break
215255
except KeyboardInterrupt:
216256
logger.debug("Received interrupt signal, shutting down...")
217257
break
218258
except Exception as e:
219259
logger.exception(f"Executor error: {e}")
220-
# If layer allocation changed, try to reload
221-
if gradient_server is not None and gradient_server._layer_allocation_changed:
222-
logger.info("Attempting to reload executor after error...")
223-
if executor is not None:
224-
executor.shutdown()
225-
continue
226-
else:
227-
raise
260+
# Shutdown all executor processes on error
261+
for proc in executor_subprocs:
262+
if proc.is_alive():
263+
stop_executor_process(proc)
264+
raise
228265
except KeyboardInterrupt:
229266
logger.debug("Received interrupt signal, shutting down...")
230267
except Exception as e:
231268
logger.exception(e)
232269
finally:
233-
thread_pool = []
234-
235-
# Shutdown http server
236-
if http_server_process is not None:
237-
t = threading.Thread(target=stop_http_server, args=(http_server_process,))
238-
t.start()
239-
thread_pool.append(t)
240-
241-
# Shutdown gradient server
242-
if gradient_server is not None:
243-
gradient_server.shutdown()
270+
# Shutdown all processes
271+
logger.debug("Shutting down all processes...")
244272

245273
# Shutdown executor subprocesses
246274
for executor_process in executor_subprocs:
247-
t = threading.Thread(target=stop_executor_process, args=(executor_process,))
248-
t.start()
249-
thread_pool.append(t)
275+
if executor_process.is_alive():
276+
stop_executor_process(executor_process)
250277

251-
# Shutdown executor main process
252-
if executor is not None:
253-
executor.shutdown()
278+
# Shutdown P2P server subprocess
279+
if p2p_server_process is not None:
280+
stop_p2p_server(p2p_server_process)
281+
282+
# Shutdown http server
283+
if http_server_process is not None:
284+
stop_http_server(http_server_process)
254285

255-
for t in thread_pool:
256-
t.join()
286+
logger.debug("All processes shut down.")

0 commit comments

Comments
 (0)