@@ -51,32 +51,74 @@ def __init__(
5151 if clear_buffers :
5252 self ._clear_buffer ()
5353
54- # -------------------------------------------------------------------------
55- # Helper utilities
56- # -------------------------------------------------------------------------
57- def _host_from_provider (self , provider : Dict ) -> Optional [str ]:
58- """Extract the host identifier from a provider configuration."""
59- # Most providers expose the host under ``api_host`` or ``host``.
60- return provider .get ("api_host" ) or provider .get ("host" )
54+ # -----------------------------------------------------------------
55+ # Overridden public API
56+ # -----------------------------------------------------------------
57+ def get_provider (
58+ self ,
59+ model_name : str ,
60+ providers : List [Dict ],
61+ options : Optional [Dict [str , Any ]] = None ,
62+ ) -> Dict | None :
63+ """
64+ Execute the optimisation steps before falling back to the base
65+ implementation.
66+ """
67+ if not providers :
68+ return None
6169
62- def _last_host_key (self , model_name : str ) -> str :
63- """Redis key that stores the last host used for a given model."""
64- return f"{ self ._get_redis_key (model_name )} :last_host"
70+ # Initialise Redis structures (same as parent).
71+ redis_key , _ = self .init_provider (
72+ model_name = model_name , providers = providers , options = options
73+ )
74+ if not redis_key :
75+ return None
6576
66- def _model_hosts_set_key (self , model_name : str ) -> str :
67- """Redis set key that holds all hosts where *model_name* is loaded."""
68- return f"{ self ._get_redis_key (model_name )} :hosts"
77+ # ---- Step 1 -------------------------------------------------
78+ provider = self ._step1_last_host (model_name , providers )
79+ # print("last_host_provider", provider)
80+ # print("last_host_provider", provider)
81+ # print("last_host_provider", provider)
82+ # print("last_host_provider", provider)
83+ if provider :
84+ self ._record_selection (model_name , provider )
85+ return provider
6986
70- def _host_occupancy_key (self , host_name : str ) -> str :
71- """
72- Redis hash key that stores the model currently occupying *host_name*.
73- The hash field used is ``model``.
74- """
75- return self ._host_key (host_name )
87+ # ---- Step 2 -------------------------------------------------
88+ provider = self ._step2_existing_hosts (model_name , providers )
89+ # print("existing_host_provider", provider)
90+ # print("existing_host_provider", provider)
91+ # print("existing_host_provider", provider)
92+ # print("existing_host_provider", provider)
93+ if provider :
94+ self ._record_selection (model_name , provider )
95+ return provider
7696
77- # -------------------------------------------------------------------------
97+ # ---- Step 3 -------------------------------------------------
98+ provider = self ._step3_unused_host (model_name , providers )
99+ # print("unused_host_provider", provider)
100+ # print("unused_host_provider", provider)
101+ # print("unused_host_provider", provider)
102+ # print("unused_host_provider", provider)
103+ if provider :
104+ self ._record_selection (model_name , provider )
105+ return provider
106+
107+ # ---- Step 4 – fallback ---------------------------------------
108+ provider = super ().get_provider (
109+ model_name = model_name , providers = providers , options = options
110+ )
111+ # print("first_available_host_provider", provider)
112+ # print("first_available_host_provider", provider)
113+ # print("first_available_host_provider", provider)
114+ # print("first_available_host_provider", provider)
115+ if provider :
116+ self ._record_selection (model_name , provider )
117+ return provider
118+
119+ # -----------------------------------------------------------------
78120 # Step 1 – reuse last host
79- # -------------------------------------------------------------------------
121+ # -----------------------------------------------------------------
80122
81123 def _step1_last_host (
82124 self , model_name : str , providers : List [Dict ]
@@ -92,9 +134,7 @@ def _step1_last_host(
92134 return None
93135
94136 # Verify host is not occupied by a different model.
95- occupancy_hash = self ._host_occupancy_key (last_host )
96- current_model = self .redis_client .hget (occupancy_hash , "model" )
97- if current_model and current_model != model_name :
137+ if not self ._is_host_free (last_host , model_name ):
98138 return None # host busy with another model
99139
100140 # Find a provider that belongs to this host.
@@ -117,9 +157,9 @@ def _step1_last_host(
117157 continue
118158 return None
119159
120- # -------------------------------------------------------------------------
160+ # -----------------------------------------------------------------
121161 # Step 2 – reuse any host that already runs this model
122- # -------------------------------------------------------------------------
162+ # -----------------------------------------------------------------
123163
124164 def _step2_existing_hosts (
125165 self , model_name : str , providers : List [Dict ]
@@ -139,9 +179,7 @@ def _step2_existing_hosts(
139179 if host not in known_hosts :
140180 continue
141181 # Ensure the host is not occupied by a different model.
142- occ_key = self ._host_occupancy_key (host )
143- cur = self .redis_client .hget (occ_key , "model" )
144- if cur and cur != model_name :
182+ if not self ._is_host_free (host , model_name ):
145183 continue
146184
147185 field = self ._provider_field (provider )
@@ -158,9 +196,9 @@ def _step2_existing_hosts(
158196 continue
159197 return None
160198
161- # ---------------------------------------------------------------------
199+ # -----------------------------------------------------------------
162200 # Step 3 – pick a host that does NOT already have this model loaded
163- # --------------------------------------------------------------------------
201+ # -----------------------------------------------------------------
164202
165203 def _step3_unused_host (
166204 self , model_name : str , providers : List [Dict ]
@@ -186,9 +224,7 @@ def _step3_unused_host(
186224 continue
187225
188226 # Ensure the host is not currently occupied by a different model.
189- occ_key = self ._host_occupancy_key (host )
190- current_model = self .redis_client .hget (occ_key , "model" )
191- if current_model and current_model != model_name :
227+ if not self ._is_host_free (host , model_name ):
192228 continue
193229
194230 field = self ._provider_field (provider )
@@ -205,9 +241,9 @@ def _step3_unused_host(
205241 continue
206242 return None
207243
208- # -------------------------------------------------------------------------
244+ # -----------------------------------------------------------------
209245 # Step 5 – bookkeeping after a successful acquisition
210- # -------------------------------------------------------------------------
246+ # -----------------------------------------------------------------
211247
212248 def _record_selection (self , model_name : str , provider : Dict ) -> None :
213249 """
@@ -228,9 +264,9 @@ def _record_selection(self, model_name: str, provider: Dict) -> None:
228264 occ_key = self ._host_occupancy_key (host )
229265 self .redis_client .hset (occ_key , "model" , model_name )
230266
231- # -------------------------------------------------------
267+ # -----------------------------------------------------------------
232268 # Ensure buffers are cleared on construction (delegated to parent)
233- # -------------------------------------------------------
269+ # -----------------------------------------------------------------
234270
235271 def _clear_buffer (self ) -> None :
236272 """
@@ -252,68 +288,35 @@ def _clear_buffer(self) -> None:
252288 self .logger .debug (f"Removing { self } => { key } from redis" )
253289 self .redis_client .delete (key )
254290
255- # -------------------------------------------------------------------------
256- # Overridden public API
257- # -------------------------------------------------------------------------
258-
259- def get_provider (
260- self ,
261- model_name : str ,
262- providers : List [Dict ],
263- options : Optional [Dict [str , Any ]] = None ,
264- ) -> Dict | None :
265- """
266- Execute the optimisation steps before falling back to the base
267- implementation.
268- """
269- if not providers :
270- return None
271-
272- # Initialise Redis structures (same as parent).
273- redis_key , _ = self .init_provider (
274- model_name = model_name , providers = providers , options = options
275- )
276- if not redis_key :
277- return None
291+ # -----------------------------------------------------------------
292+ # Helper utilities
293+ # -----------------------------------------------------------------
294+ def _host_from_provider (self , provider : Dict ) -> Optional [str ]:
295+ """Extract the host identifier from a provider configuration."""
296+ # Most providers expose the host under ``api_host`` or ``host``.
297+ return provider .get ("api_host" ) or provider .get ("host" )
278298
279- # ---- Step 1 ---------------------------------------------------------
280- provider = self ._step1_last_host (model_name , providers )
281- print ("last_host_provider" , provider )
282- print ("last_host_provider" , provider )
283- print ("last_host_provider" , provider )
284- print ("last_host_provider" , provider )
285- if provider :
286- self ._record_selection (model_name , provider )
287- return provider
299+ def _last_host_key (self , model_name : str ) -> str :
300+ """Redis key that stores the last host used for a given model."""
301+ return f"{ self ._get_redis_key (model_name )} :last_host"
288302
289- # ---- Step 2 ---------------------------------------------------------
290- provider = self ._step2_existing_hosts (model_name , providers )
291- print ("existing_host_provider" , provider )
292- print ("existing_host_provider" , provider )
293- print ("existing_host_provider" , provider )
294- print ("existing_host_provider" , provider )
295- if provider :
296- self ._record_selection (model_name , provider )
297- return provider
303+ def _model_hosts_set_key (self , model_name : str ) -> str :
304+ """Redis set key that holds all hosts where *model_name* is loaded."""
305+ return f"{ self ._get_redis_key (model_name )} :hosts"
298306
299- # ---- Step 3 ---------------------------------------------------------
300- provider = self ._step3_unused_host (model_name , providers )
301- print ("unused_host_provider" , provider )
302- print ("unused_host_provider" , provider )
303- print ("unused_host_provider" , provider )
304- print ("unused_host_provider" , provider )
305- if provider :
306- self ._record_selection (model_name , provider )
307- return provider
307+ def _host_occupancy_key (self , host_name : str ) -> str :
308+ """
309+ Redis hash key that stores the model currently occupying *host_name*.
310+ The hash field used is ``model``.
311+ """
312+ return self ._host_key (host_name )
308313
309- # ---- Step 4 – fallback ---------------------------------------------
310- provider = super ().get_provider (
311- model_name = model_name , providers = providers , options = options
312- )
313- print ("first_available_host_provider" , provider )
314- print ("first_available_host_provider" , provider )
315- print ("first_available_host_provider" , provider )
316- print ("first_available_host_provider" , provider )
317- if provider :
318- self ._record_selection (model_name , provider )
319- return provider
314+ # New helper to check if a host is free for the requested model.
315+ def _is_host_free (self , host : str , model_name : str ) -> bool :
316+ """
317+ Return ``True`` if *host* is either unoccupied or occupied by the same
318+ *model_name*. ``False`` means the host is occupied by a different model.
319+ """
320+ occ_key = self ._host_occupancy_key (host )
321+ current_model = self .redis_client .hget (occ_key , "model" )
322+ return not (current_model and current_model != model_name )
0 commit comments