Skip to content

Commit 3b01d20

Browse files
committed
Add PrepareRequestData method for the prefix cache plugin
1 parent cbbfd0e commit 3b01d20

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package plugins
2+
3+
import (
4+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
5+
)
6+
7+
const (
	// PrefixCacheMatchPercentKey is the attribute key under which a
	// *PrefixCacheMatchPercent value is attached to an endpoint.
	PrefixCacheMatchPercentKey = "PrefixCacheMatchPercentKey"

	// PrefixCacheMatchPrecentKey is the original, misspelled name of
	// PrefixCacheMatchPercentKey. It is kept with the identical value so that
	// existing callers keep compiling; new code should prefer
	// PrefixCacheMatchPercentKey.
	PrefixCacheMatchPrecentKey = PrefixCacheMatchPercentKey
)
10+
11+
// PrefixCacheMatchPercent carries a single prefix-cache match value for one
// endpoint.
//
// NOTE(review): despite the "Percent" name, the prefix scheduling plugin's
// PrepareRequestData stores the longest-prefix match length divided by the
// number of prompt prefix hashes, i.e. a ratio rather than a 0-100 value.
// Confirm the intended scale with other producers and consumers.
type PrefixCacheMatchPercent struct {
	// matchPercentage is the value supplied at construction time and
	// returned by MatchPercentage.
	matchPercentage float64
}

// NewPrefixCacheMatchPercent allocates a PrefixCacheMatchPercent holding
// matchPercentage. The value is stored as given; no range check is applied.
func NewPrefixCacheMatchPercent(matchPercentage float64) *PrefixCacheMatchPercent {
	result := new(PrefixCacheMatchPercent)
	result.matchPercentage = matchPercentage
	return result
}

// MatchPercentage reports the value the receiver was constructed with.
func (p *PrefixCacheMatchPercent) MatchPercentage() float64 {
	return p.matchPercentage
}
24+
25+
func (p *PrefixCacheMatchPercent) Clone() datalayer.Cloneable {
26+
return &PrefixCacheMatchPercent{
27+
matchPercentage: p.matchPercentage,
28+
}
29+
}

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
k8stypes "k8s.io/apimachinery/pkg/types"
2929
"sigs.k8s.io/controller-runtime/pkg/log"
3030

31+
dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins"
3132
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3233
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
3334
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
@@ -206,6 +207,32 @@ func (p *Plugin) WithName(name string) *Plugin {
206207
return p
207208
}
208209

210+
func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error {
211+
// pre score step, hashing prompt and find longest prefix match.
212+
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch)
213+
state := &SchedulingContextState{
214+
PrefixHashes: hashes,
215+
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
216+
}
217+
for server, matchLen := range state.PrefixCacheServers {
218+
log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "server", server, "longest-prefix-match", matchLen)
219+
220+
}
221+
222+
total := len(state.PrefixHashes)
223+
podScoreFunc := func(pod types.Pod) float64 {
224+
if total == 0 {
225+
return 0
226+
}
227+
matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)]
228+
return float64(matchLen) / float64(total)
229+
}
230+
for _, pod := range pods {
231+
pod.Put(dplugins.PrefixCacheMatchPrecentKey, dplugins.NewPrefixCacheMatchPercent(podScoreFunc(pod)))
232+
}
233+
return nil
234+
}
235+
209236
// Score returns the scoring result for the given list of pods based on context.
210237
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
211238
// pre score step, hashing prompt and find longest prefix match.

0 commit comments

Comments
 (0)