Skip to content

Commit 78012cf

Browse files
committed
Add PrepareRequestData method for the prefix cache plugin
1 parent 5debae8 commit 78012cf

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package plugins
2+
3+
import (
4+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
5+
)
6+
7+
const (
	// PrefixCacheMatchPercentKey is the key under which a
	// *PrefixCacheMatchPercent value is attached to a pod.
	PrefixCacheMatchPercentKey = "PrefixCacheMatchPercentKey"

	// PrefixCacheMatchPrecentKey is the original, misspelled name of
	// PrefixCacheMatchPercentKey. It is kept as an alias with the same value
	// so existing callers keep compiling; prefer PrefixCacheMatchPercentKey
	// in new code.
	PrefixCacheMatchPrecentKey = PrefixCacheMatchPercentKey
)
10+
11+
// PrefixCacheMatchPercent carries the prefix cache match score computed for a
// single pod.
//
// NOTE: despite the name, the value written by the prefix plugin's
// PrepareRequestData is a fraction (longest prefix match length divided by the
// number of prompt hashes), not a number scaled to 0-100.
type PrefixCacheMatchPercent struct {
	matchPercentage float64
}

// NewPrefixCacheMatchPercent returns a PrefixCacheMatchPercent holding
// matchPercentage. The value is stored exactly as given; no range validation
// is performed.
func NewPrefixCacheMatchPercent(matchPercentage float64) *PrefixCacheMatchPercent {
	return &PrefixCacheMatchPercent{
		matchPercentage: matchPercentage,
	}
}

// MatchPercentage returns the stored match value.
//
// A nil receiver reports 0 instead of panicking, so a caller holding a nil
// *PrefixCacheMatchPercent reads it as "no prefix match".
func (p *PrefixCacheMatchPercent) MatchPercentage() float64 {
	if p == nil {
		return 0
	}
	return p.matchPercentage
}
24+
25+
func (p *PrefixCacheMatchPercent) Clone() datalayer.Cloneable {
26+
return &PrefixCacheMatchPercent{
27+
matchPercentage: p.matchPercentage,
28+
}
29+
}

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"sigs.k8s.io/controller-runtime/pkg/log"
3030

3131
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
32+
dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins"
3233
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3334
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
3435
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
@@ -207,6 +208,32 @@ func (p *Plugin) WithName(name string) *Plugin {
207208
return p
208209
}
209210

211+
func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error {
212+
// pre score step, hashing prompt and find longest prefix match.
213+
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch)
214+
state := &SchedulingContextState{
215+
PrefixHashes: hashes,
216+
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
217+
}
218+
for server, matchLen := range state.PrefixCacheServers {
219+
log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "server", server, "longest-prefix-match", matchLen)
220+
221+
}
222+
223+
total := len(state.PrefixHashes)
224+
podScoreFunc := func(pod types.Pod) float64 {
225+
if total == 0 {
226+
return 0
227+
}
228+
matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)]
229+
return float64(matchLen) / float64(total)
230+
}
231+
for _, pod := range pods {
232+
pod.Put(dplugins.PrefixCacheMatchPrecentKey, dplugins.NewPrefixCacheMatchPercent(podScoreFunc(pod)))
233+
}
234+
return nil
235+
}
236+
210237
// Score returns the scoring result for the given list of pods based on context.
211238
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
212239
// pre score step, hashing prompt and find longest prefix match.

0 commit comments

Comments
 (0)